1 // Copyright 2020 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
5 // Package fillreturns defines an Analyzer that will attempt to
6 // automatically fill in a return statement that has missing
7 // values with zero value elements.
20 "golang.org/x/tools/go/analysis"
21 "golang.org/x/tools/go/ast/astutil"
22 "golang.org/x/tools/internal/analysisinternal"
// Doc is the help text for this checker. It describes the suggested fix
// offered for "wrong number of return values (want %d, got %d)" type errors.
25 const Doc = `suggested fixes for "wrong number of return values (want %d, got %d)"
27 This checker provides suggested fixes for type errors of the
28 type "wrong number of return values (want %d, got %d)". For example:
29 func m() (int, string, *bool, error) {
33 func m() (int, string, *bool, error) {
34 return 0, "", nil, nil
37 This functionality is similar to https://github.com/sqs/goreturns.
// Analyzer is the analysis.Analyzer exported by this package. It depends on
// no other analyzers (Requires is empty) and sets RunDespiteErrors, which
// lets the driver run it on packages that contain type errors.
40 var Analyzer = &analysis.Analyzer{
43 Requires: []*analysis.Analyzer{},
45 RunDespiteErrors: true,
// wrongReturnNumRegex matches the type-checker message
// "wrong number of return values (want N, got M)" and captures the two
// decimal counts N and M as submatches 1 and 2.
48 var wrongReturnNumRegex = regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`)
// run walks the type errors of the pass, keeps those accepted by FixesError,
// and for an error located on a return statement reports a diagnostic that
// carries a "Fill in return values" suggested fix. The fix text is a rebuilt
// return statement, formatted with format.Node, in which each declared result
// type is matched against the existing return values, a similarly named
// identifier, or a generated zero value.
50 func run(pass *analysis.Pass) (interface{}, error) {
51 info := pass.TypesInfo
53 return nil, fmt.Errorf("nil TypeInfo")
// Collect the type errors recorded for this pass.
56 errors := analysisinternal.GetTypeErrors(pass)
58 for _, typeErr := range errors {
59 // Filter out the errors that are not relevant to this analyzer.
60 if !FixesError(typeErr.Msg) {
// Look for the file whose position range contains the error position.
64 for _, f := range pass.Files {
65 if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
74 // Get the end position of the error.
// The file is formatted into buf and the formatted bytes are passed to
// TypeErrorEndPos together with the error's start position.
76 if err := format.Node(&buf, pass.Fset, file); err != nil {
79 typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)
81 // Get the path for the relevant range.
82 path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
86 // Check to make sure the node of interest is a ReturnStmt.
87 ret, ok := path[0].(*ast.ReturnStmt)
92 // Get the function type that encloses the ReturnStmt.
93 var enclosingFunc *ast.FuncType
94 for _, n := range path {
95 switch node := n.(type) {
97 enclosingFunc = node.Type
99 enclosingFunc = node.Type
101 if enclosingFunc != nil {
105 if enclosingFunc == nil {
109 // Find the function declaration that encloses the ReturnStmt.
110 var outer *ast.FuncDecl
111 for _, p := range path {
112 if p, ok := p.(*ast.FuncDecl); ok {
121 // Skip any return statements that contain function calls with multiple return values.
122 for _, expr := range ret.Results {
123 e, ok := expr.(*ast.CallExpr)
// A call whose type is a tuple of more than one element yields several values.
127 if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 {
132 // Duplicate the return values to track which values have been matched.
133 remaining := make([]ast.Expr, len(ret.Results))
134 copy(remaining, ret.Results)
// NOTE(review): enclosingFunc.Results is dereferenced here, but the nil check
// above only covers enclosingFunc. Results is nil for a function that
// declares no results - confirm that case cannot reach this line.
136 fixed := make([]ast.Expr, len(enclosingFunc.Results.List))
138 // For each value in the return function declaration, find the leftmost element
139 // in the return statement that has the desired type. If no such element exists,
140 // fill in the missing value with the appropriate "zero" value.
// NOTE(review): retTyps receives one entry per field of the result list, so a
// field declaring several names (e.g. "(a, b int)") contributes a single
// type - confirm this is intended.
141 var retTyps []types.Type
142 for _, ret := range enclosingFunc.Results.List {
143 retTyps = append(retTyps, info.TypeOf(ret.Type))
146 analysisinternal.FindMatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg)
147 for i, retTyp := range retTyps {
150 for j, val := range remaining {
// matchingTypes receives the type of the returned value first and the
// declared result type second.
151 if !matchingTypes(info.TypeOf(val), retTyp) {
154 if !analysisinternal.IsZeroValue(val) {
158 // If the current match is a "zero" value, we keep searching in
159 // case we find a non-"zero" value match. If we do not find a
160 // non-"zero" value, we will use the "zero" value.
// Drop the element at idx from the values still waiting to be placed.
166 remaining = append(remaining[:idx], remaining[idx+1:]...)
168 idents, ok := matches[retTyp]
170 return nil, fmt.Errorf("invalid return type: %v", retTyp)
172 // Find the identifier whose name is most similar to the return type.
173 // If we do not find any identifier that matches the pattern,
174 // generate a zero value.
175 value := analysisinternal.FindBestMatch(retTyp.String(), idents)
177 value = analysisinternal.ZeroValue(
178 pass.Fset, file, pass.Pkg, retTyp)
187 // Remove any non-matching "zero values" from the leftover values.
188 var nonZeroRemaining []ast.Expr
189 for _, expr := range remaining {
190 if !analysisinternal.IsZeroValue(expr) {
191 nonZeroRemaining = append(nonZeroRemaining, expr)
194 // Append leftover return values to end of new return statement.
195 fixed = append(fixed, nonZeroRemaining...)
197 newRet := &ast.ReturnStmt{
202 // Convert the new return statement AST to text.
203 var newBuf bytes.Buffer
204 if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
// The diagnostic reuses the type error's message and attaches the formatted
// return statement as the replacement text of a single suggested fix.
208 pass.Report(analysis.Diagnostic{
211 Message: typeErr.Msg,
212 SuggestedFixes: []analysis.SuggestedFix{{
213 Message: "Fill in return values",
214 TextEdits: []analysis.TextEdit{{
217 NewText: newBuf.Bytes(),
// matchingTypes compares want against got. It tests, in order:
//   - whether the two are the same value or types.Identical;
//   - when want is an untyped basic type and got's underlying type is basic,
//     whether their IsConstType info bits are equal;
//   - whether want is assignable or convertible to got.
//
// run calls it with the type of a returned expression as want and the
// declared result type as got.
225 func matchingTypes(want, got types.Type) bool {
226 if want == got || types.Identical(want, got) {
229 // Code segment to help check for untyped equality from (golang/go#32146).
230 if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
231 if lhs, ok := got.Underlying().(*types.Basic); ok {
232 return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
235 return types.AssignableTo(want, got) || types.ConvertibleTo(want, got)
// FixesError reports whether msg is a type error that this analyzer handles;
// run uses it to filter out errors that are not relevant. The message is
// trimmed of surrounding whitespace and matched against wrongReturnNumRegex,
// and both captured counts are checked with strconv.Atoi.
238 func FixesError(msg string) bool {
239 matches := wrongReturnNumRegex.FindStringSubmatch(strings.TrimSpace(msg))
// A full match yields three entries: the whole match plus the two counts.
240 if len(matches) < 3 {
243 if _, err := strconv.Atoi(matches[1]); err != nil {
246 if _, err := strconv.Atoi(matches[2]); err != nil {