+++ /dev/null
-package pattern
-
-import (
- "fmt"
- "go/ast"
- "go/token"
- "go/types"
- "reflect"
-
- "honnef.co/go/tools/lint"
-)
-
-var tokensByString = map[string]Token{
- "INT": Token(token.INT),
- "FLOAT": Token(token.FLOAT),
- "IMAG": Token(token.IMAG),
- "CHAR": Token(token.CHAR),
- "STRING": Token(token.STRING),
- "+": Token(token.ADD),
- "-": Token(token.SUB),
- "*": Token(token.MUL),
- "/": Token(token.QUO),
- "%": Token(token.REM),
- "&": Token(token.AND),
- "|": Token(token.OR),
- "^": Token(token.XOR),
- "<<": Token(token.SHL),
- ">>": Token(token.SHR),
- "&^": Token(token.AND_NOT),
- "+=": Token(token.ADD_ASSIGN),
- "-=": Token(token.SUB_ASSIGN),
- "*=": Token(token.MUL_ASSIGN),
- "/=": Token(token.QUO_ASSIGN),
- "%=": Token(token.REM_ASSIGN),
- "&=": Token(token.AND_ASSIGN),
- "|=": Token(token.OR_ASSIGN),
- "^=": Token(token.XOR_ASSIGN),
- "<<=": Token(token.SHL_ASSIGN),
- ">>=": Token(token.SHR_ASSIGN),
- "&^=": Token(token.AND_NOT_ASSIGN),
- "&&": Token(token.LAND),
- "||": Token(token.LOR),
- "<-": Token(token.ARROW),
- "++": Token(token.INC),
- "--": Token(token.DEC),
- "==": Token(token.EQL),
- "<": Token(token.LSS),
- ">": Token(token.GTR),
- "=": Token(token.ASSIGN),
- "!": Token(token.NOT),
- "!=": Token(token.NEQ),
- "<=": Token(token.LEQ),
- ">=": Token(token.GEQ),
- ":=": Token(token.DEFINE),
- "...": Token(token.ELLIPSIS),
- "IMPORT": Token(token.IMPORT),
- "VAR": Token(token.VAR),
- "TYPE": Token(token.TYPE),
- "CONST": Token(token.CONST),
-}
-
-func maybeToken(node Node) (Node, bool) {
- if node, ok := node.(String); ok {
- if tok, ok := tokensByString[string(node)]; ok {
- return tok, true
- }
- return node, false
- }
- return node, false
-}
-
-func isNil(v interface{}) bool {
- if v == nil {
- return true
- }
- if _, ok := v.(Nil); ok {
- return true
- }
- return false
-}
-
-type matcher interface {
- Match(*Matcher, interface{}) (interface{}, bool)
-}
-
-type State = map[string]interface{}
-
-type Matcher struct {
- TypesInfo *types.Info
- State State
-}
-
-func (m *Matcher) fork() *Matcher {
- state := make(State, len(m.State))
- for k, v := range m.State {
- state[k] = v
- }
- return &Matcher{
- TypesInfo: m.TypesInfo,
- State: state,
- }
-}
-
-func (m *Matcher) merge(mc *Matcher) {
- m.State = mc.State
-}
-
-func (m *Matcher) Match(a Node, b ast.Node) bool {
- m.State = State{}
- _, ok := match(m, a, b)
- return ok
-}
-
-func Match(a Node, b ast.Node) (*Matcher, bool) {
- m := &Matcher{}
- ret := m.Match(a, b)
- return m, ret
-}
-
-// Match two items, which may be (Node, AST) or (AST, AST)
-func match(m *Matcher, l, r interface{}) (interface{}, bool) {
- if _, ok := r.(Node); ok {
- panic("Node mustn't be on right side of match")
- }
-
- switch l := l.(type) {
- case *ast.ParenExpr:
- return match(m, l.X, r)
- case *ast.ExprStmt:
- return match(m, l.X, r)
- case *ast.DeclStmt:
- return match(m, l.Decl, r)
- case *ast.LabeledStmt:
- return match(m, l.Stmt, r)
- case *ast.BlockStmt:
- return match(m, l.List, r)
- case *ast.FieldList:
- return match(m, l.List, r)
- }
-
- switch r := r.(type) {
- case *ast.ParenExpr:
- return match(m, l, r.X)
- case *ast.ExprStmt:
- return match(m, l, r.X)
- case *ast.DeclStmt:
- return match(m, l, r.Decl)
- case *ast.LabeledStmt:
- return match(m, l, r.Stmt)
- case *ast.BlockStmt:
- if r == nil {
- return match(m, l, nil)
- }
- return match(m, l, r.List)
- case *ast.FieldList:
- if r == nil {
- return match(m, l, nil)
- }
- return match(m, l, r.List)
- case *ast.BasicLit:
- if r == nil {
- return match(m, l, nil)
- }
- }
-
- if l, ok := l.(matcher); ok {
- return l.Match(m, r)
- }
-
- if l, ok := l.(Node); ok {
- // Matching of pattern with concrete value
- return matchNodeAST(m, l, r)
- }
-
- if l == nil || r == nil {
- return nil, l == r
- }
-
- {
- ln, ok1 := l.(ast.Node)
- rn, ok2 := r.(ast.Node)
- if ok1 && ok2 {
- return matchAST(m, ln, rn)
- }
- }
-
- {
- obj, ok := l.(types.Object)
- if ok {
- switch r := r.(type) {
- case *ast.Ident:
- return obj, obj == m.TypesInfo.ObjectOf(r)
- case *ast.SelectorExpr:
- return obj, obj == m.TypesInfo.ObjectOf(r.Sel)
- default:
- return obj, false
- }
- }
- }
-
- {
- ln, ok1 := l.([]ast.Expr)
- rn, ok2 := r.([]ast.Expr)
- if ok1 || ok2 {
- if ok1 && !ok2 {
- rn = []ast.Expr{r.(ast.Expr)}
- } else if !ok1 && ok2 {
- ln = []ast.Expr{l.(ast.Expr)}
- }
-
- if len(ln) != len(rn) {
- return nil, false
- }
- for i, ll := range ln {
- if _, ok := match(m, ll, rn[i]); !ok {
- return nil, false
- }
- }
- return r, true
- }
- }
-
- {
- ln, ok1 := l.([]ast.Stmt)
- rn, ok2 := r.([]ast.Stmt)
- if ok1 || ok2 {
- if ok1 && !ok2 {
- rn = []ast.Stmt{r.(ast.Stmt)}
- } else if !ok1 && ok2 {
- ln = []ast.Stmt{l.(ast.Stmt)}
- }
-
- if len(ln) != len(rn) {
- return nil, false
- }
- for i, ll := range ln {
- if _, ok := match(m, ll, rn[i]); !ok {
- return nil, false
- }
- }
- return r, true
- }
- }
-
- {
- ln, ok1 := l.([]*ast.Field)
- rn, ok2 := r.([]*ast.Field)
- if ok1 || ok2 {
- if ok1 && !ok2 {
- rn = []*ast.Field{r.(*ast.Field)}
- } else if !ok1 && ok2 {
- ln = []*ast.Field{l.(*ast.Field)}
- }
-
- if len(ln) != len(rn) {
- return nil, false
- }
- for i, ll := range ln {
- if _, ok := match(m, ll, rn[i]); !ok {
- return nil, false
- }
- }
- return r, true
- }
- }
-
- panic(fmt.Sprintf("unsupported comparison: %T and %T", l, r))
-}
-
-// Match a Node with an AST node
-func matchNodeAST(m *Matcher, a Node, b interface{}) (interface{}, bool) {
- switch b := b.(type) {
- case []ast.Stmt:
- // 'a' is not a List or we'd be using its Match
- // implementation.
-
- if len(b) != 1 {
- return nil, false
- }
- return match(m, a, b[0])
- case []ast.Expr:
- // 'a' is not a List or we'd be using its Match
- // implementation.
-
- if len(b) != 1 {
- return nil, false
- }
- return match(m, a, b[0])
- case ast.Node:
- ra := reflect.ValueOf(a)
- rb := reflect.ValueOf(b).Elem()
-
- if ra.Type().Name() != rb.Type().Name() {
- return nil, false
- }
-
- for i := 0; i < ra.NumField(); i++ {
- af := ra.Field(i)
- fieldName := ra.Type().Field(i).Name
- bf := rb.FieldByName(fieldName)
- if (bf == reflect.Value{}) {
- panic(fmt.Sprintf("internal error: could not find field %s in type %t when comparing with %T", fieldName, b, a))
- }
- ai := af.Interface()
- bi := bf.Interface()
- if ai == nil {
- return b, bi == nil
- }
- if _, ok := match(m, ai.(Node), bi); !ok {
- return b, false
- }
- }
- return b, true
- case nil:
- return nil, a == Nil{}
- default:
- panic(fmt.Sprintf("unhandled type %T", b))
- }
-}
-
-// Match two AST nodes
-func matchAST(m *Matcher, a, b ast.Node) (interface{}, bool) {
- ra := reflect.ValueOf(a)
- rb := reflect.ValueOf(b)
-
- if ra.Type() != rb.Type() {
- return nil, false
- }
- if ra.IsNil() || rb.IsNil() {
- return rb, ra.IsNil() == rb.IsNil()
- }
-
- ra = ra.Elem()
- rb = rb.Elem()
- for i := 0; i < ra.NumField(); i++ {
- af := ra.Field(i)
- bf := rb.Field(i)
- if af.Type() == rtTokPos || af.Type() == rtObject || af.Type() == rtCommentGroup {
- continue
- }
-
- switch af.Kind() {
- case reflect.Slice:
- if af.Len() != bf.Len() {
- return nil, false
- }
- for j := 0; j < af.Len(); j++ {
- if _, ok := match(m, af.Index(j).Interface().(ast.Node), bf.Index(j).Interface().(ast.Node)); !ok {
- return nil, false
- }
- }
- case reflect.String:
- if af.String() != bf.String() {
- return nil, false
- }
- case reflect.Int:
- if af.Int() != bf.Int() {
- return nil, false
- }
- case reflect.Bool:
- if af.Bool() != bf.Bool() {
- return nil, false
- }
- case reflect.Ptr, reflect.Interface:
- if _, ok := match(m, af.Interface(), bf.Interface()); !ok {
- return nil, false
- }
- default:
- panic(fmt.Sprintf("internal error: unhandled kind %s (%T)", af.Kind(), af.Interface()))
- }
- }
- return b, true
-}
-
-func (b Binding) Match(m *Matcher, node interface{}) (interface{}, bool) {
- if isNil(b.Node) {
- v, ok := m.State[b.Name]
- if ok {
- // Recall value
- return match(m, v, node)
- }
- // Matching anything
- b.Node = Any{}
- }
-
- // Store value
- if _, ok := m.State[b.Name]; ok {
- panic(fmt.Sprintf("binding already created: %s", b.Name))
- }
- new, ret := match(m, b.Node, node)
- if ret {
- m.State[b.Name] = new
- }
- return new, ret
-}
-
-func (Any) Match(m *Matcher, node interface{}) (interface{}, bool) {
- return node, true
-}
-
-func (l List) Match(m *Matcher, node interface{}) (interface{}, bool) {
- v := reflect.ValueOf(node)
- if v.Kind() == reflect.Slice {
- if isNil(l.Head) {
- return node, v.Len() == 0
- }
- if v.Len() == 0 {
- return nil, false
- }
- // OPT(dh): don't check the entire tail if head didn't match
- _, ok1 := match(m, l.Head, v.Index(0).Interface())
- _, ok2 := match(m, l.Tail, v.Slice(1, v.Len()).Interface())
- return node, ok1 && ok2
- }
- // Our empty list does not equal an untyped Go nil. This way, we can
- // tell apart an if with no else and an if with an empty else.
- return nil, false
-}
-
-func (s String) Match(m *Matcher, node interface{}) (interface{}, bool) {
- switch o := node.(type) {
- case token.Token:
- if tok, ok := maybeToken(s); ok {
- return match(m, tok, node)
- }
- return nil, false
- case string:
- return o, string(s) == o
- default:
- return nil, false
- }
-}
-
-func (tok Token) Match(m *Matcher, node interface{}) (interface{}, bool) {
- o, ok := node.(token.Token)
- if !ok {
- return nil, false
- }
- return o, token.Token(tok) == o
-}
-
-func (Nil) Match(m *Matcher, node interface{}) (interface{}, bool) {
- return nil, isNil(node)
-}
-
-func (builtin Builtin) Match(m *Matcher, node interface{}) (interface{}, bool) {
- ident, ok := node.(*ast.Ident)
- if !ok {
- return nil, false
- }
- obj := m.TypesInfo.ObjectOf(ident)
- if obj != types.Universe.Lookup(ident.Name) {
- return nil, false
- }
- return match(m, builtin.Name, ident.Name)
-}
-
-func (obj Object) Match(m *Matcher, node interface{}) (interface{}, bool) {
- ident, ok := node.(*ast.Ident)
- if !ok {
- return nil, false
- }
-
- id := m.TypesInfo.ObjectOf(ident)
- _, ok = match(m, obj.Name, ident.Name)
- return id, ok
-}
-
-func (fn Function) Match(m *Matcher, node interface{}) (interface{}, bool) {
- var name string
- var obj types.Object
- switch node := node.(type) {
- case *ast.Ident:
- obj = m.TypesInfo.ObjectOf(node)
- switch obj := obj.(type) {
- case *types.Func:
- name = lint.FuncName(obj)
- case *types.Builtin:
- name = obj.Name()
- default:
- return nil, false
- }
- case *ast.SelectorExpr:
- var ok bool
- obj, ok = m.TypesInfo.ObjectOf(node.Sel).(*types.Func)
- if !ok {
- return nil, false
- }
- name = lint.FuncName(obj.(*types.Func))
- default:
- return nil, false
- }
- _, ok := match(m, fn.Name, name)
- return obj, ok
-}
-
-func (or Or) Match(m *Matcher, node interface{}) (interface{}, bool) {
- for _, opt := range or.Nodes {
- mc := m.fork()
- if ret, ok := match(mc, opt, node); ok {
- m.merge(mc)
- return ret, true
- }
- }
- return nil, false
-}
-
-func (not Not) Match(m *Matcher, node interface{}) (interface{}, bool) {
- _, ok := match(m, not.Node, node)
- if ok {
- return nil, false
- }
- return node, true
-}
-
-var (
- // Types of fields in go/ast structs that we want to skip
- rtTokPos = reflect.TypeOf(token.Pos(0))
- rtObject = reflect.TypeOf((*ast.Object)(nil))
- rtCommentGroup = reflect.TypeOf((*ast.CommentGroup)(nil))
-)
-
-var (
- _ matcher = Binding{}
- _ matcher = Any{}
- _ matcher = List{}
- _ matcher = String("")
- _ matcher = Token(0)
- _ matcher = Nil{}
- _ matcher = Builtin{}
- _ matcher = Object{}
- _ matcher = Function{}
- _ matcher = Or{}
- _ matcher = Not{}
-)