--- /dev/null
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package sigchanyzer defines an Analyzer that detects
+// misuse of unbuffered signal as argument to signal.Notify.
+package sigchanyzer
+
+import (
+ "bytes"
+ "go/ast"
+ "go/format"
+ "go/token"
+ "go/types"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/inspect"
+ "golang.org/x/tools/go/ast/inspector"
+)
+
+const Doc = `check for unbuffered channel of os.Signal
+
+This checker reports call expression of the form signal.Notify(c <-chan os.Signal, sig ...os.Signal),
+where c is an unbuffered channel, which can be at risk of missing the signal.`
+
+// Analyzer describes sigchanyzer analysis function detector.
+var Analyzer = &analysis.Analyzer{
+ Name: "sigchanyzer",
+ Doc: Doc,
+ Requires: []*analysis.Analyzer{inspect.Analyzer},
+ Run: run,
+}
+
+func run(pass *analysis.Pass) (interface{}, error) {
+ inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
+
+ nodeFilter := []ast.Node{
+ (*ast.CallExpr)(nil),
+ }
+ inspect.Preorder(nodeFilter, func(n ast.Node) {
+ call := n.(*ast.CallExpr)
+ if !isSignalNotify(pass.TypesInfo, call) {
+ return
+ }
+ var chanDecl *ast.CallExpr
+ switch arg := call.Args[0].(type) {
+ case *ast.Ident:
+ if decl, ok := findDecl(arg).(*ast.CallExpr); ok {
+ chanDecl = decl
+ }
+ case *ast.CallExpr:
+ chanDecl = arg
+ }
+ if chanDecl == nil || len(chanDecl.Args) != 1 {
+ return
+ }
+ chanDecl.Args = append(chanDecl.Args, &ast.BasicLit{
+ Kind: token.INT,
+ Value: "1",
+ })
+ var buf bytes.Buffer
+ if err := format.Node(&buf, token.NewFileSet(), chanDecl); err != nil {
+ return
+ }
+ pass.Report(analysis.Diagnostic{
+ Pos: call.Pos(),
+ End: call.End(),
+ Message: "misuse of unbuffered os.Signal channel as argument to signal.Notify",
+ SuggestedFixes: []analysis.SuggestedFix{{
+ Message: "Change to buffer channel",
+ TextEdits: []analysis.TextEdit{{
+ Pos: chanDecl.Pos(),
+ End: chanDecl.End(),
+ NewText: buf.Bytes(),
+ }},
+ }},
+ })
+ })
+ return nil, nil
+}
+
+func isSignalNotify(info *types.Info, call *ast.CallExpr) bool {
+ check := func(id *ast.Ident) bool {
+ obj := info.ObjectOf(id)
+ return obj.Name() == "Notify" && obj.Pkg().Path() == "os/signal"
+ }
+ switch fun := call.Fun.(type) {
+ case *ast.SelectorExpr:
+ return check(fun.Sel)
+ case *ast.Ident:
+ if fun, ok := findDecl(fun).(*ast.SelectorExpr); ok {
+ return check(fun.Sel)
+ }
+ return false
+ default:
+ return false
+ }
+}
+
+func findDecl(arg *ast.Ident) ast.Node {
+ if arg.Obj == nil {
+ return nil
+ }
+ switch as := arg.Obj.Decl.(type) {
+ case *ast.AssignStmt:
+ if len(as.Lhs) != len(as.Rhs) {
+ return nil
+ }
+ for i, lhs := range as.Lhs {
+ lid, ok := lhs.(*ast.Ident)
+ if !ok {
+ continue
+ }
+ if lid.Obj == arg.Obj {
+ return as.Rhs[i]
+ }
+ }
+ case *ast.ValueSpec:
+ if len(as.Names) != len(as.Values) {
+ return nil
+ }
+ for i, name := range as.Names {
+ if name.Obj == arg.Obj {
+ return as.Values[i]
+ }
+ }
+ }
+ return nil
+}