1 // Copyright 2014 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
7 // This file defines the AST rewriting pass.
8 // Most of it was plundered directly from
9 // $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution).
22 "golang.org/x/tools/go/ast/astutil"
25 // transformItem takes a reflect.Value representing a variable of type ast.Node,
26 // transforms its child elements recursively with apply, and then transforms the
27 // actual element if it contains an expression.
28 func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
29 // don't bother if val is invalid to start with
31 return reflect.Value{}, false, nil
34 rv, changed, newEnv := tr.apply(tr.transformItem, rv)
38 return rv, changed, newEnv
42 tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs
44 if tr.matchExpr(tr.before, e) {
46 fmt.Fprintf(os.Stderr, "%s matches %s",
47 astString(tr.fset, tr.before), astString(tr.fset, e))
49 fmt.Fprintf(os.Stderr, " with:")
50 for name, ast := range tr.env {
51 fmt.Fprintf(os.Stderr, " %s->%s",
52 name, astString(tr.fset, ast))
55 fmt.Fprintf(os.Stderr, "\n")
59 // Clone the replacement tree, performing parameter substitution.
60 // We update all positions to n.Pos() to aid comment placement.
61 rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
62 reflect.ValueOf(e.Pos()))
68 return rv, changed, newEnv
71 // Transform applies the transformation to the specified parsed file,
72 // whose type information is supplied in info, and returns the number
73 // of replacements that were made.
75 // It mutates the AST in place (the identity of the root node is
76 // unchanged), and may add nodes for which no type information is
77 // available in info.
78 //
79 // Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go.
81 func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
82 if !tr.seenInfos[info] {
83 tr.seenInfos[info] = true
84 mergeTypeInfo(tr.info, info)
90 fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
91 fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
92 fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
95 o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
99 file2 := o.Interface().(*ast.File)
101 // By construction, the root node is unchanged.
106 // Add any necessary imports.
107 // TODO(adonovan): remove no-longer needed imports too.
109 pkgs := make(map[string]*types.Package)
110 for obj := range tr.importedObjs {
111 pkgs[obj.Pkg().Path()] = obj.Pkg()
114 for _, imp := range file.Imports {
115 path, _ := strconv.Unquote(imp.Path.Value)
118 delete(pkgs, pkg.Path()) // don't import self
120 // NB: AddImport may completely replace the AST!
121 // It thus renders info and tr.info no longer relevant to file.
123 for path := range pkgs {
124 paths = append(paths, path)
127 for _, path := range paths {
128 astutil.AddImport(tr.fset, file, path)
137 // setValue is a wrapper for x.Set(y); it protects
138 // the caller from panics if x cannot be changed to y.
139 func setValue(x, y reflect.Value) {
140 // don't bother if y is invalid to start with
145 if x := recover(); x != nil {
146 if s, ok := x.(string); ok &&
147 (strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
148 // x cannot be set to y - ignore this rewrite
157 // Values/types for special cases.
159 objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
160 scopePtrNil = reflect.ValueOf((*ast.Scope)(nil))
162 identType = reflect.TypeOf((*ast.Ident)(nil))
163 selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
164 objectPtrType = reflect.TypeOf((*ast.Object)(nil))
165 statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem()
166 positionType = reflect.TypeOf(token.NoPos)
167 scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
170 // apply replaces each AST field x in val with f(x), returning val, or a newly
171 // built slice when val is a statement list. To avoid extra conversions, f
172 // operates on the reflect.Value form: it takes a reflect.Value representing the
173 // variable to modify (of type ast.Node) and returns the transformed value,
174 // whether any change was made, and a map of identifiers to ast.Expr (so we can
175 // do contextually correct substitutions in the parent statements).
176 func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
178 return reflect.Value{}, false, nil
181 // *ast.Objects introduce cycles and are likely incorrect after
182 // rewrite; don't follow them but replace with nil instead
183 if val.Type() == objectPtrType {
184 return objectPtrNil, false, nil
187 // similarly for scopes: they are likely incorrect after a rewrite;
188 // replace them with nil
189 if val.Type() == scopePtrType {
190 return scopePtrNil, false, nil
193 switch v := reflect.Indirect(val); v.Kind() {
195 // no possible rewriting of statements.
196 if v.Type().Elem() != statementType {
198 var envp map[string]ast.Expr
199 for i := 0; i < v.Len(); i++ {
201 o, localchanged, env := f(e)
204 // we clobber envp here,
205 // which means if we have two successive
206 // replacements inside the same statement
207 // we will only generate the setup for one of them.
212 return val, changed, envp
215 // statements are rewritten.
217 for i := 0; i < v.Len(); i++ {
219 o, changed, env := f(e)
221 for _, s := range tr.afterStmts {
222 t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
223 out = append(out, t.(ast.Stmt))
227 out = append(out, e.Interface().(ast.Stmt))
229 return reflect.ValueOf(out), false, nil
232 var envp map[string]ast.Expr
233 for i := 0; i < v.NumField(); i++ {
235 o, localchanged, env := f(e)
242 return val, changed, envp
243 case reflect.Interface:
245 o, changed, env := f(e)
247 return val, changed, env
249 return val, false, nil
252 // subst returns a copy of (replacement) pattern with values from env
253 // substituted in place of wildcards and pos used as the position of
254 // tokens from the pattern. If env is nil, wildcards are not substituted;
255 // if pos is the zero reflect.Value, positions are not rewritten.
256 func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
257 if !pattern.IsValid() {
258 return reflect.Value{}
261 // *ast.Objects introduce cycles and are likely incorrect after
262 // rewrite; don't follow them but replace with nil instead
263 if pattern.Type() == objectPtrType {
267 // similarly for scopes: they are likely incorrect after a rewrite;
268 // replace them with nil
269 if pattern.Type() == scopePtrType {
273 // Wildcard gets replaced with map value.
274 if env != nil && pattern.Type() == identType {
275 id := pattern.Interface().(*ast.Ident)
276 if old, ok := env[id.Name]; ok {
277 return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
281 // Emit qualified identifiers in the pattern by appropriate
282 // (possibly qualified) identifier in the input.
284 // The template cannot contain dot imports, so all identifiers
285 // for imported objects are explicitly qualified.
287 // We assume (unsoundly) that there are no dot or named
288 // imports in the input code, nor are any imported package
289 // names shadowed, so the usual normal qualified identifier
290 // syntax may be used.
291 // TODO(adonovan): fix: avoid this assumption.
293 // A refactoring may be applied to a package referenced by the
294 // template. Objects belonging to the current package are
295 // denoted by unqualified identifiers.
297 if tr.importedObjs != nil && pattern.Type() == selectorExprType {
298 obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
300 if sel, ok := tr.importedObjs[obj]; ok {
302 if obj.Pkg() == tr.currentPkg {
303 id = sel.Sel // unqualified
305 id = sel // pkg-qualified
308 // Return a clone of id.
309 saved := tr.importedObjs
310 tr.importedObjs = nil // break cycle
311 r := tr.subst(nil, reflect.ValueOf(id), pos)
312 tr.importedObjs = saved
318 if pos.IsValid() && pattern.Type() == positionType {
319 // use new position only if old position was valid in the first place
320 if old := pattern.Interface().(token.Pos); !old.IsValid() {
327 switch p := pattern; p.Kind() {
329 v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
330 for i := 0; i < p.Len(); i++ {
331 v.Index(i).Set(tr.subst(env, p.Index(i), pos))
336 v := reflect.New(p.Type()).Elem()
337 for i := 0; i < p.NumField(); i++ {
338 v.Field(i).Set(tr.subst(env, p.Field(i), pos))
343 v := reflect.New(p.Type()).Elem()
344 if elem := p.Elem(); elem.IsValid() {
345 v.Set(tr.subst(env, elem, pos).Addr())
348 // Duplicate type information for duplicated ast.Expr.
349 // All ast.Node implementations are *structs,
350 // so this case catches them all.
351 if e := rvToExpr(v); e != nil {
352 updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
356 case reflect.Interface:
357 v := reflect.New(p.Type()).Elem()
358 if elem := p.Elem(); elem.IsValid() {
359 v.Set(tr.subst(env, elem, pos))
367 // -- utilities -------------------------------------------------------
369 func rvToExpr(rv reflect.Value) ast.Expr {
370 if rv.CanInterface() {
371 if e, ok := rv.Interface().(ast.Expr); ok {
378 // updateTypeInfo duplicates type information for the existing AST old
379 // so that it also applies to duplicated AST new.
380 func updateTypeInfo(info *types.Info, new, old ast.Expr) {
381 switch new := new.(type) {
383 orig := old.(*ast.Ident)
384 if obj, ok := info.Defs[orig]; ok {
387 if obj, ok := info.Uses[orig]; ok {
391 case *ast.SelectorExpr:
392 orig := old.(*ast.SelectorExpr)
393 if sel, ok := info.Selections[orig]; ok {
394 info.Selections[new] = sel
398 if tv, ok := info.Types[old]; ok {