+++ /dev/null
-// Copyright 2019 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package source
-
-import (
- "context"
- "errors"
- "fmt"
- "go/ast"
- "go/token"
- "go/types"
- "sort"
-
- "golang.org/x/tools/internal/event"
- "golang.org/x/tools/internal/lsp/protocol"
- "golang.org/x/xerrors"
-)
-
-func Implementation(ctx context.Context, snapshot Snapshot, f FileHandle, pp protocol.Position) ([]protocol.Location, error) {
- ctx, done := event.Start(ctx, "source.Implementation")
- defer done()
-
- impls, err := implementations(ctx, snapshot, f, pp)
- if err != nil {
- return nil, err
- }
- var locations []protocol.Location
- for _, impl := range impls {
- if impl.pkg == nil || len(impl.pkg.CompiledGoFiles()) == 0 {
- continue
- }
- rng, err := objToMappedRange(snapshot, impl.pkg, impl.obj)
- if err != nil {
- return nil, err
- }
- pr, err := rng.Range()
- if err != nil {
- return nil, err
- }
- locations = append(locations, protocol.Location{
- URI: protocol.URIFromSpanURI(rng.URI()),
- Range: pr,
- })
- }
- sort.Slice(locations, func(i, j int) bool {
- li, lj := locations[i], locations[j]
- if li.URI == lj.URI {
- return protocol.CompareRange(li.Range, lj.Range) < 0
- }
- return li.URI < lj.URI
- })
- return locations, nil
-}
-
-var ErrNotAType = errors.New("not a type name or method")
-
-// implementations returns the concrete implementations of the specified
-// interface, or the interfaces implemented by the specified concrete type.
-func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
- var (
- impls []qualifiedObject
- seen = make(map[token.Position]bool)
- fset = s.FileSet()
- )
-
- qos, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp)
- if err != nil {
- return nil, err
- }
- for _, qo := range qos {
- var (
- queryType types.Type
- queryMethod *types.Func
- )
-
- switch obj := qo.obj.(type) {
- case *types.Func:
- queryMethod = obj
- if recv := obj.Type().(*types.Signature).Recv(); recv != nil {
- queryType = ensurePointer(recv.Type())
- }
- case *types.TypeName:
- queryType = ensurePointer(obj.Type())
- }
-
- if queryType == nil {
- return nil, ErrNotAType
- }
-
- if types.NewMethodSet(queryType).Len() == 0 {
- return nil, nil
- }
-
- // Find all named types, even local types (which can have methods
- // due to promotion).
- var (
- allNamed []*types.Named
- pkgs = make(map[*types.Package]Package)
- )
- knownPkgs, err := s.KnownPackages(ctx)
- if err != nil {
- return nil, err
- }
- for _, pkg := range knownPkgs {
- pkgs[pkg.GetTypes()] = pkg
- info := pkg.GetTypesInfo()
- for _, obj := range info.Defs {
- obj, ok := obj.(*types.TypeName)
- // We ignore aliases 'type M = N' to avoid duplicate reporting
- // of the Named type N.
- if !ok || obj.IsAlias() {
- continue
- }
- if named, ok := obj.Type().(*types.Named); ok {
- allNamed = append(allNamed, named)
- }
- }
- }
-
- // Find all the named types that match our query.
- for _, named := range allNamed {
- var (
- candObj types.Object = named.Obj()
- candType = ensurePointer(named)
- )
-
- if !concreteImplementsIntf(candType, queryType) {
- continue
- }
-
- ms := types.NewMethodSet(candType)
- if ms.Len() == 0 {
- // Skip empty interfaces.
- continue
- }
-
- // If client queried a method, look up corresponding candType method.
- if queryMethod != nil {
- sel := ms.Lookup(queryMethod.Pkg(), queryMethod.Name())
- if sel == nil {
- continue
- }
- candObj = sel.Obj()
- }
-
- pos := fset.Position(candObj.Pos())
- if candObj == queryMethod || seen[pos] {
- continue
- }
-
- seen[pos] = true
-
- impls = append(impls, qualifiedObject{
- obj: candObj,
- pkg: pkgs[candObj.Pkg()],
- })
- }
- }
-
- return impls, nil
-}
-
-// concreteImplementsIntf returns true if a is an interface type implemented by
-// concrete type b, or vice versa.
-func concreteImplementsIntf(a, b types.Type) bool {
- aIsIntf, bIsIntf := IsInterface(a), IsInterface(b)
-
- // Make sure exactly one is an interface type.
- if aIsIntf == bIsIntf {
- return false
- }
-
- // Rearrange if needed so "a" is the concrete type.
- if aIsIntf {
- a, b = b, a
- }
-
- return types.AssignableTo(a, b)
-}
-
-// ensurePointer wraps T in a *types.Pointer if T is a named, non-interface
-// type. This is useful to make sure you consider a named type's full method
-// set.
-func ensurePointer(T types.Type) types.Type {
- if _, ok := T.(*types.Named); ok && !IsInterface(T) {
- return types.NewPointer(T)
- }
-
- return T
-}
-
-type qualifiedObject struct {
- obj types.Object
-
- // pkg is the Package that contains obj's definition.
- pkg Package
-
- // node is the *ast.Ident or *ast.ImportSpec we followed to find obj, if any.
- node ast.Node
-
- // sourcePkg is the Package that contains node, if any.
- sourcePkg Package
-}
-
-var (
- errBuiltin = errors.New("builtin object")
- errNoObjectFound = errors.New("no object found")
-)
-
-// qualifiedObjsAtProtocolPos returns info for all the type.Objects
-// referenced at the given position. An object will be returned for
-// every package that the file belongs to, in every typechecking mode
-// applicable.
-func qualifiedObjsAtProtocolPos(ctx context.Context, s Snapshot, fh FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
- pkgs, err := s.PackagesForFile(ctx, fh.URI(), TypecheckAll)
- if err != nil {
- return nil, err
- }
- // Check all the packages that the file belongs to.
- var qualifiedObjs []qualifiedObject
- for _, searchpkg := range pkgs {
- astFile, pos, err := getASTFile(searchpkg, fh, pp)
- if err != nil {
- return nil, err
- }
- path := pathEnclosingObjNode(astFile, pos)
- if path == nil {
- continue
- }
- var objs []types.Object
- switch leaf := path[0].(type) {
- case *ast.Ident:
- // If leaf represents an implicit type switch object or the type
- // switch "assign" variable, expand to all of the type switch's
- // implicit objects.
- if implicits, _ := typeSwitchImplicits(searchpkg, path); len(implicits) > 0 {
- objs = append(objs, implicits...)
- } else {
- obj := searchpkg.GetTypesInfo().ObjectOf(leaf)
- if obj == nil {
- return nil, xerrors.Errorf("%w for %q", errNoObjectFound, leaf.Name)
- }
- objs = append(objs, obj)
- }
- case *ast.ImportSpec:
- // Look up the implicit *types.PkgName.
- obj := searchpkg.GetTypesInfo().Implicits[leaf]
- if obj == nil {
- return nil, xerrors.Errorf("%w for import %q", errNoObjectFound, ImportPath(leaf))
- }
- objs = append(objs, obj)
- }
- // Get all of the transitive dependencies of the search package.
- pkgs := make(map[*types.Package]Package)
- var addPkg func(pkg Package)
- addPkg = func(pkg Package) {
- pkgs[pkg.GetTypes()] = pkg
- for _, imp := range pkg.Imports() {
- if _, ok := pkgs[imp.GetTypes()]; !ok {
- addPkg(imp)
- }
- }
- }
- addPkg(searchpkg)
- for _, obj := range objs {
- if obj.Parent() == types.Universe {
- return nil, xerrors.Errorf("%q: %w", obj.Name(), errBuiltin)
- }
- pkg, ok := pkgs[obj.Pkg()]
- if !ok {
- event.Error(ctx, fmt.Sprintf("no package for obj %s: %v", obj, obj.Pkg()), err)
- continue
- }
- qualifiedObjs = append(qualifiedObjs, qualifiedObject{
- obj: obj,
- pkg: pkg,
- sourcePkg: searchpkg,
- node: path[0],
- })
- }
- }
- // Return an error if no objects were found since callers will assume that
- // the slice has at least 1 element.
- if len(qualifiedObjs) == 0 {
- return nil, errNoObjectFound
- }
- return qualifiedObjs, nil
-}
-
-func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, token.Pos, error) {
- pgf, err := pkg.File(f.URI())
- if err != nil {
- return nil, 0, err
- }
- spn, err := pgf.Mapper.PointSpan(pos)
- if err != nil {
- return nil, 0, err
- }
- rng, err := spn.Range(pgf.Mapper.Converter)
- if err != nil {
- return nil, 0, err
- }
- return pgf.File, rng.Start, nil
-}
-
-// pathEnclosingObjNode returns the AST path to the object-defining
-// node associated with pos. "Object-defining" means either an
-// *ast.Ident mapped directly to a types.Object or an ast.Node mapped
-// implicitly to a types.Object.
-func pathEnclosingObjNode(f *ast.File, pos token.Pos) []ast.Node {
- var (
- path []ast.Node
- found bool
- )
-
- ast.Inspect(f, func(n ast.Node) bool {
- if found {
- return false
- }
-
- if n == nil {
- path = path[:len(path)-1]
- return false
- }
-
- path = append(path, n)
-
- switch n := n.(type) {
- case *ast.Ident:
- // Include the position directly after identifier. This handles
- // the common case where the cursor is right after the
- // identifier the user is currently typing. Previously we
- // handled this by calling astutil.PathEnclosingInterval twice,
- // once for "pos" and once for "pos-1".
- found = n.Pos() <= pos && pos <= n.End()
- case *ast.ImportSpec:
- if n.Path.Pos() <= pos && pos < n.Path.End() {
- found = true
- // If import spec has a name, add name to path even though
- // position isn't in the name.
- if n.Name != nil {
- path = append(path, n.Name)
- }
- }
- case *ast.StarExpr:
- // Follow star expressions to the inner identifier.
- if pos == n.Star {
- pos = n.X.Pos()
- }
- case *ast.SelectorExpr:
- // If pos is on the ".", move it into the selector.
- if pos == n.X.End() {
- pos = n.Sel.Pos()
- }
- }
-
- return !found
- })
-
- if len(path) == 0 {
- return nil
- }
-
- // Reverse path so leaf is first element.
- for i := 0; i < len(path)/2; i++ {
- path[i], path[len(path)-1-i] = path[len(path)-1-i], path[i]
- }
-
- return path
-}