+++ /dev/null
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package eg
-
-// This file defines the AST rewriting pass.
-// Most of it was plundered directly from
-// $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution).
-
-import (
- "fmt"
- "go/ast"
- "go/token"
- "go/types"
- "os"
- "reflect"
- "sort"
- "strconv"
- "strings"
-
- "golang.org/x/tools/go/ast/astutil"
-)
-
-// transformItem takes a reflect.Value representing a variable of type ast.Node
-// transforms its child elements recursively with apply, and then transforms the
-// actual element if it contains an expression.
-func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
- // don't bother if val is invalid to start with
- if !rv.IsValid() {
- return reflect.Value{}, false, nil
- }
-
- rv, changed, newEnv := tr.apply(tr.transformItem, rv)
-
- e := rvToExpr(rv)
- if e == nil {
- return rv, changed, newEnv
- }
-
- savedEnv := tr.env
- tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs
-
- if tr.matchExpr(tr.before, e) {
- if tr.verbose {
- fmt.Fprintf(os.Stderr, "%s matches %s",
- astString(tr.fset, tr.before), astString(tr.fset, e))
- if len(tr.env) > 0 {
- fmt.Fprintf(os.Stderr, " with:")
- for name, ast := range tr.env {
- fmt.Fprintf(os.Stderr, " %s->%s",
- name, astString(tr.fset, ast))
- }
- }
- fmt.Fprintf(os.Stderr, "\n")
- }
- tr.nsubsts++
-
- // Clone the replacement tree, performing parameter substitution.
- // We update all positions to n.Pos() to aid comment placement.
- rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
- reflect.ValueOf(e.Pos()))
- changed = true
- newEnv = tr.env
- }
- tr.env = savedEnv
-
- return rv, changed, newEnv
-}
-
-// Transform applies the transformation to the specified parsed file,
-// whose type information is supplied in info, and returns the number
-// of replacements that were made.
-//
-// It mutates the AST in place (the identity of the root node is
-// unchanged), and may add nodes for which no type information is
-// available in info.
-//
-// Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go.
-//
-func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
- if !tr.seenInfos[info] {
- tr.seenInfos[info] = true
- mergeTypeInfo(tr.info, info)
- }
- tr.currentPkg = pkg
- tr.nsubsts = 0
-
- if tr.verbose {
- fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
- fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
- fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
- }
-
- o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
- if changed {
- panic("BUG")
- }
- file2 := o.Interface().(*ast.File)
-
- // By construction, the root node is unchanged.
- if file != file2 {
- panic("BUG")
- }
-
- // Add any necessary imports.
- // TODO(adonovan): remove no-longer needed imports too.
- if tr.nsubsts > 0 {
- pkgs := make(map[string]*types.Package)
- for obj := range tr.importedObjs {
- pkgs[obj.Pkg().Path()] = obj.Pkg()
- }
-
- for _, imp := range file.Imports {
- path, _ := strconv.Unquote(imp.Path.Value)
- delete(pkgs, path)
- }
- delete(pkgs, pkg.Path()) // don't import self
-
- // NB: AddImport may completely replace the AST!
- // It thus renders info and tr.info no longer relevant to file.
- var paths []string
- for path := range pkgs {
- paths = append(paths, path)
- }
- sort.Strings(paths)
- for _, path := range paths {
- astutil.AddImport(tr.fset, file, path)
- }
- }
-
- tr.currentPkg = nil
-
- return tr.nsubsts
-}
-
-// setValue is a wrapper for x.SetValue(y); it protects
-// the caller from panics if x cannot be changed to y.
-func setValue(x, y reflect.Value) {
- // don't bother if y is invalid to start with
- if !y.IsValid() {
- return
- }
- defer func() {
- if x := recover(); x != nil {
- if s, ok := x.(string); ok &&
- (strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
- // x cannot be set to y - ignore this rewrite
- return
- }
- panic(x)
- }
- }()
- x.Set(y)
-}
-
-// Values/types for special cases.
-var (
- objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
- scopePtrNil = reflect.ValueOf((*ast.Scope)(nil))
-
- identType = reflect.TypeOf((*ast.Ident)(nil))
- selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
- objectPtrType = reflect.TypeOf((*ast.Object)(nil))
- statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem()
- positionType = reflect.TypeOf(token.NoPos)
- scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
-)
-
-// apply replaces each AST field x in val with f(x), returning val.
-// To avoid extra conversions, f operates on the reflect.Value form.
-// f takes a reflect.Value representing the variable to modify of type ast.Node.
-// It returns a reflect.Value containing the transformed value of type ast.Node,
-// whether any change was made, and a map of identifiers to ast.Expr (so we can
-// do contextually correct substitutions in the parent statements).
-func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
- if !val.IsValid() {
- return reflect.Value{}, false, nil
- }
-
- // *ast.Objects introduce cycles and are likely incorrect after
- // rewrite; don't follow them but replace with nil instead
- if val.Type() == objectPtrType {
- return objectPtrNil, false, nil
- }
-
- // similarly for scopes: they are likely incorrect after a rewrite;
- // replace them with nil
- if val.Type() == scopePtrType {
- return scopePtrNil, false, nil
- }
-
- switch v := reflect.Indirect(val); v.Kind() {
- case reflect.Slice:
- // no possible rewriting of statements.
- if v.Type().Elem() != statementType {
- changed := false
- var envp map[string]ast.Expr
- for i := 0; i < v.Len(); i++ {
- e := v.Index(i)
- o, localchanged, env := f(e)
- if localchanged {
- changed = true
- // we clobber envp here,
- // which means if we have two successive
- // replacements inside the same statement
- // we will only generate the setup for one of them.
- envp = env
- }
- setValue(e, o)
- }
- return val, changed, envp
- }
-
- // statements are rewritten.
- var out []ast.Stmt
- for i := 0; i < v.Len(); i++ {
- e := v.Index(i)
- o, changed, env := f(e)
- if changed {
- for _, s := range tr.afterStmts {
- t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
- out = append(out, t.(ast.Stmt))
- }
- }
- setValue(e, o)
- out = append(out, e.Interface().(ast.Stmt))
- }
- return reflect.ValueOf(out), false, nil
- case reflect.Struct:
- changed := false
- var envp map[string]ast.Expr
- for i := 0; i < v.NumField(); i++ {
- e := v.Field(i)
- o, localchanged, env := f(e)
- if localchanged {
- changed = true
- envp = env
- }
- setValue(e, o)
- }
- return val, changed, envp
- case reflect.Interface:
- e := v.Elem()
- o, changed, env := f(e)
- setValue(v, o)
- return val, changed, env
- }
- return val, false, nil
-}
-
-// subst returns a copy of (replacement) pattern with values from env
-// substituted in place of wildcards and pos used as the position of
-// tokens from the pattern. if env == nil, subst returns a copy of
-// pattern and doesn't change the line number information.
-func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
- if !pattern.IsValid() {
- return reflect.Value{}
- }
-
- // *ast.Objects introduce cycles and are likely incorrect after
- // rewrite; don't follow them but replace with nil instead
- if pattern.Type() == objectPtrType {
- return objectPtrNil
- }
-
- // similarly for scopes: they are likely incorrect after a rewrite;
- // replace them with nil
- if pattern.Type() == scopePtrType {
- return scopePtrNil
- }
-
- // Wildcard gets replaced with map value.
- if env != nil && pattern.Type() == identType {
- id := pattern.Interface().(*ast.Ident)
- if old, ok := env[id.Name]; ok {
- return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
- }
- }
-
- // Emit qualified identifiers in the pattern by appropriate
- // (possibly qualified) identifier in the input.
- //
- // The template cannot contain dot imports, so all identifiers
- // for imported objects are explicitly qualified.
- //
- // We assume (unsoundly) that there are no dot or named
- // imports in the input code, nor are any imported package
- // names shadowed, so the usual normal qualified identifier
- // syntax may be used.
- // TODO(adonovan): fix: avoid this assumption.
- //
- // A refactoring may be applied to a package referenced by the
- // template. Objects belonging to the current package are
- // denoted by unqualified identifiers.
- //
- if tr.importedObjs != nil && pattern.Type() == selectorExprType {
- obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
- if obj != nil {
- if sel, ok := tr.importedObjs[obj]; ok {
- var id ast.Expr
- if obj.Pkg() == tr.currentPkg {
- id = sel.Sel // unqualified
- } else {
- id = sel // pkg-qualified
- }
-
- // Return a clone of id.
- saved := tr.importedObjs
- tr.importedObjs = nil // break cycle
- r := tr.subst(nil, reflect.ValueOf(id), pos)
- tr.importedObjs = saved
- return r
- }
- }
- }
-
- if pos.IsValid() && pattern.Type() == positionType {
- // use new position only if old position was valid in the first place
- if old := pattern.Interface().(token.Pos); !old.IsValid() {
- return pattern
- }
- return pos
- }
-
- // Otherwise copy.
- switch p := pattern; p.Kind() {
- case reflect.Slice:
- v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
- for i := 0; i < p.Len(); i++ {
- v.Index(i).Set(tr.subst(env, p.Index(i), pos))
- }
- return v
-
- case reflect.Struct:
- v := reflect.New(p.Type()).Elem()
- for i := 0; i < p.NumField(); i++ {
- v.Field(i).Set(tr.subst(env, p.Field(i), pos))
- }
- return v
-
- case reflect.Ptr:
- v := reflect.New(p.Type()).Elem()
- if elem := p.Elem(); elem.IsValid() {
- v.Set(tr.subst(env, elem, pos).Addr())
- }
-
- // Duplicate type information for duplicated ast.Expr.
- // All ast.Node implementations are *structs,
- // so this case catches them all.
- if e := rvToExpr(v); e != nil {
- updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
- }
- return v
-
- case reflect.Interface:
- v := reflect.New(p.Type()).Elem()
- if elem := p.Elem(); elem.IsValid() {
- v.Set(tr.subst(env, elem, pos))
- }
- return v
- }
-
- return pattern
-}
-
-// -- utilities -------------------------------------------------------
-
-func rvToExpr(rv reflect.Value) ast.Expr {
- if rv.CanInterface() {
- if e, ok := rv.Interface().(ast.Expr); ok {
- return e
- }
- }
- return nil
-}
-
-// updateTypeInfo duplicates type information for the existing AST old
-// so that it also applies to duplicated AST new.
-func updateTypeInfo(info *types.Info, new, old ast.Expr) {
- switch new := new.(type) {
- case *ast.Ident:
- orig := old.(*ast.Ident)
- if obj, ok := info.Defs[orig]; ok {
- info.Defs[new] = obj
- }
- if obj, ok := info.Uses[orig]; ok {
- info.Uses[new] = obj
- }
-
- case *ast.SelectorExpr:
- orig := old.(*ast.SelectorExpr)
- if sel, ok := info.Selections[orig]; ok {
- info.Selections[new] = sel
- }
- }
-
- if tv, ok := info.Types[old]; ok {
- info.Types[new] = tv
- }
-}