--- /dev/null
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package expect
+
+import (
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "path/filepath"
+ "regexp"
+ "strconv"
+ "strings"
+ "text/scanner"
+
+ "golang.org/x/mod/modfile"
+)
+
+const commentStart = "@"
+const commentStartLen = len(commentStart)
+
+// Identifier is the type for an identifier in an Note argument list.
+type Identifier string
+
+// Parse collects all the notes present in a file.
+// If content is nil, the filename specified is read and parsed, otherwise the
+// content is used and the filename is used for positions and error messages.
+// Each comment whose text starts with @ is parsed as a comma-separated
+// sequence of notes.
+// See the package documentation for details about the syntax of those
+// notes.
+func Parse(fset *token.FileSet, filename string, content []byte) ([]*Note, error) {
+ var src interface{}
+ if content != nil {
+ src = content
+ }
+ switch filepath.Ext(filename) {
+ case ".go":
+ // TODO: We should write this in terms of the scanner.
+ // there are ways you can break the parser such that it will not add all the
+ // comments to the ast, which may result in files where the tests are silently
+ // not run.
+ file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
+ if file == nil {
+ return nil, err
+ }
+ return ExtractGo(fset, file)
+ case ".mod":
+ file, err := modfile.Parse(filename, content, nil)
+ if err != nil {
+ return nil, err
+ }
+ f := fset.AddFile(filename, -1, len(content))
+ f.SetLinesForContent(content)
+ notes, err := extractMod(fset, file)
+ if err != nil {
+ return nil, err
+ }
+ // Since modfile.Parse does not return an *ast, we need to add the offset
+ // within the file's contents to the file's base relative to the fileset.
+ for _, note := range notes {
+ note.Pos += token.Pos(f.Base())
+ }
+ return notes, nil
+ }
+ return nil, nil
+}
+
+// extractMod collects all the notes present in a go.mod file.
+// Each comment whose text starts with @ is parsed as a comma-separated
+// sequence of notes.
+// See the package documentation for details about the syntax of those
+// notes.
+// Only allow notes to appear with the following format: "//@mark()" or // @mark()
+func extractMod(fset *token.FileSet, file *modfile.File) ([]*Note, error) {
+ var notes []*Note
+ for _, stmt := range file.Syntax.Stmt {
+ comment := stmt.Comment()
+ if comment == nil {
+ continue
+ }
+ // Handle the case for markers of `// indirect` to be on the line before
+ // the require statement.
+ // TODO(golang/go#36894): have a more intuitive approach for // indirect
+ for _, cmt := range comment.Before {
+ text, adjust := getAdjustedNote(cmt.Token)
+ if text == "" {
+ continue
+ }
+ parsed, err := parse(fset, token.Pos(int(cmt.Start.Byte)+adjust), text)
+ if err != nil {
+ return nil, err
+ }
+ notes = append(notes, parsed...)
+ }
+ // Handle the normal case for markers on the same line.
+ for _, cmt := range comment.Suffix {
+ text, adjust := getAdjustedNote(cmt.Token)
+ if text == "" {
+ continue
+ }
+ parsed, err := parse(fset, token.Pos(int(cmt.Start.Byte)+adjust), text)
+ if err != nil {
+ return nil, err
+ }
+ notes = append(notes, parsed...)
+ }
+ }
+ return notes, nil
+}
+
+// ExtractGo collects all the notes present in an AST.
+// Each comment whose text starts with @ is parsed as a comma-separated
+// sequence of notes.
+// See the package documentation for details about the syntax of those
+// notes.
+func ExtractGo(fset *token.FileSet, file *ast.File) ([]*Note, error) {
+ var notes []*Note
+ for _, g := range file.Comments {
+ for _, c := range g.List {
+ text, adjust := getAdjustedNote(c.Text)
+ if text == "" {
+ continue
+ }
+ parsed, err := parse(fset, token.Pos(int(c.Pos())+adjust), text)
+ if err != nil {
+ return nil, err
+ }
+ notes = append(notes, parsed...)
+ }
+ }
+ return notes, nil
+}
+
+func getAdjustedNote(text string) (string, int) {
+ if strings.HasPrefix(text, "/*") {
+ text = strings.TrimSuffix(text, "*/")
+ }
+ text = text[2:] // remove "//" or "/*" prefix
+
+ // Allow notes to appear within comments.
+ // For example:
+ // "// //@mark()" is valid.
+ // "// @mark()" is not valid.
+ // "// /*@mark()*/" is not valid.
+ var adjust int
+ if i := strings.Index(text, commentStart); i > 2 {
+ // Get the text before the commentStart.
+ pre := text[i-2 : i]
+ if pre != "//" {
+ return "", 0
+ }
+ text = text[i:]
+ adjust = i
+ }
+ if !strings.HasPrefix(text, commentStart) {
+ return "", 0
+ }
+ text = text[commentStartLen:]
+ return text, commentStartLen + adjust + 1
+}
+
+const invalidToken rune = 0
+
+type tokens struct {
+ scanner scanner.Scanner
+ current rune
+ err error
+ base token.Pos
+}
+
+func (t *tokens) Init(base token.Pos, text string) *tokens {
+ t.base = base
+ t.scanner.Init(strings.NewReader(text))
+ t.scanner.Mode = scanner.GoTokens
+ t.scanner.Whitespace ^= 1 << '\n' // don't skip new lines
+ t.scanner.Error = func(s *scanner.Scanner, msg string) {
+ t.Errorf("%v", msg)
+ }
+ return t
+}
+
+func (t *tokens) Consume() string {
+ t.current = invalidToken
+ return t.scanner.TokenText()
+}
+
+func (t *tokens) Token() rune {
+ if t.err != nil {
+ return scanner.EOF
+ }
+ if t.current == invalidToken {
+ t.current = t.scanner.Scan()
+ }
+ return t.current
+}
+
+func (t *tokens) Skip(r rune) int {
+ i := 0
+ for t.Token() == '\n' {
+ t.Consume()
+ i++
+ }
+ return i
+}
+
+func (t *tokens) TokenString() string {
+ return scanner.TokenString(t.Token())
+}
+
+func (t *tokens) Pos() token.Pos {
+ return t.base + token.Pos(t.scanner.Position.Offset)
+}
+
+func (t *tokens) Errorf(msg string, args ...interface{}) {
+ if t.err != nil {
+ return
+ }
+ t.err = fmt.Errorf(msg, args...)
+}
+
+func parse(fset *token.FileSet, base token.Pos, text string) ([]*Note, error) {
+ t := new(tokens).Init(base, text)
+ notes := parseComment(t)
+ if t.err != nil {
+ return nil, fmt.Errorf("%v:%s", fset.Position(t.Pos()), t.err)
+ }
+ return notes, nil
+}
+
+func parseComment(t *tokens) []*Note {
+ var notes []*Note
+ for {
+ t.Skip('\n')
+ switch t.Token() {
+ case scanner.EOF:
+ return notes
+ case scanner.Ident:
+ notes = append(notes, parseNote(t))
+ default:
+ t.Errorf("unexpected %s parsing comment, expect identifier", t.TokenString())
+ return nil
+ }
+ switch t.Token() {
+ case scanner.EOF:
+ return notes
+ case ',', '\n':
+ t.Consume()
+ default:
+ t.Errorf("unexpected %s parsing comment, expect separator", t.TokenString())
+ return nil
+ }
+ }
+}
+
+func parseNote(t *tokens) *Note {
+ n := &Note{
+ Pos: t.Pos(),
+ Name: t.Consume(),
+ }
+
+ switch t.Token() {
+ case ',', '\n', scanner.EOF:
+ // no argument list present
+ return n
+ case '(':
+ n.Args = parseArgumentList(t)
+ return n
+ default:
+ t.Errorf("unexpected %s parsing note", t.TokenString())
+ return nil
+ }
+}
+
+func parseArgumentList(t *tokens) []interface{} {
+ args := []interface{}{} // @name() is represented by a non-nil empty slice.
+ t.Consume() // '('
+ t.Skip('\n')
+ for t.Token() != ')' {
+ args = append(args, parseArgument(t))
+ if t.Token() != ',' {
+ break
+ }
+ t.Consume()
+ t.Skip('\n')
+ }
+ if t.Token() != ')' {
+ t.Errorf("unexpected %s parsing argument list", t.TokenString())
+ return nil
+ }
+ t.Consume() // ')'
+ return args
+}
+
+func parseArgument(t *tokens) interface{} {
+ switch t.Token() {
+ case scanner.Ident:
+ v := t.Consume()
+ switch v {
+ case "true":
+ return true
+ case "false":
+ return false
+ case "nil":
+ return nil
+ case "re":
+ if t.Token() != scanner.String && t.Token() != scanner.RawString {
+ t.Errorf("re must be followed by string, got %s", t.TokenString())
+ return nil
+ }
+ pattern, _ := strconv.Unquote(t.Consume()) // can't fail
+ re, err := regexp.Compile(pattern)
+ if err != nil {
+ t.Errorf("invalid regular expression %s: %v", pattern, err)
+ return nil
+ }
+ return re
+ default:
+ return Identifier(v)
+ }
+
+ case scanner.String, scanner.RawString:
+ v, _ := strconv.Unquote(t.Consume()) // can't fail
+ return v
+
+ case scanner.Int:
+ s := t.Consume()
+ v, err := strconv.ParseInt(s, 0, 0)
+ if err != nil {
+ t.Errorf("cannot convert %v to int: %v", s, err)
+ }
+ return v
+
+ case scanner.Float:
+ s := t.Consume()
+ v, err := strconv.ParseFloat(s, 64)
+ if err != nil {
+ t.Errorf("cannot convert %v to float: %v", s, err)
+ }
+ return v
+
+ case scanner.Char:
+ t.Errorf("unexpected char literal %s", t.Consume())
+ return nil
+
+ default:
+ t.Errorf("unexpected %s parsing argument", t.TokenString())
+ return nil
+ }
+}