+++ /dev/null
-// Copyright 2020 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Package fillreturns defines an Analyzer that will attempt to
-// automatically fill in a return statement that has missing
-// values with zero value elements.
-package fillreturns
-
-import (
- "bytes"
- "fmt"
- "go/ast"
- "go/format"
- "go/types"
- "regexp"
- "strconv"
- "strings"
-
- "golang.org/x/tools/go/analysis"
- "golang.org/x/tools/go/ast/astutil"
- "golang.org/x/tools/internal/analysisinternal"
-)
-
-const Doc = `suggested fixes for "wrong number of return values (want %d, got %d)"
-
-This checker provides suggested fixes for type errors of the
-type "wrong number of return values (want %d, got %d)". For example:
- func m() (int, string, *bool, error) {
- return
- }
-will turn into
- func m() (int, string, *bool, error) {
- return 0, "", nil, nil
- }
-
-This functionality is similar to https://github.com/sqs/goreturns.
-`
-
-var Analyzer = &analysis.Analyzer{
- Name: "fillreturns",
- Doc: Doc,
- Requires: []*analysis.Analyzer{},
- Run: run,
- RunDespiteErrors: true,
-}
-
-var wrongReturnNumRegex = regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`)
-
-func run(pass *analysis.Pass) (interface{}, error) {
- info := pass.TypesInfo
- if info == nil {
- return nil, fmt.Errorf("nil TypeInfo")
- }
-
- errors := analysisinternal.GetTypeErrors(pass)
-outer:
- for _, typeErr := range errors {
- // Filter out the errors that are not relevant to this analyzer.
- if !FixesError(typeErr.Msg) {
- continue
- }
- var file *ast.File
- for _, f := range pass.Files {
- if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
- file = f
- break
- }
- }
- if file == nil {
- continue
- }
-
- // Get the end position of the error.
- var buf bytes.Buffer
- if err := format.Node(&buf, pass.Fset, file); err != nil {
- continue
- }
- typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)
-
- // Get the path for the relevant range.
- path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
- if len(path) == 0 {
- return nil, nil
- }
- // Check to make sure the node of interest is a ReturnStmt.
- ret, ok := path[0].(*ast.ReturnStmt)
- if !ok {
- return nil, nil
- }
-
- // Get the function type that encloses the ReturnStmt.
- var enclosingFunc *ast.FuncType
- for _, n := range path {
- switch node := n.(type) {
- case *ast.FuncLit:
- enclosingFunc = node.Type
- case *ast.FuncDecl:
- enclosingFunc = node.Type
- }
- if enclosingFunc != nil {
- break
- }
- }
- if enclosingFunc == nil {
- continue
- }
-
- // Find the function declaration that encloses the ReturnStmt.
- var outer *ast.FuncDecl
- for _, p := range path {
- if p, ok := p.(*ast.FuncDecl); ok {
- outer = p
- break
- }
- }
- if outer == nil {
- return nil, nil
- }
-
- // Skip any return statements that contain function calls with multiple return values.
- for _, expr := range ret.Results {
- e, ok := expr.(*ast.CallExpr)
- if !ok {
- continue
- }
- if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 {
- continue outer
- }
- }
-
- // Duplicate the return values to track which values have been matched.
- remaining := make([]ast.Expr, len(ret.Results))
- copy(remaining, ret.Results)
-
- fixed := make([]ast.Expr, len(enclosingFunc.Results.List))
-
- // For each value in the return function declaration, find the leftmost element
- // in the return statement that has the desired type. If no such element exits,
- // fill in the missing value with the appropriate "zero" value.
- var retTyps []types.Type
- for _, ret := range enclosingFunc.Results.List {
- retTyps = append(retTyps, info.TypeOf(ret.Type))
- }
- matches :=
- analysisinternal.FindMatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg)
- for i, retTyp := range retTyps {
- var match ast.Expr
- var idx int
- for j, val := range remaining {
- if !matchingTypes(info.TypeOf(val), retTyp) {
- continue
- }
- if !analysisinternal.IsZeroValue(val) {
- match, idx = val, j
- break
- }
- // If the current match is a "zero" value, we keep searching in
- // case we find a non-"zero" value match. If we do not find a
- // non-"zero" value, we will use the "zero" value.
- match, idx = val, j
- }
-
- if match != nil {
- fixed[i] = match
- remaining = append(remaining[:idx], remaining[idx+1:]...)
- } else {
- idents, ok := matches[retTyp]
- if !ok {
- return nil, fmt.Errorf("invalid return type: %v", retTyp)
- }
- // Find the identifer whose name is most similar to the return type.
- // If we do not find any identifer that matches the pattern,
- // generate a zero value.
- value := analysisinternal.FindBestMatch(retTyp.String(), idents)
- if value == nil {
- value = analysisinternal.ZeroValue(
- pass.Fset, file, pass.Pkg, retTyp)
- }
- if value == nil {
- return nil, nil
- }
- fixed[i] = value
- }
- }
-
- // Remove any non-matching "zero values" from the leftover values.
- var nonZeroRemaining []ast.Expr
- for _, expr := range remaining {
- if !analysisinternal.IsZeroValue(expr) {
- nonZeroRemaining = append(nonZeroRemaining, expr)
- }
- }
- // Append leftover return values to end of new return statement.
- fixed = append(fixed, nonZeroRemaining...)
-
- newRet := &ast.ReturnStmt{
- Return: ret.Pos(),
- Results: fixed,
- }
-
- // Convert the new return statement AST to text.
- var newBuf bytes.Buffer
- if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
- return nil, err
- }
-
- pass.Report(analysis.Diagnostic{
- Pos: typeErr.Pos,
- End: typeErrEndPos,
- Message: typeErr.Msg,
- SuggestedFixes: []analysis.SuggestedFix{{
- Message: "Fill in return values",
- TextEdits: []analysis.TextEdit{{
- Pos: ret.Pos(),
- End: ret.End(),
- NewText: newBuf.Bytes(),
- }},
- }},
- })
- }
- return nil, nil
-}
-
-func matchingTypes(want, got types.Type) bool {
- if want == got || types.Identical(want, got) {
- return true
- }
- // Code segment to help check for untyped equality from (golang/go#32146).
- if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
- if lhs, ok := got.Underlying().(*types.Basic); ok {
- return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
- }
- }
- return types.AssignableTo(want, got) || types.ConvertibleTo(want, got)
-}
-
-func FixesError(msg string) bool {
- matches := wrongReturnNumRegex.FindStringSubmatch(strings.TrimSpace(msg))
- if len(matches) < 3 {
- return false
- }
- if _, err := strconv.Atoi(matches[1]); err != nil {
- return false
- }
- if _, err := strconv.Atoi(matches[2]); err != nil {
- return false
- }
- return true
-}