--- /dev/null
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fillreturns defines an Analyzer that will attempt to
+// automatically fill in a return statement that has missing
+// values with zero value elements.
+package fillreturns
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/types"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/ast/astutil"
+ "golang.org/x/tools/internal/analysisinternal"
+)
+
+const Doc = `suggested fixes for "wrong number of return values (want %d, got %d)"
+
+This checker provides suggested fixes for type errors of the
+type "wrong number of return values (want %d, got %d)". For example:
+ func m() (int, string, *bool, error) {
+ return
+ }
+will turn into
+ func m() (int, string, *bool, error) {
+ return 0, "", nil, nil
+ }
+
+This functionality is similar to https://github.com/sqs/goreturns.
+`
+
+var Analyzer = &analysis.Analyzer{
+ Name: "fillreturns",
+ Doc: Doc,
+ Requires: []*analysis.Analyzer{},
+ Run: run,
+ RunDespiteErrors: true,
+}
+
+var wrongReturnNumRegex = regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`)
+
+func run(pass *analysis.Pass) (interface{}, error) {
+ info := pass.TypesInfo
+ if info == nil {
+ return nil, fmt.Errorf("nil TypeInfo")
+ }
+
+ errors := analysisinternal.GetTypeErrors(pass)
+outer:
+ for _, typeErr := range errors {
+ // Filter out the errors that are not relevant to this analyzer.
+ if !FixesError(typeErr.Msg) {
+ continue
+ }
+ var file *ast.File
+ for _, f := range pass.Files {
+ if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
+ file = f
+ break
+ }
+ }
+ if file == nil {
+ continue
+ }
+
+ // Get the end position of the error.
+ var buf bytes.Buffer
+ if err := format.Node(&buf, pass.Fset, file); err != nil {
+ continue
+ }
+ typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)
+
+ // Get the path for the relevant range.
+ path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
+ if len(path) == 0 {
+ return nil, nil
+ }
+ // Check to make sure the node of interest is a ReturnStmt.
+ ret, ok := path[0].(*ast.ReturnStmt)
+ if !ok {
+ return nil, nil
+ }
+
+ // Get the function type that encloses the ReturnStmt.
+ var enclosingFunc *ast.FuncType
+ for _, n := range path {
+ switch node := n.(type) {
+ case *ast.FuncLit:
+ enclosingFunc = node.Type
+ case *ast.FuncDecl:
+ enclosingFunc = node.Type
+ }
+ if enclosingFunc != nil {
+ break
+ }
+ }
+ if enclosingFunc == nil {
+ continue
+ }
+
+ // Find the function declaration that encloses the ReturnStmt.
+ var outer *ast.FuncDecl
+ for _, p := range path {
+ if p, ok := p.(*ast.FuncDecl); ok {
+ outer = p
+ break
+ }
+ }
+ if outer == nil {
+ return nil, nil
+ }
+
+ // Skip any return statements that contain function calls with multiple return values.
+ for _, expr := range ret.Results {
+ e, ok := expr.(*ast.CallExpr)
+ if !ok {
+ continue
+ }
+ if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 {
+ continue outer
+ }
+ }
+
+ // Duplicate the return values to track which values have been matched.
+ remaining := make([]ast.Expr, len(ret.Results))
+ copy(remaining, ret.Results)
+
+ fixed := make([]ast.Expr, len(enclosingFunc.Results.List))
+
+ // For each value in the return function declaration, find the leftmost element
+ // in the return statement that has the desired type. If no such element exits,
+ // fill in the missing value with the appropriate "zero" value.
+ var retTyps []types.Type
+ for _, ret := range enclosingFunc.Results.List {
+ retTyps = append(retTyps, info.TypeOf(ret.Type))
+ }
+ matches :=
+ analysisinternal.FindMatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg)
+ for i, retTyp := range retTyps {
+ var match ast.Expr
+ var idx int
+ for j, val := range remaining {
+ if !matchingTypes(info.TypeOf(val), retTyp) {
+ continue
+ }
+ if !analysisinternal.IsZeroValue(val) {
+ match, idx = val, j
+ break
+ }
+ // If the current match is a "zero" value, we keep searching in
+ // case we find a non-"zero" value match. If we do not find a
+ // non-"zero" value, we will use the "zero" value.
+ match, idx = val, j
+ }
+
+ if match != nil {
+ fixed[i] = match
+ remaining = append(remaining[:idx], remaining[idx+1:]...)
+ } else {
+ idents, ok := matches[retTyp]
+ if !ok {
+ return nil, fmt.Errorf("invalid return type: %v", retTyp)
+ }
+ // Find the identifer whose name is most similar to the return type.
+ // If we do not find any identifer that matches the pattern,
+ // generate a zero value.
+ value := analysisinternal.FindBestMatch(retTyp.String(), idents)
+ if value == nil {
+ value = analysisinternal.ZeroValue(
+ pass.Fset, file, pass.Pkg, retTyp)
+ }
+ if value == nil {
+ return nil, nil
+ }
+ fixed[i] = value
+ }
+ }
+
+ // Remove any non-matching "zero values" from the leftover values.
+ var nonZeroRemaining []ast.Expr
+ for _, expr := range remaining {
+ if !analysisinternal.IsZeroValue(expr) {
+ nonZeroRemaining = append(nonZeroRemaining, expr)
+ }
+ }
+ // Append leftover return values to end of new return statement.
+ fixed = append(fixed, nonZeroRemaining...)
+
+ newRet := &ast.ReturnStmt{
+ Return: ret.Pos(),
+ Results: fixed,
+ }
+
+ // Convert the new return statement AST to text.
+ var newBuf bytes.Buffer
+ if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
+ return nil, err
+ }
+
+ pass.Report(analysis.Diagnostic{
+ Pos: typeErr.Pos,
+ End: typeErrEndPos,
+ Message: typeErr.Msg,
+ SuggestedFixes: []analysis.SuggestedFix{{
+ Message: "Fill in return values",
+ TextEdits: []analysis.TextEdit{{
+ Pos: ret.Pos(),
+ End: ret.End(),
+ NewText: newBuf.Bytes(),
+ }},
+ }},
+ })
+ }
+ return nil, nil
+}
+
+func matchingTypes(want, got types.Type) bool {
+ if want == got || types.Identical(want, got) {
+ return true
+ }
+ // Code segment to help check for untyped equality from (golang/go#32146).
+ if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
+ if lhs, ok := got.Underlying().(*types.Basic); ok {
+ return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
+ }
+ }
+ return types.AssignableTo(want, got) || types.ConvertibleTo(want, got)
+}
+
+func FixesError(msg string) bool {
+ matches := wrongReturnNumRegex.FindStringSubmatch(strings.TrimSpace(msg))
+ if len(matches) < 3 {
+ return false
+ }
+ if _, err := strconv.Atoi(matches[1]); err != nil {
+ return false
+ }
+ if _, err := strconv.Atoi(matches[2]); err != nil {
+ return false
+ }
+ return true
+}