// Copyright 2014 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package eg // This file defines the AST rewriting pass. // Most of it was plundered directly from // $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution). import ( "fmt" "go/ast" "go/token" "go/types" "os" "reflect" "sort" "strconv" "strings" "golang.org/x/tools/go/ast/astutil" ) // transformItem takes a reflect.Value representing a variable of type ast.Node // transforms its child elements recursively with apply, and then transforms the // actual element if it contains an expression. func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) { // don't bother if val is invalid to start with if !rv.IsValid() { return reflect.Value{}, false, nil } rv, changed, newEnv := tr.apply(tr.transformItem, rv) e := rvToExpr(rv) if e == nil { return rv, changed, newEnv } savedEnv := tr.env tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs if tr.matchExpr(tr.before, e) { if tr.verbose { fmt.Fprintf(os.Stderr, "%s matches %s", astString(tr.fset, tr.before), astString(tr.fset, e)) if len(tr.env) > 0 { fmt.Fprintf(os.Stderr, " with:") for name, ast := range tr.env { fmt.Fprintf(os.Stderr, " %s->%s", name, astString(tr.fset, ast)) } } fmt.Fprintf(os.Stderr, "\n") } tr.nsubsts++ // Clone the replacement tree, performing parameter substitution. // We update all positions to n.Pos() to aid comment placement. rv = tr.subst(tr.env, reflect.ValueOf(tr.after), reflect.ValueOf(e.Pos())) changed = true newEnv = tr.env } tr.env = savedEnv return rv, changed, newEnv } // Transform applies the transformation to the specified parsed file, // whose type information is supplied in info, and returns the number // of replacements that were made. 
// // It mutates the AST in place (the identity of the root node is // unchanged), and may add nodes for which no type information is // available in info. // // Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go. // func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int { if !tr.seenInfos[info] { tr.seenInfos[info] = true mergeTypeInfo(tr.info, info) } tr.currentPkg = pkg tr.nsubsts = 0 if tr.verbose { fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before)) fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after)) fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts) } o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file)) if changed { panic("BUG") } file2 := o.Interface().(*ast.File) // By construction, the root node is unchanged. if file != file2 { panic("BUG") } // Add any necessary imports. // TODO(adonovan): remove no-longer needed imports too. if tr.nsubsts > 0 { pkgs := make(map[string]*types.Package) for obj := range tr.importedObjs { pkgs[obj.Pkg().Path()] = obj.Pkg() } for _, imp := range file.Imports { path, _ := strconv.Unquote(imp.Path.Value) delete(pkgs, path) } delete(pkgs, pkg.Path()) // don't import self // NB: AddImport may completely replace the AST! // It thus renders info and tr.info no longer relevant to file. var paths []string for path := range pkgs { paths = append(paths, path) } sort.Strings(paths) for _, path := range paths { astutil.AddImport(tr.fset, file, path) } } tr.currentPkg = nil return tr.nsubsts } // setValue is a wrapper for x.SetValue(y); it protects // the caller from panics if x cannot be changed to y. 
func setValue(x, y reflect.Value) {
	// don't bother if y is invalid to start with
	if !y.IsValid() {
		return
	}
	// Swallow only the panics whose message says the assignment is
	// impossible; anything else is re-raised unchanged. Note that the
	// inner x (the recovered value) shadows the parameter x here.
	defer func() {
		if x := recover(); x != nil {
			if s, ok := x.(string); ok &&
				(strings.Contains(s, "type mismatch") ||
					strings.Contains(s, "not assignable")) {
				// x cannot be set to y - ignore this rewrite
				return
			}
			panic(x)
		}
	}()
	x.Set(y)
}

// Values/types for special cases.
var (
	// Typed nil pointers used as replacements for objects and scopes.
	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))

	// Types compared against reflect.Value.Type() in apply and subst.
	identType        = reflect.TypeOf((*ast.Ident)(nil))
	selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
	objectPtrType    = reflect.TypeOf((*ast.Object)(nil))
	statementType    = reflect.TypeOf((*ast.Stmt)(nil)).Elem() // the ast.Stmt interface type itself
	positionType     = reflect.TypeOf(token.NoPos)
	scopePtrType     = reflect.TypeOf((*ast.Scope)(nil))
)

