// Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package astutil import ( "fmt" "go/ast" "reflect" "sort" ) // An ApplyFunc is invoked by Apply for each node n, even if n is nil, // before and/or after the node's children, using a Cursor describing // the current node and providing operations on it. // // The return value of ApplyFunc controls the syntax tree traversal. // See Apply for details. type ApplyFunc func(*Cursor) bool // Apply traverses a syntax tree recursively, starting with root, // and calling pre and post for each node as described below. // Apply returns the syntax tree, possibly modified. // // If pre is not nil, it is called for each node before the node's // children are traversed (pre-order). If pre returns false, no // children are traversed, and post is not called for that node. // // If post is not nil, and a prior call of pre didn't return false, // post is called for each node after its children are traversed // (post-order). If post returns false, traversal is terminated and // Apply returns immediately. // // Only fields that refer to AST nodes are considered children; // i.e., token.Pos, Scopes, Objects, and fields of basic types // (strings, etc.) are ignored. // // Children are traversed in the order in which they appear in the // respective node's struct definition. A package's files are // traversed in the filenames' alphabetical order. // func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) { parent := &struct{ ast.Node }{root} defer func() { if r := recover(); r != nil && r != abort { panic(r) } result = parent.Node }() a := &application{pre: pre, post: post} a.apply(parent, "Node", nil, root) return } var abort = new(int) // singleton, to signal termination of Apply // A Cursor describes a node encountered during Apply. // Information about the node and its parent is available // from the Node, Parent, Name, and Index methods. // // If p is a variable of type and value of the current parent node // c.Parent(), and f is the field identifier with name c.Name(), // the following invariants hold: // // p.f == c.Node() if c.Index() < 0 // p.f[c.Index()] == c.Node() if c.Index() >= 0 // // The methods Replace, Delete, InsertBefore, and InsertAfter // can be used to change the AST without disrupting Apply. type Cursor struct { parent ast.Node name string iter *iterator // valid if non-nil node ast.Node } // Node returns the current Node. func (c *Cursor) Node() ast.Node { return c.node } // Parent returns the parent of the current Node. func (c *Cursor) Parent() ast.Node { return c.parent } // Name returns the name of the parent Node field that contains the current Node. // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns // the filename for the current Node. func (c *Cursor) Name() string { return c.name } // Index reports the index >= 0 of the current Node in the slice of Nodes that // contains it, or a value < 0 if the current Node is not part of a slice. // The index of the current node changes if InsertBefore is called while // processing the current node. func (c *Cursor) Index() int { if c.iter != nil { return c.iter.index } return -1 } // field returns the current node's parent field value. func (c *Cursor) field() reflect.Value { return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name) } // Replace replaces the current Node with n. // The replacement node is not walked by Apply. func (c *Cursor) Replace(n ast.Node) { if _, ok := c.node.(*ast.File); ok { file, ok := n.(*ast.File) if !ok { panic("attempt to replace *ast.File with non-*ast.File") } c.parent.(*ast.Package).Files[c.name] = file return } v := c.field() if i := c.Index(); i >= 0 { v = v.Index(i) } v.Set(reflect.ValueOf(n)) } // Delete deletes the current Node from its containing slice. // If the current Node is not part of a slice, Delete panics. // As a special case, if the current node is a package file, // Delete removes it from the package's Files map. func (c *Cursor) Delete() { if _, ok := c.node.(*ast.File); ok { delete(c.parent.(*ast.Package).Files, c.name) return } i := c.Index() if i < 0 { panic("Delete node not contained in slice") } v := c.field() l := v.Len() reflect.Copy(v.Slice(i, l), v.Slice(i+1, l)) v.Index(l - 1).Set(reflect.Zero(v.Type().Elem())) v.SetLen(l - 1) c.iter.step-- } // InsertAfter inserts n after the current Node in its containing slice. // If the current Node is not part of a slice, InsertAfter panics. // Apply does not walk n. func (c *Cursor) InsertAfter(n ast.Node) { i := c.Index() if i < 0 { panic("InsertAfter node not contained in slice") } v := c.field() v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) l := v.Len() reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l)) v.Index(i + 1).Set(reflect.ValueOf(n)) c.iter.step++ } // InsertBefore inserts n before the current Node in its containing slice. // If the current Node is not part of a slice, InsertBefore panics. // Apply will not walk n. func (c *Cursor) InsertBefore(n ast.Node) { i := c.Index() if i < 0 { panic("InsertBefore node not contained in slice") } v := c.field() v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) l := v.Len() reflect.Copy(v.Slice(i+1, l), v.Slice(i, l)) v.Index(i).Set(reflect.ValueOf(n)) c.iter.index++ } // application carries all the shared data so we can pass it around cheaply. type application struct { pre, post ApplyFunc cursor Cursor iter iterator } func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) { // convert typed nil into untyped nil if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() { n = nil } // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead saved := a.cursor a.cursor.parent = parent a.cursor.name = name a.cursor.iter = iter a.cursor.node = n if a.pre != nil && !a.pre(&a.cursor) { a.cursor = saved return } // walk children // (the order of the cases matches the order of the corresponding node types in go/ast) switch n := n.(type) { case nil: // nothing to do // Comments and fields case *ast.Comment: // nothing to do case *ast.CommentGroup: if n != nil { a.applyList(n, "List") } case *ast.Field: a.apply(n, "Doc", nil, n.Doc) a.applyList(n, "Names") a.apply(n, "Type", nil, n.Type) a.apply(n, "Tag", nil, n.Tag) a.apply(n, "Comment", nil, n.Comment) case *ast.FieldList: a.applyList(n, "List") // Expressions case *ast.BadExpr, *ast.Ident, *ast.BasicLit: // nothing to do case *ast.Ellipsis: a.apply(n, "Elt", nil, n.Elt) case *ast.FuncLit: a.apply(n, "Type", nil, n.Type) a.apply(n, "Body", nil, n.Body) case *ast.CompositeLit: a.apply(n, "Type", nil, n.Type) a.applyList(n, "Elts") case *ast.ParenExpr: a.apply(n, "X", nil, n.X) case *ast.SelectorExpr: a.apply(n, "X", nil, n.X) a.apply(n, "Sel", nil, n.Sel) case *ast.IndexExpr: a.apply(n, "X", nil, n.X) a.apply(n, "Index", nil, n.Index) case *ast.SliceExpr: a.apply(n, "X", nil, n.X) a.apply(n, "Low", nil, n.Low) a.apply(n, "High", nil, n.High) a.apply(n, "Max", nil, n.Max) case *ast.TypeAssertExpr: a.apply(n, "X", nil, n.X) a.apply(n, "Type", nil, n.Type) case *ast.CallExpr: a.apply(n, "Fun", nil, n.Fun) a.applyList(n, "Args") case *ast.StarExpr: a.apply(n, "X", nil, n.X) case *ast.UnaryExpr: a.apply(n, "X", nil, n.X) case *ast.BinaryExpr: a.apply(n, "X", nil, n.X) a.apply(n, "Y", nil, n.Y) case *ast.KeyValueExpr: a.apply(n, "Key", nil, n.Key) a.apply(n, "Value", nil, n.Value) // Types case *ast.ArrayType: a.apply(n, "Len", nil, n.Len) a.apply(n, "Elt", nil, n.Elt) case *ast.StructType: a.apply(n, "Fields", nil, n.Fields) case *ast.FuncType: a.apply(n, "Params", nil, n.Params) a.apply(n, "Results", nil, n.Results) case *ast.InterfaceType: a.apply(n, "Methods", nil, n.Methods) case *ast.MapType: a.apply(n, "Key", nil, n.Key) a.apply(n, "Value", nil, n.Value) case *ast.ChanType: a.apply(n, "Value", nil, n.Value) // Statements case *ast.BadStmt: // nothing to do case *ast.DeclStmt: a.apply(n, "Decl", nil, n.Decl) case *ast.EmptyStmt: // nothing to do case *ast.LabeledStmt: a.apply(n, "Label", nil, n.Label) a.apply(n, "Stmt", nil, n.Stmt) case *ast.ExprStmt: a.apply(n, "X", nil, n.X) case *ast.SendStmt: a.apply(n, "Chan", nil, n.Chan) a.apply(n, "Value", nil, n.Value) case *ast.IncDecStmt: a.apply(n, "X", nil, n.X) case *ast.AssignStmt: a.applyList(n, "Lhs") a.applyList(n, "Rhs") case *ast.GoStmt: a.apply(n, "Call", nil, n.Call) case *ast.DeferStmt: a.apply(n, "Call", nil, n.Call) case *ast.ReturnStmt: a.applyList(n, "Results") case *ast.BranchStmt: a.apply(n, "Label", nil, n.Label) case *ast.BlockStmt: a.applyList(n, "List") case *ast.IfStmt: a.apply(n, "Init", nil, n.Init) a.apply(n, "Cond", nil, n.Cond) a.apply(n, "Body", nil, n.Body) a.apply(n, "Else", nil, n.Else) case *ast.CaseClause: a.applyList(n, "List") a.applyList(n, "Body") case *ast.SwitchStmt: a.apply(n, "Init", nil, n.Init) a.apply(n, "Tag", nil, n.Tag) a.apply(n, "Body", nil, n.Body) case *ast.TypeSwitchStmt: a.apply(n, "Init", nil, n.Init) a.apply(n, "Assign", nil, n.Assign) a.apply(n, "Body", nil, n.Body) case *ast.CommClause: a.apply(n, "Comm", nil, n.Comm) a.applyList(n, "Body") case *ast.SelectStmt: a.apply(n, "Body", nil, n.Body) case *ast.ForStmt: a.apply(n, "Init", nil, n.Init) a.apply(n, "Cond", nil, n.Cond) a.apply(n, "Post", nil, n.Post) a.apply(n, "Body", nil, n.Body) case *ast.RangeStmt: a.apply(n, "Key", nil, n.Key) a.apply(n, "Value", nil, n.Value) a.apply(n, "X", nil, n.X) a.apply(n, "Body", nil, n.Body) // Declarations case *ast.ImportSpec: a.apply(n, "Doc", nil, n.Doc) a.apply(n, "Name", nil, n.Name) a.apply(n, "Path", nil, n.Path) a.apply(n, "Comment", nil, n.Comment) case *ast.ValueSpec: a.apply(n, "Doc", nil, n.Doc) a.applyList(n, "Names") a.apply(n, "Type", nil, n.Type) a.applyList(n, "Values") a.apply(n, "Comment", nil, n.Comment) case *ast.TypeSpec: a.apply(n, "Doc", nil, n.Doc) a.apply(n, "Name", nil, n.Name) a.apply(n, "Type", nil, n.Type) a.apply(n, "Comment", nil, n.Comment) case *ast.BadDecl: // nothing to do case *ast.GenDecl: a.apply(n, "Doc", nil, n.Doc) a.applyList(n, "Specs") case *ast.FuncDecl: a.apply(n, "Doc", nil, n.Doc) a.apply(n, "Recv", nil, n.Recv) a.apply(n, "Name", nil, n.Name) a.apply(n, "Type", nil, n.Type) a.apply(n, "Body", nil, n.Body) // Files and packages case *ast.File: a.apply(n, "Doc", nil, n.Doc) a.apply(n, "Name", nil, n.Name) a.applyList(n, "Decls") // Don't walk n.Comments; they have either been walked already if // they are Doc comments, or they can be easily walked explicitly. case *ast.Package: // collect and sort names for reproducible behavior var names []string for name := range n.Files { names = append(names, name) } sort.Strings(names) for _, name := range names { a.apply(n, name, nil, n.Files[name]) } default: panic(fmt.Sprintf("Apply: unexpected node type %T", n)) } if a.post != nil && !a.post(&a.cursor) { panic(abort) } a.cursor = saved } // An iterator controls iteration over a slice of nodes. type iterator struct { index, step int } func (a *application) applyList(parent ast.Node, name string) { // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead saved := a.iter a.iter.index = 0 for { // must reload parent.name each time, since cursor modifications might change it v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name) if a.iter.index >= v.Len() { break } // element x may be nil in a bad AST - be cautious var x ast.Node if e := v.Index(a.iter.index); e.IsValid() { x = e.Interface().(ast.Node) } a.iter.step = 1 a.apply(parent, name, &a.iter, x) a.iter.index += a.iter.step } a.iter = saved }