Giant blob of minor changes
[dotfiles/.git] / .config / coc / extensions / coc-go-data / tools / pkg / mod / golang.org / x / tools@v0.0.0-20201028153306-37f0764111ff / internal / lsp / source / implementation.go
1 // Copyright 2019 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package source
6
7 import (
8         "context"
9         "errors"
10         "fmt"
11         "go/ast"
12         "go/token"
13         "go/types"
14         "sort"
15
16         "golang.org/x/tools/internal/event"
17         "golang.org/x/tools/internal/lsp/protocol"
18         "golang.org/x/xerrors"
19 )
20
21 func Implementation(ctx context.Context, snapshot Snapshot, f FileHandle, pp protocol.Position) ([]protocol.Location, error) {
22         ctx, done := event.Start(ctx, "source.Implementation")
23         defer done()
24
25         impls, err := implementations(ctx, snapshot, f, pp)
26         if err != nil {
27                 return nil, err
28         }
29         var locations []protocol.Location
30         for _, impl := range impls {
31                 if impl.pkg == nil || len(impl.pkg.CompiledGoFiles()) == 0 {
32                         continue
33                 }
34                 rng, err := objToMappedRange(snapshot, impl.pkg, impl.obj)
35                 if err != nil {
36                         return nil, err
37                 }
38                 pr, err := rng.Range()
39                 if err != nil {
40                         return nil, err
41                 }
42                 locations = append(locations, protocol.Location{
43                         URI:   protocol.URIFromSpanURI(rng.URI()),
44                         Range: pr,
45                 })
46         }
47         sort.Slice(locations, func(i, j int) bool {
48                 li, lj := locations[i], locations[j]
49                 if li.URI == lj.URI {
50                         return protocol.CompareRange(li.Range, lj.Range) < 0
51                 }
52                 return li.URI < lj.URI
53         })
54         return locations, nil
55 }
56
57 var ErrNotAType = errors.New("not a type name or method")
58
59 // implementations returns the concrete implementations of the specified
60 // interface, or the interfaces implemented by the specified concrete type.
61 func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
62         var (
63                 impls []qualifiedObject
64                 seen  = make(map[token.Position]bool)
65                 fset  = s.FileSet()
66         )
67
68         qos, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp)
69         if err != nil {
70                 return nil, err
71         }
72         for _, qo := range qos {
73                 var (
74                         queryType   types.Type
75                         queryMethod *types.Func
76                 )
77
78                 switch obj := qo.obj.(type) {
79                 case *types.Func:
80                         queryMethod = obj
81                         if recv := obj.Type().(*types.Signature).Recv(); recv != nil {
82                                 queryType = ensurePointer(recv.Type())
83                         }
84                 case *types.TypeName:
85                         queryType = ensurePointer(obj.Type())
86                 }
87
88                 if queryType == nil {
89                         return nil, ErrNotAType
90                 }
91
92                 if types.NewMethodSet(queryType).Len() == 0 {
93                         return nil, nil
94                 }
95
96                 // Find all named types, even local types (which can have methods
97                 // due to promotion).
98                 var (
99                         allNamed []*types.Named
100                         pkgs     = make(map[*types.Package]Package)
101                 )
102                 knownPkgs, err := s.KnownPackages(ctx)
103                 if err != nil {
104                         return nil, err
105                 }
106                 for _, pkg := range knownPkgs {
107                         pkgs[pkg.GetTypes()] = pkg
108                         info := pkg.GetTypesInfo()
109                         for _, obj := range info.Defs {
110                                 obj, ok := obj.(*types.TypeName)
111                                 // We ignore aliases 'type M = N' to avoid duplicate reporting
112                                 // of the Named type N.
113                                 if !ok || obj.IsAlias() {
114                                         continue
115                                 }
116                                 if named, ok := obj.Type().(*types.Named); ok {
117                                         allNamed = append(allNamed, named)
118                                 }
119                         }
120                 }
121
122                 // Find all the named types that match our query.
123                 for _, named := range allNamed {
124                         var (
125                                 candObj  types.Object = named.Obj()
126                                 candType              = ensurePointer(named)
127                         )
128
129                         if !concreteImplementsIntf(candType, queryType) {
130                                 continue
131                         }
132
133                         ms := types.NewMethodSet(candType)
134                         if ms.Len() == 0 {
135                                 // Skip empty interfaces.
136                                 continue
137                         }
138
139                         // If client queried a method, look up corresponding candType method.
140                         if queryMethod != nil {
141                                 sel := ms.Lookup(queryMethod.Pkg(), queryMethod.Name())
142                                 if sel == nil {
143                                         continue
144                                 }
145                                 candObj = sel.Obj()
146                         }
147
148                         pos := fset.Position(candObj.Pos())
149                         if candObj == queryMethod || seen[pos] {
150                                 continue
151                         }
152
153                         seen[pos] = true
154
155                         impls = append(impls, qualifiedObject{
156                                 obj: candObj,
157                                 pkg: pkgs[candObj.Pkg()],
158                         })
159                 }
160         }
161
162         return impls, nil
163 }
164
165 // concreteImplementsIntf returns true if a is an interface type implemented by
166 // concrete type b, or vice versa.
167 func concreteImplementsIntf(a, b types.Type) bool {
168         aIsIntf, bIsIntf := IsInterface(a), IsInterface(b)
169
170         // Make sure exactly one is an interface type.
171         if aIsIntf == bIsIntf {
172                 return false
173         }
174
175         // Rearrange if needed so "a" is the concrete type.
176         if aIsIntf {
177                 a, b = b, a
178         }
179
180         return types.AssignableTo(a, b)
181 }
182
183 // ensurePointer wraps T in a *types.Pointer if T is a named, non-interface
184 // type. This is useful to make sure you consider a named type's full method
185 // set.
186 func ensurePointer(T types.Type) types.Type {
187         if _, ok := T.(*types.Named); ok && !IsInterface(T) {
188                 return types.NewPointer(T)
189         }
190
191         return T
192 }
193
194 type qualifiedObject struct {
195         obj types.Object
196
197         // pkg is the Package that contains obj's definition.
198         pkg Package
199
200         // node is the *ast.Ident or *ast.ImportSpec we followed to find obj, if any.
201         node ast.Node
202
203         // sourcePkg is the Package that contains node, if any.
204         sourcePkg Package
205 }
206
207 var (
208         errBuiltin       = errors.New("builtin object")
209         errNoObjectFound = errors.New("no object found")
210 )
211
212 // qualifiedObjsAtProtocolPos returns info for all the type.Objects
213 // referenced at the given position. An object will be returned for
214 // every package that the file belongs to, in every typechecking mode
215 // applicable.
216 func qualifiedObjsAtProtocolPos(ctx context.Context, s Snapshot, fh FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
217         pkgs, err := s.PackagesForFile(ctx, fh.URI(), TypecheckAll)
218         if err != nil {
219                 return nil, err
220         }
221         // Check all the packages that the file belongs to.
222         var qualifiedObjs []qualifiedObject
223         for _, searchpkg := range pkgs {
224                 astFile, pos, err := getASTFile(searchpkg, fh, pp)
225                 if err != nil {
226                         return nil, err
227                 }
228                 path := pathEnclosingObjNode(astFile, pos)
229                 if path == nil {
230                         continue
231                 }
232                 var objs []types.Object
233                 switch leaf := path[0].(type) {
234                 case *ast.Ident:
235                         // If leaf represents an implicit type switch object or the type
236                         // switch "assign" variable, expand to all of the type switch's
237                         // implicit objects.
238                         if implicits, _ := typeSwitchImplicits(searchpkg, path); len(implicits) > 0 {
239                                 objs = append(objs, implicits...)
240                         } else {
241                                 obj := searchpkg.GetTypesInfo().ObjectOf(leaf)
242                                 if obj == nil {
243                                         return nil, xerrors.Errorf("%w for %q", errNoObjectFound, leaf.Name)
244                                 }
245                                 objs = append(objs, obj)
246                         }
247                 case *ast.ImportSpec:
248                         // Look up the implicit *types.PkgName.
249                         obj := searchpkg.GetTypesInfo().Implicits[leaf]
250                         if obj == nil {
251                                 return nil, xerrors.Errorf("%w for import %q", errNoObjectFound, ImportPath(leaf))
252                         }
253                         objs = append(objs, obj)
254                 }
255                 // Get all of the transitive dependencies of the search package.
256                 pkgs := make(map[*types.Package]Package)
257                 var addPkg func(pkg Package)
258                 addPkg = func(pkg Package) {
259                         pkgs[pkg.GetTypes()] = pkg
260                         for _, imp := range pkg.Imports() {
261                                 if _, ok := pkgs[imp.GetTypes()]; !ok {
262                                         addPkg(imp)
263                                 }
264                         }
265                 }
266                 addPkg(searchpkg)
267                 for _, obj := range objs {
268                         if obj.Parent() == types.Universe {
269                                 return nil, xerrors.Errorf("%q: %w", obj.Name(), errBuiltin)
270                         }
271                         pkg, ok := pkgs[obj.Pkg()]
272                         if !ok {
273                                 event.Error(ctx, fmt.Sprintf("no package for obj %s: %v", obj, obj.Pkg()), err)
274                                 continue
275                         }
276                         qualifiedObjs = append(qualifiedObjs, qualifiedObject{
277                                 obj:       obj,
278                                 pkg:       pkg,
279                                 sourcePkg: searchpkg,
280                                 node:      path[0],
281                         })
282                 }
283         }
284         // Return an error if no objects were found since callers will assume that
285         // the slice has at least 1 element.
286         if len(qualifiedObjs) == 0 {
287                 return nil, errNoObjectFound
288         }
289         return qualifiedObjs, nil
290 }
291
292 func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, token.Pos, error) {
293         pgf, err := pkg.File(f.URI())
294         if err != nil {
295                 return nil, 0, err
296         }
297         spn, err := pgf.Mapper.PointSpan(pos)
298         if err != nil {
299                 return nil, 0, err
300         }
301         rng, err := spn.Range(pgf.Mapper.Converter)
302         if err != nil {
303                 return nil, 0, err
304         }
305         return pgf.File, rng.Start, nil
306 }
307
308 // pathEnclosingObjNode returns the AST path to the object-defining
309 // node associated with pos. "Object-defining" means either an
310 // *ast.Ident mapped directly to a types.Object or an ast.Node mapped
311 // implicitly to a types.Object.
312 func pathEnclosingObjNode(f *ast.File, pos token.Pos) []ast.Node {
313         var (
314                 path  []ast.Node
315                 found bool
316         )
317
318         ast.Inspect(f, func(n ast.Node) bool {
319                 if found {
320                         return false
321                 }
322
323                 if n == nil {
324                         path = path[:len(path)-1]
325                         return false
326                 }
327
328                 path = append(path, n)
329
330                 switch n := n.(type) {
331                 case *ast.Ident:
332                         // Include the position directly after identifier. This handles
333                         // the common case where the cursor is right after the
334                         // identifier the user is currently typing. Previously we
335                         // handled this by calling astutil.PathEnclosingInterval twice,
336                         // once for "pos" and once for "pos-1".
337                         found = n.Pos() <= pos && pos <= n.End()
338                 case *ast.ImportSpec:
339                         if n.Path.Pos() <= pos && pos < n.Path.End() {
340                                 found = true
341                                 // If import spec has a name, add name to path even though
342                                 // position isn't in the name.
343                                 if n.Name != nil {
344                                         path = append(path, n.Name)
345                                 }
346                         }
347                 case *ast.StarExpr:
348                         // Follow star expressions to the inner identifier.
349                         if pos == n.Star {
350                                 pos = n.X.Pos()
351                         }
352                 case *ast.SelectorExpr:
353                         // If pos is on the ".", move it into the selector.
354                         if pos == n.X.End() {
355                                 pos = n.Sel.Pos()
356                         }
357                 }
358
359                 return !found
360         })
361
362         if len(path) == 0 {
363                 return nil
364         }
365
366         // Reverse path so leaf is first element.
367         for i := 0; i < len(path)/2; i++ {
368                 path[i], path[len(path)-1-i] = path[len(path)-1-i], path[i]
369         }
370
371         return path
372 }