--- /dev/null
+package pattern
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "reflect"
+)
+
+type Pattern struct {
+ Root Node
+ // Relevant contains instances of ast.Node that could potentially
+ // initiate a successful match of the pattern.
+ Relevant []reflect.Type
+}
+
+func MustParse(s string) Pattern {
+ p := &Parser{AllowTypeInfo: true}
+ pat, err := p.Parse(s)
+ if err != nil {
+ panic(err)
+ }
+ return pat
+}
+
+func roots(node Node) []reflect.Type {
+ switch node := node.(type) {
+ case Or:
+ var out []reflect.Type
+ for _, el := range node.Nodes {
+ out = append(out, roots(el)...)
+ }
+ return out
+ case Not:
+ return roots(node.Node)
+ case Binding:
+ return roots(node.Node)
+ case Nil, nil:
+ // this branch is reached via bindings
+ return allTypes
+ default:
+ Ts, ok := nodeToASTTypes[reflect.TypeOf(node)]
+ if !ok {
+ panic(fmt.Sprintf("internal error: unhandled type %T", node))
+ }
+ return Ts
+ }
+}
+
+var allTypes = []reflect.Type{
+ reflect.TypeOf((*ast.RangeStmt)(nil)),
+ reflect.TypeOf((*ast.AssignStmt)(nil)),
+ reflect.TypeOf((*ast.IndexExpr)(nil)),
+ reflect.TypeOf((*ast.Ident)(nil)),
+ reflect.TypeOf((*ast.ValueSpec)(nil)),
+ reflect.TypeOf((*ast.GenDecl)(nil)),
+ reflect.TypeOf((*ast.BinaryExpr)(nil)),
+ reflect.TypeOf((*ast.ForStmt)(nil)),
+ reflect.TypeOf((*ast.ArrayType)(nil)),
+ reflect.TypeOf((*ast.DeferStmt)(nil)),
+ reflect.TypeOf((*ast.MapType)(nil)),
+ reflect.TypeOf((*ast.ReturnStmt)(nil)),
+ reflect.TypeOf((*ast.SliceExpr)(nil)),
+ reflect.TypeOf((*ast.StarExpr)(nil)),
+ reflect.TypeOf((*ast.UnaryExpr)(nil)),
+ reflect.TypeOf((*ast.SendStmt)(nil)),
+ reflect.TypeOf((*ast.SelectStmt)(nil)),
+ reflect.TypeOf((*ast.ImportSpec)(nil)),
+ reflect.TypeOf((*ast.IfStmt)(nil)),
+ reflect.TypeOf((*ast.GoStmt)(nil)),
+ reflect.TypeOf((*ast.Field)(nil)),
+ reflect.TypeOf((*ast.SelectorExpr)(nil)),
+ reflect.TypeOf((*ast.StructType)(nil)),
+ reflect.TypeOf((*ast.KeyValueExpr)(nil)),
+ reflect.TypeOf((*ast.FuncType)(nil)),
+ reflect.TypeOf((*ast.FuncLit)(nil)),
+ reflect.TypeOf((*ast.FuncDecl)(nil)),
+ reflect.TypeOf((*ast.ChanType)(nil)),
+ reflect.TypeOf((*ast.CallExpr)(nil)),
+ reflect.TypeOf((*ast.CaseClause)(nil)),
+ reflect.TypeOf((*ast.CommClause)(nil)),
+ reflect.TypeOf((*ast.CompositeLit)(nil)),
+ reflect.TypeOf((*ast.EmptyStmt)(nil)),
+ reflect.TypeOf((*ast.SwitchStmt)(nil)),
+ reflect.TypeOf((*ast.TypeSwitchStmt)(nil)),
+ reflect.TypeOf((*ast.TypeAssertExpr)(nil)),
+ reflect.TypeOf((*ast.TypeSpec)(nil)),
+ reflect.TypeOf((*ast.InterfaceType)(nil)),
+ reflect.TypeOf((*ast.BranchStmt)(nil)),
+ reflect.TypeOf((*ast.IncDecStmt)(nil)),
+ reflect.TypeOf((*ast.BasicLit)(nil)),
+}
+
+var nodeToASTTypes = map[reflect.Type][]reflect.Type{
+ reflect.TypeOf(String("")): nil,
+ reflect.TypeOf(Token(0)): nil,
+ reflect.TypeOf(List{}): {reflect.TypeOf((*ast.BlockStmt)(nil)), reflect.TypeOf((*ast.FieldList)(nil))},
+ reflect.TypeOf(Builtin{}): {reflect.TypeOf((*ast.Ident)(nil))},
+ reflect.TypeOf(Object{}): {reflect.TypeOf((*ast.Ident)(nil))},
+ reflect.TypeOf(Function{}): {reflect.TypeOf((*ast.Ident)(nil)), reflect.TypeOf((*ast.SelectorExpr)(nil))},
+ reflect.TypeOf(Any{}): allTypes,
+ reflect.TypeOf(RangeStmt{}): {reflect.TypeOf((*ast.RangeStmt)(nil))},
+ reflect.TypeOf(AssignStmt{}): {reflect.TypeOf((*ast.AssignStmt)(nil))},
+ reflect.TypeOf(IndexExpr{}): {reflect.TypeOf((*ast.IndexExpr)(nil))},
+ reflect.TypeOf(Ident{}): {reflect.TypeOf((*ast.Ident)(nil))},
+ reflect.TypeOf(ValueSpec{}): {reflect.TypeOf((*ast.ValueSpec)(nil))},
+ reflect.TypeOf(GenDecl{}): {reflect.TypeOf((*ast.GenDecl)(nil))},
+ reflect.TypeOf(BinaryExpr{}): {reflect.TypeOf((*ast.BinaryExpr)(nil))},
+ reflect.TypeOf(ForStmt{}): {reflect.TypeOf((*ast.ForStmt)(nil))},
+ reflect.TypeOf(ArrayType{}): {reflect.TypeOf((*ast.ArrayType)(nil))},
+ reflect.TypeOf(DeferStmt{}): {reflect.TypeOf((*ast.DeferStmt)(nil))},
+ reflect.TypeOf(MapType{}): {reflect.TypeOf((*ast.MapType)(nil))},
+ reflect.TypeOf(ReturnStmt{}): {reflect.TypeOf((*ast.ReturnStmt)(nil))},
+ reflect.TypeOf(SliceExpr{}): {reflect.TypeOf((*ast.SliceExpr)(nil))},
+ reflect.TypeOf(StarExpr{}): {reflect.TypeOf((*ast.StarExpr)(nil))},
+ reflect.TypeOf(UnaryExpr{}): {reflect.TypeOf((*ast.UnaryExpr)(nil))},
+ reflect.TypeOf(SendStmt{}): {reflect.TypeOf((*ast.SendStmt)(nil))},
+ reflect.TypeOf(SelectStmt{}): {reflect.TypeOf((*ast.SelectStmt)(nil))},
+ reflect.TypeOf(ImportSpec{}): {reflect.TypeOf((*ast.ImportSpec)(nil))},
+ reflect.TypeOf(IfStmt{}): {reflect.TypeOf((*ast.IfStmt)(nil))},
+ reflect.TypeOf(GoStmt{}): {reflect.TypeOf((*ast.GoStmt)(nil))},
+ reflect.TypeOf(Field{}): {reflect.TypeOf((*ast.Field)(nil))},
+ reflect.TypeOf(SelectorExpr{}): {reflect.TypeOf((*ast.SelectorExpr)(nil))},
+ reflect.TypeOf(StructType{}): {reflect.TypeOf((*ast.StructType)(nil))},
+ reflect.TypeOf(KeyValueExpr{}): {reflect.TypeOf((*ast.KeyValueExpr)(nil))},
+ reflect.TypeOf(FuncType{}): {reflect.TypeOf((*ast.FuncType)(nil))},
+ reflect.TypeOf(FuncLit{}): {reflect.TypeOf((*ast.FuncLit)(nil))},
+ reflect.TypeOf(FuncDecl{}): {reflect.TypeOf((*ast.FuncDecl)(nil))},
+ reflect.TypeOf(ChanType{}): {reflect.TypeOf((*ast.ChanType)(nil))},
+ reflect.TypeOf(CallExpr{}): {reflect.TypeOf((*ast.CallExpr)(nil))},
+ reflect.TypeOf(CaseClause{}): {reflect.TypeOf((*ast.CaseClause)(nil))},
+ reflect.TypeOf(CommClause{}): {reflect.TypeOf((*ast.CommClause)(nil))},
+ reflect.TypeOf(CompositeLit{}): {reflect.TypeOf((*ast.CompositeLit)(nil))},
+ reflect.TypeOf(EmptyStmt{}): {reflect.TypeOf((*ast.EmptyStmt)(nil))},
+ reflect.TypeOf(SwitchStmt{}): {reflect.TypeOf((*ast.SwitchStmt)(nil))},
+ reflect.TypeOf(TypeSwitchStmt{}): {reflect.TypeOf((*ast.TypeSwitchStmt)(nil))},
+ reflect.TypeOf(TypeAssertExpr{}): {reflect.TypeOf((*ast.TypeAssertExpr)(nil))},
+ reflect.TypeOf(TypeSpec{}): {reflect.TypeOf((*ast.TypeSpec)(nil))},
+ reflect.TypeOf(InterfaceType{}): {reflect.TypeOf((*ast.InterfaceType)(nil))},
+ reflect.TypeOf(BranchStmt{}): {reflect.TypeOf((*ast.BranchStmt)(nil))},
+ reflect.TypeOf(IncDecStmt{}): {reflect.TypeOf((*ast.IncDecStmt)(nil))},
+ reflect.TypeOf(BasicLit{}): {reflect.TypeOf((*ast.BasicLit)(nil))},
+}
+
+var requiresTypeInfo = map[string]bool{
+ "Function": true,
+ "Builtin": true,
+ "Object": true,
+}
+
+type Parser struct {
+ // Allow nodes that rely on type information
+ AllowTypeInfo bool
+
+ lex *lexer
+ cur item
+ last *item
+ items chan item
+}
+
+func (p *Parser) Parse(s string) (Pattern, error) {
+ p.cur = item{}
+ p.last = nil
+ p.items = nil
+
+ fset := token.NewFileSet()
+ p.lex = &lexer{
+ f: fset.AddFile("<input>", -1, len(s)),
+ input: s,
+ items: make(chan item),
+ }
+ go p.lex.run()
+ p.items = p.lex.items
+ root, err := p.node()
+ if err != nil {
+ // drain lexer if parsing failed
+ for range p.lex.items {
+ }
+ return Pattern{}, err
+ }
+ if item := <-p.lex.items; item.typ != itemEOF {
+ return Pattern{}, fmt.Errorf("unexpected token %s after end of pattern", item.typ)
+ }
+ return Pattern{
+ Root: root,
+ Relevant: roots(root),
+ }, nil
+}
+
+func (p *Parser) next() item {
+ if p.last != nil {
+ n := *p.last
+ p.last = nil
+ return n
+ }
+ var ok bool
+ p.cur, ok = <-p.items
+ if !ok {
+ p.cur = item{typ: eof}
+ }
+ return p.cur
+}
+
+func (p *Parser) rewind() {
+ p.last = &p.cur
+}
+
+func (p *Parser) peek() item {
+ n := p.next()
+ p.rewind()
+ return n
+}
+
+func (p *Parser) accept(typ itemType) (item, bool) {
+ n := p.next()
+ if n.typ == typ {
+ return n, true
+ }
+ p.rewind()
+ return item{}, false
+}
+
+func (p *Parser) unexpectedToken(valid string) error {
+ if p.cur.typ == itemError {
+ return fmt.Errorf("error lexing input: %s", p.cur.val)
+ }
+ var got string
+ switch p.cur.typ {
+ case itemTypeName, itemVariable, itemString:
+ got = p.cur.val
+ default:
+ got = "'" + p.cur.typ.String() + "'"
+ }
+
+ pos := p.lex.f.Position(token.Pos(p.cur.pos))
+ return fmt.Errorf("%s: expected %s, found %s", pos, valid, got)
+}
+
+func (p *Parser) node() (Node, error) {
+ if _, ok := p.accept(itemLeftParen); !ok {
+ return nil, p.unexpectedToken("'('")
+ }
+ typ, ok := p.accept(itemTypeName)
+ if !ok {
+ return nil, p.unexpectedToken("Node type")
+ }
+
+ var objs []Node
+ for {
+ if _, ok := p.accept(itemRightParen); ok {
+ break
+ } else {
+ p.rewind()
+ obj, err := p.object()
+ if err != nil {
+ return nil, err
+ }
+ objs = append(objs, obj)
+ }
+ }
+
+ return p.populateNode(typ.val, objs)
+}
+
+func populateNode(typ string, objs []Node, allowTypeInfo bool) (Node, error) {
+ T, ok := structNodes[typ]
+ if !ok {
+ return nil, fmt.Errorf("unknown node %s", typ)
+ }
+
+ if !allowTypeInfo && requiresTypeInfo[typ] {
+ return nil, fmt.Errorf("Node %s requires type information", typ)
+ }
+
+ pv := reflect.New(T)
+ v := pv.Elem()
+
+ if v.NumField() == 1 {
+ f := v.Field(0)
+ if f.Type().Kind() == reflect.Slice {
+ // Variadic node
+ f.Set(reflect.AppendSlice(f, reflect.ValueOf(objs)))
+ return v.Interface().(Node), nil
+ }
+ }
+ if len(objs) != v.NumField() {
+ return nil, fmt.Errorf("tried to initialize node %s with %d values, expected %d", typ, len(objs), v.NumField())
+ }
+ for i := 0; i < v.NumField(); i++ {
+ f := v.Field(i)
+ if f.Kind() == reflect.String {
+ if obj, ok := objs[i].(String); ok {
+ f.Set(reflect.ValueOf(string(obj)))
+ } else {
+ return nil, fmt.Errorf("first argument of (Binding name node) must be string, but got %s", objs[i])
+ }
+ } else {
+ f.Set(reflect.ValueOf(objs[i]))
+ }
+ }
+ return v.Interface().(Node), nil
+}
+
+func (p *Parser) populateNode(typ string, objs []Node) (Node, error) {
+ return populateNode(typ, objs, p.AllowTypeInfo)
+}
+
+var structNodes = map[string]reflect.Type{
+ "Any": reflect.TypeOf(Any{}),
+ "Ellipsis": reflect.TypeOf(Ellipsis{}),
+ "List": reflect.TypeOf(List{}),
+ "Binding": reflect.TypeOf(Binding{}),
+ "RangeStmt": reflect.TypeOf(RangeStmt{}),
+ "AssignStmt": reflect.TypeOf(AssignStmt{}),
+ "IndexExpr": reflect.TypeOf(IndexExpr{}),
+ "Ident": reflect.TypeOf(Ident{}),
+ "Builtin": reflect.TypeOf(Builtin{}),
+ "ValueSpec": reflect.TypeOf(ValueSpec{}),
+ "GenDecl": reflect.TypeOf(GenDecl{}),
+ "BinaryExpr": reflect.TypeOf(BinaryExpr{}),
+ "ForStmt": reflect.TypeOf(ForStmt{}),
+ "ArrayType": reflect.TypeOf(ArrayType{}),
+ "DeferStmt": reflect.TypeOf(DeferStmt{}),
+ "MapType": reflect.TypeOf(MapType{}),
+ "ReturnStmt": reflect.TypeOf(ReturnStmt{}),
+ "SliceExpr": reflect.TypeOf(SliceExpr{}),
+ "StarExpr": reflect.TypeOf(StarExpr{}),
+ "UnaryExpr": reflect.TypeOf(UnaryExpr{}),
+ "SendStmt": reflect.TypeOf(SendStmt{}),
+ "SelectStmt": reflect.TypeOf(SelectStmt{}),
+ "ImportSpec": reflect.TypeOf(ImportSpec{}),
+ "IfStmt": reflect.TypeOf(IfStmt{}),
+ "GoStmt": reflect.TypeOf(GoStmt{}),
+ "Field": reflect.TypeOf(Field{}),
+ "SelectorExpr": reflect.TypeOf(SelectorExpr{}),
+ "StructType": reflect.TypeOf(StructType{}),
+ "KeyValueExpr": reflect.TypeOf(KeyValueExpr{}),
+ "FuncType": reflect.TypeOf(FuncType{}),
+ "FuncLit": reflect.TypeOf(FuncLit{}),
+ "FuncDecl": reflect.TypeOf(FuncDecl{}),
+ "ChanType": reflect.TypeOf(ChanType{}),
+ "CallExpr": reflect.TypeOf(CallExpr{}),
+ "CaseClause": reflect.TypeOf(CaseClause{}),
+ "CommClause": reflect.TypeOf(CommClause{}),
+ "CompositeLit": reflect.TypeOf(CompositeLit{}),
+ "EmptyStmt": reflect.TypeOf(EmptyStmt{}),
+ "SwitchStmt": reflect.TypeOf(SwitchStmt{}),
+ "TypeSwitchStmt": reflect.TypeOf(TypeSwitchStmt{}),
+ "TypeAssertExpr": reflect.TypeOf(TypeAssertExpr{}),
+ "TypeSpec": reflect.TypeOf(TypeSpec{}),
+ "InterfaceType": reflect.TypeOf(InterfaceType{}),
+ "BranchStmt": reflect.TypeOf(BranchStmt{}),
+ "IncDecStmt": reflect.TypeOf(IncDecStmt{}),
+ "BasicLit": reflect.TypeOf(BasicLit{}),
+ "Object": reflect.TypeOf(Object{}),
+ "Function": reflect.TypeOf(Function{}),
+ "Or": reflect.TypeOf(Or{}),
+ "Not": reflect.TypeOf(Not{}),
+}
+
+func (p *Parser) object() (Node, error) {
+ n := p.next()
+ switch n.typ {
+ case itemLeftParen:
+ p.rewind()
+ node, err := p.node()
+ if err != nil {
+ return node, err
+ }
+ if p.peek().typ == itemColon {
+ p.next()
+ tail, err := p.object()
+ if err != nil {
+ return node, err
+ }
+ return List{Head: node, Tail: tail}, nil
+ }
+ return node, nil
+ case itemLeftBracket:
+ p.rewind()
+ return p.array()
+ case itemVariable:
+ v := n
+ if v.val == "nil" {
+ return Nil{}, nil
+ }
+ var b Binding
+ if _, ok := p.accept(itemAt); ok {
+ o, err := p.node()
+ if err != nil {
+ return nil, err
+ }
+ b = Binding{
+ Name: v.val,
+ Node: o,
+ }
+ } else {
+ p.rewind()
+ b = Binding{Name: v.val}
+ }
+ if p.peek().typ == itemColon {
+ p.next()
+ tail, err := p.object()
+ if err != nil {
+ return b, err
+ }
+ return List{Head: b, Tail: tail}, nil
+ }
+ return b, nil
+ case itemBlank:
+ return Any{}, nil
+ case itemString:
+ return String(n.val), nil
+ default:
+ return nil, p.unexpectedToken("object")
+ }
+}
+
+func (p *Parser) array() (Node, error) {
+ if _, ok := p.accept(itemLeftBracket); !ok {
+ return nil, p.unexpectedToken("'['")
+ }
+
+ var objs []Node
+ for {
+ if _, ok := p.accept(itemRightBracket); ok {
+ break
+ } else {
+ p.rewind()
+ obj, err := p.object()
+ if err != nil {
+ return nil, err
+ }
+ objs = append(objs, obj)
+ }
+ }
+
+ tail := List{}
+ for i := len(objs) - 1; i >= 0; i-- {
+ l := List{
+ Head: objs[i],
+ Tail: tail,
+ }
+ tail = l
+ }
+ return tail, nil
+}
+
+/*
+Node ::= itemLeftParen itemTypeName Object* itemRightParen
+Object ::= Node | Array | Binding | itemVariable | itemBlank | itemString
+Array := itemLeftBracket Object* itemRightBracket
+Array := Object itemColon Object
+Binding ::= itemVariable itemAt Node
+*/