--- /dev/null
+// Invoke with //go:generate helper/helper -t Server -d protocol/tsserver.go -u lsp -o server_gen.go
+// invoke in internal/lsp
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "log"
+ "os"
+ "sort"
+ "strings"
+ "text/template"
+)
+
+var (
+ typ = flag.String("t", "Server", "generate code for this type")
+ def = flag.String("d", "", "the file the type is defined in") // this relies on punning
+ use = flag.String("u", "", "look for uses in this package")
+ out = flag.String("o", "", "where to write the generated file")
+)
+
+func main() {
+ log.SetFlags(log.Lshortfile)
+ flag.Parse()
+ if *typ == "" || *def == "" || *use == "" || *out == "" {
+ flag.PrintDefaults()
+ return
+ }
+ // read the type definition and see what methods we're looking for
+ doTypes()
+
+ // parse the package and see which methods are defined
+ doUses()
+
+ output()
+}
+
+// replace "\\\n" with nothing before using
+var tmpl = `
+package lsp
+
+// code generated by helper. DO NOT EDIT.
+
+import (
+ "context"
+
+ "golang.org/x/tools/internal/lsp/protocol"
+)
+
+{{range $key, $v := .Stuff}}
+func (s *{{$.Type}}) {{$v.Name}}({{.Param}}) {{.Result}} {
+ {{if ne .Found ""}} return s.{{.Internal}}({{.Invoke}})\
+ {{else}}return {{if lt 1 (len .Results)}}nil, {{end}}notImplemented("{{.Name}}"){{end}}
+}
+{{end}}
+`
+
+func output() {
+ // put in empty param names as needed
+ for _, t := range types {
+ if t.paramnames == nil {
+ t.paramnames = make([]string, len(t.paramtypes))
+ }
+ for i, p := range t.paramtypes {
+ cm := ""
+ if i > 0 {
+ cm = ", "
+ }
+ t.Param += fmt.Sprintf("%s%s %s", cm, t.paramnames[i], p)
+ t.Invoke += fmt.Sprintf("%s%s", cm, t.paramnames[i])
+ }
+ if len(t.Results) > 1 {
+ t.Result = "("
+ }
+ for i, r := range t.Results {
+ cm := ""
+ if i > 0 {
+ cm = ", "
+ }
+ t.Result += fmt.Sprintf("%s%s", cm, r)
+ }
+ if len(t.Results) > 1 {
+ t.Result += ")"
+ }
+ }
+
+ fd, err := os.Create(*out)
+ if err != nil {
+ log.Fatal(err)
+ }
+ t, err := template.New("foo").Parse(tmpl)
+ if err != nil {
+ log.Fatal(err)
+ }
+ type par struct {
+ Type string
+ Stuff []*Function
+ }
+ p := par{*typ, types}
+ if false { // debugging the template
+ t.Execute(os.Stderr, &p)
+ }
+ buf := bytes.NewBuffer(nil)
+ err = t.Execute(buf, &p)
+ if err != nil {
+ log.Fatal(err)
+ }
+ ans, err := format.Source(bytes.Replace(buf.Bytes(), []byte("\\\n"), []byte{}, -1))
+ if err != nil {
+ log.Fatal(err)
+ }
+ fd.Write(ans)
+}
+
+func doUses() {
+ fset := token.NewFileSet()
+ pkgs, err := parser.ParseDir(fset, *use, nil, 0)
+ if err != nil {
+ log.Fatalf("%q:%v", *use, err)
+ }
+ pkg := pkgs["lsp"] // CHECK
+ files := pkg.Files
+ for fname, f := range files {
+ for _, d := range f.Decls {
+ fd, ok := d.(*ast.FuncDecl)
+ if !ok {
+ continue
+ }
+ nm := fd.Name.String()
+ if ast.IsExported(nm) {
+ // we're looking for things like didChange
+ continue
+ }
+ if fx, ok := byname[nm]; ok {
+ if fx.Found != "" {
+ log.Fatalf("found %s in %s and %s", fx.Internal, fx.Found, fname)
+ }
+ fx.Found = fname
+ // and the Paramnames
+ ft := fd.Type
+ for _, f := range ft.Params.List {
+ nm := ""
+ if len(f.Names) > 0 {
+ nm = f.Names[0].String()
+ }
+ fx.paramnames = append(fx.paramnames, nm)
+ }
+ }
+ }
+ }
+ if false {
+ for i, f := range types {
+ log.Printf("%d %s %s", i, f.Internal, f.Found)
+ }
+ }
+}
+
+type Function struct {
+ Name string
+ Internal string // first letter lower case
+ paramtypes []string
+ paramnames []string
+ Results []string
+ Param string
+ Result string // do it in code, easier than in a template
+ Invoke string
+ Found string // file it was found in
+}
+
+var types []*Function
+var byname = map[string]*Function{} // internal names
+
+func doTypes() {
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, *def, nil, 0)
+ if err != nil {
+ log.Fatal(err)
+ }
+ fd, err := os.Create("/tmp/ast")
+ if err != nil {
+ log.Fatal(err)
+ }
+ ast.Fprint(fd, fset, f, ast.NotNilFilter)
+ ast.Inspect(f, inter)
+ sort.Slice(types, func(i, j int) bool { return types[i].Name < types[j].Name })
+ if false {
+ for i, f := range types {
+ log.Printf("%d %s(%v) %v", i, f.Name, f.paramtypes, f.Results)
+ }
+ }
+}
+
+func inter(n ast.Node) bool {
+ x, ok := n.(*ast.TypeSpec)
+ if !ok || x.Name.Name != *typ {
+ return true
+ }
+ m := x.Type.(*ast.InterfaceType).Methods.List
+ for _, fld := range m {
+ fn := fld.Type.(*ast.FuncType)
+ p := fn.Params.List
+ r := fn.Results.List
+ fx := &Function{
+ Name: fld.Names[0].String(),
+ }
+ fx.Internal = strings.ToLower(fx.Name[:1]) + fx.Name[1:]
+ for _, f := range p {
+ fx.paramtypes = append(fx.paramtypes, whatis(f.Type))
+ }
+ for _, f := range r {
+ fx.Results = append(fx.Results, whatis(f.Type))
+ }
+ types = append(types, fx)
+ byname[fx.Internal] = fx
+ }
+ return false
+}
+
+func whatis(x ast.Expr) string {
+ switch n := x.(type) {
+ case *ast.SelectorExpr:
+ return whatis(n.X) + "." + n.Sel.String()
+ case *ast.StarExpr:
+ return "*" + whatis(n.X)
+ case *ast.Ident:
+ if ast.IsExported(n.Name) {
+ // these are from package protocol
+ return "protocol." + n.Name
+ }
+ return n.Name
+ case *ast.ArrayType:
+ return "[]" + whatis(n.Elt)
+ case *ast.InterfaceType:
+ return "interface{}"
+ default:
+ log.Fatalf("Fatal %T", x)
+ return fmt.Sprintf("%T", x)
+ }
+}