1 // Copyright (c) 2019, Daniel Martí <mvdan@mvdan.cc>
2 // See LICENSE for licensing information
4 // Package format exposes gofumpt's formatting in an API similar to go/format.
5 // In general, the APIs are only guaranteed to work well when the input source
6 // is in canonical gofmt format.
24 "github.com/google/go-cmp/cmp"
25 "golang.org/x/mod/semver"
26 "golang.org/x/tools/go/ast/astutil"
30 // LangVersion corresponds to the Go language version a piece of code is
31 // written in. The version is used to decide whether to apply formatting
32 // rules which require new language features. When inside a Go module,
33 // LangVersion should generally be specified as the result of:
35 // go list -m -f {{.GoVersion}}
37 // LangVersion is treated as a semantic version, which might start with
38 // a "v" prefix. Like Go versions, it might also be incomplete; "1.14"
39 // is equivalent to "1.14.0". When empty, it is equivalent to "v1", to
40 // not use language features which could break programs.
46 // Source formats src in gofumpt's format, assuming that src holds a valid Go
48 func Source(src []byte, opts Options) ([]byte, error) {
49 fset := token.NewFileSet()
50 file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
55 File(fset, file, opts)
58 if err := format.Node(&buf, fset, file); err != nil {
61 return buf.Bytes(), nil
64 // File modifies a file and fset in place to follow gofumpt's format. The
65 // changes might include adding or removing newlines in fset,
66 // modifying the position of nodes, or modifying literal values.
67 func File(fset *token.FileSet, file *ast.File, opts Options) {
68 if opts.LangVersion == "" {
69 opts.LangVersion = "v1"
70 } else if opts.LangVersion[0] != 'v' {
71 opts.LangVersion = "v" + opts.LangVersion
73 if !semver.IsValid(opts.LangVersion) {
74 panic(fmt.Sprintf("invalid semver string: %q", opts.LangVersion))
77 File: fset.File(file.Pos()),
82 pre := func(c *astutil.Cursor) bool {
84 if _, ok := c.Node().(*ast.BlockStmt); ok {
89 post := func(c *astutil.Cursor) bool {
90 if _, ok := c.Node().(*ast.BlockStmt); ok {
95 astutil.Apply(file, pre, post)
98 // Multiline nodes which could fit on a single line under this many
99 // bytes may be collapsed onto a single line.
100 const shortLineLimit = 60
102 var rxOctalInteger = regexp.MustCompile(`\A0[0-7_]+\z`)
104 type fumpter struct {
115 func (f *fumpter) commentsBetween(p1, p2 token.Pos) []*ast.CommentGroup {
116 comments := f.astFile.Comments
117 i1 := sort.Search(len(comments), func(i int) bool {
118 return comments[i].Pos() >= p1
120 comments = comments[i1:]
121 i2 := sort.Search(len(comments), func(i int) bool {
122 return comments[i].Pos() >= p2
124 comments = comments[:i2]
128 func (f *fumpter) inlineComment(pos token.Pos) *ast.Comment {
129 comments := f.astFile.Comments
130 i := sort.Search(len(comments), func(i int) bool {
131 return comments[i].Pos() >= pos
133 if i >= len(comments) {
137 for _, comment := range comments[i].List {
138 if f.Line(comment.Pos()) == line {
145 // addNewline is a hack to let us force a newline at a certain position.
146 func (f *fumpter) addNewline(at token.Pos) {
147 offset := f.Offset(at)
149 field := reflect.ValueOf(f.File).Elem().FieldByName("lines")
151 lines := make([]int, 0, n+1)
152 for i := 0; i < n; i++ {
153 cur := int(field.Index(i).Int())
155 // This newline already exists; do nothing. Duplicate
156 // newlines can't exist.
159 if offset >= 0 && offset < cur {
160 lines = append(lines, offset)
163 lines = append(lines, cur)
166 lines = append(lines, offset)
168 if !f.SetLines(lines) {
169 panic(fmt.Sprintf("could not set lines to %v", lines))
173 // removeLines removes all newlines between two positions, so that they end
174 // up on the same line.
175 func (f *fumpter) removeLines(fromLine, toLine int) {
176 for fromLine < toLine {
177 f.MergeLine(fromLine)
182 // removeLinesBetween is like removeLines, but it leaves one newline between the
184 func (f *fumpter) removeLinesBetween(from, to token.Pos) {
185 f.removeLines(f.Line(from)+1, f.Line(to))
190 func (b *byteCounter) Write(p []byte) (n int, err error) {
191 *b += byteCounter(len(p))
195 func (f *fumpter) printLength(node ast.Node) int {
196 var count byteCounter
197 if err := format.Node(&count, f.fset, node); err != nil {
198 panic(fmt.Sprintf("unexpected print error: %v", err))
201 // Add the space taken by an inline comment.
202 if c := f.inlineComment(node.End()); c != nil {
203 fmt.Fprintf(&count, " %s", c.Text)
206 // Add an approximation of the indentation level. We can't know the
207 // number of tabs go/printer will add ahead of time. Trying to print the
208 // entire top-level declaration would tell us that, but then it's near
209 // impossible to reliably find our node again.
210 return int(count) + (f.blockLevel * 8)
213 // rxCommentDirective covers all common Go comment directives:
215 // //go: | standard Go directives, like go:noinline
216 // //someword: | similar to the syntax above, like lint:ignore
217 // //line | inserted line information for cmd/compile
218 // //export | to mark cgo funcs for exporting
219 // //extern | C function declarations for gccgo
220 // //sys(nb)? | syscall function wrapper prototypes
221 // //nolint | nolint directive for golangci
223 // Note that the "someword:" matching expects a letter afterward, such as
224 // "go:generate", to prevent matching false positives like "https://site".
225 var rxCommentDirective = regexp.MustCompile(`^([a-z]+:[a-z]+|line\b|export\b|extern\b|sys(nb)?\b|nolint\b)`)
227 // applyPre inspects the node under cursor c and applies formatting changes to it.
228 func (f *fumpter) applyPre(c *astutil.Cursor) {
229 switch node := c.Node().(type) {
232 var lastEnd token.Pos
233 for _, decl := range node.Decls {
235 comments := f.commentsBetween(lastEnd, pos)
236 if len(comments) > 0 {
237 pos = comments[0].Pos()
240 // multiline top-level declarations should be separated
241 multi := f.Line(pos) < f.Line(decl.End())
242 if multi && lastMulti && f.Line(lastEnd)+1 == f.Line(pos) {
243 f.addNewline(lastEnd)
250 // Join contiguous lone var/const/import lines; abort if there
251 // are empty lines or comments in between.
252 newDecls := make([]ast.Decl, 0, len(node.Decls))
253 for i := 0; i < len(node.Decls); {
254 newDecls = append(newDecls, node.Decls[i])
255 start, ok := node.Decls[i].(*ast.GenDecl)
256 if !ok || isCgoImport(start) {
260 lastPos := start.Pos()
261 for i++; i < len(node.Decls); {
262 cont, ok := node.Decls[i].(*ast.GenDecl)
263 if !ok || cont.Tok != start.Tok || cont.Lparen != token.NoPos ||
264 f.Line(lastPos) < f.Line(cont.Pos())-1 || isCgoImport(cont) {
267 start.Specs = append(start.Specs, cont.Specs...)
268 if c := f.inlineComment(cont.End()); c != nil {
269 // don't move an inline comment outside
270 start.Rparen = c.End()
276 node.Decls = newDecls
278 // Comments aren't nodes, so they're not walked by default.
280 for _, group := range node.Comments {
281 for _, comment := range group.List {
282 body := strings.TrimPrefix(comment.Text, "//")
283 if body == comment.Text {
287 if rxCommentDirective.MatchString(body) {
288 // this line is a directive
291 r, _ := utf8.DecodeRuneInString(body)
292 if !unicode.IsLetter(r) && !unicode.IsNumber(r) && !unicode.IsSpace(r) {
293 // this line could be code like "//{"
297 // If none of the comment group's lines look like a
298 // directive or code, add spaces, if needed.
299 for _, comment := range group.List {
300 body := strings.TrimPrefix(comment.Text, "//")
301 r, _ := utf8.DecodeRuneInString(body)
302 if !unicode.IsSpace(r) {
303 comment.Text = "// " + strings.TrimPrefix(comment.Text, "//")
309 decl, ok := node.Decl.(*ast.GenDecl)
310 if !ok || decl.Tok != token.VAR || len(decl.Specs) != 1 {
311 break // e.g. const name = "value"
313 spec := decl.Specs[0].(*ast.ValueSpec)
314 if spec.Type != nil {
315 break // e.g. var name Type
318 names := make([]ast.Expr, len(spec.Names))
319 for i, name := range spec.Names {
321 if name.Name != "_" {
325 c.Replace(&ast.AssignStmt{
332 if node.Tok == token.IMPORT && node.Lparen.IsValid() {
333 f.joinStdImports(node)
336 // Single var declarations shouldn't use parentheses, unless
337 // there's a comment on the grouped declaration.
338 if node.Tok == token.VAR && len(node.Specs) == 1 &&
339 node.Lparen.IsValid() && node.Doc == nil {
340 specPos := node.Specs[0].Pos()
341 specEnd := node.Specs[0].End()
343 if len(f.commentsBetween(node.TokPos, specPos)) > 0 {
344 // If the single spec has any comment, it must
345 // go before the entire declaration now.
346 node.TokPos = specPos
348 f.removeLines(f.Line(node.TokPos), f.Line(specPos))
350 f.removeLines(f.Line(specEnd), f.Line(node.Rparen))
352 // Remove the parentheses. go/printer will automatically
353 // get rid of the newlines.
354 node.Lparen = token.NoPos
355 node.Rparen = token.NoPos
360 comments := f.commentsBetween(node.Lbrace, node.Rbrace)
361 if len(node.List) == 0 && len(comments) == 0 {
362 f.removeLinesBetween(node.Lbrace, node.Rbrace)
368 switch parent := c.Parent().(type) {
379 if len(node.List) > 1 && !isFuncBody {
380 // only if we have a single statement, or if
384 var bodyPos, bodyEnd token.Pos
386 if len(node.List) > 0 {
387 bodyPos = node.List[0].Pos()
388 bodyEnd = node.List[len(node.List)-1].End()
390 if len(comments) > 0 {
391 if pos := comments[0].Pos(); !bodyPos.IsValid() || pos < bodyPos {
394 if pos := comments[len(comments)-1].End(); !bodyPos.IsValid() || pos > bodyEnd {
399 if cond != nil && f.Line(cond.Pos()) != f.Line(cond.End()) {
400 // The body is preceded by a multi-line condition, so an
401 // empty line can help readability.
403 f.removeLinesBetween(node.Lbrace, bodyPos)
405 f.removeLinesBetween(bodyEnd, node.Rbrace)
407 case *ast.CompositeLit:
408 if len(node.Elts) == 0 {
409 // doesn't have elements
412 openLine := f.Line(node.Lbrace)
413 closeLine := f.Line(node.Rbrace)
414 if openLine == closeLine {
415 // all in a single line
419 newlineAroundElems := false
420 newlineBetweenElems := false
422 for i, elem := range node.Elts {
423 if f.Line(elem.Pos()) > lastLine {
425 newlineAroundElems = true
427 newlineBetweenElems = true
430 lastLine = f.Line(elem.End())
432 if closeLine > lastLine {
433 newlineAroundElems = true
436 if newlineBetweenElems || newlineAroundElems {
437 first := node.Elts[0]
438 if openLine == f.Line(first.Pos()) {
439 // We want the newline right after the brace.
440 f.addNewline(node.Lbrace + 1)
441 closeLine = f.Line(node.Rbrace)
443 last := node.Elts[len(node.Elts)-1]
444 if closeLine == f.Line(last.End()) {
445 // We want the newline right before the brace.
446 f.addNewline(node.Rbrace)
450 // If there's a newline between any consecutive elements, there
451 // must be a newline between all composite literal elements.
452 if !newlineBetweenElems {
455 for i1, elem1 := range node.Elts {
457 if i2 >= len(node.Elts) {
460 elem2 := node.Elts[i2]
461 // TODO: do we care about &{}?
462 _, ok1 := elem1.(*ast.CompositeLit)
463 _, ok2 := elem2.(*ast.CompositeLit)
467 if f.Line(elem1.End()) == f.Line(elem2.Pos()) {
468 f.addNewline(elem1.End())
472 case *ast.CaseClause:
474 openLine := f.Line(node.Case)
475 closeLine := f.Line(node.Colon)
476 if openLine == closeLine {
480 if len(f.commentsBetween(node.Case, node.Colon)) > 0 {
481 // don't move comments
484 if f.printLength(node) > shortLineLimit {
485 // too long to collapse
488 f.removeLines(openLine, closeLine)
490 case *ast.CommClause:
494 if node.NumFields() == 0 {
495 // Empty field lists should not contain a newline.
496 openLine := f.Line(node.Pos())
497 closeLine := f.Line(node.End())
498 f.removeLines(openLine, closeLine)
501 // Merging adjacent fields (e.g. parameters) is disabled by default.
505 switch c.Parent().(type) {
506 case *ast.FuncDecl, *ast.FuncType, *ast.InterfaceType:
507 node.List = f.mergeAdjacentFields(node.List)
509 case *ast.StructType:
510 // Do not merge adjacent fields in structs.
514 // The 0o prefix for octal number literals was introduced in Go 1.13.
515 if semver.Compare(f.LangVersion, "v1.13") >= 0 {
516 if node.Kind == token.INT && rxOctalInteger.MatchString(node.Value) {
517 node.Value = "0o" + node.Value[1:]
524 func (f *fumpter) stmts(list []ast.Stmt) {
525 for i, stmt := range list {
526 ifs, ok := stmt.(*ast.IfStmt)
528 continue // not an if following another statement
530 as, ok := list[i-1].(*ast.AssignStmt)
531 if !ok || as.Tok != token.DEFINE ||
532 !identEqual(as.Lhs[len(as.Lhs)-1], "err") {
533 continue // not "..., err := ..."
535 be, ok := ifs.Cond.(*ast.BinaryExpr)
536 if !ok || ifs.Init != nil || ifs.Else != nil {
537 continue // complex if
539 if be.Op != token.NEQ || !identEqual(be.X, "err") ||
540 !identEqual(be.Y, "nil") {
541 continue // not "err != nil"
543 f.removeLinesBetween(as.End(), ifs.Pos())
547 func identEqual(expr ast.Expr, name string) bool {
548 id, ok := expr.(*ast.Ident)
549 return ok && id.Name == name
552 // isCgoImport returns true if the declaration is simply:
556 // Note that parentheses do not affect the result.
557 func isCgoImport(decl *ast.GenDecl) bool {
558 if decl.Tok != token.IMPORT || len(decl.Specs) != 1 {
561 spec := decl.Specs[0].(*ast.ImportSpec)
562 return spec.Path.Value == `"C"`
565 // joinStdImports ensures that all standard library imports are together and at
566 // the top of the imports list.
567 func (f *fumpter) joinStdImports(d *ast.GenDecl) {
568 var std, other []ast.Spec
572 for i, spec := range d.Specs {
573 spec := spec.(*ast.ImportSpec)
574 if coms := f.commentsBetween(lastEnd, spec.Pos()); len(coms) > 0 {
575 lastEnd = coms[len(coms)-1].End()
577 if i > 0 && firstGroup && f.Line(spec.Pos()) > f.Line(lastEnd)+1 {
580 // We're still in the first group, update lastEnd.
584 path, _ := strconv.Unquote(spec.Path.Value)
586 // Imports with a period are definitely third party.
587 case strings.Contains(path, "."):
589 // "test" and "example" are reserved as per golang.org/issue/37641.
590 // "internal" is unreachable.
591 case strings.HasPrefix(path, "test/") ||
592 strings.HasPrefix(path, "example/") ||
593 strings.HasPrefix(path, "internal/"):
595 // To be conservative, if an import has a name or an inline
596 // comment, and isn't part of the top group, treat it as non-std.
597 case !firstGroup && (spec.Name != nil || spec.Comment != nil):
598 other = append(other, spec)
602 // If we're moving this std import further up, reset its
603 // position, to avoid breaking comments.
604 if !firstGroup || len(other) > 0 {
605 setPos(reflect.ValueOf(spec), d.Pos())
608 std = append(std, spec)
610 // Ensure there is an empty line between std imports and other imports.
611 if len(std) > 0 && len(other) > 0 && f.Line(std[len(std)-1].End())+1 >= f.Line(other[0].Pos()) {
612 // We add two newlines, as that's necessary in some edge cases.
613 // For example, if the std and non-std imports were together and
614 // without indentation, adding one newline isn't enough. Two
615 // empty lines will be printed as one by go/printer, anyway.
616 f.addNewline(other[0].Pos() - 1)
617 f.addNewline(other[0].Pos())
619 // Finally, join the imports, keeping std at the top.
620 d.Specs = append(std, other...)
622 // If we moved any std imports to the first group, we need to sort them
625 ast.SortImports(f.fset, f.astFile)
629 // mergeAdjacentFields returns fields with adjacent fields merged if possible.
630 func (f *fumpter) mergeAdjacentFields(fields []*ast.Field) []*ast.Field {
631 // If there are fewer than two fields then there is nothing to merge.
636 // Otherwise, iterate over adjacent pairs of fields, merging if possible,
637 // and mutating fields. Elements of fields may be mutated (if merged with
638 // following fields), discarded (if merged with a preceding field), or left
641 for j := 1; j < len(fields); j++ {
642 if f.shouldMergeAdjacentFields(fields[i], fields[j]) {
643 fields[i].Names = append(fields[i].Names, fields[j].Names...)
646 fields[i] = fields[j]
652 func (f *fumpter) shouldMergeAdjacentFields(f1, f2 *ast.Field) bool {
653 if len(f1.Names) == 0 || len(f2.Names) == 0 {
654 // Both must have names for the merge to work.
657 if f.Line(f1.Pos()) != f.Line(f2.Pos()) {
658 // Trust the user if they used separate lines.
662 // Only merge if the types are equal.
663 opt := cmp.Comparer(func(x, y token.Pos) bool { return true })
664 return cmp.Equal(f1.Type, f2.Type, opt)
667 var posType = reflect.TypeOf(token.NoPos)
669 // setPos recursively sets all position fields in the node v to pos.
670 func setPos(v reflect.Value, pos token.Pos) {
671 if v.Kind() == reflect.Ptr {
677 if v.Type() == posType {
678 v.Set(reflect.ValueOf(pos))
680 if v.Kind() == reflect.Struct {
681 for i := 0; i < v.NumField(); i++ {
682 setPos(v.Field(i), pos)