// apply replaces each AST field x in val with f(x).
// To avoid extra conversions, f operates on the reflect.Value form.
// f takes a reflect.Value representing the variable to modify.
// It returns a reflect.Value containing the transformed value,
// whether any change was made, and a map of identifiers to ast.Expr (so we can
// do contextually correct substitutions in the parent statements).
//
// apply itself returns val (modified in place) for structs, interfaces
// and slices whose element type is not ast.Stmt, together with the
// change flag and environment reported by f (for slices and structs,
// those of the last element or field that changed). For a slice of
// ast.Stmt it instead returns a newly built slice in which, before
// each statement that f reported as changed, a copy of every
// statement in tr.afterStmts (instantiated with that change's
// environment) has been inserted; in that case the returned change
// flag is false and the environment is nil. *ast.Object and
// *ast.Scope values are replaced by typed nil pointers.
func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
	if !val.IsValid() {
		return reflect.Value{}, false, nil
	}

	// *ast.Objects introduce cycles and are likely incorrect after
	// rewrite; don't follow them but replace with nil instead
	if val.Type() == objectPtrType {
		return objectPtrNil, false, nil
	}

	// similarly for scopes: they are likely incorrect after a rewrite;
	// replace them with nil
	if val.Type() == scopePtrType {
		return scopePtrNil, false, nil
	}

	switch v := reflect.Indirect(val); v.Kind() {
	case reflect.Slice:
		// Not a statement list: rewrite the elements in place; no
		// statements are inserted here.
		if v.Type().Elem() != statementType {
			changed := false
			var envp map[string]ast.Expr
			for i := 0; i < v.Len(); i++ {
				e := v.Index(i)
				o, localchanged, env := f(e)
				if localchanged {
					changed = true
					// we clobber envp here,
					// which means if we have two successive
					// replacements inside the same statement
					// we will only generate the setup for one of them.
					envp = env
				}
				setValue(e, o)
			}
			return val, changed, envp
		}

		// Statement list: build a new list, inserting the instantiated
		// tr.afterStmts ahead of every statement that changed.
		var out []ast.Stmt
		for i := 0; i < v.Len(); i++ {
			e := v.Index(i)
			o, changed, env := f(e)
			if changed {
				for _, s := range tr.afterStmts {
					// Invalid pos: positions in the copy are left as they are in s.
					t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
					out = append(out, t.(ast.Stmt))
				}
			}
			setValue(e, o)
			out = append(out, e.Interface().(ast.Stmt))
		}
		// The change is consumed here, so it is not reported upward.
		return reflect.ValueOf(out), false, nil

	case reflect.Struct:
		changed := false
		var envp map[string]ast.Expr
		for i := 0; i < v.NumField(); i++ {
			e := v.Field(i)
			o, localchanged, env := f(e)
			if localchanged {
				changed = true
				// As for slices, only the last changed field's
				// environment is kept.
				envp = env
			}
			setValue(e, o)
		}
		return val, changed, envp

	case reflect.Interface:
		e := v.Elem()
		o, changed, env := f(e)
		setValue(v, o)
		return val, changed, env
	}
	// Any other kind is returned untouched.
	return val, false, nil
}

