1 // Copyright 2020 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
18 "golang.org/x/tools/go/analysis"
19 "golang.org/x/tools/go/ast/astutil"
20 "golang.org/x/tools/internal/analysisinternal"
21 "golang.org/x/tools/internal/span"
24 func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
25 expr, path, ok, err := canExtractVariable(rng, file)
27 return nil, fmt.Errorf("extractVariable: cannot extract %s: %v", fset.Position(rng.Start), err)
30 // Create new AST node for extracted code.
32 switch expr := expr.(type) {
33 // TODO: stricter rules for selectorExpr.
34 case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
35 *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
36 lhsNames = append(lhsNames, generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0))
38 tup, ok := info.TypeOf(expr).(*types.Tuple)
40 // If the call expression only has one return value, we can treat it the
41 // same as our standard extract variable case.
42 lhsNames = append(lhsNames,
43 generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0))
46 for i := 0; i < tup.Len(); i++ {
47 // Generate a unique variable for each return value.
48 lhsNames = append(lhsNames,
49 generateAvailableIdentifier(expr.Pos(), file, path, info, "x", i))
52 return nil, fmt.Errorf("cannot extract %T", expr)
55 insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
56 if insertBeforeStmt == nil {
57 return nil, fmt.Errorf("cannot find location to insert extraction")
59 tok := fset.File(expr.Pos())
61 return nil, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
63 newLineIndent := "\n" + calculateIndentation(src, tok, insertBeforeStmt)
65 lhs := strings.Join(lhsNames, ", ")
66 assignStmt := &ast.AssignStmt{
67 Lhs: []ast.Expr{ast.NewIdent(lhs)},
69 Rhs: []ast.Expr{expr},
72 if err := format.Node(&buf, fset, assignStmt); err != nil {
75 assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent
77 return &analysis.SuggestedFix{
78 TextEdits: []analysis.TextEdit{
85 Pos: insertBeforeStmt.Pos(),
86 End: insertBeforeStmt.Pos(),
87 NewText: []byte(assignment),
93 // canExtractVariable reports whether the code in the given range can be
94 // extracted to a variable.
95 func canExtractVariable(rng span.Range, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
96 if rng.Start == rng.End {
97 return nil, nil, false, fmt.Errorf("start and end are equal")
99 path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
101 return nil, nil, false, fmt.Errorf("no path enclosing interval")
103 for _, n := range path {
104 if _, ok := n.(*ast.ImportSpec); ok {
105 return nil, nil, false, fmt.Errorf("cannot extract variable in an import block")
109 if rng.Start != node.Pos() || rng.End != node.End() {
110 return nil, nil, false, fmt.Errorf("range does not map to an AST node")
112 expr, ok := node.(ast.Expr)
114 return nil, nil, false, fmt.Errorf("node is not an expression")
117 case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr,
118 *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
119 return expr, path, true, nil
121 return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
124 // Calculate indentation for insertion.
125 // When inserting lines of code, we must ensure that the lines have consistent
126 // formatting (i.e. the proper indentation). To do so, we observe the indentation on the
127 // line of code on which the insertion occurs.
128 func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string {
129 line := tok.Line(insertBeforeStmt.Pos())
130 lineOffset := tok.Offset(tok.LineStart(line))
131 stmtOffset := tok.Offset(insertBeforeStmt.Pos())
132 return string(content[lineOffset:stmtOffset])
135 // generateAvailableIdentifier adjusts the new identifier name until there are no collisions in scope.
136 // Possible collisions include other function and variable names.
137 func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) string {
138 scopes := CollectScopes(info, path, pos)
139 name := prefix + fmt.Sprintf("%d", idx)
140 for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) {
142 name = fmt.Sprintf("%v%d", prefix, idx)
147 // isValidName checks for variable collision in scope.
148 func isValidName(name string, scopes []*types.Scope) bool {
149 for _, scope := range scopes {
153 if scope.Lookup(name) != nil {
160 // returnVariable keeps track of the information we need to properly introduce a new variable
161 // that we will return in the extracted function.
162 type returnVariable struct {
163 // name is the identifier that is used on the left-hand side of the call to
164 // the extracted function.
166 // decl is the declaration of the variable. It is used in the type signature of the
167 // extracted function and for variable declarations.
169 // zeroVal is the "zero value" of the type of the variable. It is used in a return
170 // statement in the extracted function.
174 // extractFunction refactors the selected block of code into a new function.
175 // It also replaces the selected block of code with a call to the extracted
176 // function. First, we manually adjust the selection range. We remove trailing
177 // and leading whitespace characters to ensure the range is precisely bounded
178 // by AST nodes. Next, we determine the variables that will be the parameters
179 // and return values of the extracted function. Lastly, we construct the call
180 // of the function and insert this call as well as the extracted function into
181 // their proper locations.
182 func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
183 p, ok, err := canExtractFunction(fset, rng, src, file, info)
185 return nil, fmt.Errorf("extractFunction: cannot extract %s: %v",
186 fset.Position(rng.Start), err)
188 tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
189 fileScope := info.Scopes[file]
190 if fileScope == nil {
191 return nil, fmt.Errorf("extractFunction: file scope is empty")
193 pkgScope := fileScope.Parent()
195 return nil, fmt.Errorf("extractFunction: package scope is empty")
198 // TODO: Support non-nested return statements.
199 // A return statement is non-nested if its parent node is equal to the parent node
200 // of the first node in the selection. These cases must be handled separately because
201 // non-nested return statements are guaranteed to execute. Our control flow does not
202 // properly consider these situations yet.
203 var retStmts []*ast.ReturnStmt
204 var hasNonNestedReturn bool
205 startParent := findParent(outer, start)
206 ast.Inspect(outer, func(n ast.Node) bool {
210 if n.Pos() < rng.Start || n.End() > rng.End {
211 return n.Pos() <= rng.End
213 ret, ok := n.(*ast.ReturnStmt)
217 if findParent(outer, n) == startParent {
218 hasNonNestedReturn = true
221 retStmts = append(retStmts, ret)
224 if hasNonNestedReturn {
225 return nil, fmt.Errorf("extractFunction: selected block contains non-nested return")
227 containsReturnStatement := len(retStmts) > 0
229 // Now that we have determined the correct range for the selection block,
230 // we must determine the signature of the extracted function. We will then replace
231 // the block with an assignment statement that calls the extracted function with
232 // the appropriate parameters and return values.
233 variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
239 params, returns []ast.Expr // used when calling the extracted function
240 paramTypes, returnTypes []*ast.Field // used in the signature of the extracted function
241 uninitialized []types.Object // vars we will need to initialize before the call
244 // Avoid duplicates while traversing vars and uninitialized.
245 seenVars := make(map[types.Object]ast.Expr)
246 seenUninitialized := make(map[types.Object]struct{})
248 // Some variables on the left-hand side of our assignment statement may be free. If our
249 // selection begins in the same scope in which the free variable is defined, we can
250 // redefine it in our assignment statement. See the following example, where 'b' and
251 // 'err' (both free variables) can be redefined in the second funcCall() while maintaining
257 // a, err := funcCall()
259 // b, err = funcCall()
263 // a, err := funcCall()
264 // b, err := funcCall()
266 // We track the number of free variables that can be redefined to maintain our preference
267 // of using "x, y, z := fn()" style assignment statements.
268 var canRedefineCount int
270 // Each identifier in the selected block must become (1) a parameter to the
271 // extracted function, (2) a return value of the extracted function, or (3) a local
272 // variable in the extracted function. Determine the outcome(s) for each variable
273 // based on whether it is free, altered within the selected block, and used outside
274 // of the selected block.
275 for _, v := range variables {
276 if _, ok := seenVars[v.obj]; ok {
279 typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type())
281 return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name())
283 seenVars[v.obj] = typ
284 identifier := ast.NewIdent(v.obj.Name())
285 // An identifier must meet three conditions to become a return value of the
286 // extracted function. (1) its value must be defined or reassigned within
287 // the selection (isAssigned), (2) it must be used at least once after the
288 // selection (isUsed), and (3) its first use after the selection
289 // cannot be its own reassignment or redefinition (varOverridden).
290 if v.obj.Parent() == nil {
291 return nil, fmt.Errorf("parent nil")
293 isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj)
294 if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) {
295 returnTypes = append(returnTypes, &ast.Field{Type: typ})
296 returns = append(returns, identifier)
298 uninitialized = append(uninitialized, v.obj)
299 } else if v.obj.Parent().Pos() == startParent.Pos() {
303 // An identifier must meet two conditions to become a parameter of the
304 // extracted function. (1) it must be free (isFree), and (2) its first
305 // use within the selection cannot be its own definition (isDefined).
306 if v.free && !v.defined {
307 params = append(params, identifier)
308 paramTypes = append(paramTypes, &ast.Field{
309 Names: []*ast.Ident{identifier},
315 // Find the function literal that encloses the selection. The enclosing function literal
316 // may not be the enclosing function declaration (i.e. 'outer'). For example, in the
320 // ast.Inspect(node, func(n ast.Node) bool {
321 // v := 1 // this line extracted
326 // 'outer' is main(). However, the extracted selection most directly belongs to
327 // the anonymous function literal, the second argument of ast.Inspect(). We use the
328 // enclosing function literal to determine the proper return types for return statements
329 // within the selection. We still need the enclosing function declaration because this is
330 // the top-level declaration. We inspect the top-level declaration to look for variables
331 // as well as for code replacement.
332 enclosing := outer.Type
333 for _, p := range path {
337 if fl, ok := p.(*ast.FuncLit); ok {
343 // We put the selection in a constructed file. We can then traverse and edit
344 // the extracted selection without modifying the original AST.
345 startOffset := tok.Offset(rng.Start)
346 endOffset := tok.Offset(rng.End)
347 selection := src[startOffset:endOffset]
348 extractedBlock, err := parseBlockStmt(fset, selection)
353 // We need to account for return statements in the selected block, as they will complicate
354 // the logical flow of the extracted function. See the following example, where ** denotes
355 // the range to be extracted.
373 // cond0, ret0 := x0(a, b)
380 // func x0(a int, b int) (bool, int) {
387 // We handle returns by adding an additional boolean return value to the extracted function.
388 // This bool reports whether the original function would have returned. Because the
389 // extracted selection contains a return statement, we must also add the types in the
390 // return signature of the enclosing function to the return signature of the
391 // extracted function. We then add an extra if statement checking this boolean value
392 // in the original function. If the condition is met, the original function should
393 // return a value, mimicking the functionality of the original return statement(s)
396 var retVars []*returnVariable
397 var ifReturn *ast.IfStmt
398 if containsReturnStatement {
399 // The selected block contained return statements, so we have to modify the
400 // signature of the extracted function as described above. Adjust all of
401 // the return statements in the extracted function to reflect this change in
403 if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
404 pkg, extractedBlock); err != nil {
407 // Collect the additional return values and types needed to accommodate return
408 // statements in the selection. Update the type signature of the extracted
409 // function and construct the if statement that will be inserted in the enclosing
411 retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start)
417 // Add a return statement to the end of the new function. This return statement must include
418 // the values for the types of the original extracted function signature and (if a return
419 // statement is present in the selection) enclosing function signature.
420 hasReturnValues := len(returns)+len(retVars) > 0
422 extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{
423 Results: append(returns, getZeroVals(retVars)...),
427 // Construct the appropriate call to the extracted function.
428 // We must meet two conditions to use ":=" instead of '='. (1) there must be at least
429 // one variable on the lhs that is uninitialized (non-free) prior to the assignment.
430 // (2) all of the initialized (free) variables on the lhs must be able to be redefined.
432 canDefineCount := len(uninitialized) + canRedefineCount
433 canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns)
437 funName := generateAvailableIdentifier(rng.Start, file, path, info, "fn", 0)
438 extractedFunCall := generateFuncCall(hasReturnValues, params,
439 append(returns, getNames(retVars)...), funName, sym)
441 // Build the extracted function.
442 newFunc := &ast.FuncDecl{
443 Name: ast.NewIdent(funName),
445 Params: &ast.FieldList{List: paramTypes},
446 Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
448 Body: extractedBlock,
451 // Create variable declarations for any identifiers that need to be initialized prior to
452 // calling the extracted function. We do not manually initialize variables if every return
453 // value is uninitialized. We can use := to initialize the variables in this situation.
454 var declarations []ast.Stmt
455 if canDefineCount != len(returns) {
456 declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
459 var declBuf, replaceBuf, newFuncBuf, ifBuf bytes.Buffer
460 if err := format.Node(&declBuf, fset, declarations); err != nil {
463 if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil {
467 if err := format.Node(&ifBuf, fset, ifReturn); err != nil {
471 if err := format.Node(&newFuncBuf, fset, newFunc); err != nil {
475 // We're going to replace the whole enclosing function,
476 // so preserve the text before and after the selected block.
477 outerStart := tok.Offset(outer.Pos())
478 outerEnd := tok.Offset(outer.End())
479 before := src[outerStart:startOffset]
480 after := src[endOffset:outerEnd]
481 newLineIndent := "\n" + calculateIndentation(src, tok, start)
483 var fullReplacement strings.Builder
484 fullReplacement.Write(before)
485 if declBuf.Len() > 0 { // add any initializations, if needed
486 initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
488 fullReplacement.WriteString(initializations)
490 fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function
491 if ifBuf.Len() > 0 { // add the if statement below the function call, if needed
492 ifstatement := newLineIndent +
493 strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent)
494 fullReplacement.WriteString(ifstatement)
496 fullReplacement.Write(after)
497 fullReplacement.WriteString("\n\n") // add newlines after the enclosing function
498 fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function
500 return &analysis.SuggestedFix{
501 TextEdits: []analysis.TextEdit{{
504 NewText: []byte(fullReplacement.String()),
509 // adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or
510 // trailing whitespace characters from selection. In the following example, each line
511 // of the if statement is indented once. There are also two extra spaces after the
512 // closing bracket before the line break.
518 // By default, a valid range begins at 'if' and ends at the first whitespace character
519 // after the '}'. But, users are likely to highlight full lines rather than adjusting
520 // their cursors for whitespace. To support this use case, we must manually adjust the
521 // ranges to match the correct AST node. In this particular example, we would adjust
522 // rng.Start forward by one byte, and rng.End backwards by two bytes.
523 func adjustRangeForWhitespace(rng span.Range, tok *token.File, content []byte) span.Range {
524 offset := tok.Offset(rng.Start)
525 for offset < len(content) {
526 if !unicode.IsSpace(rune(content[offset])) {
529 // Move forwards one byte to find a non-whitespace character.
532 rng.Start = tok.Pos(offset)
534 // Move backwards to find a non-whitespace character.
535 offset = tok.Offset(rng.End)
536 for o := offset - 1; 0 <= o && o < len(content); o-- {
537 if !unicode.IsSpace(rune(content[o])) {
542 rng.End = tok.Pos(offset)
546 // findParent finds the parent AST node of the given target node, if the target is a
547 // descendant of the starting node.
548 func findParent(start ast.Node, target ast.Node) ast.Node {
550 analysisinternal.WalkASTWithParent(start, func(n, p ast.Node) bool {
560 // variable describes the status of a variable within a selection.
561 type variable struct {
564 // free reports whether the variable is a free variable, meaning it should
565 // be a parameter to the extracted function.
568 // assigned reports whether the variable is assigned to in the selection.
571 // defined reports whether the variable is defined in the selection.
575 // collectFreeVars maps each identifier in the given range to whether it is "free."
576 // Given a range, a variable in that range is defined as "free" if it is declared
577 // outside of the range and neither at the file scope nor package scope. These free
578 // variables will be used as arguments in the extracted function. It also returns a
579 // list of identifiers that may need to be returned by the extracted function.
580 // Some of the code in this function has been adapted from tools/cmd/guru/freevars.go.
581 func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) {
582 // id returns non-nil if n denotes an object that is referenced by the span
583 // and defined either within the span or in the lexical environment. The bool
584 // return value acts as an indicator for where it was defined.
585 id := func(n *ast.Ident) (types.Object, bool) {
588 return info.Defs[n], false
590 if obj.Name() == "_" {
591 return nil, false // exclude objects denoting '_'
593 if _, ok := obj.(*types.PkgName); ok {
594 return nil, false // imported package
596 if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) {
597 return nil, false // not defined in this file
599 scope := obj.Parent()
601 return nil, false // e.g. interface method, struct field
603 if scope == fileScope || scope == pkgScope {
604 return nil, false // defined at file or package scope
606 if rng.Start <= obj.Pos() && obj.Pos() <= rng.End {
607 return obj, false // defined within selection => not free
611 // sel returns non-nil if n denotes a selection o.x.y that is referenced by the
612 // span and defined either within the span or in the lexical environment. The bool
613 // return value acts as an indicator for where it was defined.
614 var sel func(n *ast.SelectorExpr) (types.Object, bool)
615 sel = func(n *ast.SelectorExpr) (types.Object, bool) {
616 switch x := astutil.Unparen(n.X).(type) {
617 case *ast.SelectorExpr:
624 seen := make(map[types.Object]*variable)
625 firstUseIn := make(map[types.Object]token.Pos)
626 var vars []types.Object
627 ast.Inspect(node, func(n ast.Node) bool {
631 if rng.Start <= n.Pos() && n.End() <= rng.End {
633 var isFree, prune bool
634 switch n := n.(type) {
637 case *ast.SelectorExpr:
642 seen[obj] = &variable{
646 vars = append(vars, obj)
647 // Find the first time that the object is used in the selection.
648 first, ok := firstUseIn[obj]
649 if !ok || n.Pos() < first {
650 firstUseIn[obj] = n.Pos()
657 return n.Pos() <= rng.End
660 // Find identifiers that are initialized or whose values are altered at some
661 // point in the selected block. For example, in a selected block from lines 2-4,
662 // variables x, y, and z are included in assigned. However, in a selected block
663 // from lines 3-4, only variables y and z are included in assigned.
670 ast.Inspect(node, func(n ast.Node) bool {
674 if n.Pos() < rng.Start || n.End() > rng.End {
675 return n.Pos() <= rng.End
677 switch n := n.(type) {
678 case *ast.AssignStmt:
679 for _, assignment := range n.Lhs {
680 lhs, ok := assignment.(*ast.Ident)
688 if _, ok := seen[obj]; !ok {
691 seen[obj].assigned = true
692 if n.Tok != token.DEFINE {
695 // Find identifiers that are defined prior to being used
696 // elsewhere in the selection.
697 // TODO: Include identifiers that are assigned prior to being
698 // used elsewhere in the selection. Then, change the assignment
699 // to a definition in the extracted function.
700 if firstUseIn[obj] != lhs.Pos() {
703 // Ensure that the object is not used in its own re-definition.
706 // f, e := math.Frexp(f)
707 for _, expr := range n.Rhs {
708 if referencesObj(info, expr, obj) {
711 if _, ok := seen[obj]; !ok {
714 seen[obj].defined = true
720 gen, ok := n.Decl.(*ast.GenDecl)
724 for _, spec := range gen.Specs {
725 vSpecs, ok := spec.(*ast.ValueSpec)
729 for _, vSpec := range vSpecs.Names {
734 if _, ok := seen[obj]; !ok {
737 seen[obj].assigned = true
741 case *ast.IncDecStmt:
742 if ident, ok := n.X.(*ast.Ident); !ok {
744 } else if obj, _ := id(ident); obj == nil {
747 if _, ok := seen[obj]; !ok {
750 seen[obj].assigned = true
755 var variables []*variable
756 for _, obj := range vars {
759 return nil, fmt.Errorf("no seen types.Object for %v", obj)
761 variables = append(variables, v)
763 return variables, nil
766 // referencesObj checks whether the given object appears in the given expression.
767 func referencesObj(info *types.Info, expr ast.Expr, obj types.Object) bool {
769 ast.Inspect(expr, func(n ast.Node) bool {
773 ident, ok := n.(*ast.Ident)
777 objUse := info.Uses[ident]
787 type fnExtractParams struct {
795 // canExtractFunction reports whether the code in the given range can be
796 // extracted to a function.
797 func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Info) (*fnExtractParams, bool, error) {
798 if rng.Start == rng.End {
799 return nil, false, fmt.Errorf("start and end are equal")
801 tok := fset.File(file.Pos())
803 return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
805 rng = adjustRangeForWhitespace(rng, tok, src)
806 path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
808 return nil, false, fmt.Errorf("no path enclosing interval")
810 // Node that encloses the selection must be a statement.
811 // TODO: Support function extraction for an expression.
812 _, ok := path[0].(ast.Stmt)
814 return nil, false, fmt.Errorf("node is not a statement")
817 // Find the function declaration that encloses the selection.
818 var outer *ast.FuncDecl
819 for _, p := range path {
820 if p, ok := p.(*ast.FuncDecl); ok {
826 return nil, false, fmt.Errorf("no enclosing function")
829 // Find the nodes at the start and end of the selection.
830 var start, end ast.Node
831 ast.Inspect(outer, func(n ast.Node) bool {
835 // Do not override 'start' with a node that begins at the same location
836 // but is nested further from 'outer'.
837 if start == nil && n.Pos() == rng.Start && n.End() <= rng.End {
840 if end == nil && n.End() == rng.End && n.Pos() >= rng.Start {
843 return n.Pos() <= rng.End
845 if start == nil || end == nil {
846 return nil, false, fmt.Errorf("range does not map to AST nodes")
848 return &fnExtractParams{
857 // objUsed checks if the object is used within the range. It returns the first occurrence of
858 // the object in the range, if it exists.
859 func objUsed(info *types.Info, rng span.Range, obj types.Object) (bool, *ast.Ident) {
860 var firstUse *ast.Ident
861 for id, objUse := range info.Uses {
865 if id.Pos() < rng.Start || id.End() > rng.End {
868 if firstUse == nil || id.Pos() < firstUse.Pos() {
872 return firstUse != nil, firstUse
875 // varOverridden traverses the given AST node until we find the given identifier. Then, we
876 // examine the occurrence of the given identifier and check for (1) whether the identifier
877 // is being redefined. If the identifier is free, we also check for (2) whether the identifier
878 // is being reassigned. We will not include an identifier in the return statement of the
879 // extracted function if it meets one of the above conditions.
880 func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFree bool, node ast.Node) bool {
882 ast.Inspect(node, func(n ast.Node) bool {
886 assignment, ok := n.(*ast.AssignStmt)
890 // A free variable is initialized prior to the selection. We can always reassign
891 // this variable after the selection because it has already been defined.
892 // Conversely, a non-free variable is initialized within the selection. Thus, we
893 // cannot reassign this variable after the selection unless it is initialized and
894 // returned by the extracted function.
895 if !isFree && assignment.Tok == token.ASSIGN {
898 for _, assigned := range assignment.Lhs {
899 ident, ok := assigned.(*ast.Ident)
900 // Check if we found the first use of the identifier.
901 if !ok || ident != firstUse {
904 objUse := info.Uses[ident]
905 if objUse == nil || objUse != obj {
908 // Ensure that the object is not used in its own definition.
911 // f, e := math.Frexp(f)
912 for _, expr := range assignment.Rhs {
913 if referencesObj(info, expr, obj) {
925 // parseBlockStmt generates an AST file from the given text. We then return the portion of the
926 // file that represents the text.
927 func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) {
928 text := "package main\nfunc _() { " + string(src) + " }"
929 extract, err := parser.ParseFile(fset, "", text, 0)
933 if len(extract.Decls) == 0 {
934 return nil, fmt.Errorf("parsed file does not contain any declarations")
936 decl, ok := extract.Decls[0].(*ast.FuncDecl)
938 return nil, fmt.Errorf("parsed file does not contain expected function declaration")
940 if decl.Body == nil {
941 return nil, fmt.Errorf("extracted function has no body")
943 return decl.Body, nil
946 // generateReturnInfo generates the information we need to adjust the return statements and
947 // signature of the extracted function. We prepare names, signatures, and "zero values" that
948 // represent the new variables. We also use this information to construct the if statement that
949 // is inserted below the call to the extracted function.
950 func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos) ([]*returnVariable, *ast.IfStmt, error) {
951 // Generate information for the added bool value.
952 cond := &ast.Ident{Name: generateAvailableIdentifier(pos, file, path, info, "cond", 0)}
953 retVars := []*returnVariable{
956 decl: &ast.Field{Type: ast.NewIdent("bool")},
957 zeroVal: ast.NewIdent("false"),
960 // Generate information for the values in the return signature of the enclosing function.
961 if enclosing.Results != nil {
962 for i, field := range enclosing.Results.List {
963 typ := info.TypeOf(field.Type)
965 return nil, nil, fmt.Errorf(
966 "failed type conversion, AST expression: %T", field.Type)
968 expr := analysisinternal.TypeExpr(fset, file, pkg, typ)
970 return nil, nil, fmt.Errorf("nil AST expression")
972 retVars = append(retVars, &returnVariable{
973 name: ast.NewIdent(generateAvailableIdentifier(pos, file,
974 path, info, "ret", i)),
975 decl: &ast.Field{Type: expr},
976 zeroVal: analysisinternal.ZeroValue(
977 fset, file, pkg, typ),
981 // Create the return statement for the enclosing function. We must exclude the variable
982 // for the condition of the if statement (cond) from the return statement.
983 ifReturn := &ast.IfStmt{
985 Body: &ast.BlockStmt{
986 List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}},
989 return retVars, ifReturn, nil
992 // adjustReturnStatements adds "zero values" of the given types to each return statement
993 // in the given AST node.
994 func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, fset *token.FileSet, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error {
995 var zeroVals []ast.Expr
996 // Create "zero values" for each type.
997 for _, returnType := range returnTypes {
999 for obj, typ := range seenVars {
1000 if typ != returnType.Type {
1003 val = analysisinternal.ZeroValue(fset, file, pkg, obj.Type())
1008 "could not find matching AST expression for %T", returnType.Type)
1010 zeroVals = append(zeroVals, val)
1012 // Add "zero values" to each return statement.
1013 // The bool reports whether the enclosing function should return after calling the
1014 // extracted function. We set the bool to 'true' because, if these return statements
1015 // execute, the extracted function terminates early, and the enclosing function must
1017 zeroVals = append(zeroVals, ast.NewIdent("true"))
1018 ast.Inspect(extractedBlock, func(n ast.Node) bool {
1022 if n, ok := n.(*ast.ReturnStmt); ok {
1023 n.Results = append(zeroVals, n.Results...)
1031 // generateFuncCall constructs a call expression for the extracted function, described by the
1032 // given parameters and return variables.
1033 func generateFuncCall(hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node {
1034 var replace ast.Node
1036 callExpr := &ast.CallExpr{
1037 Fun: ast.NewIdent(name),
1040 replace = &ast.AssignStmt{
1043 Rhs: []ast.Expr{callExpr},
1046 replace = &ast.CallExpr{
1047 Fun: ast.NewIdent(name),
1054 // initializeVars creates variable declarations, if needed.
1055 // Our preference is to replace the selected block with an "x, y, z := fn()" style
1056 // assignment statement. We can use this style when all of the variables in the
1057 // extracted function's return statement are either not defined prior to the extracted block
1058 // or can be safely redefined. However, for example, if z is already defined
1059 // in a different scope, we replace the selected block with:
1064 func initializeVars(uninitialized []types.Object, retVars []*returnVariable, seenUninitialized map[types.Object]struct{}, seenVars map[types.Object]ast.Expr) []ast.Stmt {
1065 var declarations []ast.Stmt
1066 for _, obj := range uninitialized {
1067 if _, ok := seenUninitialized[obj]; ok {
1070 seenUninitialized[obj] = struct{}{}
1071 valSpec := &ast.ValueSpec{
1072 Names: []*ast.Ident{ast.NewIdent(obj.Name())},
1073 Type: seenVars[obj],
1075 genDecl := &ast.GenDecl{
1077 Specs: []ast.Spec{valSpec},
1079 declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
1081 // Each variable added from a return statement in the selection
1082 // must be initialized.
1083 for i, retVar := range retVars {
1084 n := retVar.name.(*ast.Ident)
1085 valSpec := &ast.ValueSpec{
1086 Names: []*ast.Ident{n},
1087 Type: retVars[i].decl.Type,
1089 genDecl := &ast.GenDecl{
1091 Specs: []ast.Spec{valSpec},
1093 declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
1098 // getNames returns the names from the given list of returnVariable.
1099 func getNames(retVars []*returnVariable) []ast.Expr {
1100 var names []ast.Expr
1101 for _, retVar := range retVars {
1102 names = append(names, retVar.name)
1107 // getZeroVals returns the "zero values" from the given list of returnVariable.
1108 func getZeroVals(retVars []*returnVariable) []ast.Expr {
1110 for _, retVar := range retVars {
1111 zvs = append(zvs, retVar.zeroVal)
1116 // getDecls returns the declarations from the given list of returnVariable.
1117 func getDecls(retVars []*returnVariable) []*ast.Field {
1118 var decls []*ast.Field
1119 for _, retVar := range retVars {
1120 decls = append(decls, retVar.decl)