--- /dev/null
+// Copyright (c) 2019, Daniel Martí <mvdan@mvdan.cc>
+// See LICENSE for licensing information
+
+// Package format exposes gofumpt's formatting in an API similar to go/format.
+// In general, the APIs are only guaranteed to work well when the input source
+// is in canonical gofmt format.
+package format
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "reflect"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
+ "unicode"
+ "unicode/utf8"
+
+ "github.com/google/go-cmp/cmp"
+ "golang.org/x/mod/semver"
+ "golang.org/x/tools/go/ast/astutil"
+)
+
+// Options controls how File and Source format a piece of Go code.
+type Options struct {
+ // LangVersion corresponds to the Go language version a piece of code is
+ // written in. The version is used to decide whether to apply formatting
+ // rules which require new language features. When inside a Go module,
+ // LangVersion should generally be specified as the result of:
+ //
+ // go list -m -f {{.GoVersion}}
+ //
+ // LangVersion is treated as a semantic version, which might start with
+ // a "v" prefix. Like Go versions, it might also be incomplete; "1.14"
+ // is equivalent to "1.14.0". When empty, it is equivalent to "v1", to
+ // not use language features which could break programs.
+ LangVersion string
+
+ // ExtraRules enables formatting rules which are disabled by default. In
+ // this file it gates the merging of adjacent fields of the same type,
+ // such as function parameters.
+ ExtraRules bool
+}
+
+// Source formats src in gofumpt's format, assuming that src holds a valid Go
+// source file.
+func Source(src []byte, opts Options) ([]byte, error) {
+ fset := token.NewFileSet()
+ file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
+ if err != nil {
+ return nil, err
+ }
+
+ File(fset, file, opts)
+
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, file); err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+// File rewrites file, and the line information held in fset, in place so that
+// they follow gofumpt's format. The changes may include adding or removing
+// newlines in fset, moving nodes, and modifying literal values.
+//
+// An empty opts.LangVersion is treated as "v1", and a missing "v" prefix is
+// added. File panics if the resulting string is not a valid semantic version.
+func File(fset *token.FileSet, file *ast.File, opts Options) {
+	switch {
+	case opts.LangVersion == "":
+		opts.LangVersion = "v1"
+	case opts.LangVersion[0] != 'v':
+		opts.LangVersion = "v" + opts.LangVersion
+	}
+	if !semver.IsValid(opts.LangVersion) {
+		panic(fmt.Sprintf("invalid semver string: %q", opts.LangVersion))
+	}
+
+	state := &fumpter{
+		Options: opts,
+		File:    fset.File(file.Pos()),
+		fset:    fset,
+		astFile: file,
+	}
+	isBlock := func(n ast.Node) bool {
+		_, ok := n.(*ast.BlockStmt)
+		return ok
+	}
+
+	// blockLevel counts the enclosing block statements. It is bumped only
+	// after applyPre has seen the block itself, and restored once the block
+	// has been fully walked.
+	enter := func(cur *astutil.Cursor) bool {
+		state.applyPre(cur)
+		if isBlock(cur.Node()) {
+			state.blockLevel++
+		}
+		return true
+	}
+	leave := func(cur *astutil.Cursor) bool {
+		if isBlock(cur.Node()) {
+			state.blockLevel--
+		}
+		return true
+	}
+	astutil.Apply(file, enter, leave)
+}
+
+// Multiline nodes which could fit on a single line under this many
+// bytes may be collapsed onto a single line.
+const shortLineLimit = 60
+
+// rxOctalInteger matches an entire literal made of a leading 0 followed by one
+// or more octal digits or underscores, such as 0644.
+var rxOctalInteger = regexp.MustCompile(`\A0[0-7_]+\z`)
+
+// fumpter holds the state used while formatting a single file.
+type fumpter struct {
+ // Options are the caller's options; File normalizes LangVersion before
+ // storing them here.
+ Options
+
+ // File is the token.File looked up from the syntax tree's position. Its
+ // methods, such as Line, Offset, MergeLine and SetLines, are promoted
+ // onto fumpter.
+ *token.File
+ fset *token.FileSet
+
+ astFile *ast.File
+
+ // blockLevel is the number of block statements enclosing the node being
+ // visited; printLength uses it to approximate indentation.
+ blockLevel int
+}
+
+// commentsBetween returns the file's comment groups whose starting position
+// is at or after p1 and strictly before p2. Both bounds are located by binary
+// search, which assumes astFile.Comments is ordered by position.
+func (f *fumpter) commentsBetween(p1, p2 token.Pos) []*ast.CommentGroup {
+	groups := f.astFile.Comments
+	firstAtOrAfter := func(pos token.Pos) int {
+		return sort.Search(len(groups), func(i int) bool {
+			return groups[i].Pos() >= pos
+		})
+	}
+	// Drop everything before p1 first; the upper bound is then searched
+	// within what remains.
+	groups = groups[firstAtOrAfter(p1):]
+	return groups[:firstAtOrAfter(p2)]
+}
+
+// inlineComment looks at the first comment group starting at or after pos and
+// returns the first of its comments which begins on the same line as pos. It
+// returns nil if there is no such group or no such comment.
+func (f *fumpter) inlineComment(pos token.Pos) *ast.Comment {
+	groups := f.astFile.Comments
+	idx := sort.Search(len(groups), func(i int) bool {
+		return groups[i].Pos() >= pos
+	})
+	if idx == len(groups) {
+		return nil
+	}
+	wantLine := f.Line(pos)
+	for _, candidate := range groups[idx].List {
+		if f.Line(candidate.Pos()) != wantLine {
+			continue
+		}
+		return candidate
+	}
+	return nil
+}
+
+// addNewline is a hack to force a line break at a given position. It reads the
+// line table of the token.File through reflection, from its unexported "lines"
+// field, inserts the offset of at in order, and writes the table back with
+// SetLines. If a line already starts at that offset, nothing is changed. It
+// panics if SetLines rejects the new table.
+func (f *fumpter) addNewline(at token.Pos) {
+	// pending is the offset still waiting to be inserted; it becomes -1 once
+	// it has been placed.
+	pending := f.Offset(at)
+
+	table := reflect.ValueOf(f.File).Elem().FieldByName("lines")
+	count := table.Len()
+	updated := make([]int, 0, count+1)
+	for i := 0; i < count; i++ {
+		start := int(table.Index(i).Int())
+		switch {
+		case pending == start:
+			// This newline already exists; duplicates can't exist.
+			return
+		case pending >= 0 && pending < start:
+			updated = append(updated, pending)
+			pending = -1
+		}
+		updated = append(updated, start)
+	}
+	if pending >= 0 {
+		// The offset is past every existing line start.
+		updated = append(updated, pending)
+	}
+	if !f.SetLines(updated) {
+		panic(fmt.Sprintf("could not set lines to %v", updated))
+	}
+}
+
+// removeLines merges lines so that whatever sits on lines fromLine through
+// toLine ends up on a single line. It calls MergeLine(fromLine) once per line
+// separating the two, and does nothing if fromLine is not before toLine.
+func (f *fumpter) removeLines(fromLine, toLine int) {
+	for remaining := toLine - fromLine; remaining > 0; remaining-- {
+		f.MergeLine(fromLine)
+	}
+}
+
+// removeLinesBetween is like removeLines, but it keeps one newline between the
+// two positions: merging starts on the line after from and ends on the line
+// holding to.
+func (f *fumpter) removeLinesBetween(from, to token.Pos) {
+	firstMerged := f.Line(from) + 1
+	lastLine := f.Line(to)
+	f.removeLines(firstMerged, lastLine)
+}
+
+// byteCounter discards whatever is written to it and only keeps a running
+// total of the number of bytes it has been given.
+type byteCounter int
+
+// Write adds len(p) to the counter. It always reports the whole of p as
+// written and never returns an error.
+func (b *byteCounter) Write(p []byte) (n int, err error) {
+	n = len(p)
+	*b = *b + byteCounter(n)
+	return n, nil
+}
+
+// printLength estimates how many bytes node would take up if printed on one
+// line: the size of its go/format output, plus any inline comment following
+// it, plus an approximation of its indentation. It panics if the node cannot
+// be printed.
+func (f *fumpter) printLength(node ast.Node) int {
+	var size byteCounter
+	err := format.Node(&size, f.fset, node)
+	if err != nil {
+		panic(fmt.Sprintf("unexpected print error: %v", err))
+	}
+
+	// An inline comment after the node shares its line, so count it as well,
+	// along with the space separating it from the code.
+	trailing := f.inlineComment(node.End())
+	if trailing != nil {
+		fmt.Fprintf(&size, " %s", trailing.Text)
+	}
+
+	// The number of tabs go/printer will add can't be known ahead of time.
+	// Printing the entire top-level declaration would tell us, but then it's
+	// near impossible to reliably find our node again. Approximate the
+	// indentation as eight bytes per enclosing block instead.
+	indent := f.blockLevel * 8
+	return int(size) + indent
+}
+
+// rxCommentDirective covers all common Go comment directives:
+//
+// //go: | standard Go directives, like go:noinline
+// //someword: | similar to the syntax above, like lint:ignore
+// //line | inserted line information for cmd/compile
+// //export | to mark cgo funcs for exporting
+// //extern | C function declarations for gccgo
+// //sys(nb)? | syscall function wrapper prototypes
+// //nolint | nolint directive for golangci
+//
+// The expression is applied to the comment text after its leading "//" has
+// been trimmed, and is anchored to the start of the remaining text.
+var rxCommentDirective = regexp.MustCompile(`^([a-z]+:|line\b|export\b|extern\b|sys(nb)?\b|nolint\b)`)
+
+// applyPre is called by File's pre-order callback for every node visited by
+// astutil.Apply. It applies the formatting rules for the node under the
+// cursor c, mutating the syntax tree, the comments, and the file's line table
+// in place. Nodes of any type not listed in the switch are left untouched.
+func (f *fumpter) applyPre(c *astutil.Cursor) {
+ switch node := c.Node().(type) {
+ case *ast.File:
+ var lastMulti bool
+ var lastEnd token.Pos
+ for _, decl := range node.Decls {
+ pos := decl.Pos()
+ comments := f.commentsBetween(lastEnd, pos)
+ if len(comments) > 0 {
+ pos = comments[0].Pos()
+ }
+
+ // multiline top-level declarations should be separated
+ multi := f.Line(pos) < f.Line(decl.End())
+ if multi && lastMulti && f.Line(lastEnd)+1 == f.Line(pos) {
+ f.addNewline(lastEnd)
+ }
+
+ lastMulti = multi
+ lastEnd = decl.End()
+ }
+
+ // Join contiguous lone var/const/import lines; abort if there
+ // are empty lines or comments in between.
+ newDecls := make([]ast.Decl, 0, len(node.Decls))
+ for i := 0; i < len(node.Decls); {
+ newDecls = append(newDecls, node.Decls[i])
+ start, ok := node.Decls[i].(*ast.GenDecl)
+ if !ok {
+ i++
+ continue
+ }
+ lastPos := start.Pos()
+ // Absorb the following declarations into start for as long as
+ // they use the same token, have no parentheses, and begin at
+ // most one line after the previous one.
+ for i++; i < len(node.Decls); {
+ cont, ok := node.Decls[i].(*ast.GenDecl)
+ if !ok || cont.Tok != start.Tok || cont.Lparen != token.NoPos ||
+ f.Line(lastPos) < f.Line(cont.Pos())-1 {
+ break
+ }
+ start.Specs = append(start.Specs, cont.Specs...)
+ if c := f.inlineComment(cont.End()); c != nil {
+ // don't move an inline comment outside
+ start.Rparen = c.End()
+ }
+ lastPos = cont.Pos()
+ i++
+ }
+ }
+ node.Decls = newDecls
+
+ // Comments aren't nodes, so they're not walked by default.
+ groupLoop:
+ for _, group := range node.Comments {
+ for _, comment := range group.List {
+ body := strings.TrimPrefix(comment.Text, "//")
+ if body == comment.Text {
+ // /*-style comment
+ continue groupLoop
+ }
+ if rxCommentDirective.MatchString(body) {
+ // this line is a directive
+ continue groupLoop
+ }
+ r, _ := utf8.DecodeRuneInString(body)
+ if !unicode.IsLetter(r) && !unicode.IsNumber(r) && !unicode.IsSpace(r) {
+ // this line could be code like "//{"
+ continue groupLoop
+ }
+ }
+ // If none of the comment group's lines look like a
+ // directive or code, add spaces, if needed.
+ for _, comment := range group.List {
+ body := strings.TrimPrefix(comment.Text, "//")
+ r, _ := utf8.DecodeRuneInString(body)
+ if !unicode.IsSpace(r) {
+ comment.Text = "// " + strings.TrimPrefix(comment.Text, "//")
+ }
+ }
+ }
+
+ case *ast.DeclStmt:
+ // Replace a "var name = value" statement with an assignment
+ // statement: "name := value", or "_ = value" if every name is
+ // the blank identifier.
+ decl, ok := node.Decl.(*ast.GenDecl)
+ if !ok || decl.Tok != token.VAR || len(decl.Specs) != 1 {
+ break // e.g. const name = "value"
+ }
+ spec := decl.Specs[0].(*ast.ValueSpec)
+ if spec.Type != nil {
+ break // e.g. var name Type
+ }
+ tok := token.ASSIGN
+ names := make([]ast.Expr, len(spec.Names))
+ for i, name := range spec.Names {
+ names[i] = name
+ if name.Name != "_" {
+ tok = token.DEFINE
+ }
+ }
+ c.Replace(&ast.AssignStmt{
+ Lhs: names,
+ Tok: tok,
+ Rhs: spec.Values,
+ })
+
+ case *ast.GenDecl:
+ if node.Tok == token.IMPORT && node.Lparen.IsValid() {
+ f.joinStdImports(node)
+ }
+
+ // Single var declarations shouldn't use parentheses, unless
+ // there's a comment on the grouped declaration.
+ if node.Tok == token.VAR && len(node.Specs) == 1 &&
+ node.Lparen.IsValid() && node.Doc == nil {
+ specPos := node.Specs[0].Pos()
+ specEnd := node.Specs[0].End()
+
+ if len(f.commentsBetween(node.TokPos, specPos)) > 0 {
+ // If the single spec has any comment, it must
+ // go before the entire declaration now.
+ node.TokPos = specPos
+ } else {
+ f.removeLines(f.Line(node.TokPos), f.Line(specPos))
+ }
+ f.removeLines(f.Line(specEnd), f.Line(node.Rparen))
+
+ // Remove the parentheses. go/printer will automatically
+ // get rid of the newlines.
+ node.Lparen = token.NoPos
+ node.Rparen = token.NoPos
+ }
+
+ case *ast.BlockStmt:
+ f.stmts(node.List)
+ comments := f.commentsBetween(node.Lbrace, node.Rbrace)
+ if len(node.List) == 0 && len(comments) == 0 {
+ f.removeLinesBetween(node.Lbrace, node.Rbrace)
+ break
+ }
+
+ isFuncBody := false
+ switch c.Parent().(type) {
+ case *ast.FuncDecl:
+ isFuncBody = true
+ case *ast.FuncLit:
+ isFuncBody = true
+ }
+
+ if len(node.List) > 1 && !isFuncBody {
+ // only if we have a single statement, or if
+ // it's a func body.
+ break
+ }
+ var bodyPos, bodyEnd token.Pos
+
+ // bodyPos and bodyEnd span the statements and comments inside the
+ // braces; empty lines between them and the braces are removed.
+ if len(node.List) > 0 {
+ bodyPos = node.List[0].Pos()
+ bodyEnd = node.List[len(node.List)-1].End()
+ }
+ if len(comments) > 0 {
+ if pos := comments[0].Pos(); !bodyPos.IsValid() || pos < bodyPos {
+ bodyPos = pos
+ }
+ // NOTE(review): this tests bodyPos, which is always valid by
+ // now; it looks like bodyEnd was intended. The outcome appears
+ // to be the same, since a zero bodyEnd is below any comment end
+ // position - confirm before changing.
+ if pos := comments[len(comments)-1].End(); !bodyPos.IsValid() || pos > bodyEnd {
+ bodyEnd = pos
+ }
+ }
+
+ f.removeLinesBetween(node.Lbrace, bodyPos)
+ f.removeLinesBetween(bodyEnd, node.Rbrace)
+
+ case *ast.CompositeLit:
+ if len(node.Elts) == 0 {
+ // doesn't have elements
+ break
+ }
+ openLine := f.Line(node.Lbrace)
+ closeLine := f.Line(node.Rbrace)
+ if openLine == closeLine {
+ // all in a single line
+ break
+ }
+
+ // Record whether the literal has a newline next to either brace,
+ // and whether it has one between any two elements.
+ newlineAroundElems := false
+ newlineBetweenElems := false
+ lastLine := openLine
+ for i, elem := range node.Elts {
+ if f.Line(elem.Pos()) > lastLine {
+ if i == 0 {
+ newlineAroundElems = true
+ } else {
+ newlineBetweenElems = true
+ }
+ }
+ lastLine = f.Line(elem.End())
+ }
+ if closeLine > lastLine {
+ newlineAroundElems = true
+ }
+
+ if newlineBetweenElems || newlineAroundElems {
+ first := node.Elts[0]
+ if openLine == f.Line(first.Pos()) {
+ // We want the newline right after the brace.
+ f.addNewline(node.Lbrace + 1)
+ closeLine = f.Line(node.Rbrace)
+ }
+ last := node.Elts[len(node.Elts)-1]
+ if closeLine == f.Line(last.End()) {
+ // We want the newline right before the brace.
+ f.addNewline(node.Rbrace)
+ }
+ }
+
+ // If there's a newline between any consecutive elements, there
+ // must be a newline between all composite literal elements.
+ if !newlineBetweenElems {
+ break
+ }
+ for i1, elem1 := range node.Elts {
+ i2 := i1 + 1
+ if i2 >= len(node.Elts) {
+ break
+ }
+ elem2 := node.Elts[i2]
+ // TODO: do we care about &{}?
+ _, ok1 := elem1.(*ast.CompositeLit)
+ _, ok2 := elem2.(*ast.CompositeLit)
+ if !ok1 && !ok2 {
+ continue
+ }
+ if f.Line(elem1.End()) == f.Line(elem2.Pos()) {
+ f.addNewline(elem1.End())
+ }
+ }
+
+ case *ast.CaseClause:
+ f.stmts(node.Body)
+ // Collapse a case list spanning several lines onto a single line,
+ // unless it contains comments or would be too long.
+ openLine := f.Line(node.Case)
+ closeLine := f.Line(node.Colon)
+ if openLine == closeLine {
+ // nothing to do
+ break
+ }
+ if len(f.commentsBetween(node.Case, node.Colon)) > 0 {
+ // don't move comments
+ break
+ }
+ if f.printLength(node) > shortLineLimit {
+ // too long to collapse
+ break
+ }
+ f.removeLines(openLine, closeLine)
+
+ case *ast.CommClause:
+ f.stmts(node.Body)
+
+ case *ast.FieldList:
+ // Merging adjacent fields (e.g. parameters) is disabled by default.
+ if !f.ExtraRules {
+ break
+ }
+ switch c.Parent().(type) {
+ case *ast.FuncDecl, *ast.FuncType, *ast.InterfaceType:
+ node.List = f.mergeAdjacentFields(node.List)
+ c.Replace(node)
+ case *ast.StructType:
+ // Do not merge adjacent fields in structs.
+ }
+
+ case *ast.BasicLit:
+ // Octal number literals were introduced in 1.13.
+ if semver.Compare(f.LangVersion, "v1.13") >= 0 {
+ if node.Kind == token.INT && rxOctalInteger.MatchString(node.Value) {
+ // Swap the leading "0" for the explicit "0o" prefix.
+ node.Value = "0o" + node.Value[1:]
+ c.Replace(node)
+ }
+ }
+ }
+}
+
+// stmts removes any empty lines between an assignment of the form
+// "..., err := ..." and an "if err != nil" statement directly following it in
+// list, as long as the if has no init statement and no else branch.
+func (f *fumpter) stmts(list []ast.Stmt) {
+	// Start at the second statement; the first has nothing before it.
+	for i := 1; i < len(list); i++ {
+		ifStmt, isIf := list[i].(*ast.IfStmt)
+		if !isIf {
+			continue // not an if following another statement
+		}
+		assign, isAssign := list[i-1].(*ast.AssignStmt)
+		if !isAssign || assign.Tok != token.DEFINE {
+			continue // not a ":=" assignment
+		}
+		if last := assign.Lhs[len(assign.Lhs)-1]; !identEqual(last, "err") {
+			continue // not "..., err := ..."
+		}
+		if ifStmt.Init != nil || ifStmt.Else != nil {
+			continue // complex if
+		}
+		cond, isBinary := ifStmt.Cond.(*ast.BinaryExpr)
+		if !isBinary {
+			continue // complex if
+		}
+		isErrCheck := cond.Op == token.NEQ &&
+			identEqual(cond.X, "err") &&
+			identEqual(cond.Y, "nil")
+		if !isErrCheck {
+			continue // not "err != nil"
+		}
+		f.removeLinesBetween(assign.End(), ifStmt.Pos())
+	}
+}
+
+// identEqual reports whether expr is an identifier whose name is exactly name.
+// Any other kind of expression, including a nil one, yields false.
+func identEqual(expr ast.Expr, name string) bool {
+	if ident, isIdent := expr.(*ast.Ident); isIdent {
+		return ident.Name == name
+	}
+	return false
+}
+
+// joinStdImports ensures that all standard library imports are together and at
+// the top of the imports list.
+//
+// It rewrites d.Specs in place, may add newlines to the file's line table, and
+// calls ast.SortImports if any standard library import was moved. Every spec
+// in d is asserted to be an *ast.ImportSpec.
+func (f *fumpter) joinStdImports(d *ast.GenDecl) {
+ var std, other []ast.Spec
+ // firstGroup stays true until an import is found more than one line after
+ // the end of the previous import or comment.
+ firstGroup := true
+ lastEnd := d.Pos()
+ needsSort := false
+ for i, spec := range d.Specs {
+ spec := spec.(*ast.ImportSpec)
+ if coms := f.commentsBetween(lastEnd, spec.Pos()); len(coms) > 0 {
+ lastEnd = coms[len(coms)-1].End()
+ }
+ if i > 0 && firstGroup && f.Line(spec.Pos()) > f.Line(lastEnd)+1 {
+ firstGroup = false
+ } else {
+ // We're still in the first group, update lastEnd.
+ lastEnd = spec.End()
+ }
+
+ // NOTE(review): the Unquote error is discarded; presumably the
+ // parser only produces valid string literals here - confirm.
+ path, _ := strconv.Unquote(spec.Path.Value)
+ // Each matching case falls through to the last one, so any of these
+ // conditions marks the import as non-std.
+ switch {
+ // Imports with a period are definitely third party.
+ case strings.Contains(path, "."):
+ fallthrough
+ // "test" and "example" are reserved as per golang.org/issue/37641.
+ // "internal" is unreachable.
+ case strings.HasPrefix(path, "test/") ||
+ strings.HasPrefix(path, "example/") ||
+ strings.HasPrefix(path, "internal/"):
+ fallthrough
+ // To be conservative, if an import has a name or an inline
+ // comment, and isn't part of the top group, treat it as non-std.
+ case !firstGroup && (spec.Name != nil || spec.Comment != nil):
+ other = append(other, spec)
+ continue
+ }
+
+ // If we're moving this std import further up, reset its
+ // position, to avoid breaking comments.
+ if !firstGroup || len(other) > 0 {
+ setPos(reflect.ValueOf(spec), d.Pos())
+ needsSort = true
+ }
+ std = append(std, spec)
+ }
+ // Ensure there is an empty line between std imports and other imports.
+ if len(std) > 0 && len(other) > 0 && f.Line(std[len(std)-1].End())+1 >= f.Line(other[0].Pos()) {
+ // We add two newlines, as that's necessary in some edge cases.
+ // For example, if the std and non-std imports were together and
+ // without indentation, adding one newline isn't enough. Two
+ // empty lines will be printed as one by go/printer, anyway.
+ f.addNewline(other[0].Pos() - 1)
+ f.addNewline(other[0].Pos())
+ }
+ // Finally, join the imports, keeping std at the top.
+ d.Specs = append(std, other...)
+
+ // If we moved any std imports to the first group, we need to sort them
+ // again.
+ if needsSort {
+ ast.SortImports(f.fset, f.astFile)
+ }
+}
+
+// mergeAdjacentFields returns fields with adjacent fields merged if possible.
+// The merge happens in place: the result shares its backing array with fields,
+// and the names of a merged field are appended to the field it is merged into.
+func (f *fumpter) mergeAdjacentFields(fields []*ast.Field) []*ast.Field {
+	// With fewer than two fields there is nothing to merge.
+	if len(fields) <= 1 {
+		return fields
+	}
+
+	// kept is the index of the last field retained so far. Each later field
+	// is either folded into it (and so discarded), or moved into the next
+	// slot, where it becomes the target for the fields following it.
+	kept := 0
+	for _, next := range fields[1:] {
+		target := fields[kept]
+		if f.shouldMergeAdjacentFields(target, next) {
+			target.Names = append(target.Names, next.Names...)
+			continue
+		}
+		kept++
+		fields[kept] = next
+	}
+	return fields[:kept+1]
+}
+
+// shouldMergeAdjacentFields reports whether f2 can be merged into f1: both
+// must have names, both must start on the same line, and their types must be
+// equal once positions are ignored.
+func (f *fumpter) shouldMergeAdjacentFields(f1, f2 *ast.Field) bool {
+	switch {
+	case len(f1.Names) == 0, len(f2.Names) == 0:
+		// Both must have names for the merge to work.
+		return false
+	case f.Line(f1.Pos()) != f.Line(f2.Pos()):
+		// Trust the user if they used separate lines.
+		return false
+	}
+
+	// Only merge if the types are equal. Every pair of token.Pos values is
+	// treated as equal, so that only the shape of the types is compared.
+	ignorePos := cmp.Comparer(func(_, _ token.Pos) bool { return true })
+	return cmp.Equal(f1.Type, f2.Type, ignorePos)
+}
+
+// posType is the reflection type of token.Pos, used by setPos to recognise
+// position fields.
+var posType = reflect.TypeOf(token.Pos(0))
+
+// setPos recursively sets all position fields in the node v to pos. It follows
+// pointers and descends into struct fields; nil pointers are skipped, and
+// values of any other kind are left alone.
+func setPos(v reflect.Value, pos token.Pos) {
+	if v.Kind() == reflect.Ptr {
+		v = v.Elem()
+	}
+	if !v.IsValid() {
+		// v was a nil pointer.
+		return
+	}
+	if v.Type() == posType {
+		v.Set(reflect.ValueOf(pos))
+	}
+	if v.Kind() != reflect.Struct {
+		return
+	}
+	for i, n := 0, v.NumField(); i < n; i++ {
+		setPos(v.Field(i), pos)
+	}
+}