// subst returns a copy of (replacement) pattern with values from env
// substituted in place of wildcards and pos used as the position of
// tokens from the pattern. If env == nil, no wildcard substitution is
// performed. If pos is not a valid reflect.Value, positions are copied
// from the pattern unchanged; otherwise every valid token.Pos in the
// pattern is replaced by pos.
func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
	if !pattern.IsValid() {
		return reflect.Value{}
	}

	// *ast.Objects introduce cycles and are likely incorrect after
	// rewrite; don't follow them but replace with nil instead
	if pattern.Type() == objectPtrType {
		return objectPtrNil
	}

	// similarly for scopes: they are likely incorrect after a rewrite;
	// replace them with nil
	if pattern.Type() == scopePtrType {
		return scopePtrNil
	}

	// Wildcard gets replaced with map value.
	// The bound expression is cloned with a nil env and an invalid pos,
	// so it is not itself subject to wildcard substitution and keeps
	// its own positions.
	if env != nil && pattern.Type() == identType {
		id := pattern.Interface().(*ast.Ident)
		if old, ok := env[id.Name]; ok {
			return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
		}
	}

	// Replace qualified identifiers in the pattern with the appropriate
	// (possibly qualified) identifier in the input.
	//
	// The template cannot contain dot imports, so all identifiers
	// for imported objects are explicitly qualified.
	//
	// We assume (unsoundly) that there are no dot or named
	// imports in the input code, nor are any imported package
	// names shadowed, so the usual normal qualified identifier
	// syntax may be used.
	// TODO(adonovan): fix: avoid this assumption.
	//
	// A refactoring may be applied to a package referenced by the
	// template.  Objects belonging to the current package are
	// denoted by unqualified identifiers.
	//
	if tr.importedObjs != nil && pattern.Type() == selectorExprType {
		obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
		if obj != nil {
			if sel, ok := tr.importedObjs[obj]; ok {
				var id ast.Expr
				if obj.Pkg() == tr.currentPkg {
					id = sel.Sel // unqualified
				} else {
					id = sel // pkg-qualified
				}

				// Return a clone of id.
				// tr.importedObjs is cleared for the duration of the
				// recursive call so that this branch is not taken
				// again while cloning, then restored.
				saved := tr.importedObjs
				tr.importedObjs = nil // break cycle
				r := tr.subst(nil, reflect.ValueOf(id), pos)
				tr.importedObjs = saved
				return r
			}
		}
	}

	if pos.IsValid() && pattern.Type() == positionType {
		// use new position only if old position was valid in the first place
		if old := pattern.Interface().(token.Pos); !old.IsValid() {
			return pattern
		}
		return pos
	}

	// Otherwise copy.
	switch p := pattern; p.Kind() {
	case reflect.Slice:
		// New slice of the same length, each element copied recursively.
		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
		for i := 0; i < p.Len(); i++ {
			v.Index(i).Set(tr.subst(env, p.Index(i), pos))
		}
		return v

	case reflect.Struct:
		// Fresh zero struct, each field copied recursively.
		v := reflect.New(p.Type()).Elem()
		for i := 0; i < p.NumField(); i++ {
			v.Field(i).Set(tr.subst(env, p.Field(i), pos))
		}
		return v

	case reflect.Ptr:
		// v starts as a nil pointer of p's type; it is pointed at a
		// copy of the pointee only when p is non-nil.
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos).Addr())
		}

		// Duplicate type information for duplicated ast.Expr.
		// All ast.Node implementations are *structs,
		// so this case catches them all.
		if e := rvToExpr(v); e != nil {
			updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
		}
		return v

	case reflect.Interface:
		// v starts as a nil interface of p's type; it receives a copy
		// of the dynamic value only when p is non-nil.
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos))
		}
		return v
	}

	// Remaining kinds are returned as is, without copying.
	return pattern
}

// -- utilities -------------------------------------------------------

// rvToExpr returns rv's value as an ast.Expr, or nil if rv cannot be
// converted to an interface value or its value does not implement
// ast.Expr.
func rvToExpr(rv reflect.Value) ast.Expr {
	if rv.CanInterface() {
		if e, ok := rv.Interface().(ast.Expr); ok {
			return e
		}
	}
	return nil
}

// updateTypeInfo duplicates type information for the existing AST old
// so that it also applies to duplicated AST new.
//
// Entries are copied in info.Defs and info.Uses for identifiers, in
// info.Selections for selector expressions, and in info.Types for any
// expression. old must have the same dynamic type as new when new is
// an *ast.Ident or *ast.SelectorExpr; otherwise the type assertion
// panics.
func updateTypeInfo(info *types.Info, new, old ast.Expr) {
	switch new := new.(type) {
	case *ast.Ident:
		orig := old.(*ast.Ident)
		if obj, ok := info.Defs[orig]; ok {
			info.Defs[new] = obj
		}
		if obj, ok := info.Uses[orig]; ok {
			info.Uses[new] = obj
		}

	case *ast.SelectorExpr:
		orig := old.(*ast.SelectorExpr)
		if sel, ok := info.Selections[orig]; ok {
			info.Selections[new] = sel
		}
	}

	if tv, ok := info.Types[old]; ok {
		info.Types[new] = tv
	}
}