+++ /dev/null
-// Copyright 2020 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package source
-
-import (
- "bytes"
- "fmt"
- "go/ast"
- "go/format"
- "go/parser"
- "go/token"
- "go/types"
- "strings"
- "unicode"
-
- "golang.org/x/tools/go/analysis"
- "golang.org/x/tools/go/ast/astutil"
- "golang.org/x/tools/internal/analysisinternal"
- "golang.org/x/tools/internal/span"
-)
-
-func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
- expr, path, ok, err := canExtractVariable(rng, file)
- if !ok {
- return nil, fmt.Errorf("extractVariable: cannot extract %s: %v", fset.Position(rng.Start), err)
- }
-
- // Create new AST node for extracted code.
- var lhsNames []string
- switch expr := expr.(type) {
- // TODO: stricter rules for selectorExpr.
- case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
- *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
- lhsNames = append(lhsNames, generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0))
- case *ast.CallExpr:
- tup, ok := info.TypeOf(expr).(*types.Tuple)
- if !ok {
- // If the call expression only has one return value, we can treat it the
- // same as our standard extract variable case.
- lhsNames = append(lhsNames,
- generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0))
- break
- }
- for i := 0; i < tup.Len(); i++ {
- // Generate a unique variable for each return value.
- lhsNames = append(lhsNames,
- generateAvailableIdentifier(expr.Pos(), file, path, info, "x", i))
- }
- default:
- return nil, fmt.Errorf("cannot extract %T", expr)
- }
-
- insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
- if insertBeforeStmt == nil {
- return nil, fmt.Errorf("cannot find location to insert extraction")
- }
- tok := fset.File(expr.Pos())
- if tok == nil {
- return nil, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
- }
- newLineIndent := "\n" + calculateIndentation(src, tok, insertBeforeStmt)
-
- lhs := strings.Join(lhsNames, ", ")
- assignStmt := &ast.AssignStmt{
- Lhs: []ast.Expr{ast.NewIdent(lhs)},
- Tok: token.DEFINE,
- Rhs: []ast.Expr{expr},
- }
- var buf bytes.Buffer
- if err := format.Node(&buf, fset, assignStmt); err != nil {
- return nil, err
- }
- assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent
-
- return &analysis.SuggestedFix{
- TextEdits: []analysis.TextEdit{
- {
- Pos: rng.Start,
- End: rng.End,
- NewText: []byte(lhs),
- },
- {
- Pos: insertBeforeStmt.Pos(),
- End: insertBeforeStmt.Pos(),
- NewText: []byte(assignment),
- },
- },
- }, nil
-}
-
-// canExtractVariable reports whether the code in the given range can be
-// extracted to a variable.
-func canExtractVariable(rng span.Range, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
- if rng.Start == rng.End {
- return nil, nil, false, fmt.Errorf("start and end are equal")
- }
- path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
- if len(path) == 0 {
- return nil, nil, false, fmt.Errorf("no path enclosing interval")
- }
- for _, n := range path {
- if _, ok := n.(*ast.ImportSpec); ok {
- return nil, nil, false, fmt.Errorf("cannot extract variable in an import block")
- }
- }
- node := path[0]
- if rng.Start != node.Pos() || rng.End != node.End() {
- return nil, nil, false, fmt.Errorf("range does not map to an AST node")
- }
- expr, ok := node.(ast.Expr)
- if !ok {
- return nil, nil, false, fmt.Errorf("node is not an expression")
- }
- switch expr.(type) {
- case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr,
- *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
- return expr, path, true, nil
- }
- return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
-}
-
-// Calculate indentation for insertion.
-// When inserting lines of code, we must ensure that the lines have consistent
-// formatting (i.e. the proper indentation). To do so, we observe the indentation on the
-// line of code on which the insertion occurs.
-func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string {
- line := tok.Line(insertBeforeStmt.Pos())
- lineOffset := tok.Offset(tok.LineStart(line))
- stmtOffset := tok.Offset(insertBeforeStmt.Pos())
- return string(content[lineOffset:stmtOffset])
-}
-
-// generateAvailableIdentifier adjusts the new function name until there are no collisons in scope.
-// Possible collisions include other function and variable names.
-func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) string {
- scopes := CollectScopes(info, path, pos)
- name := prefix + fmt.Sprintf("%d", idx)
- for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) {
- idx++
- name = fmt.Sprintf("%v%d", prefix, idx)
- }
- return name
-}
-
-// isValidName checks for variable collision in scope.
-func isValidName(name string, scopes []*types.Scope) bool {
- for _, scope := range scopes {
- if scope == nil {
- continue
- }
- if scope.Lookup(name) != nil {
- return false
- }
- }
- return true
-}
-
-// returnVariable keeps track of the information we need to properly introduce a new variable
-// that we will return in the extracted function.
-type returnVariable struct {
- // name is the identifier that is used on the left-hand side of the call to
- // the extracted function.
- name ast.Expr
- // decl is the declaration of the variable. It is used in the type signature of the
- // extracted function and for variable declarations.
- decl *ast.Field
- // zeroVal is the "zero value" of the type of the variable. It is used in a return
- // statement in the extracted function.
- zeroVal ast.Expr
-}
-
-// extractFunction refactors the selected block of code into a new function.
-// It also replaces the selected block of code with a call to the extracted
-// function. First, we manually adjust the selection range. We remove trailing
-// and leading whitespace characters to ensure the range is precisely bounded
-// by AST nodes. Next, we determine the variables that will be the paramters
-// and return values of the extracted function. Lastly, we construct the call
-// of the function and insert this call as well as the extracted function into
-// their proper locations.
-func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
- p, ok, err := canExtractFunction(fset, rng, src, file, info)
- if !ok {
- return nil, fmt.Errorf("extractFunction: cannot extract %s: %v",
- fset.Position(rng.Start), err)
- }
- tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
- fileScope := info.Scopes[file]
- if fileScope == nil {
- return nil, fmt.Errorf("extractFunction: file scope is empty")
- }
- pkgScope := fileScope.Parent()
- if pkgScope == nil {
- return nil, fmt.Errorf("extractFunction: package scope is empty")
- }
-
- // TODO: Support non-nested return statements.
- // A return statement is non-nested if its parent node is equal to the parent node
- // of the first node in the selection. These cases must be handled seperately because
- // non-nested return statements are guaranteed to execute. Our control flow does not
- // properly consider these situations yet.
- var retStmts []*ast.ReturnStmt
- var hasNonNestedReturn bool
- startParent := findParent(outer, start)
- ast.Inspect(outer, func(n ast.Node) bool {
- if n == nil {
- return false
- }
- if n.Pos() < rng.Start || n.End() > rng.End {
- return n.Pos() <= rng.End
- }
- ret, ok := n.(*ast.ReturnStmt)
- if !ok {
- return true
- }
- if findParent(outer, n) == startParent {
- hasNonNestedReturn = true
- return false
- }
- retStmts = append(retStmts, ret)
- return false
- })
- if hasNonNestedReturn {
- return nil, fmt.Errorf("extractFunction: selected block contains non-nested return")
- }
- containsReturnStatement := len(retStmts) > 0
-
- // Now that we have determined the correct range for the selection block,
- // we must determine the signature of the extracted function. We will then replace
- // the block with an assignment statement that calls the extracted function with
- // the appropriate parameters and return values.
- variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
- if err != nil {
- return nil, err
- }
-
- var (
- params, returns []ast.Expr // used when calling the extracted function
- paramTypes, returnTypes []*ast.Field // used in the signature of the extracted function
- uninitialized []types.Object // vars we will need to initialize before the call
- )
-
- // Avoid duplicates while traversing vars and uninitialzed.
- seenVars := make(map[types.Object]ast.Expr)
- seenUninitialized := make(map[types.Object]struct{})
-
- // Some variables on the left-hand side of our assignment statement may be free. If our
- // selection begins in the same scope in which the free variable is defined, we can
- // redefine it in our assignment statement. See the following example, where 'b' and
- // 'err' (both free variables) can be redefined in the second funcCall() while maintaing
- // correctness.
- //
- //
- // Not Redefined:
- //
- // a, err := funcCall()
- // var b int
- // b, err = funcCall()
- //
- // Redefined:
- //
- // a, err := funcCall()
- // b, err := funcCall()
- //
- // We track the number of free variables that can be redefined to maintain our preference
- // of using "x, y, z := fn()" style assignment statements.
- var canRedefineCount int
-
- // Each identifier in the selected block must become (1) a parameter to the
- // extracted function, (2) a return value of the extracted function, or (3) a local
- // variable in the extracted function. Determine the outcome(s) for each variable
- // based on whether it is free, altered within the selected block, and used outside
- // of the selected block.
- for _, v := range variables {
- if _, ok := seenVars[v.obj]; ok {
- continue
- }
- typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type())
- if typ == nil {
- return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name())
- }
- seenVars[v.obj] = typ
- identifier := ast.NewIdent(v.obj.Name())
- // An identifier must meet three conditions to become a return value of the
- // extracted function. (1) its value must be defined or reassigned within
- // the selection (isAssigned), (2) it must be used at least once after the
- // selection (isUsed), and (3) its first use after the selection
- // cannot be its own reassignment or redefinition (objOverriden).
- if v.obj.Parent() == nil {
- return nil, fmt.Errorf("parent nil")
- }
- isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj)
- if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) {
- returnTypes = append(returnTypes, &ast.Field{Type: typ})
- returns = append(returns, identifier)
- if !v.free {
- uninitialized = append(uninitialized, v.obj)
- } else if v.obj.Parent().Pos() == startParent.Pos() {
- canRedefineCount++
- }
- }
- // An identifier must meet two conditions to become a parameter of the
- // extracted function. (1) it must be free (isFree), and (2) its first
- // use within the selection cannot be its own definition (isDefined).
- if v.free && !v.defined {
- params = append(params, identifier)
- paramTypes = append(paramTypes, &ast.Field{
- Names: []*ast.Ident{identifier},
- Type: typ,
- })
- }
- }
-
- // Find the function literal that encloses the selection. The enclosing function literal
- // may not be the enclosing function declaration (i.e. 'outer'). For example, in the
- // following block:
- //
- // func main() {
- // ast.Inspect(node, func(n ast.Node) bool {
- // v := 1 // this line extracted
- // return true
- // })
- // }
- //
- // 'outer' is main(). However, the extracted selection most directly belongs to
- // the anonymous function literal, the second argument of ast.Inspect(). We use the
- // enclosing function literal to determine the proper return types for return statements
- // within the selection. We still need the enclosing function declaration because this is
- // the top-level declaration. We inspect the top-level declaration to look for variables
- // as well as for code replacement.
- enclosing := outer.Type
- for _, p := range path {
- if p == enclosing {
- break
- }
- if fl, ok := p.(*ast.FuncLit); ok {
- enclosing = fl.Type
- break
- }
- }
-
- // We put the selection in a constructed file. We can then traverse and edit
- // the extracted selection without modifying the original AST.
- startOffset := tok.Offset(rng.Start)
- endOffset := tok.Offset(rng.End)
- selection := src[startOffset:endOffset]
- extractedBlock, err := parseBlockStmt(fset, selection)
- if err != nil {
- return nil, err
- }
-
- // We need to account for return statements in the selected block, as they will complicate
- // the logical flow of the extracted function. See the following example, where ** denotes
- // the range to be extracted.
- //
- // Before:
- //
- // func _() int {
- // a := 1
- // b := 2
- // **if a == b {
- // return a
- // }**
- // ...
- // }
- //
- // After:
- //
- // func _() int {
- // a := 1
- // b := 2
- // cond0, ret0 := x0(a, b)
- // if cond0 {
- // return ret0
- // }
- // ...
- // }
- //
- // func x0(a int, b int) (bool, int) {
- // if a == b {
- // return true, a
- // }
- // return false, 0
- // }
- //
- // We handle returns by adding an additional boolean return value to the extracted function.
- // This bool reports whether the original function would have returned. Because the
- // extracted selection contains a return statement, we must also add the types in the
- // return signature of the enclosing function to the return signature of the
- // extracted function. We then add an extra if statement checking this boolean value
- // in the original function. If the condition is met, the original function should
- // return a value, mimicking the functionality of the original return statement(s)
- // in the selection.
-
- var retVars []*returnVariable
- var ifReturn *ast.IfStmt
- if containsReturnStatement {
- // The selected block contained return statements, so we have to modify the
- // signature of the extracted function as described above. Adjust all of
- // the return statements in the extracted function to reflect this change in
- // signature.
- if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
- pkg, extractedBlock); err != nil {
- return nil, err
- }
- // Collect the additional return values and types needed to accomodate return
- // statements in the selection. Update the type signature of the extracted
- // function and construct the if statement that will be inserted in the enclosing
- // function.
- retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start)
- if err != nil {
- return nil, err
- }
- }
-
- // Add a return statement to the end of the new function. This return statement must include
- // the values for the types of the original extracted function signature and (if a return
- // statement is present in the selection) enclosing function signature.
- hasReturnValues := len(returns)+len(retVars) > 0
- if hasReturnValues {
- extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{
- Results: append(returns, getZeroVals(retVars)...),
- })
- }
-
- // Construct the appropriate call to the extracted function.
- // We must meet two conditions to use ":=" instead of '='. (1) there must be at least
- // one variable on the lhs that is uninitailized (non-free) prior to the assignment.
- // (2) all of the initialized (free) variables on the lhs must be able to be redefined.
- sym := token.ASSIGN
- canDefineCount := len(uninitialized) + canRedefineCount
- canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns)
- if canDefine {
- sym = token.DEFINE
- }
- funName := generateAvailableIdentifier(rng.Start, file, path, info, "fn", 0)
- extractedFunCall := generateFuncCall(hasReturnValues, params,
- append(returns, getNames(retVars)...), funName, sym)
-
- // Build the extracted function.
- newFunc := &ast.FuncDecl{
- Name: ast.NewIdent(funName),
- Type: &ast.FuncType{
- Params: &ast.FieldList{List: paramTypes},
- Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
- },
- Body: extractedBlock,
- }
-
- // Create variable declarations for any identifiers that need to be initialized prior to
- // calling the extracted function. We do not manually initialize variables if every return
- // value is unitialized. We can use := to initialize the variables in this situation.
- var declarations []ast.Stmt
- if canDefineCount != len(returns) {
- declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
- }
-
- var declBuf, replaceBuf, newFuncBuf, ifBuf bytes.Buffer
- if err := format.Node(&declBuf, fset, declarations); err != nil {
- return nil, err
- }
- if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil {
- return nil, err
- }
- if ifReturn != nil {
- if err := format.Node(&ifBuf, fset, ifReturn); err != nil {
- return nil, err
- }
- }
- if err := format.Node(&newFuncBuf, fset, newFunc); err != nil {
- return nil, err
- }
-
- // We're going to replace the whole enclosing function,
- // so preserve the text before and after the selected block.
- outerStart := tok.Offset(outer.Pos())
- outerEnd := tok.Offset(outer.End())
- before := src[outerStart:startOffset]
- after := src[endOffset:outerEnd]
- newLineIndent := "\n" + calculateIndentation(src, tok, start)
-
- var fullReplacement strings.Builder
- fullReplacement.Write(before)
- if declBuf.Len() > 0 { // add any initializations, if needed
- initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
- newLineIndent
- fullReplacement.WriteString(initializations)
- }
- fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function
- if ifBuf.Len() > 0 { // add the if statement below the function call, if needed
- ifstatement := newLineIndent +
- strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent)
- fullReplacement.WriteString(ifstatement)
- }
- fullReplacement.Write(after)
- fullReplacement.WriteString("\n\n") // add newlines after the enclosing function
- fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function
-
- return &analysis.SuggestedFix{
- TextEdits: []analysis.TextEdit{{
- Pos: outer.Pos(),
- End: outer.End(),
- NewText: []byte(fullReplacement.String()),
- }},
- }, nil
-}
-
-// adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or
-// trailing whitespace characters from selection. In the following example, each line
-// of the if statement is indented once. There are also two extra spaces after the
-// closing bracket before the line break.
-//
-// \tif (true) {
-// \t _ = 1
-// \t} \n
-//
-// By default, a valid range begins at 'if' and ends at the first whitespace character
-// after the '}'. But, users are likely to highlight full lines rather than adjusting
-// their cursors for whitespace. To support this use case, we must manually adjust the
-// ranges to match the correct AST node. In this particular example, we would adjust
-// rng.Start forward by one byte, and rng.End backwards by two bytes.
-func adjustRangeForWhitespace(rng span.Range, tok *token.File, content []byte) span.Range {
- offset := tok.Offset(rng.Start)
- for offset < len(content) {
- if !unicode.IsSpace(rune(content[offset])) {
- break
- }
- // Move forwards one byte to find a non-whitespace character.
- offset += 1
- }
- rng.Start = tok.Pos(offset)
-
- // Move backwards to find a non-whitespace character.
- offset = tok.Offset(rng.End)
- for o := offset - 1; 0 <= o && o < len(content); o-- {
- if !unicode.IsSpace(rune(content[o])) {
- break
- }
- offset = o
- }
- rng.End = tok.Pos(offset)
- return rng
-}
-
-// findParent finds the parent AST node of the given target node, if the target is a
-// descendant of the starting node.
-func findParent(start ast.Node, target ast.Node) ast.Node {
- var parent ast.Node
- analysisinternal.WalkASTWithParent(start, func(n, p ast.Node) bool {
- if n == target {
- parent = p
- return false
- }
- return true
- })
- return parent
-}
-
-// variable describes the status of a variable within a selection.
-type variable struct {
- obj types.Object
-
- // free reports whether the variable is a free variable, meaning it should
- // be a parameter to the extracted function.
- free bool
-
- // assigned reports whether the variable is assigned to in the selection.
- assigned bool
-
- // defined reports whether the variable is defined in the selection.
- defined bool
-}
-
-// collectFreeVars maps each identifier in the given range to whether it is "free."
-// Given a range, a variable in that range is defined as "free" if it is declared
-// outside of the range and neither at the file scope nor package scope. These free
-// variables will be used as arguments in the extracted function. It also returns a
-// list of identifiers that may need to be returned by the extracted function.
-// Some of the code in this function has been adapted from tools/cmd/guru/freevars.go.
-func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) {
- // id returns non-nil if n denotes an object that is referenced by the span
- // and defined either within the span or in the lexical environment. The bool
- // return value acts as an indicator for where it was defined.
- id := func(n *ast.Ident) (types.Object, bool) {
- obj := info.Uses[n]
- if obj == nil {
- return info.Defs[n], false
- }
- if obj.Name() == "_" {
- return nil, false // exclude objects denoting '_'
- }
- if _, ok := obj.(*types.PkgName); ok {
- return nil, false // imported package
- }
- if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) {
- return nil, false // not defined in this file
- }
- scope := obj.Parent()
- if scope == nil {
- return nil, false // e.g. interface method, struct field
- }
- if scope == fileScope || scope == pkgScope {
- return nil, false // defined at file or package scope
- }
- if rng.Start <= obj.Pos() && obj.Pos() <= rng.End {
- return obj, false // defined within selection => not free
- }
- return obj, true
- }
- // sel returns non-nil if n denotes a selection o.x.y that is referenced by the
- // span and defined either within the span or in the lexical environment. The bool
- // return value acts as an indicator for where it was defined.
- var sel func(n *ast.SelectorExpr) (types.Object, bool)
- sel = func(n *ast.SelectorExpr) (types.Object, bool) {
- switch x := astutil.Unparen(n.X).(type) {
- case *ast.SelectorExpr:
- return sel(x)
- case *ast.Ident:
- return id(x)
- }
- return nil, false
- }
- seen := make(map[types.Object]*variable)
- firstUseIn := make(map[types.Object]token.Pos)
- var vars []types.Object
- ast.Inspect(node, func(n ast.Node) bool {
- if n == nil {
- return false
- }
- if rng.Start <= n.Pos() && n.End() <= rng.End {
- var obj types.Object
- var isFree, prune bool
- switch n := n.(type) {
- case *ast.Ident:
- obj, isFree = id(n)
- case *ast.SelectorExpr:
- obj, isFree = sel(n)
- prune = true
- }
- if obj != nil {
- seen[obj] = &variable{
- obj: obj,
- free: isFree,
- }
- vars = append(vars, obj)
- // Find the first time that the object is used in the selection.
- first, ok := firstUseIn[obj]
- if !ok || n.Pos() < first {
- firstUseIn[obj] = n.Pos()
- }
- if prune {
- return false
- }
- }
- }
- return n.Pos() <= rng.End
- })
-
- // Find identifiers that are initialized or whose values are altered at some
- // point in the selected block. For example, in a selected block from lines 2-4,
- // variables x, y, and z are included in assigned. However, in a selected block
- // from lines 3-4, only variables y and z are included in assigned.
- //
- // 1: var a int
- // 2: var x int
- // 3: y := 3
- // 4: z := x + a
- //
- ast.Inspect(node, func(n ast.Node) bool {
- if n == nil {
- return false
- }
- if n.Pos() < rng.Start || n.End() > rng.End {
- return n.Pos() <= rng.End
- }
- switch n := n.(type) {
- case *ast.AssignStmt:
- for _, assignment := range n.Lhs {
- lhs, ok := assignment.(*ast.Ident)
- if !ok {
- continue
- }
- obj, _ := id(lhs)
- if obj == nil {
- continue
- }
- if _, ok := seen[obj]; !ok {
- continue
- }
- seen[obj].assigned = true
- if n.Tok != token.DEFINE {
- continue
- }
- // Find identifiers that are defined prior to being used
- // elsewhere in the selection.
- // TODO: Include identifiers that are assigned prior to being
- // used elsewhere in the selection. Then, change the assignment
- // to a definition in the extracted function.
- if firstUseIn[obj] != lhs.Pos() {
- continue
- }
- // Ensure that the object is not used in its own re-definition.
- // For example:
- // var f float64
- // f, e := math.Frexp(f)
- for _, expr := range n.Rhs {
- if referencesObj(info, expr, obj) {
- continue
- }
- if _, ok := seen[obj]; !ok {
- continue
- }
- seen[obj].defined = true
- break
- }
- }
- return false
- case *ast.DeclStmt:
- gen, ok := n.Decl.(*ast.GenDecl)
- if !ok {
- return false
- }
- for _, spec := range gen.Specs {
- vSpecs, ok := spec.(*ast.ValueSpec)
- if !ok {
- continue
- }
- for _, vSpec := range vSpecs.Names {
- obj, _ := id(vSpec)
- if obj == nil {
- continue
- }
- if _, ok := seen[obj]; !ok {
- continue
- }
- seen[obj].assigned = true
- }
- }
- return false
- case *ast.IncDecStmt:
- if ident, ok := n.X.(*ast.Ident); !ok {
- return false
- } else if obj, _ := id(ident); obj == nil {
- return false
- } else {
- if _, ok := seen[obj]; !ok {
- return false
- }
- seen[obj].assigned = true
- }
- }
- return true
- })
- var variables []*variable
- for _, obj := range vars {
- v, ok := seen[obj]
- if !ok {
- return nil, fmt.Errorf("no seen types.Object for %v", obj)
- }
- variables = append(variables, v)
- }
- return variables, nil
-}
-
-// referencesObj checks whether the given object appears in the given expression.
-func referencesObj(info *types.Info, expr ast.Expr, obj types.Object) bool {
- var hasObj bool
- ast.Inspect(expr, func(n ast.Node) bool {
- if n == nil {
- return false
- }
- ident, ok := n.(*ast.Ident)
- if !ok {
- return true
- }
- objUse := info.Uses[ident]
- if obj == objUse {
- hasObj = true
- return false
- }
- return false
- })
- return hasObj
-}
-
-type fnExtractParams struct {
- tok *token.File
- path []ast.Node
- rng span.Range
- outer *ast.FuncDecl
- start ast.Node
-}
-
-// canExtractFunction reports whether the code in the given range can be
-// extracted to a function.
-func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Info) (*fnExtractParams, bool, error) {
- if rng.Start == rng.End {
- return nil, false, fmt.Errorf("start and end are equal")
- }
- tok := fset.File(file.Pos())
- if tok == nil {
- return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
- }
- rng = adjustRangeForWhitespace(rng, tok, src)
- path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
- if len(path) == 0 {
- return nil, false, fmt.Errorf("no path enclosing interval")
- }
- // Node that encloses the selection must be a statement.
- // TODO: Support function extraction for an expression.
- _, ok := path[0].(ast.Stmt)
- if !ok {
- return nil, false, fmt.Errorf("node is not a statement")
- }
-
- // Find the function declaration that encloses the selection.
- var outer *ast.FuncDecl
- for _, p := range path {
- if p, ok := p.(*ast.FuncDecl); ok {
- outer = p
- break
- }
- }
- if outer == nil {
- return nil, false, fmt.Errorf("no enclosing function")
- }
-
- // Find the nodes at the start and end of the selection.
- var start, end ast.Node
- ast.Inspect(outer, func(n ast.Node) bool {
- if n == nil {
- return false
- }
- // Do not override 'start' with a node that begins at the same location
- // but is nested further from 'outer'.
- if start == nil && n.Pos() == rng.Start && n.End() <= rng.End {
- start = n
- }
- if end == nil && n.End() == rng.End && n.Pos() >= rng.Start {
- end = n
- }
- return n.Pos() <= rng.End
- })
- if start == nil || end == nil {
- return nil, false, fmt.Errorf("range does not map to AST nodes")
- }
- return &fnExtractParams{
- tok: tok,
- path: path,
- rng: rng,
- outer: outer,
- start: start,
- }, true, nil
-}
-
-// objUsed checks if the object is used within the range. It returns the first occurence of
-// the object in the range, if it exists.
-func objUsed(info *types.Info, rng span.Range, obj types.Object) (bool, *ast.Ident) {
- var firstUse *ast.Ident
- for id, objUse := range info.Uses {
- if obj != objUse {
- continue
- }
- if id.Pos() < rng.Start || id.End() > rng.End {
- continue
- }
- if firstUse == nil || id.Pos() < firstUse.Pos() {
- firstUse = id
- }
- }
- return firstUse != nil, firstUse
-}
-
-// varOverridden traverses the given AST node until we find the given identifier. Then, we
-// examine the occurrence of the given identifier and check for (1) whether the identifier
-// is being redefined. If the identifier is free, we also check for (2) whether the identifier
-// is being reassigned. We will not include an identifier in the return statement of the
-// extracted function if it meets one of the above conditions.
-func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFree bool, node ast.Node) bool {
- var isOverriden bool
- ast.Inspect(node, func(n ast.Node) bool {
- if n == nil {
- return false
- }
- assignment, ok := n.(*ast.AssignStmt)
- if !ok {
- return true
- }
- // A free variable is initialized prior to the selection. We can always reassign
- // this variable after the selection because it has already been defined.
- // Conversely, a non-free variable is initialized within the selection. Thus, we
- // cannot reassign this variable after the selection unless it is initialized and
- // returned by the extracted function.
- if !isFree && assignment.Tok == token.ASSIGN {
- return false
- }
- for _, assigned := range assignment.Lhs {
- ident, ok := assigned.(*ast.Ident)
- // Check if we found the first use of the identifier.
- if !ok || ident != firstUse {
- continue
- }
- objUse := info.Uses[ident]
- if objUse == nil || objUse != obj {
- continue
- }
- // Ensure that the object is not used in its own definition.
- // For example:
- // var f float64
- // f, e := math.Frexp(f)
- for _, expr := range assignment.Rhs {
- if referencesObj(info, expr, obj) {
- return false
- }
- }
- isOverriden = true
- return false
- }
- return false
- })
- return isOverriden
-}
-
-// parseExtraction generates an AST file from the given text. We then return the portion of the
-// file that represents the text.
-func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) {
- text := "package main\nfunc _() { " + string(src) + " }"
- extract, err := parser.ParseFile(fset, "", text, 0)
- if err != nil {
- return nil, err
- }
- if len(extract.Decls) == 0 {
- return nil, fmt.Errorf("parsed file does not contain any declarations")
- }
- decl, ok := extract.Decls[0].(*ast.FuncDecl)
- if !ok {
- return nil, fmt.Errorf("parsed file does not contain expected function declaration")
- }
- if decl.Body == nil {
- return nil, fmt.Errorf("extracted function has no body")
- }
- return decl.Body, nil
-}
-
-// generateReturnInfo generates the information we need to adjust the return statements and
-// signature of the extracted function. We prepare names, signatures, and "zero values" that
-// represent the new variables. We also use this information to construct the if statement that
-// is inserted below the call to the extracted function.
-func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos) ([]*returnVariable, *ast.IfStmt, error) {
- // Generate information for the added bool value.
- cond := &ast.Ident{Name: generateAvailableIdentifier(pos, file, path, info, "cond", 0)}
- retVars := []*returnVariable{
- {
- name: cond,
- decl: &ast.Field{Type: ast.NewIdent("bool")},
- zeroVal: ast.NewIdent("false"),
- },
- }
- // Generate information for the values in the return signature of the enclosing function.
- if enclosing.Results != nil {
- for i, field := range enclosing.Results.List {
- typ := info.TypeOf(field.Type)
- if typ == nil {
- return nil, nil, fmt.Errorf(
- "failed type conversion, AST expression: %T", field.Type)
- }
- expr := analysisinternal.TypeExpr(fset, file, pkg, typ)
- if expr == nil {
- return nil, nil, fmt.Errorf("nil AST expression")
- }
- retVars = append(retVars, &returnVariable{
- name: ast.NewIdent(generateAvailableIdentifier(pos, file,
- path, info, "ret", i)),
- decl: &ast.Field{Type: expr},
- zeroVal: analysisinternal.ZeroValue(
- fset, file, pkg, typ),
- })
- }
- }
- // Create the return statement for the enclosing function. We must exclude the variable
- // for the condition of the if statement (cond) from the return statement.
- ifReturn := &ast.IfStmt{
- Cond: cond,
- Body: &ast.BlockStmt{
- List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}},
- },
- }
- return retVars, ifReturn, nil
-}
-
-// adjustReturnStatements adds "zero values" of the given types to each return statement
-// in the given AST node.
-func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, fset *token.FileSet, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error {
- var zeroVals []ast.Expr
- // Create "zero values" for each type.
- for _, returnType := range returnTypes {
- var val ast.Expr
- for obj, typ := range seenVars {
- if typ != returnType.Type {
- continue
- }
- val = analysisinternal.ZeroValue(fset, file, pkg, obj.Type())
- break
- }
- if val == nil {
- return fmt.Errorf(
- "could not find matching AST expression for %T", returnType.Type)
- }
- zeroVals = append(zeroVals, val)
- }
- // Add "zero values" to each return statement.
- // The bool reports whether the enclosing function should return after calling the
- // extracted function. We set the bool to 'true' because, if these return statements
- // execute, the extracted function terminates early, and the enclosing function must
- // return as well.
- zeroVals = append(zeroVals, ast.NewIdent("true"))
- ast.Inspect(extractedBlock, func(n ast.Node) bool {
- if n == nil {
- return false
- }
- if n, ok := n.(*ast.ReturnStmt); ok {
- n.Results = append(zeroVals, n.Results...)
- return false
- }
- return true
- })
- return nil
-}
-
-// generateFuncCall constructs a call expression for the extracted function, described by the
-// given parameters and return variables.
-func generateFuncCall(hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node {
- var replace ast.Node
- if hasReturnVals {
- callExpr := &ast.CallExpr{
- Fun: ast.NewIdent(name),
- Args: params,
- }
- replace = &ast.AssignStmt{
- Lhs: returns,
- Tok: token,
- Rhs: []ast.Expr{callExpr},
- }
- } else {
- replace = &ast.CallExpr{
- Fun: ast.NewIdent(name),
- Args: params,
- }
- }
- return replace
-}
-
-// initializeVars creates variable declarations, if needed.
-// Our preference is to replace the selected block with an "x, y, z := fn()" style
-// assignment statement. We can use this style when all of the variables in the
-// extracted function's return statement are either not defined prior to the extracted block
-// or can be safely redefined. However, for example, if z is already defined
-// in a different scope, we replace the selected block with:
-//
-// var x int
-// var y string
-// x, y, z = fn()
-func initializeVars(uninitialized []types.Object, retVars []*returnVariable, seenUninitialized map[types.Object]struct{}, seenVars map[types.Object]ast.Expr) []ast.Stmt {
- var declarations []ast.Stmt
- for _, obj := range uninitialized {
- if _, ok := seenUninitialized[obj]; ok {
- continue
- }
- seenUninitialized[obj] = struct{}{}
- valSpec := &ast.ValueSpec{
- Names: []*ast.Ident{ast.NewIdent(obj.Name())},
- Type: seenVars[obj],
- }
- genDecl := &ast.GenDecl{
- Tok: token.VAR,
- Specs: []ast.Spec{valSpec},
- }
- declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
- }
- // Each variable added from a return statement in the selection
- // must be initialized.
- for i, retVar := range retVars {
- n := retVar.name.(*ast.Ident)
- valSpec := &ast.ValueSpec{
- Names: []*ast.Ident{n},
- Type: retVars[i].decl.Type,
- }
- genDecl := &ast.GenDecl{
- Tok: token.VAR,
- Specs: []ast.Spec{valSpec},
- }
- declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
- }
- return declarations
-}
-
-// getNames returns the names from the given list of returnVariable.
-func getNames(retVars []*returnVariable) []ast.Expr {
- var names []ast.Expr
- for _, retVar := range retVars {
- names = append(names, retVar.name)
- }
- return names
-}
-
-// getZeroVals returns the "zero values" from the given list of returnVariable.
-func getZeroVals(retVars []*returnVariable) []ast.Expr {
- var zvs []ast.Expr
- for _, retVar := range retVars {
- zvs = append(zvs, retVar.zeroVal)
- }
- return zvs
-}
-
-// getDecls returns the declarations from the given list of returnVariable.
-func getDecls(retVars []*returnVariable) []*ast.Field {
- var decls []*ast.Field
- for _, retVar := range retVars {
- decls = append(decls, retVar.decl)
- }
- return decls
-}