--- /dev/null
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package inspector_test
+
+import (
+ "go/ast"
+ "go/build"
+ "go/parser"
+ "go/token"
+ "log"
+ "path/filepath"
+ "reflect"
+ "strings"
+ "testing"
+
+ "golang.org/x/tools/go/ast/inspector"
+)
+
+var netFiles []*ast.File
+
+func init() {
+ files, err := parseNetFiles()
+ if err != nil {
+ log.Fatal(err)
+ }
+ netFiles = files
+}
+
+func parseNetFiles() ([]*ast.File, error) {
+ pkg, err := build.Default.Import("net", "", 0)
+ if err != nil {
+ return nil, err
+ }
+ fset := token.NewFileSet()
+ var files []*ast.File
+ for _, filename := range pkg.GoFiles {
+ filename = filepath.Join(pkg.Dir, filename)
+ f, err := parser.ParseFile(fset, filename, nil, 0)
+ if err != nil {
+ return nil, err
+ }
+ files = append(files, f)
+ }
+ return files, nil
+}
+
+// TestAllNodes compares Inspector against ast.Inspect.
+func TestInspectAllNodes(t *testing.T) {
+ inspect := inspector.New(netFiles)
+
+ var nodesA []ast.Node
+ inspect.Nodes(nil, func(n ast.Node, push bool) bool {
+ if push {
+ nodesA = append(nodesA, n)
+ }
+ return true
+ })
+ var nodesB []ast.Node
+ for _, f := range netFiles {
+ ast.Inspect(f, func(n ast.Node) bool {
+ if n != nil {
+ nodesB = append(nodesB, n)
+ }
+ return true
+ })
+ }
+ compare(t, nodesA, nodesB)
+}
+
+// TestPruning compares Inspector against ast.Inspect,
+// pruning descent within ast.CallExpr nodes.
+func TestInspectPruning(t *testing.T) {
+ inspect := inspector.New(netFiles)
+
+ var nodesA []ast.Node
+ inspect.Nodes(nil, func(n ast.Node, push bool) bool {
+ if push {
+ nodesA = append(nodesA, n)
+ _, isCall := n.(*ast.CallExpr)
+ return !isCall // don't descend into function calls
+ }
+ return false
+ })
+ var nodesB []ast.Node
+ for _, f := range netFiles {
+ ast.Inspect(f, func(n ast.Node) bool {
+ if n != nil {
+ nodesB = append(nodesB, n)
+ _, isCall := n.(*ast.CallExpr)
+ return !isCall // don't descend into function calls
+ }
+ return false
+ })
+ }
+ compare(t, nodesA, nodesB)
+}
+
+func compare(t *testing.T, nodesA, nodesB []ast.Node) {
+ if len(nodesA) != len(nodesB) {
+ t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
+ } else {
+ for i := range nodesA {
+ if a, b := nodesA[i], nodesB[i]; a != b {
+ t.Errorf("node %d is inconsistent: %T, %T", i, a, b)
+ }
+ }
+ }
+}
+
+func TestTypeFiltering(t *testing.T) {
+ const src = `package a
+func f() {
+ print("hi")
+ panic("oops")
+}
+`
+ fset := token.NewFileSet()
+ f, _ := parser.ParseFile(fset, "a.go", src, 0)
+ inspect := inspector.New([]*ast.File{f})
+
+ var got []string
+ fn := func(n ast.Node, push bool) bool {
+ if push {
+ got = append(got, typeOf(n))
+ }
+ return true
+ }
+
+ // no type filtering
+ inspect.Nodes(nil, fn)
+ if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) {
+ t.Errorf("inspect: got %s, want %s", got, want)
+ }
+
+ // type filtering
+ nodeTypes := []ast.Node{
+ (*ast.BasicLit)(nil),
+ (*ast.CallExpr)(nil),
+ }
+ got = nil
+ inspect.Nodes(nodeTypes, fn)
+ if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) {
+ t.Errorf("inspect: got %s, want %s", got, want)
+ }
+
+ // inspect with stack
+ got = nil
+ inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool {
+ if push {
+ var line []string
+ for _, n := range stack {
+ line = append(line, typeOf(n))
+ }
+ got = append(got, strings.Join(line, " "))
+ }
+ return true
+ })
+ want := []string{
+ "File FuncDecl BlockStmt ExprStmt CallExpr",
+ "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
+ "File FuncDecl BlockStmt ExprStmt CallExpr",
+ "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("inspect: got %s, want %s", got, want)
+ }
+}
+
+func typeOf(n ast.Node) string {
+ return strings.TrimPrefix(reflect.TypeOf(n).String(), "*ast.")
+}
+
+// The numbers show a marginal improvement (ASTInspect/Inspect) of 3.5x,
+// but a break-even point (NewInspector/(ASTInspect-Inspect)) of about 5
+// traversals.
+//
+// BenchmarkNewInspector 4.5 ms
+// BenchmarkNewInspect 0.33ms
+// BenchmarkASTInspect 1.2 ms
+
+func BenchmarkNewInspector(b *testing.B) {
+ // Measure one-time construction overhead.
+ for i := 0; i < b.N; i++ {
+ inspector.New(netFiles)
+ }
+}
+
+func BenchmarkInspect(b *testing.B) {
+ b.StopTimer()
+ inspect := inspector.New(netFiles)
+ b.StartTimer()
+
+ // Measure marginal cost of traversal.
+ var ndecls, nlits int
+ for i := 0; i < b.N; i++ {
+ inspect.Preorder(nil, func(n ast.Node) {
+ switch n.(type) {
+ case *ast.FuncDecl:
+ ndecls++
+ case *ast.FuncLit:
+ nlits++
+ }
+ })
+ }
+}
+
+func BenchmarkASTInspect(b *testing.B) {
+ var ndecls, nlits int
+ for i := 0; i < b.N; i++ {
+ for _, f := range netFiles {
+ ast.Inspect(f, func(n ast.Node) bool {
+ switch n.(type) {
+ case *ast.FuncDecl:
+ ndecls++
+ case *ast.FuncLit:
+ nlits++
+ }
+ return true
+ })
+ }
+ }
+}