1 // Copyright (c) 2019, Daniel Martà <mvdan@mvdan.cc>
2 // See LICENSE for licensing information
4 // Package format exposes gofumpt's formatting in an API similar to go/format.
5 // In general, the APIs are only guaranteed to work well when the input source
6 // is in canonical gofmt format.
24 "github.com/google/go-cmp/cmp"
25 "golang.org/x/mod/semver"
26 "golang.org/x/tools/go/ast/astutil"
30 // LangVersion corresponds to the Go language version a piece of code is
31 // written in. The version is used to decide whether to apply formatting
32 // rules which require new language features. When inside a Go module,
33 // LangVersion should generally be specified as the result of:
35 // go list -m -f {{.GoVersion}}
37 // LangVersion is treated as a semantic version, which might start with
38 // a "v" prefix. Like Go versions, it might also be incomplete; "1.14"
39 // is equivalent to "1.14.0". When empty, it is equivalent to "v1", to
40 // not use language features which could break programs.
46 // Source formats src in gofumpt's format, assuming that src holds a valid Go source file.
// NOTE(review): File panics on an invalid opts.LangVersion; no recover is
// visible in this excerpt, so Source presumably panics in that case too.
48 func Source(src []byte, opts Options) ([]byte, error) {
// Parse with parser.ParseComments so that comments are kept in the AST and
// printed back out.
49 fset := token.NewFileSet()
50 file, err := parser.ParseFile(fset, "", src, parser.ParseComments)
// Rewrite the AST and the position information in fset in place.
55 File(fset, file, opts)
// Print the rewritten AST with go/format's standard printer.
58 if err := format.Node(&buf, fset, file); err != nil {
61 return buf.Bytes(), nil
64 // File modifies a file and fset in place to follow gofumpt's format. The
65 // changes might include adding or removing newlines in fset,
66 // modifying the position of nodes, or modifying literal values.
// It panics if opts.LangVersion, after normalization, is not a valid
// semantic version. opts is received by value, so the caller's Options is
// not modified by the normalization below.
67 func File(fset *token.FileSet, file *ast.File, opts Options) {
// Normalize LangVersion: empty means "v1", and a missing "v" prefix is
// added because golang.org/x/mod/semver only accepts "v"-prefixed versions.
68 if opts.LangVersion == "" {
69 opts.LangVersion = "v1"
70 } else if opts.LangVersion[0] != 'v' {
71 opts.LangVersion = "v" + opts.LangVersion
73 if !semver.IsValid(opts.LangVersion) {
74 panic(fmt.Sprintf("invalid semver string: %q", opts.LangVersion))
// NOTE(review): the start of this composite literal is not visible here;
// its File field is set to the token.File that holds file's positions.
77 File: fset.File(file.Pos()),
// pre and post both special-case *ast.BlockStmt; the rest of their bodies is
// not visible in this excerpt. NOTE(review): presumably they adjust
// f.blockLevel (read by printLength) and dispatch to applyPre.
82 pre := func(c *astutil.Cursor) bool {
84 if _, ok := c.Node().(*ast.BlockStmt); ok {
89 post := func(c *astutil.Cursor) bool {
90 if _, ok := c.Node().(*ast.BlockStmt); ok {
// astutil.Apply calls pre before, and post after, visiting each node's
// children.
95 astutil.Apply(file, pre, post)
98 // Multiline nodes which could fit on a single line under this many
99 // bytes may be collapsed onto a single line.
// The limit is compared against printLength's estimate, which also adds an
// approximation of the indentation.
100 const shortLineLimit = 60
// rxOctalInteger matches legacy octal integer literals such as 0755: a
// leading 0 followed by one or more octal digits or underscores. A lone "0"
// does not match. Matching literals are rewritten to the "0o" prefix form
// when LangVersion is at least v1.13.
102 var rxOctalInteger = regexp.MustCompile(`\A0[0-7_]+\z`)
// fumpter holds the state used while formatting a single file.
// NOTE(review): the field list is not visible in this excerpt. From its uses,
// it has a File field holding a *token.File whose methods (Line, Offset,
// MergeLine, SetLines) are called directly on f, and it also carries fset,
// astFile, blockLevel and LangVersion.
104 type fumpter struct {
// commentsBetween narrows f.astFile.Comments, using binary search at both
// ends, to the comment groups that start at or after p1 and before p2.
// NOTE(review): this assumes f.astFile.Comments is sorted by position, as
// go/parser produces it. The return statement is not visible in this excerpt;
// presumably it returns the narrowed slice, which shares backing storage with
// f.astFile.Comments, so callers should not append to it.
115 func (f *fumpter) commentsBetween(p1, p2 token.Pos) []*ast.CommentGroup {
116 comments := f.astFile.Comments
117 i1 := sort.Search(len(comments), func(i int) bool {
118 return comments[i].Pos() >= p1
120 comments = comments[i1:]
121 i2 := sort.Search(len(comments), func(i int) bool {
122 return comments[i].Pos() >= p2
124 comments = comments[:i2]
// inlineComment looks for a comment on the same line as pos. It searches only
// the first comment group that starts at or after pos. Callers treat a nil
// result as "no inline comment".
// NOTE(review): the lines computing `line` and the return statements are not
// visible in this excerpt; presumably line is f.Line(pos).
128 func (f *fumpter) inlineComment(pos token.Pos) *ast.Comment {
129 comments := f.astFile.Comments
130 i := sort.Search(len(comments), func(i int) bool {
131 return comments[i].Pos() >= pos
133 if i >= len(comments) {
137 for _, comment := range comments[i].List {
138 if f.Line(comment.Pos()) == line {
145 // addNewline is a hack to let us force a newline at a certain position.
// go/token offers no API to insert a single line start. This method reads
// the unexported "lines" table of f.File via reflection, inserts the offset
// of at in sorted order (doing nothing if that offset is already a line
// start), and installs the new table with SetLines. It panics if SetLines
// rejects the table.
// NOTE(review): this depends on go/token's private field name "lines".
146 func (f *fumpter) addNewline(at token.Pos) {
147 offset := f.Offset(at)
// reflect can read the unexported field but not set it, so the updated
// table is written back through SetLines below.
149 field := reflect.ValueOf(f.File).Elem().FieldByName("lines")
151 lines := make([]int, 0, n+1)
152 for i := 0; i < n; i++ {
153 cur := int(field.Index(i).Int())
155 // This newline already exists; do nothing. Duplicate
156 // newlines can't exist.
// Insert offset before the first existing line start that is greater than
// it.
159 if offset >= 0 && offset < cur {
160 lines = append(lines, offset)
163 lines = append(lines, cur)
// NOTE(review): the guarding condition is not visible; presumably this
// appends offset when it lies past the last existing line start.
166 lines = append(lines, offset)
168 if !f.SetLines(lines) {
169 panic(fmt.Sprintf("could not set lines to %v", lines))
173 // removeLines removes all newlines between lines fromLine and toLine, so that
174 // they end up on the same line.
175 func (f *fumpter) removeLines(fromLine, toLine int) {
// token.File.MergeLine joins a line with the one that follows it, and panics
// on an invalid line number. NOTE(review): the loop's update step is not
// visible in this excerpt.
176 for fromLine < toLine {
177 f.MergeLine(fromLine)
182 // removeLinesBetween is like removeLines, but it takes positions and leaves one newline between them.
184 func (f *fumpter) removeLinesBetween(from, to token.Pos) {
// Merging starts at the line after from's line, so the line holding `to`
// ends up directly below the line holding `from`. Nothing happens if `to` is
// already at most one line below `from`.
185 f.removeLines(f.Line(from)+1, f.Line(to))
// Write implements io.Writer for byteCounter. It adds len(p) to the counter
// and discards the bytes, so printed output can be measured without being
// buffered.
// NOTE(review): the return statement is not visible in this excerpt; to
// satisfy io.Writer it must report n == len(p) with a nil error.
190 func (b *byteCounter) Write(p []byte) (n int, err error) {
191 *b += byteCounter(len(p))
// printLength estimates how wide node would be once printed. The estimate is
// the sum of:
//   - the bytes go/format produces for node (including newlines if the node
//     spans several lines),
//   - a space plus the text of an inline comment on the line where node ends,
//   - 8 bytes per f.blockLevel, as a stand-in for tab indentation.
//
// It panics if go/format cannot print the node.
195 func (f *fumpter) printLength(node ast.Node) int {
196 var count byteCounter
197 if err := format.Node(&count, f.fset, node); err != nil {
198 panic(fmt.Sprintf("unexpected print error: %v", err))
201 // Add the space taken by an inline comment.
202 if c := f.inlineComment(node.End()); c != nil {
203 fmt.Fprintf(&count, " %s", c.Text)
206 // Add an approximation of the indentation level. We can't know the
207 // number of tabs go/printer will add ahead of time. Trying to print the
208 // entire top-level declaration would tell us that, but then it's near
209 // impossible to reliably find our node again.
210 return int(count) + (f.blockLevel * 8)
213 // rxCommentDirective covers all common Go comment directives:
215 // //go: | standard Go directives, like go:noinline
216 // //someword: | similar to the syntax above, like lint:ignore
217 // //line | inserted line information for cmd/compile
218 // //export | to mark cgo funcs for exporting
219 // //extern | C function declarations for gccgo
220 // //sys(nb)? | syscall function wrapper prototypes
221 // //nolint | nolint directive for golangci
//
// It is matched against a comment's text with the leading "//" removed, so
// "// go:noinline" (with a space) is not treated as a directive.
222 var rxCommentDirective = regexp.MustCompile(`^([a-z]+:|line\b|export\b|extern\b|sys(nb)?\b|nolint\b)`)
224 // applyPre applies gofumpt's per-node rules to the node at cursor c,
// dispatching on the node's concrete type. Most rules mutate the node and the
// file's line information in place. A single-spec, untyped "var" declaration
// statement is instead replaced with an assignment via c.Replace.
// NOTE(review): most case labels and the call site are not visible in this
// excerpt; presumably applyPre is called from the pre hook built in File.
225 func (f *fumpter) applyPre(c *astutil.Cursor) {
226 switch node := c.Node().(type) {
// *ast.File (case label not visible): separate adjacent multiline top-level
// declarations with an empty line.
229 var lastEnd token.Pos
230 for _, decl := range node.Decls {
232 comments := f.commentsBetween(lastEnd, pos)
233 if len(comments) > 0 {
// A declaration's leading comments count as part of it.
234 pos = comments[0].Pos()
237 // multiline top-level declarations should be separated
238 multi := f.Line(pos) < f.Line(decl.End())
239 if multi && lastMulti && f.Line(lastEnd)+1 == f.Line(pos) {
240 f.addNewline(lastEnd)
247 // Join contiguous lone var/const/import lines; abort if there
248 // are empty lines or comments in between.
249 newDecls := make([]ast.Decl, 0, len(node.Decls))
250 for i := 0; i < len(node.Decls); {
251 newDecls = append(newDecls, node.Decls[i])
252 start, ok := node.Decls[i].(*ast.GenDecl)
257 lastPos := start.Pos()
258 for i++; i < len(node.Decls); {
259 cont, ok := node.Decls[i].(*ast.GenDecl)
// Stop joining at a different kind of declaration, at a parenthesized
// group, or when an empty line separates the two declarations.
260 if !ok || cont.Tok != start.Tok || cont.Lparen != token.NoPos ||
261 f.Line(lastPos) < f.Line(cont.Pos())-1 {
264 start.Specs = append(start.Specs, cont.Specs...)
265 if c := f.inlineComment(cont.End()); c != nil {
266 // don't move an inline comment outside
267 start.Rparen = c.End()
273 node.Decls = newDecls
275 // Comments aren't nodes, so they're not walked by default.
277 for _, group := range node.Comments {
278 for _, comment := range group.List {
279 body := strings.TrimPrefix(comment.Text, "//")
// An unchanged text means this is not a "//" line comment.
280 if body == comment.Text {
284 if rxCommentDirective.MatchString(body) {
285 // this line is a directive
288 r, _ := utf8.DecodeRuneInString(body)
289 if !unicode.IsLetter(r) && !unicode.IsNumber(r) && !unicode.IsSpace(r) {
290 // this line could be code like "//{"
294 // If none of the comment group's lines look like a
295 // directive or code, add spaces, if needed.
296 for _, comment := range group.List {
297 body := strings.TrimPrefix(comment.Text, "//")
298 r, _ := utf8.DecodeRuneInString(body)
299 if !unicode.IsSpace(r) {
300 comment.Text = "// " + strings.TrimPrefix(comment.Text, "//")
// *ast.DeclStmt (case label not visible): rewrite a single-spec "var"
// declaration without a type, such as "var a = x", as "a := x".
// NOTE(review): the handling of blank "_" names is only partly visible here.
306 decl, ok := node.Decl.(*ast.GenDecl)
307 if !ok || decl.Tok != token.VAR || len(decl.Specs) != 1 {
308 break // e.g. const name = "value"
310 spec := decl.Specs[0].(*ast.ValueSpec)
311 if spec.Type != nil {
312 break // e.g. var name Type
315 names := make([]ast.Expr, len(spec.Names))
316 for i, name := range spec.Names {
318 if name.Name != "_" {
322 c.Replace(&ast.AssignStmt{
// *ast.GenDecl (case label not visible).
329 if node.Tok == token.IMPORT && node.Lparen.IsValid() {
330 f.joinStdImports(node)
333 // Single var declarations shouldn't use parentheses, unless
334 // there's a comment on the grouped declaration.
335 if node.Tok == token.VAR && len(node.Specs) == 1 &&
336 node.Lparen.IsValid() && node.Doc == nil {
337 specPos := node.Specs[0].Pos()
338 specEnd := node.Specs[0].End()
340 if len(f.commentsBetween(node.TokPos, specPos)) > 0 {
341 // If the single spec has any comment, it must
342 // go before the entire declaration now.
343 node.TokPos = specPos
345 f.removeLines(f.Line(node.TokPos), f.Line(specPos))
347 f.removeLines(f.Line(specEnd), f.Line(node.Rparen))
349 // Remove the parentheses. go/printer will automatically
350 // get rid of the newlines.
351 node.Lparen = token.NoPos
352 node.Rparen = token.NoPos
// *ast.BlockStmt (case label not visible).
357 comments := f.commentsBetween(node.Lbrace, node.Rbrace)
358 if len(node.List) == 0 && len(comments) == 0 {
// Remove empty lines inside an empty block that has no comments.
359 f.removeLinesBetween(node.Lbrace, node.Rbrace)
// NOTE(review): the cases that set isFuncBody are not visible; presumably it
// is true when the parent is a function declaration or literal.
364 switch c.Parent().(type) {
371 if len(node.List) > 1 && !isFuncBody {
372 // only if we have a single statement, or if
// Compute the span of the block's contents: its statements, extended to
// cover comments before the first or after the last statement.
// NOTE(review): the assignments that extend bodyPos/bodyEnd to the comment
// positions are not visible in this excerpt.
376 var bodyPos, bodyEnd token.Pos
378 if len(node.List) > 0 {
379 bodyPos = node.List[0].Pos()
380 bodyEnd = node.List[len(node.List)-1].End()
382 if len(comments) > 0 {
383 if pos := comments[0].Pos(); !bodyPos.IsValid() || pos < bodyPos {
386 if pos := comments[len(comments)-1].End(); !bodyPos.IsValid() || pos > bodyEnd {
// Remove empty lines right after "{" and right before "}".
391 f.removeLinesBetween(node.Lbrace, bodyPos)
392 f.removeLinesBetween(bodyEnd, node.Rbrace)
394 case *ast.CompositeLit:
395 if len(node.Elts) == 0 {
396 // doesn't have elements
399 openLine := f.Line(node.Lbrace)
400 closeLine := f.Line(node.Rbrace)
401 if openLine == closeLine {
402 // all in a single line
// Record whether any element starts on a later line than where the previous
// element (or the opening brace) ended, and whether "}" is on a later line
// than the last element. NOTE(review): the initialization of lastLine and the
// conditions separating "around" from "between" are not visible here.
406 newlineAroundElems := false
407 newlineBetweenElems := false
409 for i, elem := range node.Elts {
410 if f.Line(elem.Pos()) > lastLine {
412 newlineAroundElems = true
414 newlineBetweenElems = true
417 lastLine = f.Line(elem.End())
419 if closeLine > lastLine {
420 newlineAroundElems = true
423 if newlineBetweenElems || newlineAroundElems {
424 first := node.Elts[0]
425 if openLine == f.Line(first.Pos()) {
426 // We want the newline right after the brace.
427 f.addNewline(node.Lbrace + 1)
// Adding a newline shifts the closing brace's line, so re-read it.
428 closeLine = f.Line(node.Rbrace)
430 last := node.Elts[len(node.Elts)-1]
431 if closeLine == f.Line(last.End()) {
432 // We want the newline right before the brace.
433 f.addNewline(node.Rbrace)
437 // If there's a newline between any consecutive elements, there
438 // must be a newline between all composite literal elements.
439 if !newlineBetweenElems {
442 for i1, elem1 := range node.Elts {
444 if i2 >= len(node.Elts) {
447 elem2 := node.Elts[i2]
448 // TODO: do we care about &{}?
449 _, ok1 := elem1.(*ast.CompositeLit)
450 _, ok2 := elem2.(*ast.CompositeLit)
454 if f.Line(elem1.End()) == f.Line(elem2.Pos()) {
455 f.addNewline(elem1.End())
459 case *ast.CaseClause:
// Collapse a case list spread over several lines onto the "case" line,
// unless it contains comments or would exceed shortLineLimit.
// NOTE(review): printLength(node) measures the whole clause, including its
// body statements, not just the "case ...:" header.
461 openLine := f.Line(node.Case)
462 closeLine := f.Line(node.Colon)
463 if openLine == closeLine {
467 if len(f.commentsBetween(node.Case, node.Colon)) > 0 {
468 // don't move comments
471 if f.printLength(node) > shortLineLimit {
472 // too long to collapse
475 f.removeLines(openLine, closeLine)
477 case *ast.CommClause:
// *ast.FieldList (case label not visible).
481 // Merging adjacent fields (e.g. parameters) is disabled by default.
485 switch c.Parent().(type) {
486 case *ast.FuncDecl, *ast.FuncType, *ast.InterfaceType:
487 node.List = f.mergeAdjacentFields(node.List)
489 case *ast.StructType:
490 // Do not merge adjacent fields in structs.
// *ast.BasicLit (case label not visible): rewrite legacy octal literals,
// e.g. 0755, to the 0o755 form.
494 // Octal number literals were introduced in 1.13.
495 if semver.Compare(f.LangVersion, "v1.13") >= 0 {
496 if node.Kind == token.INT && rxOctalInteger.MatchString(node.Value) {
497 node.Value = "0o" + node.Value[1:]
// stmts looks for an "..., err := ..." assignment immediately followed by a
// plain "if err != nil" check (no init statement, no else branch). For each
// such pair it removes the line breaks between them, so the check sits on
// the line directly after the assignment.
504 func (f *fumpter) stmts(list []ast.Stmt) {
505 for i, stmt := range list {
506 ifs, ok := stmt.(*ast.IfStmt)
// NOTE(review): the condition guarding this continue is not visible;
// presumably it also skips i == 0, since list[i-1] is read below.
508 continue // not an if following another statement
510 as, ok := list[i-1].(*ast.AssignStmt)
511 if !ok || as.Tok != token.DEFINE ||
512 !identEqual(as.Lhs[len(as.Lhs)-1], "err") {
513 continue // not "..., err := ..."
515 be, ok := ifs.Cond.(*ast.BinaryExpr)
516 if !ok || ifs.Init != nil || ifs.Else != nil {
517 continue // complex if
519 if be.Op != token.NEQ || !identEqual(be.X, "err") ||
520 !identEqual(be.Y, "nil") {
521 continue // not "err != nil"
523 f.removeLinesBetween(as.End(), ifs.Pos())
// identEqual reports whether expr is an *ast.Ident named name. The check is
// purely syntactic and does not resolve scopes, so a shadowed "err" or a
// redefined "nil" still matches.
527 func identEqual(expr ast.Expr, name string) bool {
528 id, ok := expr.(*ast.Ident)
529 return ok && id.Name == name
532 // joinStdImports ensures that all standard library imports are together and at
533 // the top of the imports list.
// The visible classification rules are:
//   - a path containing a dot is third party;
//   - paths under "test/", "example/" and "internal/" get a dedicated case;
//   - outside the first group, named or inline-commented imports count as
//     non-std.
//
// Std imports that move up have all their positions reset to d.Pos() with
// setPos. An empty line is forced between the std and non-std groups, and
// d.Specs is rewritten in place with std first.
534 func (f *fumpter) joinStdImports(d *ast.GenDecl) {
535 var std, other []ast.Spec
539 for i, spec := range d.Specs {
540 spec := spec.(*ast.ImportSpec)
// Comments between specs move lastEnd forward, so a comment line is not
// mistaken for an empty line.
541 if coms := f.commentsBetween(lastEnd, spec.Pos()); len(coms) > 0 {
542 lastEnd = coms[len(coms)-1].End()
// An empty line before this spec means it is past the first group.
// NOTE(review): the body of this if is not visible in this excerpt.
544 if i > 0 && firstGroup && f.Line(spec.Pos()) > f.Line(lastEnd)+1 {
547 // We're still in the first group, update lastEnd.
// NOTE(review): the Unquote error is discarded (path is then ""); an import
// path that parsed successfully is expected to be a valid quoted string.
551 path, _ := strconv.Unquote(spec.Path.Value)
553 // Imports with a period are definitely third party.
554 case strings.Contains(path, "."):
556 // "test" and "example" are reserved as per golang.org/issue/37641.
557 // "internal" is unreachable.
558 case strings.HasPrefix(path, "test/") ||
559 strings.HasPrefix(path, "example/") ||
560 strings.HasPrefix(path, "internal/"):
562 // To be conservative, if an import has a name or an inline
563 // comment, and isn't part of the top group, treat it as non-std.
564 case !firstGroup && (spec.Name != nil || spec.Comment != nil):
565 other = append(other, spec)
569 // If we're moving this std import further up, reset its
570 // position, to avoid breaking comments.
571 if !firstGroup || len(other) > 0 {
572 setPos(reflect.ValueOf(spec), d.Pos())
575 std = append(std, spec)
577 // Ensure there is an empty line between std imports and other imports.
578 if len(std) > 0 && len(other) > 0 && f.Line(std[len(std)-1].End())+1 >= f.Line(other[0].Pos()) {
579 // We add two newlines, as that's necessary in some edge cases.
580 // For example, if the std and non-std imports were together and
581 // without indentation, adding one newline isn't enough. Two
582 // empty lines will be printed as one by go/printer, anyway.
583 f.addNewline(other[0].Pos() - 1)
584 f.addNewline(other[0].Pos())
586 // Finally, join the imports, keeping std at the top.
587 d.Specs = append(std, other...)
589 // If we moved any std imports to the first group, we need to sort them again.
592 ast.SortImports(f.fset, f.astFile)
596 // mergeAdjacentFields returns fields with adjacent fields merged if possible.
// When shouldMergeAdjacentFields allows it, the second field's names are
// appended to the first's, e.g. "a int, b int" becomes "a, b int". fields is
// modified in place (elements are overwritten and extended), so callers must
// use the returned slice. NOTE(review): presumably the result is a prefix of
// fields; the return statement is not visible in this excerpt.
597 func (f *fumpter) mergeAdjacentFields(fields []*ast.Field) []*ast.Field {
598 // If there are fewer than two fields then there is nothing to merge.
603 // Otherwise, iterate over adjacent pairs of fields, merging if possible,
604 // and mutating fields. Elements of fields may be mutated (if merged with
605 // following fields), discarded (if merged with a preceding field), or left as is.
608 for j := 1; j < len(fields); j++ {
609 if f.shouldMergeAdjacentFields(fields[i], fields[j]) {
610 fields[i].Names = append(fields[i].Names, fields[j].Names...)
613 fields[i] = fields[j]
// shouldMergeAdjacentFields reports whether f1 and f2 can be combined into a
// single field such as "a, b int". Both must be named, both must start on the
// same line, and their type expressions must be structurally equal.
619 func (f *fumpter) shouldMergeAdjacentFields(f1, f2 *ast.Field) bool {
620 if len(f1.Names) == 0 || len(f2.Names) == 0 {
621 // Both must have names for the merge to work.
624 if f.Line(f1.Pos()) != f.Line(f2.Pos()) {
625 // Trust the user if they used separate lines.
629 // Only merge if the types are equal.
// The comparer treats every pair of token.Pos values as equal, so only the
// structure of the type expressions is compared, not where they appear.
630 opt := cmp.Comparer(func(x, y token.Pos) bool { return true })
631 return cmp.Equal(f1.Type, f2.Type, opt)
// posType is the reflect.Type of token.Pos. setPos uses it to recognize
// position values while walking a node with reflection.
634 var posType = reflect.TypeOf(token.NoPos)
636 // setPos recursively sets all position fields in the node v to pos.
637 func setPos(v reflect.Value, pos token.Pos) {
638 if v.Kind() == reflect.Ptr {
644 if v.Type() == posType {
645 v.Set(reflect.ValueOf(pos))
647 if v.Kind() == reflect.Struct {
648 for i := 0; i < v.NumField(); i++ {
649 setPos(v.Field(i), pos)