From 8276ca04b44936db5683541e5ac053ad3d784d8f Mon Sep 17 00:00:00 2001 From: Owen Mansel-Chan Date: Thu, 7 Apr 2022 17:01:52 +0100 Subject: [PATCH] Use generic method not instantiated one in Uses --- extractor/extractor.go | 57 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/extractor/extractor.go b/extractor/extractor.go index b99b54907f0..7a14e8769a1 100644 --- a/extractor/extractor.go +++ b/extractor/extractor.go @@ -844,7 +844,7 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) { dbscheme.DefsTable.Emit(tw, lbl, objlbl) } } - use := tw.Package.TypesInfo.Uses[expr] + use := getObjectBeingUsed(tw, expr) if use != nil { useTyp := extractType(tw, use.Type()) objlbl, exists := tw.Labeler.LookupObjectID(use, useTyp) @@ -1893,3 +1893,58 @@ func populateTypeParamParents(tw *trap.Writer, typeparams *types.TypeParamList, } } } + +// getobjectBeingUsed looks up `ident` in `tw.Package.TypesInfo.Uses` and makes +// some changes to the object to avoid returning objects relating to instantiated +// types. +func getObjectBeingUsed(tw *trap.Writer, ident *ast.Ident) types.Object { + obj := tw.Package.TypesInfo.Uses[ident] + if obj == nil { + return nil + } + if funcObj, ok := obj.(*types.Func); ok { + sig := funcObj.Type().(*types.Signature) + if recv := sig.Recv(); recv != nil { + recvType := recv.Type() + originType, isSame := tryGetGenericType(recvType) + + if originType == nil { + if pointerType, ok := recvType.(*types.Pointer); ok { + originType, isSame = tryGetGenericType(pointerType.Elem()) + } + } + + if originType == nil || isSame { + return obj + } + + for i := 0; i < originType.NumMethods(); i++ { + meth := originType.Method(i) + if meth.Name() == funcObj.Name() { + return meth + } + } + if interfaceType, ok := originType.Underlying().(*types.Interface); ok { + for i := 0; i < interfaceType.NumMethods(); i++ { + meth := interfaceType.Method(i) + if meth.Name() == funcObj.Name() { + return meth + } + } + } + log.Fatalf("Could not find method %s on type %s", funcObj.Name(), originType) + } + } + + return obj +} + +// tryGetGenericType returns the generic type of `tp`, and a boolean indicating +// whether it is the same as `tp`. +func tryGetGenericType(tp types.Type) (*types.Named, bool) { + if namedType, ok := tp.(*types.Named); ok { + originType := namedType.Origin() + return originType, namedType == originType + } + return nil, false +}