Files
codeql/go/extractor/extractor.go
Owen Mansel-Chan 401c60654e Fix nil checks to stop creating unused labels
In Go, an interface value that holds a nil pointer of a concrete type does not
compare equal to nil. Such values are known as "typed nils". So our existing nil
checks weren't working, which shows why we needed more nil checks inside the type
switches. The solution is to explicitly check for each type we care about.
2025-05-20 13:13:22 +01:00

2070 lines
68 KiB
Go

package extractor
import (
"crypto/md5"
"encoding/hex"
"fmt"
"go/ast"
"go/constant"
"go/scanner"
"go/token"
"go/types"
"io"
"log"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/github/codeql-go/extractor/dbscheme"
"github.com/github/codeql-go/extractor/diagnostics"
"github.com/github/codeql-go/extractor/srcarchive"
"github.com/github/codeql-go/extractor/toolchain"
"github.com/github/codeql-go/extractor/trap"
"github.com/github/codeql-go/extractor/util"
"golang.org/x/tools/go/packages"
)
var (
	// MaxGoRoutines is the size of the semaphore that limits how many
	// extraction goroutines may run at once; it is assigned in init.
	MaxGoRoutines int

	// typeParamParent associates a type parameter with a types.Object.
	// NOTE(review): presumably filled in by populateTypeParamParents with the
	// declaring function or type -- confirm against that helper.
	typeParamParent = make(map[*types.TypeParam]types.Object)
)
// init configures the extractor's parallelism from the environment:
//
//   - LGTM_THREADS, if it is a positive integer, is passed to runtime.GOMAXPROCS.
//   - CODEQL_EXTRACTOR_GO_MAX_GOROUTINES (or SEMMLE_MAX_GOROUTINES), if it is a
//     positive integer, becomes MaxGoRoutines; otherwise MaxGoRoutines is 32.
func init() {
	// this sets the number of threads that the Go runtime will spawn; this is separate
	// from the number of goroutines that the program spawns, which are scheduled into
	// the system threads by the Go runtime scheduler
	threads := os.Getenv("LGTM_THREADS")
	if maxprocs, err := strconv.Atoi(threads); err == nil && maxprocs > 0 {
		log.Printf("Max threads set to %d", maxprocs)
		runtime.GOMAXPROCS(maxprocs)
	} else if threads != "" {
		log.Printf("Warning: LGTM_THREADS value %s is not valid, defaulting to using all available threads.", threads)
	}
	// if the value is empty or not set, use the Go default, which is the number of cores
	// available since Go 1.5, but is subject to change

	const defaultMaxGoRoutines = 32
	goroutines := util.Getenv(
		"CODEQL_EXTRACTOR_GO_MAX_GOROUTINES",
		"SEMMLE_MAX_GOROUTINES",
	)
	// MaxGoRoutines is later used as the size of the goroutine semaphore, so
	// only strictly positive values are accepted; zero or a negative number is
	// treated like any other invalid value.
	if limit, err := strconv.Atoi(goroutines); err == nil && limit > 0 {
		MaxGoRoutines = limit
		log.Printf("Max goroutines set to %d", MaxGoRoutines)
		return
	}
	if goroutines != "" {
		log.Printf("Warning: max goroutines value %s is not valid, defaulting to %d.", goroutines, defaultMaxGoRoutines)
	}
	MaxGoRoutines = defaultMaxGoRoutines
}
// Extract runs an extraction of the packages matched by patterns, using no
// build flags and without extracting tests. It is shorthand for calling
// ExtractWithFlags with those defaults.
func Extract(patterns []string) error {
	var noBuildFlags []string
	const extractTests = false
	return ExtractWithFlags(noBuildFlags, patterns, extractTests)
}
// ExtractWithFlags extracts the packages specified by the given patterns and build flags.
//
// It loads the packages with packages.Load (including test packages when extractTests is
// true), extracts the universe scope, then makes two passes over the loaded package graph:
// the first extracts type information and frontend errors for every package, the second
// extracts AST information for those packages whose directory lies under the package or
// module directory of one of the input packages (excluding `vendor` directories unless
// vendor extraction is enabled), together with the corresponding go.mod file if present.
//
// The only error returned is the one reported by packages.Load; the other failure paths
// in this function terminate the process through log.Fatal/log.Fatalf.
func ExtractWithFlags(buildFlags []string, patterns []string, extractTests bool) error {
startTime := time.Now()
extraction := NewExtraction(buildFlags, patterns)
defer extraction.StatWriter.Close()
modEnabled := os.Getenv("GO111MODULE") != "off"
if !modEnabled {
log.Println("Go module mode disabled.")
}
// Collect the `-mod=` build flags; they are forwarded to the toolchain queries below.
modFlags := make([]string, 0, 1)
for _, flag := range buildFlags {
if strings.HasPrefix(flag, "-mod=") {
modFlags = append(modFlags, flag)
}
}
// If CODEQL_EXTRACTOR_GO_[OPTION_]EXTRACT_VENDOR_DIRS is "true", we extract `vendor` directories;
// otherwise (the default) is to exclude them from extraction
includeVendor, oldOptionUsed := util.IsVendorDirExtractionEnabled()
if oldOptionUsed {
log.Println("Warning: obsolete option \"CODEQL_EXTRACTOR_GO_EXTRACT_VENDOR_DIRS\" was set. Use \"CODEQL_EXTRACTOR_GO_OPTION_EXTRACT_VENDOR_DIRS\" or pass `--extractor-option extract_vendor_dirs=true` instead.")
}
// Build the parenthesised suffix of the log message that describes the active modes.
modeNotifications := make([]string, 0, 2)
if extractTests {
modeNotifications = append(modeNotifications, "test extraction enabled")
}
if includeVendor {
modeNotifications = append(modeNotifications, "extracting vendor directories")
}
modeMessage := strings.Join(modeNotifications, ", ")
if modeMessage != "" {
modeMessage = " (" + modeMessage + ")"
}
log.Printf("Running packages.Load%s.", modeMessage)
// This includes test packages if either we're tracing a `go test` command,
// or if CODEQL_EXTRACTOR_GO_OPTION_EXTRACT_TESTS is set to "true".
cfg := &packages.Config{
Mode: packages.NeedName | packages.NeedFiles |
packages.NeedCompiledGoFiles |
packages.NeedImports | packages.NeedDeps |
packages.NeedTypes | packages.NeedTypesSizes |
packages.NeedTypesInfo | packages.NeedSyntax,
BuildFlags: buildFlags,
Tests: extractTests,
}
pkgs, err := packages.Load(cfg, patterns...)
if err != nil {
// the toolchain directive is only supported in Go 1.21 and above
if strings.Contains(err.Error(), "unknown directive: toolchain") {
diagnostics.EmitNewerSystemGoRequired("1.21.0")
}
return err
}
log.Println("Done running packages.Load.")
// Finding no packages is not an error by itself, but emit a diagnostic if the working
// directory does contain Go files.
if len(pkgs) == 0 {
log.Println("No packages found.")
wd, err := os.Getwd()
if err != nil {
log.Printf("Warning: failed to get working directory: %s\n", err.Error())
} else if util.FindGoFiles(wd) {
diagnostics.EmitGoFilesFoundButNotProcessed()
}
}
log.Println("Extracting universe scope.")
extractUniverseScope()
log.Println("Done extracting universe scope.")
// a map of package path to source directory and module root directory
pkgInfos := make(map[string]toolchain.PkgInfo)
// root directories of packages that we want to extract
wantedRoots := make(map[string]bool)
// Unless disabled, resolve the directories of all packages up front; packages missing
// from this map are looked up one at a time in the traversal below.
if os.Getenv("CODEQL_EXTRACTOR_GO_FAST_PACKAGE_INFO") != "false" {
log.Printf("Running go list to resolve package and module directories.")
// get all packages information
pkgInfos, err = toolchain.GetPkgsInfo(patterns, true, extractTests, modFlags...)
if err != nil {
log.Fatalf("Error getting dependency package or module directories: %v.", err)
}
log.Printf("Done running go list deps: resolved %d packages.", len(pkgInfos))
}
pkgsNotFound := make([]string, 0, len(pkgs))
// Build a map from package paths to their longest IDs--
// in the context of a `go test -c` compilation, we will see the same package more than
// once, with IDs like "abc.com/pkgname [abc.com/pkgname.test]" to distinguish the version
// that contains and is used by test code.
// For our purposes it is simplest to just ignore the non-test version, since the test
// version seems to be a superset of it.
longestPackageIds := make(map[string]string)
packages.Visit(pkgs, nil, func(pkg *packages.Package) {
if longestIDSoFar, present := longestPackageIds[pkg.PkgPath]; present {
if len(pkg.ID) > len(longestIDSoFar) {
longestPackageIds[pkg.PkgPath] = pkg.ID
}
} else {
longestPackageIds[pkg.PkgPath] = pkg.ID
}
})
// Do a post-order traversal and extract the package scope of each package
packages.Visit(pkgs, nil, func(pkg *packages.Package) {
// Note that if test extraction is enabled, we will encounter a package twice here:
// once as the main package, and once as the test package (with a package ID like
// "abc.com/pkgname [abc.com/pkgname.test]").
//
// We will extract it both times however, because we need to visit the packages
// in the right order in order to visit used types before their users, and the
// ordering determined by packages.Visit for the main and the test package may differ.
//
// This should only cause some wasted time and not inconsistency because the names for
// objects seen in this process should be the same each time.
log.Printf("Processing package %s.", pkg.PkgPath)
if _, ok := pkgInfos[pkg.PkgPath]; !ok {
pkgInfos[pkg.PkgPath] = toolchain.GetPkgInfo(pkg.PkgPath, modFlags...)
}
log.Printf("Extracting types for package %s.", pkg.PkgPath)
tw, err := trap.NewWriter(pkg.PkgPath, pkg)
if err != nil {
log.Fatal(err)
}
// this defer runs when the visitor callback returns, i.e. once per visited package
defer tw.Close()
scope := extractPackageScope(tw, pkg)
extractObjectTypes(tw)
lbl := tw.Labeler.GlobalID(util.EscapeTrapSpecialChars(pkg.PkgPath) + ";pkg")
dbscheme.PackagesTable.Emit(tw, lbl, pkg.Name, pkg.PkgPath, scope)
if len(pkg.Errors) != 0 {
log.Printf("Warning: encountered errors extracting package `%s`:", pkg.PkgPath)
for i, err := range pkg.Errors {
errString := err.Error()
log.Printf(" %s", errString)
if strings.Contains(errString, "build constraints exclude all Go files in ") {
// `err` is a NoGoError from the package cmd/go/internal/load, which we cannot access as it is internal
diagnostics.EmitPackageDifferentOSArchitecture(pkg.PkgPath)
} else if strings.Contains(errString, "cannot find package") ||
strings.Contains(errString, "no required module provides package") {
pkgsNotFound = append(pkgsNotFound, pkg.PkgPath)
}
extraction.extractError(tw, err, lbl, i)
}
}
log.Printf("Done extracting types for package %s.", pkg.PkgPath)
})
if len(pkgsNotFound) > 0 {
diagnostics.EmitCannotFindPackages(pkgsNotFound)
}
// The package directory and (if known) module directory of every input package become
// the roots under which AST extraction takes place.
for _, pkg := range pkgs {
pkgInfo, ok := pkgInfos[pkg.PkgPath]
if !ok || pkgInfo.PkgDir == "" {
log.Fatalf("Unable to get a source directory for input package %s.", pkg.PkgPath)
}
wantedRoots[pkgInfo.PkgDir] = true
if pkgInfo.ModDir != "" {
wantedRoots[pkgInfo.ModDir] = true
}
}
log.Println("Done processing dependencies.")
log.Println("Starting to extract packages.")
sep := regexp.QuoteMeta(string(filepath.Separator))
// Construct a list of directory segments to exclude from extraction, starting with ".."
excludedDirs := []string{`\.\.`}
if !includeVendor {
excludedDirs = append(excludedDirs, "vendor")
}
// If a path matches this regexp, we don't extract this package. It checks whether the path
// contains one of the `excludedDirs`.
noExtractRe := regexp.MustCompile(`.*(^|` + sep + `)(` + strings.Join(excludedDirs, "|") + `)($|` + sep + `).*`)
// extract AST information for all packages
packages.Visit(pkgs, nil, func(pkg *packages.Package) {
// If this is a variant of a package that also occurs with a longer ID, skip it;
// otherwise we would extract the same file more than once including extracting the
// body of methods twice, causing database inconsistencies.
//
// We prefer the version with the longest ID because that is (so far as I know) always
// the version that defines more entities -- the only case I'm aware of being a test
// variant of a package, which includes test-only functions in addition to the complete
// contents of the main variant.
if pkg.ID != longestPackageIds[pkg.PkgPath] {
return
}
// Extract the package if its directory lies inside any wanted root; the first matching
// root wins and ends the callback.
for root := range wantedRoots {
pkgInfo := pkgInfos[pkg.PkgPath]
relDir, err := filepath.Rel(root, pkgInfo.PkgDir)
if err != nil || noExtractRe.MatchString(relDir) {
// if the path can't be made relative or matches the noExtract regexp skip it
continue
}
extraction.extractPackage(pkg)
// Look for a go.mod in the module directory, or in the package directory if the
// module directory is unknown.
modDir := pkgInfo.ModDir
if modDir == "" {
modDir = pkgInfo.PkgDir
}
if modDir != "" {
modPath := filepath.Join(modDir, "go.mod")
if util.FileExists(modPath) {
log.Printf("Extracting %s", modPath)
start := time.Now()
err := extraction.extractGoMod(modPath)
if err != nil {
log.Printf("Failed to extract go.mod: %s", err.Error())
}
end := time.Since(start)
log.Printf("Done extracting %s (%dms)", modPath, end.Nanoseconds()/1000000)
}
}
return
}
log.Printf("Skipping dependency package %s.", pkg.PkgPath)
})
// extractPackage starts one goroutine per file and registers it with the wait group;
// wait for all of them before recording the end of the compilation.
extraction.WaitGroup.Wait()
log.Println("Done extracting packages.")
t := time.Now()
elapsed := t.Sub(startTime)
dbscheme.CompilationFinishedTable.Emit(extraction.StatWriter, extraction.Label, 0.0, elapsed.Seconds())
return nil
}
// Extraction holds the state shared by all parts of one extractor run.
type Extraction struct {
// A lock for preventing concurrent writes to maps and the stat trap writer, as they are not
// thread-safe
Lock sync.Mutex
// The key from which Label was generated (see NewExtraction)
LabelKey string
// The label of the compilation entity in the stat trap file
Label trap.Label
// The trap writer that receives compilation-level tables (compilation arguments, compiled
// files, diagnostics, timing)
StatWriter *trap.Writer
// Tracks the per-file goroutines started by extractPackage
WaitGroup sync.WaitGroup
// Limits the number of per-file goroutines running at once
GoroutineSem *semaphore
// Limits the number of files that are open at once
FdSem *semaphore
// The index that will be given to the next file first seen by GetFileInfo
NextFileId int
// Per-file bookkeeping, keyed by file path
FileInfo map[string]*FileInfo
// NOTE(review): presumably the go.mod files that have already been extracted -- it is not
// used in the visible part of this file; confirm against extractGoMod
SeenGoMods map[string]bool
}
// FileInfo is the per-file bookkeeping kept in Extraction.FileInfo.
type FileInfo struct {
// The index of the file within the compilation, assigned from Extraction.NextFileId
Idx int
// The index that GetNextErr will hand out for the next error reported in this file
NextErr int
}
// SeenFile reports whether a FileInfo entry has already been recorded for path.
func (extraction *Extraction) SeenFile(path string) bool {
	if _, recorded := extraction.FileInfo[path]; recorded {
		return true
	}
	return false
}
// GetFileInfo returns the FileInfo recorded for path. The first time a path is
// requested, a new entry is created with the next free file index and an error
// counter of zero.
func (extraction *Extraction) GetFileInfo(path string) *FileInfo {
	info, known := extraction.FileInfo[path]
	if !known {
		info = &FileInfo{Idx: extraction.NextFileId, NextErr: 0}
		extraction.FileInfo[path] = info
		extraction.NextFileId++
	}
	return info
}
// GetFileIdx returns the index assigned to path, registering the path first if
// it has not been seen before.
func (extraction *Extraction) GetFileIdx(path string) int {
	info := extraction.GetFileInfo(path)
	return info.Idx
}
// GetNextErr returns the current error index for path and advances the file's
// error counter by one, so successive calls yield 0, 1, 2, ...
func (extraction *Extraction) GetNextErr(path string) int {
	info := extraction.GetFileInfo(path)
	current := info.NextErr
	info.NextErr = current + 1
	return current
}
// NewExtraction creates the shared state for one extractor run.
//
// It picks an unused trap file name under `compilations/`, derived from the MD5
// digest of the command line ("go", the build flags, "--", the patterns), opens
// the stat trap writer on it, and emits the compilation entity together with its
// working directory and arguments. Failing to create the trap file, to determine
// the working directory or to find the extractor path is fatal.
func NewExtraction(buildFlags []string, patterns []string) *Extraction {
	// Assemble the text that identifies this invocation and hash it in one go.
	var cmdline strings.Builder
	cmdline.WriteString("go")
	for _, flag := range buildFlags {
		cmdline.WriteString(" " + flag)
	}
	cmdline.WriteString(" --")
	for _, pattern := range patterns {
		cmdline.WriteString(" " + pattern)
	}
	digest := md5.Sum([]byte(cmdline.String()))

	// split compilation files into directories to avoid filling a single directory with too many files
	pathFmt := fmt.Sprintf("compilations/%s/%s_%%d", hex.EncodeToString(digest[:1]), hex.EncodeToString(digest[1:]))

	// Try successive suffixes until one names a trap file that does not exist
	// yet. Note that `seq` ends up one past the suffix that was chosen.
	var path string
	seq := 0
	for free := false; !free; seq++ {
		path = fmt.Sprintf(pathFmt, seq)
		file, err := trap.FileFor(path)
		if err != nil {
			log.Fatalf("Error creating trap file: %s\n", err.Error())
		}
		free = !util.FileExists(file)
	}

	statWriter, err := trap.NewWriter(path, nil)
	if err != nil {
		log.Fatal(err)
	}
	lblKey := fmt.Sprintf("%s_%d;compilation", hex.EncodeToString(digest[:]), seq)
	lbl := statWriter.Labeler.GlobalID(lblKey)

	wd, err := os.Getwd()
	if err != nil {
		log.Fatalf("Unable to determine current directory: %s\n", err.Error())
	}
	dbscheme.CompilationsTable.Emit(statWriter, lbl, wd)

	extractorPath, err := util.GetExtractorPath()
	if err != nil {
		log.Fatalf("Unable to get extractor path: %s\n", err.Error())
	}

	// The recorded argument list is: extractor path, build flags, a fake "--"
	// argument to make it clear that what comes after it are patterns, and
	// finally the patterns themselves.
	args := make([]string, 0, len(buildFlags)+len(patterns)+2)
	args = append(args, extractorPath)
	args = append(args, buildFlags...)
	args = append(args, "--")
	args = append(args, patterns...)
	for argIdx, arg := range args {
		dbscheme.CompilationArgsTable.Emit(statWriter, lbl, argIdx, arg)
	}

	return &Extraction{
		LabelKey:   lblKey,
		Label:      lbl,
		StatWriter: statWriter,
		// this semaphore is used to limit the number of files that are open at once;
		// this is to prevent the extractor from running into issues with caps on the
		// number of open files that can be held by one process
		FdSem: newSemaphore(100),
		// this semaphore is used to limit the number of goroutines spawned, so we
		// don't run into memory issues
		GoroutineSem: newSemaphore(MaxGoRoutines),
		NextFileId:   0,
		FileInfo:     make(map[string]*FileInfo),
		SeenGoMods:   make(map[string]bool),
	}
}
// extractUniverseScope writes the `universe` trap file: the universe scope
// itself, every object declared in it, and the empty interface type.
func extractUniverseScope() {
	tw, err := trap.NewWriter("universe", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer tw.Close()

	universe := types.Universe
	universeLbl := tw.Labeler.ScopeID(universe, nil)
	dbscheme.ScopesTable.Emit(tw, universeLbl, dbscheme.UniverseScopeType.Index())
	extractObjects(tw, universe, universeLbl)

	// Always extract an empty interface type
	var (
		noMethods   = []*types.Func{}
		noEmbeddeds = []types.Type{}
	)
	extractType(tw, types.NewInterfaceType(noMethods, noEmbeddeds))
}
// extractObjects walks the names of `scope`, extracts every object that has not
// been labelled yet, and records the object as belonging to `scopeLabel` when
// the scope is its actual parent.
// For more information on objects, see:
// https://github.com/golang/example/blob/master/gotypes/README.md#objects
func extractObjects(tw *trap.Writer, scope *types.Scope, scopeLabel trap.Label) {
	for _, name := range scope.Names() {
		obj := scope.Lookup(name)
		typeLabel := func() trap.Label { return extractType(tw, obj.Type()) }
		lbl, known := tw.Labeler.ScopedObjectID(obj, typeLabel)
		if !known {
			switch o := obj.(type) {
			case *types.Func:
				// Populate type parameter parents for functions. Note that methods
				// do not appear as objects in any scope, so they have to be dealt
				// with separately in extractMethods.
				sig := o.Type().(*types.Signature)
				populateTypeParamParents(sig.TypeParams(), obj)
				populateTypeParamParents(sig.RecvTypeParams(), obj)
			case *types.TypeName:
				// Populate type parameter parents for defined types and alias types.
				//
				// `types.TypeName` represents a type with a name: a defined
				// type, an alias type, a type parameter, or a predeclared
				// type such as `int` or `error`. We can distinguish these
				// using `o.Type()`, except that we need to be careful with
				// alias types because before Go 1.24 they would return the
				// underlying type.
				switch typ := o.Type().(type) {
				case *types.Named:
					if !o.IsAlias() {
						populateTypeParamParents(typ.TypeParams(), obj)
					}
				case *types.Alias:
					populateTypeParamParents(typ.TypeParams(), obj)
				}
			}
			extractObject(tw, obj, lbl)
		}
		// The parent differs from `scope` if a scope is embedded into another
		// with a `.` import; such objects get no scope entry here.
		if obj.Parent() == scope {
			dbscheme.ObjectScopesTable.Emit(tw, lbl, scopeLabel)
		}
	}
}
// extractMethod makes sure the method `meth` is present in the objects table
// and returns its label. The receiver type is extracted first because the
// method's label is derived from the receiver type's label.
func extractMethod(tw *trap.Writer, meth *types.Func) trap.Label {
	sig := meth.Type().(*types.Signature)
	// ensure receiver type has been extracted
	recvTypeLbl := extractType(tw, sig.Recv().Type())

	methLbl, known := tw.Labeler.MethodID(meth, recvTypeLbl)
	if known {
		return methLbl
	}
	// Populate type parameter parents for methods. They do not appear as
	// objects in any scope, so they have to be dealt with separately here.
	populateTypeParamParents(sig.TypeParams(), meth)
	populateTypeParamParents(sig.RecvTypeParams(), meth)
	extractObject(tw, meth, methLbl)
	return methLbl
}
// extractObject emits one row of the objects table for `obj` under label `lbl`,
// choosing the object kind from the dynamic type of `obj` and from whether it is
// declared in the universe scope. An object of an unexpected type is fatal.
// If the object's type is a signature with a receiver, the receiver is
// extracted too and linked to the object.
// For more information on objects, see:
// https://github.com/golang/example/blob/master/gotypes/README.md#objects
func extractObject(tw *trap.Writer, obj types.Object, lbl trap.Label) {
	checkObjectNotSpecialized(obj)
	objName := obj.Name()
	declared := obj.Parent() != types.Universe

	var kind int
	switch obj.(type) {
	case *types.PkgName:
		kind = dbscheme.PkgObjectType.Index()
	case *types.TypeName:
		if declared {
			kind = dbscheme.DeclTypeObjectType.Index()
		} else {
			kind = dbscheme.BuiltinTypeObjectType.Index()
		}
	case *types.Const:
		if declared {
			kind = dbscheme.DeclConstObjectType.Index()
		} else {
			kind = dbscheme.BuiltinConstObjectType.Index()
		}
	case *types.Nil:
		kind = dbscheme.BuiltinConstObjectType.Index()
	case *types.Var:
		kind = dbscheme.DeclVarObjectType.Index()
	case *types.Builtin:
		kind = dbscheme.BuiltinFuncObjectType.Index()
	case *types.Func:
		kind = dbscheme.DeclFuncObjectType.Index()
	case *types.Label:
		kind = dbscheme.LabelObjectType.Index()
	default:
		log.Fatalf("unknown object of type %T", obj)
	}
	dbscheme.ObjectsTable.Emit(tw, lbl, kind, objName)

	// for methods, additionally extract information about the receiver
	sig, isSignature := obj.Type().(*types.Signature)
	if !isSignature {
		return
	}
	recv := sig.Recv()
	if recv == nil {
		return
	}
	recvLbl, known := tw.Labeler.ReceiverObjectID(recv, lbl)
	if !known {
		extractObject(tw, recv, recvLbl)
	}
	dbscheme.MethodReceiversTable.Emit(tw, lbl, recvLbl)
}
// extractObjectTypes extracts the types of all labelled objects and then emits
// the object-to-type associations.
// For more information on objects, see:
// https://github.com/golang/example/blob/master/gotypes/README.md#objects
func extractObjectTypes(tw *trap.Writer) {
	// calling `extractType` on a defined type will extract all methods defined
	// on it, which will add new objects. Therefore we keep re-running the
	// extraction pass until a pass reports no change, and only then loop over
	// all objects and emit them.
	for tw.ForEachObject(extractObjectType) {
	}
	if tw.ForEachObject(emitObjectType) {
		log.Printf("Warning: more objects were labeled while emitting object types")
	}
}
// extractObjectType extracts the type of `obj`, if it has one. The label
// argument is not used; the signature matches what tw.ForEachObject expects.
// For more information on objects, see:
// https://github.com/golang/example/blob/master/gotypes/README.md#objects
func extractObjectType(tw *trap.Writer, obj types.Object, lbl trap.Label) {
	typ := obj.Type()
	if typ == nil {
		return
	}
	extractType(tw, typ)
}
// emitObjectType records, for an object that has a type, the association
// between the object's label and the label of its type.
func emitObjectType(tw *trap.Writer, obj types.Object, lbl trap.Label) {
	typ := obj.Type()
	if typ == nil {
		return
	}
	typLbl := extractType(tw, typ)
	dbscheme.ObjectTypesTable.Emit(tw, lbl, typLbl)
}
// threePartPos matches an error position of the form "file:line:col" and
// captures the three parts.
var threePartPos = regexp.MustCompile(`^(.+):(\d+):(\d+)$`)

// twoPartPos matches an error position of the form "file:line" and captures
// the two parts. It also matches three-part positions (with the line ending up
// in the file capture), so threePartPos has to be tried first.
var twoPartPos = regexp.MustCompile(`^(.+):(\d+)$`)
// extractError extracts the message and location of a frontend error.
//
// `err` is the error reported for the package labelled `pkglbl`, and `idx` is
// its index among that package's errors. The error position is parsed from
// `err.Pos`, which is expected to be "file:line:col", "file:line", "" or "-".
// For the last two a dummy file named "-" in the working directory is used.
// The error is written to `tw`, and a matching diagnostic is written to the
// stat trap writer under the extraction lock.
func (extraction *Extraction) extractError(tw *trap.Writer, err packages.Error, pkglbl trap.Label, idx int) {
	var (
		lbl       = tw.Labeler.FreshID()
		tag       = dbscheme.ErrorTags[err.Kind]
		kind      = dbscheme.ErrorTypes[err.Kind].Index()
		pos       = err.Pos
		file      = ""
		line, col int
	)
	if pos == "" || pos == "-" {
		// extract a dummy file
		wd, e := os.Getwd()
		if e != nil {
			wd = "."
			log.Printf("Warning: failed to get working directory")
		}
		ewd, e := filepath.EvalSymlinks(wd)
		if e != nil {
			ewd = wd
			log.Printf("Warning: failed to evaluate symlinks for %s", wd)
		}
		file = filepath.Join(ewd, "-")
		extraction.extractFileInfo(tw, file, true)
	} else {
		var (
			rawfile string
			e       error
		)
		if parts := threePartPos.FindStringSubmatch(pos); parts != nil {
			// "file:line:col"
			col, e = strconv.Atoi(parts[3])
			if e != nil {
				log.Printf("Warning: malformed column number `%s`: %v", parts[3], e)
			}
			line, e = strconv.Atoi(parts[2])
			if e != nil {
				log.Printf("Warning: malformed line number `%s`: %v", parts[2], e)
			}
			rawfile = parts[1]
		} else if parts := twoPartPos.FindStringSubmatch(pos); parts != nil {
			// "file:line"
			line, e = strconv.Atoi(parts[2])
			if e != nil {
				log.Printf("Warning: malformed line number `%s`: %v", parts[2], e)
			}
			rawfile = parts[1]
		} else {
			// pos is neither empty nor "-" here, so it is simply malformed
			log.Printf("Warning: malformed error position `%s`", pos)
		}
		afile, e := filepath.Abs(rawfile)
		if e != nil {
			// Fall back to the path as reported. `file` has not been assigned
			// yet at this point, so it must not be used here.
			log.Printf("Warning: failed to get absolute path for %s", rawfile)
			afile = rawfile
		}
		file, e = filepath.EvalSymlinks(afile)
		if e != nil {
			log.Printf("Warning: failed to evaluate symlinks for %s", afile)
			file = afile
		}
		extraction.extractFileInfo(tw, file, false)
	}

	// The stat writer and the file bookkeeping are shared between goroutines.
	extraction.Lock.Lock()
	flbl := extraction.StatWriter.Labeler.FileLabelFor(file)
	diagLbl := extraction.StatWriter.Labeler.FreshID()
	dbscheme.DiagnosticsTable.Emit(
		extraction.StatWriter, diagLbl, 1, tag, err.Msg, err.Msg,
		emitLocation(extraction.StatWriter, flbl, line, col, line, col))
	dbscheme.DiagnosticForTable.Emit(extraction.StatWriter, diagLbl, extraction.Label, extraction.GetFileIdx(file), extraction.GetNextErr(file))
	extraction.Lock.Unlock()

	transformed := filepath.ToSlash(srcarchive.TransformPath(file))
	dbscheme.ErrorsTable.Emit(tw, lbl, kind, err.Msg, pos, transformed, line, col, pkglbl, idx)
}
// extractPackage starts one goroutine per syntax tree of `pkg` to extract its
// AST information. Each goroutine is registered with the wait group and holds
// one unit of the goroutine semaphore while it runs; this function blocks in
// `acquire` while the semaphore is exhausted. An extraction error is fatal.
func (extraction *Extraction) extractPackage(pkg *packages.Package) {
	extractOne := func(file *ast.File) {
		if err := extraction.extractFile(file, pkg); err != nil {
			log.Fatal(err)
		}
		extraction.GoroutineSem.release(1)
		extraction.WaitGroup.Done()
	}
	for _, astFile := range pkg.Syntax {
		extraction.WaitGroup.Add(1)
		extraction.GoroutineSem.acquire(1)
		go extractOne(astFile)
	}
}
// normalizedPath returns the name under which the file containing the `package`
// keyword of `ast` is registered in `fset`, with symlinks resolved. If the
// symlinks cannot be resolved, the registered name is returned unchanged.
func normalizedPath(ast *ast.File, fset *token.FileSet) string {
	registered := fset.File(ast.Package).Name()
	if resolved, err := filepath.EvalSymlinks(registered); err == nil {
		return resolved
	}
	return registered
}
// extractFile extracts AST information for the given file into its own trap
// file and adds the file to the source archive. Files without a `package`
// declaration are skipped. It returns the error from creating the trap writer
// or from adding the file to the source archive, if any.
func (extraction *Extraction) extractFile(ast *ast.File, pkg *packages.Package) error {
	fset := pkg.Fset
	if ast.Package == token.NoPos {
		log.Printf("Skipping extracting a file without a 'package' declaration")
		return nil
	}
	path := normalizedPath(ast, fset)

	// Take three units of the file-descriptor semaphore. Two of them are handed
	// back as soon as the trap writer has been opened and the source archive
	// copy has been made (or opening the writer failed); the last one is held
	// until this function returns.
	extraction.FdSem.acquire(3)
	log.Printf("Extracting %s", path)
	began := time.Now()
	defer extraction.FdSem.release(1)

	tw, openErr := trap.NewWriter(path, pkg)
	if openErr != nil {
		extraction.FdSem.release(2)
		return openErr
	}
	defer tw.Close()

	archiveErr := srcarchive.Add(path)
	extraction.FdSem.release(2)
	if archiveErr != nil {
		return archiveErr
	}

	extraction.extractFileInfo(tw, path, false)
	extractScopes(tw, ast, pkg)
	extractFileNode(tw, ast)
	extractObjectTypes(tw)
	extractNumLines(tw, path, ast)

	log.Printf("Done extracting %s (%dms)", path, time.Since(began).Milliseconds())
	return nil
}
// extractFileInfo extracts file-system level information for the given file,
// populating the `files`, `folders` and `containerparent` tables: one folder
// entry per leading path component and one file entry for the final component.
// Unless `isDummy` is set, the file is also recorded in the stat trap file as
// one of the files being compiled. Files that already have a FileInfo entry
// are skipped.
func (extraction *Extraction) extractFileInfo(tw *trap.Writer, file string, isDummy bool) {
	// We may visit the same file twice because `extractError` calls this function to describe files containing
	// compilation errors. It is also called for user source files being extracted.
	extraction.Lock.Lock()
	alreadySeen := extraction.SeenFile(file)
	extraction.Lock.Unlock()
	if alreadySeen {
		return
	}

	segments := strings.Split(filepath.ToSlash(srcarchive.TransformPath(file)), "/")
	last := len(segments) - 1
	parentPath := ""
	var parentLbl trap.Label

	// pathOf builds the full path of the i-th segment from the path of its
	// parent; a leading empty segment stands for the root folder "/".
	pathOf := func(i int) string {
		switch {
		case i > 0:
			return parentPath + "/" + segments[i]
		case segments[0] == "":
			return "/"
		default:
			return segments[0]
		}
	}

	// every segment but the last is a folder
	for i := 0; i < last; i++ {
		folderPath := pathOf(i)
		folderLbl := tw.Labeler.GlobalID(util.EscapeTrapSpecialChars(folderPath) + ";folder")
		dbscheme.FoldersTable.Emit(tw, folderLbl, folderPath)
		if i > 0 {
			dbscheme.ContainerParentTable.Emit(tw, parentLbl, folderLbl)
		}
		// the root contributes no prefix, so that its children are "/name" rather than "//name"
		if folderPath != "/" {
			parentPath = folderPath
		}
		parentLbl = folderLbl
	}

	// the last segment is the file itself
	filePath := pathOf(last)
	fileLbl := tw.Labeler.FileLabelFor(file)
	dbscheme.FilesTable.Emit(tw, fileLbl, filePath)
	dbscheme.ContainerParentTable.Emit(tw, parentLbl, fileLbl)
	dbscheme.HasLocationTable.Emit(tw, fileLbl, emitLocation(tw, fileLbl, 0, 0, 0, 0))

	extraction.Lock.Lock()
	statLbl := extraction.StatWriter.Labeler.FileLabelFor(file)
	if !isDummy {
		dbscheme.CompilationCompilingFilesTable.Emit(extraction.StatWriter, extraction.Label, extraction.GetFileIdx(file), statLbl)
	}
	extraction.Lock.Unlock()
}
// extractLocation emits a location in the file currently being written by `tw`
// (start line/column `sl`/`sc`, end line/column `el`/`ec`) and attaches it to
// `entity`.
func extractLocation(tw *trap.Writer, entity trap.Label, sl int, sc int, el int, ec int) {
	locLbl := emitLocation(tw, tw.Labeler.FileLabel(), sl, sc, el, ec)
	dbscheme.HasLocationTable.Emit(tw, entity, locLbl)
}
// emitLocation emits a location entity for the given file label and start/end
// line and column, and returns the location's label. The label is a global ID
// built from the file label and the four coordinates.
func emitLocation(tw *trap.Writer, filelbl trap.Label, sl int, sc int, el int, ec int) trap.Label {
	key := fmt.Sprintf("loc,{%s},%d,%d,%d,%d", filelbl, sl, sc, el, ec)
	lbl := tw.Labeler.GlobalID(key)
	dbscheme.LocationsDefaultTable.Emit(tw, lbl, filelbl, sl, sc, el, ec)
	return lbl
}
// extractNodeLocation attaches to `lbl` the source range covered by the node
// `nd`. The end column is the column of the node's last character, i.e. one
// less than the column of `nd.End()`. A nil node is ignored.
func extractNodeLocation(tw *trap.Writer, nd ast.Node, lbl trap.Label) {
	if nd != nil {
		fset := tw.Package.Fset
		from := fset.Position(nd.Pos())
		to := fset.Position(nd.End())
		extractLocation(tw, lbl, from.Line, from.Column, to.Line, to.Column-1)
	}
}
// extractPackageScope emits the package scope of `pkg`, nests it in the
// universe scope, extracts all objects declared in it, and returns the label of
// the package scope.
func extractPackageScope(tw *trap.Writer, pkg *packages.Package) trap.Label {
	scope := pkg.Types.Scope()
	lbl := tw.Labeler.ScopeID(scope, pkg.Types)
	dbscheme.ScopesTable.Emit(tw, lbl, dbscheme.PackageScopeType.Index())

	universeLbl := tw.Labeler.ScopeID(types.Universe, nil)
	dbscheme.ScopeNestingTable.Emit(tw, lbl, universeLbl)

	extractObjects(tw, scope, lbl)
	return lbl
}
// extractScopeLocation attaches to `lbl` the source range covered by `scope`,
// with the end column being one less than the column of `scope.End()`.
func extractScopeLocation(tw *trap.Writer, scope *types.Scope, lbl trap.Label) {
	fset := tw.Package.Fset
	from := fset.Position(scope.Pos())
	to := fset.Position(scope.End())
	extractLocation(tw, lbl, from.Line, from.Column, to.Line, to.Column-1)
}
// extractScopes extracts symbol table information for the package scope of
// `pkg` and, if the type checker recorded a scope for the file `nd`, for that
// file scope and all scopes nested in it. Note that this will not encounter
// methods or struct fields as they do not have a parent scope.
func extractScopes(tw *trap.Writer, nd *ast.File, pkg *packages.Package) {
	pkgScopeLabel := extractPackageScope(tw, pkg)
	if fileScope := pkg.TypesInfo.Scopes[nd]; fileScope != nil {
		extractLocalScope(tw, fileScope, pkgScopeLabel)
	}
}
// extractLocalScope emits `scope` as a local scope nested in the scope labelled
// `parentScopeLabel`, together with its location, then recurses into its child
// scopes and finally extracts the objects declared directly in `scope`.
func extractLocalScope(tw *trap.Writer, scope *types.Scope, parentScopeLabel trap.Label) {
	lbl := tw.Labeler.ScopeID(scope, nil)
	dbscheme.ScopesTable.Emit(tw, lbl, dbscheme.LocalScopeType.Index())
	extractScopeLocation(tw, scope, lbl)
	dbscheme.ScopeNestingTable.Emit(tw, lbl, parentScopeLabel)

	for i, n := 0, scope.NumChildren(); i < n; i++ {
		extractLocalScope(tw, scope.Child(i), lbl)
	}
	extractObjects(tw, scope, lbl)
}
// extractFileNode extracts AST information for the file `nd`: its package name
// as child 0, each declaration and each comment group at its index in the
// respective list, the file's doc comment, and the scope induced by the file.
func extractFileNode(tw *trap.Writer, nd *ast.File) {
	fileLbl := tw.Labeler.FileLabel()

	extractExpr(tw, nd.Name, fileLbl, 0, false)
	for declIdx := range nd.Decls {
		extractDecl(tw, nd.Decls[declIdx], fileLbl, declIdx)
	}
	for groupIdx := range nd.Comments {
		extractCommentGroup(tw, nd.Comments[groupIdx], fileLbl, groupIdx)
	}

	extractDoc(tw, nd.Doc, fileLbl)
	emitScopeNodeInfo(tw, nd, fileLbl)
}
// extractDoc records `doc` as the doc comment group of the element `elt`.
// Elements without a doc comment (`doc` is nil) get no entry.
func extractDoc(tw *trap.Writer, doc *ast.CommentGroup, elt trap.Label) {
	if doc == nil {
		return
	}
	dbscheme.DocCommentsTable.Emit(tw, elt, tw.Labeler.LocalID(doc))
}
// extractCommentGroup emits the comment group `cg` as the `idx`-th comment
// group of `parent`, together with its location and each comment it contains.
func extractCommentGroup(tw *trap.Writer, cg *ast.CommentGroup, parent trap.Label, idx int) {
	groupLbl := tw.Labeler.LocalID(cg)
	dbscheme.CommentGroupsTable.Emit(tw, groupLbl, parent, idx)
	extractNodeLocation(tw, cg, groupLbl)
	for commentIdx := range cg.List {
		extractComment(tw, cg.List[commentIdx], groupLbl, commentIdx)
	}
}
// extractComment emits the comment `c` as the `idx`-th comment of the comment
// group `parent`, together with its location.
//
// The emitted text is the comment without its delimiters: the leading `//` of a
// line comment, or the leading `/*` and trailing `*/` of a block comment.
// The delimiters are stripped only where they are actually present, so a
// malformed comment (for example an unterminated `/*` at the end of a file with
// syntax errors) neither loses real text nor causes an out-of-range slice.
func extractComment(tw *trap.Writer, c *ast.Comment, parent trap.Label, idx int) {
	lbl := tw.Labeler.LocalID(c)
	rawText := c.Text
	var kind int
	var text string
	if strings.HasPrefix(rawText, "//") {
		kind = dbscheme.SlashSlashComment.Index()
		text = strings.TrimPrefix(rawText, "//")
	} else {
		kind = dbscheme.SlashStarComment.Index()
		// Remove the opener first so that the closer is only looked for in the
		// remaining text; this keeps "/*/" from being treated as terminated.
		text = strings.TrimSuffix(strings.TrimPrefix(rawText, "/*"), "*/")
	}
	dbscheme.CommentsTable.Emit(tw, lbl, kind, parent, idx, text)
	extractNodeLocation(tw, c, lbl)
}
// emitScopeNodeInfo links the AST node `nd` (labelled `lbl`) to the scope it
// induces, if the type checker recorded one for it.
func emitScopeNodeInfo(tw *trap.Writer, nd ast.Node, lbl trap.Label) {
	scope, recorded := tw.Package.TypesInfo.Scopes[nd]
	if !recorded {
		return
	}
	scopeLbl := tw.Labeler.ScopeID(scope, tw.Package.Types)
	dbscheme.ScopeNodesTable.Emit(tw, lbl, scopeLbl)
}
// extractExpr extracts AST information for the given expression and all its subexpressions
//
// `parent` and `idx` give the position of `expr` among the children of its parent node; they
// are written to the expressions table together with the kind selected in the type switch
// below. When `skipExtractingValue` is true no constant value is recorded for `expr` itself.
// The flag is not inherited by subexpressions: every recursive call passes `false`, except
// for the left operand of a binary expression, where `skipExtractingValueForLeftOperand`
// decides.
func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, skipExtractingValue bool) {
// An interface value that wraps a nil pointer of a concrete type does not compare equal to
// `nil`, so in addition to the plain nil check we compare against a typed nil for each
// listed node type. Returning here, before `LocalID` is called, avoids creating a label
// for a node that is never emitted.
if expr == nil || expr == (*ast.Ident)(nil) || expr == (*ast.BasicLit)(nil) ||
expr == (*ast.Ellipsis)(nil) || expr == (*ast.FuncLit)(nil) ||
expr == (*ast.CompositeLit)(nil) || expr == (*ast.SelectorExpr)(nil) ||
expr == (*ast.IndexListExpr)(nil) || expr == (*ast.SliceExpr)(nil) ||
expr == (*ast.TypeAssertExpr)(nil) || expr == (*ast.CallExpr)(nil) ||
expr == (*ast.StarExpr)(nil) || expr == (*ast.KeyValueExpr)(nil) ||
expr == (*ast.UnaryExpr)(nil) || expr == (*ast.BinaryExpr)(nil) ||
expr == (*ast.ArrayType)(nil) || expr == (*ast.StructType)(nil) ||
expr == (*ast.FuncType)(nil) || expr == (*ast.InterfaceType)(nil) ||
expr == (*ast.MapType)(nil) || expr == (*ast.ChanType)(nil) {
return
}
lbl := tw.Labeler.LocalID(expr)
// the type is associated with the expression before any children are visited
extractTypeOf(tw, expr, lbl)
// `kind` is the database kind of this expression; each case below sets it and extracts
// the children of the node at fixed child indices
var kind int
switch expr := expr.(type) {
case *ast.BadExpr:
kind = dbscheme.BadExpr.Index()
case *ast.Ident:
kind = dbscheme.IdentExpr.Index()
// for identifiers the name is used both as the value and as the raw text
dbscheme.LiteralsTable.Emit(tw, lbl, expr.Name, expr.Name)
def := tw.Package.TypesInfo.Defs[expr]
// Note that there are some cases where `expr` is in the map but `def`
// is nil. The docs for `tw.Package.TypesInfo.Defs` give the following
// examples: the package name in package clauses, or symbolic variables
// `t` in `t := x.(type)` of type switch headers.
if def != nil {
defTyp := extractType(tw, def.Type())
objlbl, exists := tw.Labeler.LookupObjectID(def, defTyp)
if objlbl == trap.InvalidLabel {
log.Printf("Omitting def binding to unknown object %v", def)
} else {
// extract the object itself the first time its label is seen
if !exists {
extractObject(tw, def, objlbl)
}
dbscheme.DefsTable.Emit(tw, lbl, objlbl)
}
}
// the same identifier may additionally be recorded as a use of an object
use := getObjectBeingUsed(tw, expr)
if use != nil {
useTyp := extractType(tw, use.Type())
objlbl, exists := tw.Labeler.LookupObjectID(use, useTyp)
if objlbl == trap.InvalidLabel {
log.Printf("Omitting use binding to unknown object %v", use)
} else {
if !exists {
extractObject(tw, use, objlbl)
}
dbscheme.UsesTable.Emit(tw, lbl, objlbl)
}
}
case *ast.Ellipsis:
kind = dbscheme.EllipsisExpr.Index()
extractExpr(tw, expr.Elt, lbl, 0, false)
case *ast.BasicLit:
// `value` is the normalized value of the literal; the raw source text `expr.Value` is
// emitted alongside it
value := ""
switch expr.Kind {
case token.INT:
// NOTE(review): the error from ParseInt is discarded, so a literal that does not fit
// in an int64 is presumably recorded with whatever ParseInt returns on a range
// error rather than its true value - confirm this is intended.
ival, _ := strconv.ParseInt(expr.Value, 0, 64)
value = strconv.FormatInt(ival, 10)
kind = dbscheme.IntLitExpr.Index()
case token.FLOAT:
value = expr.Value
kind = dbscheme.FloatLitExpr.Index()
case token.IMAG:
value = expr.Value
kind = dbscheme.ImagLitExpr.Index()
case token.CHAR:
// the Unquote error is ignored; `value` is left as whatever Unquote returned
value, _ = strconv.Unquote(expr.Value)
kind = dbscheme.CharLitExpr.Index()
case token.STRING:
value, _ = strconv.Unquote(expr.Value)
kind = dbscheme.StringLitExpr.Index()
default:
log.Fatalf("unknown literal kind %v", expr.Kind)
}
dbscheme.LiteralsTable.Emit(tw, lbl, value, expr.Value)
case *ast.FuncLit:
kind = dbscheme.FuncLitExpr.Index()
extractExpr(tw, expr.Type, lbl, 0, false)
extractStmt(tw, expr.Body, lbl, 1)
case *ast.CompositeLit:
kind = dbscheme.CompositeLitExpr.Index()
// child 0 is the literal's type, the elements follow from index 1
extractExpr(tw, expr.Type, lbl, 0, false)
extractExprs(tw, expr.Elts, lbl, 1, 1)
case *ast.ParenExpr:
kind = dbscheme.ParenExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.SelectorExpr:
kind = dbscheme.SelectorExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Sel, lbl, 1, false)
case *ast.IndexExpr:
typeofx := typeOf(tw, expr.X)
if typeofx == nil {
// We are missing type information for `expr.X`, so we cannot
// determine whether this is a generic function instantiation
// or not.
kind = dbscheme.IndexExpr.Index()
} else {
if _, ok := typeofx.Underlying().(*types.Signature); ok {
kind = dbscheme.GenericFunctionInstantiationExpr.Index()
} else {
// Can't distinguish between actual index expressions (into a
// map, array, slice, string or pointer to array) and generic
// type specialization expression, so we do it later in QL.
kind = dbscheme.IndexExpr.Index()
}
}
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Index, lbl, 1, false)
case *ast.IndexListExpr:
typeofx := typeOf(tw, expr.X)
if typeofx == nil {
// We are missing type information for `expr.X`, so we cannot
// determine whether this is a generic function instantiation
// or not.
kind = dbscheme.GenericTypeInstantiationExpr.Index()
} else {
if _, ok := typeofx.Underlying().(*types.Signature); ok {
kind = dbscheme.GenericFunctionInstantiationExpr.Index()
} else {
kind = dbscheme.GenericTypeInstantiationExpr.Index()
}
}
extractExpr(tw, expr.X, lbl, 0, false)
extractExprs(tw, expr.Indices, lbl, 1, 1)
case *ast.SliceExpr:
kind = dbscheme.SliceExpr.Index()
// absent bounds are nil and are skipped by the nil check at the top of this function
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Low, lbl, 1, false)
extractExpr(tw, expr.High, lbl, 2, false)
extractExpr(tw, expr.Max, lbl, 3, false)
case *ast.TypeAssertExpr:
kind = dbscheme.TypeAssertExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
// expr.Type can be `nil` if this is the `x.(type)` in a type switch.
if expr.Type != nil {
extractExpr(tw, expr.Type, lbl, 1, false)
}
case *ast.CallExpr:
kind = dbscheme.CallOrConversionExpr.Index()
// child 0 is the callee, the arguments follow from index 1
extractExpr(tw, expr.Fun, lbl, 0, false)
extractExprs(tw, expr.Args, lbl, 1, 1)
if expr.Ellipsis.IsValid() {
dbscheme.HasEllipsisTable.Emit(tw, lbl)
}
case *ast.StarExpr:
kind = dbscheme.StarExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.KeyValueExpr:
kind = dbscheme.KeyValueExpr.Index()
extractExpr(tw, expr.Key, lbl, 0, false)
extractExpr(tw, expr.Value, lbl, 1, false)
case *ast.UnaryExpr:
// a unary expression with operator `~` is recorded as a type set literal
if expr.Op == token.TILDE {
kind = dbscheme.TypeSetLiteralExpr.Index()
} else {
tp := dbscheme.UnaryExprs[expr.Op]
if tp == nil {
log.Fatalf("unsupported unary operator %s", expr.Op)
}
kind = tp.Index()
}
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.BinaryExpr:
// a `|` expression whose type is a union is a type set literal; its operands are
// flattened into a single list of children rather than kept as a nested tree
_, isUnionType := typeOf(tw, expr).(*types.Union)
if expr.Op == token.OR && isUnionType {
kind = dbscheme.TypeSetLiteralExpr.Index()
flattenBinaryExprTree(tw, expr, lbl, 0)
} else {
tp := dbscheme.BinaryExprs[expr.Op]
if tp == nil {
log.Fatalf("unsupported binary operator %s", expr.Op)
}
kind = tp.Index()
skipLeft := skipExtractingValueForLeftOperand(tw, expr)
extractExpr(tw, expr.X, lbl, 0, skipLeft)
extractExpr(tw, expr.Y, lbl, 1, false)
}
case *ast.ArrayType:
kind = dbscheme.ArrayTypeExpr.Index()
extractExpr(tw, expr.Len, lbl, 0, false)
extractExpr(tw, expr.Elt, lbl, 1, false)
case *ast.StructType:
kind = dbscheme.StructTypeExpr.Index()
extractFields(tw, expr.Fields, lbl, 0, 1)
case *ast.FuncType:
kind = dbscheme.FuncTypeExpr.Index()
// parameters get child indices 0, 1, 2, ...; results get -1, -2, ...
extractFields(tw, expr.Params, lbl, 0, 1)
extractFields(tw, expr.Results, lbl, -1, -1)
emitScopeNodeInfo(tw, expr, lbl)
case *ast.InterfaceType:
kind = dbscheme.InterfaceTypeExpr.Index()
// expr.Methods contains methods, embedded interfaces and type set
// literals.
makeTypeSetLiteralsUnionTyped(tw, expr.Methods)
extractFields(tw, expr.Methods, lbl, 0, 1)
case *ast.MapType:
kind = dbscheme.MapTypeExpr.Index()
extractExpr(tw, expr.Key, lbl, 0, false)
extractExpr(tw, expr.Value, lbl, 1, false)
case *ast.ChanType:
// the kind of a channel type expression depends on its direction
tp := dbscheme.ChanTypeExprs[expr.Dir]
if tp == nil {
log.Fatalf("unsupported channel direction %v", expr.Dir)
}
kind = tp.Index()
extractExpr(tw, expr.Value, lbl, 0, false)
default:
log.Fatalf("unknown expression of type %T", expr)
}
// the row for the expression itself is emitted after its children have been extracted
dbscheme.ExprsTable.Emit(tw, lbl, kind, parent, idx)
extractNodeLocation(tw, expr, lbl)
if !skipExtractingValue {
extractValueOf(tw, expr, lbl)
}
}
// extractExprs extracts AST information for each expression in `exprs`, all of
// which are children of `parent`.
//
// The first expression is given child index `idx`; every following expression
// gets the index of its predecessor plus `dir` (typically 1 to count upwards or
// -1 to count downwards).
func extractExprs(tw *trap.Writer, exprs []ast.Expr, parent trap.Label, idx int, dir int) {
	for n := range exprs {
		childIdx := idx + n*dir
		extractExpr(tw, exprs[n], parent, childIdx, false)
	}
}
// extractTypeOf determines the type of `expr` via `typeOf`, makes sure that
// type has been extracted, and links it to the expression label `lbl` in the
// `type_of` table. Expressions without a known type are left untouched.
func extractTypeOf(tw *trap.Writer, expr ast.Expr, lbl trap.Label) {
	exprType := typeOf(tw, expr)
	if exprType == nil {
		return
	}
	typeLbl := extractType(tw, exprType)
	dbscheme.TypeOfTable.Emit(tw, lbl, typeLbl)
}
// extractValueOf records the constant value of `expr`, if it has one, against
// `lbl` in the `consts` table. Two renderings are stored: a display value and
// an exact one. Expressions that the type checker marks as the predeclared
// `nil` are recorded as "nil"; anything else produces no row.
func extractValueOf(tw *trap.Writer, expr ast.Expr, lbl trap.Label) {
	info := tw.Package.TypesInfo.Types[expr]
	cv := info.Value

	if cv == nil {
		// not a constant; the only other thing we record is `nil`
		if info.IsNil() {
			dbscheme.ConstValuesTable.Emit(tw, lbl, "nil", "nil")
		}
		return
	}

	// A non-nil Value means the expression is a compile-time constant. Note
	// that string literals in import statements have no associated Value and
	// therefore never reach this point.
	exact := cv.ExactString()
	display := exact
	switch cv.Kind() {
	case constant.String:
		// store the unquoted string in both columns
		display = constant.StringVal(cv)
		exact = display
	case constant.Float:
		approx, _ := constant.Float64Val(cv)
		display = fmt.Sprintf("%.20g", approx)
	case constant.Complex:
		rePart, _ := constant.Float64Val(constant.Real(cv))
		imPart, _ := constant.Float64Val(constant.Imag(cv))
		display = fmt.Sprintf("(%.20g + %.20gi)", rePart, imPart)
	}
	dbscheme.ConstValuesTable.Emit(tw, lbl, display, exact)
}
// extractFields extracts AST information for every field in `fields`, all of
// which are children of `parent`. A nil or empty list produces no output.
//
// The first field is given child index `idx`; every following field gets the
// index of its predecessor plus `dir` (typically 1 to count upwards or -1 to
// count downwards). Within a field, the type is child 0, the tag is child -1
// and the names are children 1, 2, ...
func extractFields(tw *trap.Writer, fields *ast.FieldList, parent trap.Label, idx int, dir int) {
	if fields == nil {
		return
	}
	for n, fld := range fields.List {
		fieldLbl := tw.Labeler.LocalID(fld)
		dbscheme.FieldsTable.Emit(tw, fieldLbl, parent, idx+n*dir)
		extractNodeLocation(tw, fld, fieldLbl)
		for k, name := range fld.Names {
			extractExpr(tw, name, fieldLbl, k+1, false)
		}
		extractExpr(tw, fld.Type, fieldLbl, 0, false)
		extractExpr(tw, fld.Tag, fieldLbl, -1, false)
		extractDoc(tw, fld.Doc, fieldLbl)
	}
}
// extractStmt extracts AST information for a given statement and all other statements or expressions
// nested inside it
//
// `parent` and `idx` give the position of `stmt` among the children of its parent node; they
// are written to the statements table together with the kind selected in the type switch below.
func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
// An interface value that wraps a nil pointer of a concrete type does not compare equal to
// `nil`, so in addition to the plain nil check we compare against a typed nil for each
// listed node type. Returning here, before `LocalID` is called, avoids creating a label
// for a node that is never emitted.
if stmt == nil || stmt == (*ast.DeclStmt)(nil) ||
stmt == (*ast.LabeledStmt)(nil) || stmt == (*ast.ExprStmt)(nil) ||
stmt == (*ast.SendStmt)(nil) || stmt == (*ast.IncDecStmt)(nil) ||
stmt == (*ast.AssignStmt)(nil) || stmt == (*ast.GoStmt)(nil) ||
stmt == (*ast.DeferStmt)(nil) || stmt == (*ast.BranchStmt)(nil) ||
stmt == (*ast.BlockStmt)(nil) || stmt == (*ast.IfStmt)(nil) ||
stmt == (*ast.CaseClause)(nil) || stmt == (*ast.SwitchStmt)(nil) ||
stmt == (*ast.TypeSwitchStmt)(nil) || stmt == (*ast.CommClause)(nil) ||
stmt == (*ast.ForStmt)(nil) || stmt == (*ast.RangeStmt)(nil) {
return
}
lbl := tw.Labeler.LocalID(stmt)
// `kind` is the database kind of this statement; each case below sets it and extracts
// the children of the node at fixed child indices
var kind int
switch stmt := stmt.(type) {
case *ast.BadStmt:
kind = dbscheme.BadStmtType.Index()
case *ast.DeclStmt:
kind = dbscheme.DeclStmtType.Index()
extractDecl(tw, stmt.Decl, lbl, 0)
case *ast.EmptyStmt:
kind = dbscheme.EmptyStmtType.Index()
case *ast.LabeledStmt:
kind = dbscheme.LabeledStmtType.Index()
extractExpr(tw, stmt.Label, lbl, 0, false)
extractStmt(tw, stmt.Stmt, lbl, 1)
case *ast.ExprStmt:
kind = dbscheme.ExprStmtType.Index()
extractExpr(tw, stmt.X, lbl, 0, false)
case *ast.SendStmt:
kind = dbscheme.SendStmtType.Index()
extractExpr(tw, stmt.Chan, lbl, 0, false)
extractExpr(tw, stmt.Value, lbl, 1, false)
case *ast.IncDecStmt:
// increment and decrement are distinct statement kinds, told apart by the token
if stmt.Tok == token.INC {
kind = dbscheme.IncStmtType.Index()
} else if stmt.Tok == token.DEC {
kind = dbscheme.DecStmtType.Index()
} else {
log.Fatalf("unsupported increment/decrement operator %v", stmt.Tok)
}
extractExpr(tw, stmt.X, lbl, 0, false)
case *ast.AssignStmt:
// the kind depends on the assignment operator
tp := dbscheme.AssignStmtTypes[stmt.Tok]
if tp == nil {
log.Fatalf("unsupported assignment statement with operator %v", stmt.Tok)
}
kind = tp.Index()
// left-hand sides get child indices -1, -2, ...; right-hand sides get 1, 2, ...
extractExprs(tw, stmt.Lhs, lbl, -1, -1)
extractExprs(tw, stmt.Rhs, lbl, 1, 1)
case *ast.GoStmt:
kind = dbscheme.GoStmtType.Index()
extractExpr(tw, stmt.Call, lbl, 0, false)
case *ast.DeferStmt:
kind = dbscheme.DeferStmtType.Index()
extractExpr(tw, stmt.Call, lbl, 0, false)
case *ast.ReturnStmt:
kind = dbscheme.ReturnStmtType.Index()
extractExprs(tw, stmt.Results, lbl, 0, 1)
case *ast.BranchStmt:
// break, continue, goto and fallthrough are distinct statement kinds
switch stmt.Tok {
case token.BREAK:
kind = dbscheme.BreakStmtType.Index()
case token.CONTINUE:
kind = dbscheme.ContinueStmtType.Index()
case token.GOTO:
kind = dbscheme.GotoStmtType.Index()
case token.FALLTHROUGH:
kind = dbscheme.FallthroughStmtType.Index()
default:
log.Fatalf("unsupported branch statement type %v", stmt.Tok)
}
// a missing label is a nil identifier, which extractExpr skips
extractExpr(tw, stmt.Label, lbl, 0, false)
case *ast.BlockStmt:
kind = dbscheme.BlockStmtType.Index()
extractStmts(tw, stmt.List, lbl, 0, 1)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.IfStmt:
kind = dbscheme.IfStmtType.Index()
// children: 0 = init, 1 = condition, 2 = body, 3 = else; absent parts are nil and skipped
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Cond, lbl, 1, false)
extractStmt(tw, stmt.Body, lbl, 2)
extractStmt(tw, stmt.Else, lbl, 3)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.CaseClause:
kind = dbscheme.CaseClauseType.Index()
// case expressions get child indices -1, -2, ...; body statements get 0, 1, ...
extractExprs(tw, stmt.List, lbl, -1, -1)
extractStmts(tw, stmt.Body, lbl, 0, 1)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.SwitchStmt:
kind = dbscheme.ExprSwitchStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Tag, lbl, 1, false)
extractStmt(tw, stmt.Body, lbl, 2)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.TypeSwitchStmt:
kind = dbscheme.TypeSwitchStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractStmt(tw, stmt.Assign, lbl, 1)
extractStmt(tw, stmt.Body, lbl, 2)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.CommClause:
kind = dbscheme.CommClauseType.Index()
// child 0 is the communication statement, body statements follow from index 1
extractStmt(tw, stmt.Comm, lbl, 0)
extractStmts(tw, stmt.Body, lbl, 1, 1)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.SelectStmt:
kind = dbscheme.SelectStmtType.Index()
extractStmt(tw, stmt.Body, lbl, 0)
case *ast.ForStmt:
kind = dbscheme.ForStmtType.Index()
// children: 0 = init, 1 = condition, 2 = post, 3 = body; absent parts are nil and skipped
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Cond, lbl, 1, false)
extractStmt(tw, stmt.Post, lbl, 2)
extractStmt(tw, stmt.Body, lbl, 3)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.RangeStmt:
kind = dbscheme.RangeStmtType.Index()
// children: 0 = key, 1 = value, 2 = ranged-over expression, 3 = body
extractExpr(tw, stmt.Key, lbl, 0, false)
extractExpr(tw, stmt.Value, lbl, 1, false)
extractExpr(tw, stmt.X, lbl, 2, false)
extractStmt(tw, stmt.Body, lbl, 3)
emitScopeNodeInfo(tw, stmt, lbl)
default:
log.Fatalf("unknown statement of type %T", stmt)
}
// the row for the statement itself is emitted after its children have been extracted
dbscheme.StmtsTable.Emit(tw, lbl, kind, parent, idx)
extractNodeLocation(tw, stmt, lbl)
}
// extractStmts extracts AST information for each statement in `stmts`, all of
// which are children of `parent`.
//
// The first statement is given child index `idx`; every following statement
// gets the index of its predecessor plus `dir` (typically 1 to count upwards or
// -1 to count downwards).
func extractStmts(tw *trap.Writer, stmts []ast.Stmt, parent trap.Label, idx int, dir int) {
	for n := range stmts {
		childIdx := idx + n*dir
		extractStmt(tw, stmts[n], parent, childIdx)
	}
}
// extractDecl extracts AST information for the given declaration
//
// `parent` and `idx` give the position of `decl` among the children of its
// parent node; they are written to the declarations table together with the
// kind selected below. A nil declaration (plain nil, or a nil `*ast.FuncDecl`
// or `*ast.GenDecl` wrapped in the interface) is skipped without creating a
// label, in the same way as `extractExpr` and `extractStmt` skip nil nodes.
func extractDecl(tw *trap.Writer, decl ast.Decl, parent trap.Label, idx int) {
	// An interface value that wraps a nil pointer of a concrete type does not
	// compare equal to `nil`, so the typed nils are checked explicitly.
	if decl == nil || decl == (*ast.FuncDecl)(nil) || decl == (*ast.GenDecl)(nil) {
		return
	}
	lbl := tw.Labeler.LocalID(decl)
	var kind int
	switch decl := decl.(type) {
	case *ast.BadDecl:
		kind = dbscheme.BadDeclType.Index()
	case *ast.GenDecl:
		// the keyword of a general declaration determines its kind
		switch decl.Tok {
		case token.IMPORT:
			kind = dbscheme.ImportDeclType.Index()
		case token.CONST:
			kind = dbscheme.ConstDeclType.Index()
		case token.TYPE:
			kind = dbscheme.TypeDeclType.Index()
		case token.VAR:
			kind = dbscheme.VarDeclType.Index()
		default:
			log.Fatalf("unknown declaration of kind %v", decl.Tok)
		}
		for i, spec := range decl.Specs {
			extractSpec(tw, spec, lbl, i)
		}
		extractDoc(tw, decl.Doc, lbl)
	case *ast.FuncDecl:
		kind = dbscheme.FuncDeclType.Index()
		// children: -1, -2, ... = receiver fields, 0 = name, 1 = type, 2 = body
		extractFields(tw, decl.Recv, lbl, -1, -1)
		extractExpr(tw, decl.Name, lbl, 0, false)
		extractExpr(tw, decl.Type, lbl, 1, false)
		extractStmt(tw, decl.Body, lbl, 2)
		extractDoc(tw, decl.Doc, lbl)
		// `extractExpr` above tolerates a nil `decl.Type`; do the same here
		// instead of dereferencing it unconditionally.
		if decl.Type != nil {
			extractTypeParamDecls(tw, decl.Type.TypeParams, lbl)
		}
		// Note that we currently don't extract any kind of declaration for
		// receiver type parameters. There isn't an explicit declaration, but
		// we could consider the index/indices of an IndexExpr/IndexListExpr
		// receiver as declarations.
	default:
		log.Fatalf("unknown declaration of type %T", decl)
	}
	dbscheme.DeclsTable.Emit(tw, lbl, kind, parent, idx)
	extractNodeLocation(tw, decl, lbl)
}
// extractSpec extracts AST information for the given declaration specifier
//
// `parent` and `idx` give the position of `spec` among the children of its
// parent declaration; they are written to the specifiers table together with
// the kind selected below. A nil specifier (plain nil, or a nil pointer of one
// of the concrete specifier types wrapped in the interface) is skipped.
func extractSpec(tw *trap.Writer, spec ast.Spec, parent trap.Label, idx int) {
	// An interface value that wraps a nil pointer of a concrete type does not
	// compare equal to `nil`, so the typed nils are checked explicitly. This
	// must happen before a label is requested so that no unused label is
	// created for a specifier that is never emitted.
	if spec == nil || spec == (*ast.ImportSpec)(nil) ||
		spec == (*ast.ValueSpec)(nil) || spec == (*ast.TypeSpec)(nil) {
		return
	}
	lbl := tw.Labeler.LocalID(spec)
	var kind int
	switch spec := spec.(type) {
	case *ast.ImportSpec:
		kind = dbscheme.ImportSpecType.Index()
		// children: 0 = local name (may be absent), 1 = import path
		extractExpr(tw, spec.Name, lbl, 0, false)
		extractExpr(tw, spec.Path, lbl, 1, false)
		extractDoc(tw, spec.Doc, lbl)
	case *ast.ValueSpec:
		kind = dbscheme.ValueSpecType.Index()
		// children: -1, -2, ... = names, 0 = type, 1, 2, ... = values
		for i, name := range spec.Names {
			extractExpr(tw, name, lbl, -(1 + i), false)
		}
		extractExpr(tw, spec.Type, lbl, 0, false)
		extractExprs(tw, spec.Values, lbl, 1, 1)
		extractDoc(tw, spec.Doc, lbl)
	case *ast.TypeSpec:
		// a valid `Assign` position means the spec has the form `type A = B`
		if spec.Assign.IsValid() {
			kind = dbscheme.AliasSpecType.Index()
		} else {
			kind = dbscheme.TypeDefSpecType.Index()
		}
		extractExpr(tw, spec.Name, lbl, 0, false)
		extractTypeParamDecls(tw, spec.TypeParams, lbl)
		extractExpr(tw, spec.Type, lbl, 1, false)
		extractDoc(tw, spec.Doc, lbl)
	}
	dbscheme.SpecsTable.Emit(tw, lbl, kind, parent, idx)
	extractNodeLocation(tw, spec, lbl)
}
// Determines whether the given type is an alias.
func isAlias(tp types.Type) bool {
_, ok := tp.(*types.Alias)
return ok
}
// extractType extracts type information for `tp` and returns its associated label;
// types are only extracted once, so the second time `extractType` is invoked it simply returns the label
//
// Aliases are resolved with `types.Unalias` first, so an alias and the type it stands for
// share a label. Whether a type has been seen before is decided by `getTypeLabel`.
func extractType(tw *trap.Writer, tp types.Type) trap.Label {
tp = types.Unalias(tp)
lbl, exists := getTypeLabel(tw, tp)
if !exists {
// first time this type is seen: emit its kind and the tables describing its components
var kind int
switch tp := tp.(type) {
case *types.Basic:
branch := dbscheme.BasicTypes[tp.Kind()]
if branch == nil {
log.Fatalf("unknown basic type %v", tp.Kind())
}
kind = branch.Index()
case *types.Array:
kind = dbscheme.ArrayType.Index()
// the array length is stored in its decimal string form
dbscheme.ArrayLengthTable.Emit(tw, lbl, fmt.Sprintf("%d", tp.Len()))
extractElementType(tw, lbl, tp.Elem())
case *types.Slice:
kind = dbscheme.SliceType.Index()
extractElementType(tw, lbl, tp.Elem())
case *types.Struct:
kind = dbscheme.StructType.Index()
for i := 0; i < tp.NumFields(); i++ {
field := tp.Field(i).Origin()
// ensure the field is associated with a label - note that
// struct fields do not have a parent scope, so they are not
// dealt with by `extractScopes`
fieldlbl, exists := tw.Labeler.FieldID(field, i, lbl)
if !exists {
extractObject(tw, field, fieldlbl)
}
dbscheme.FieldStructsTable.Emit(tw, fieldlbl, lbl)
// embedded fields are recorded with an empty component name
name := field.Name()
if field.Embedded() {
name = ""
}
extractComponentType(tw, lbl, i, name, field.Type())
if tp.Tag(i) != "" {
dbscheme.StructTagsTable.Emit(tw, lbl, i, tp.Tag(i))
}
}
case *types.Pointer:
kind = dbscheme.PointerType.Index()
extractBaseType(tw, lbl, tp.Elem())
case *types.Interface:
kind = dbscheme.InterfaceType.Index()
for i := 0; i < tp.NumMethods(); i++ {
// Note that methods coming from embedded interfaces can be
// accessed through `Method(i)`, so there is no need to
// deal with them separately.
meth := tp.Method(i).Origin()
// Note that methods do not have a parent scope, so they are
// not dealt with by `extractScopes`
extractMethod(tw, meth)
extractComponentType(tw, lbl, i, meth.Name(), meth.Type())
if !meth.Exported() {
dbscheme.InterfacePrivateMethodIdsTable.Emit(tw, lbl, i, meth.Id())
}
}
// embedded types are components with negative indices -1, -2, ... and an empty name
for i := 0; i < tp.NumEmbeddeds(); i++ {
component := tp.EmbeddedType(i)
// a type set literal that is not already a union is wrapped in a single-term union
if isNonUnionTypeSetLiteral(component) {
component = createUnionFromType(component)
}
extractComponentType(tw, lbl, -(i + 1), "", component)
}
case *types.Tuple:
kind = dbscheme.TupleType.Index()
for i := 0; i < tp.Len(); i++ {
extractComponentType(tw, lbl, i, "", tp.At(i).Type())
}
case *types.Signature:
kind = dbscheme.SignatureType.Index()
// parameters are components 1, 2, ...; results are components -1, -2, ...
params, results := tp.Params(), tp.Results()
if params != nil {
for i := 0; i < params.Len(); i++ {
param := params.At(i)
extractComponentType(tw, lbl, i+1, "", param.Type())
}
}
if results != nil {
for i := 0; i < results.Len(); i++ {
result := results.At(i)
extractComponentType(tw, lbl, -(i + 1), "", result.Type())
}
}
if tp.Variadic() {
dbscheme.VariadicTable.Emit(tw, lbl)
}
case *types.Map:
kind = dbscheme.MapType.Index()
extractKeyType(tw, lbl, tp.Key())
extractElementType(tw, lbl, tp.Elem())
case *types.Chan:
// NOTE(review): unlike the basic-type case above, the lookup result is not checked for
// nil before `Index()` is called - presumably every channel direction is present in
// `dbscheme.ChanTypes`; confirm.
kind = dbscheme.ChanTypes[tp.Dir()].Index()
extractElementType(tw, lbl, tp.Elem())
case *types.Named:
// everything below is recorded for the origin (generic) type rather than for an
// instantiation of it
origintp := tp.Origin()
kind = dbscheme.DefinedType.Index()
dbscheme.TypeNameTable.Emit(tw, lbl, origintp.Obj().Name())
underlying := origintp.Underlying()
extractUnderlyingType(tw, lbl, underlying)
trackInstantiatedStructFields(tw, tp, origintp)
// `exists` here is local to this case and shadows the outer variable
entitylbl, exists := tw.Labeler.LookupObjectID(origintp.Obj(), lbl)
if entitylbl == trap.InvalidLabel {
log.Printf("Omitting type-object binding for unknown object %v.\n", origintp.Obj())
} else {
if !exists {
extractObject(tw, origintp.Obj(), entitylbl)
}
dbscheme.TypeObjectTable.Emit(tw, lbl, entitylbl)
}
// ensure all methods have labels - note that methods do not have a
// parent scope, so they are not dealt with by `extractScopes`
for i := 0; i < origintp.NumMethods(); i++ {
meth := origintp.Method(i).Origin()
extractMethod(tw, meth)
}
underlyingInterface, underlyingIsInterface := underlying.(*types.Interface)
_, underlyingIsPointer := underlying.(*types.Pointer)
// associate all methods of underlying interface with this type
if underlyingIsInterface {
for i := 0; i < underlyingInterface.NumMethods(); i++ {
methlbl := extractMethod(tw, underlyingInterface.Method(i).Origin())
dbscheme.MethodHostsTable.Emit(tw, methlbl, lbl)
}
}
// If `underlying` is not a pointer or interface then methods can
// be defined on `origintp`. In this case we must ensure that
// `*origintp` is in the database, so that Method.hasQualifiedName
// correctly includes methods with receiver type `*origintp`.
if !underlyingIsInterface && !underlyingIsPointer {
extractType(tw, types.NewPointer(origintp))
}
case *types.TypeParam:
kind = dbscheme.TypeParamType.Index()
parentlbl := getTypeParamParentLabel(tw, tp)
constraintLabel := extractType(tw, tp.Constraint())
dbscheme.TypeParamTable.Emit(tw, lbl, tp.Obj().Name(), constraintLabel, parentlbl, tp.Index())
case *types.Union:
kind = dbscheme.TypeSetLiteral.Index()
// each term is a component whose name is "~" if the term has a tilde and "" otherwise
for i := 0; i < tp.Len(); i++ {
term := tp.Term(i)
tildeStr := ""
if term.Tilde() {
tildeStr = "~"
}
extractComponentType(tw, lbl, i, tildeStr, term.Type())
}
default:
log.Fatalf("unexpected type %T", tp)
}
// the row for the type itself is emitted after the tables describing its components
dbscheme.TypesTable.Emit(tw, lbl, kind)
}
return lbl
}
// getTypeLabel looks up the label associated with `tp`, creating a new label if
// it does not have one yet; the second result indicates whether the label
// already existed
//
// Type labels refer to global keys to ensure that if the same type is
// encountered during the extraction of different files it is still ultimately
// mapped to the same entity. In particular, this means that keys for compound
// types refer to the labels of their component types. For defined types, the key
// is constructed from their globally unique ID. This prevents cyclic type keys
// since type recursion in Go always goes through defined types.
//
// Aliases are resolved with `types.Unalias` before the lookup. Component types
// are extracted (via `extractType`) while the key is being built, and the new
// label is cached in `tw.Labeler.TypeLabels` before returning.
func getTypeLabel(tw *trap.Writer, tp types.Type) (trap.Label, bool) {
tp = types.Unalias(tp)
lbl, exists := tw.Labeler.TypeLabels[tp]
if !exists {
switch tp := tp.(type) {
case *types.Basic:
// key: the numeric basic kind
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%d;basictype", tp.Kind()))
case *types.Array:
// key: length and element type label
len := tp.Len()
elem := extractType(tw, tp.Elem())
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%d,{%s};arraytype", len, elem))
case *types.Slice:
elem := extractType(tw, tp.Elem())
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s};slicetype", elem))
case *types.Struct:
// key: comma-separated `name,{type label},tag` triples, one per field; embedded
// fields contribute an empty name
var b strings.Builder
for i := 0; i < tp.NumFields(); i++ {
field := tp.Field(i)
fieldTypeLbl := extractType(tw, field.Type())
if i > 0 {
b.WriteString(",")
}
name := field.Name()
if field.Embedded() {
name = ""
}
fmt.Fprintf(&b, "%s,{%s},%s", name, fieldTypeLbl, util.EscapeTrapSpecialChars(tp.Tag(i)))
}
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%s;structtype", b.String()))
case *types.Pointer:
base := extractType(tw, tp.Elem())
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s};pointertype", base))
case *types.Interface:
// key: `id,{type label}` pairs for the methods, then `;`, then the labels of the
// embedded types, then an optional `;comparable` marker
var b strings.Builder
for i := 0; i < tp.NumMethods(); i++ {
meth := tp.Method(i).Origin()
methLbl := extractType(tw, meth.Type())
if i > 0 {
b.WriteString(",")
}
fmt.Fprintf(&b, "%s,{%s}", meth.Id(), methLbl)
}
b.WriteString(";")
for i := 0; i < tp.NumEmbeddeds(); i++ {
if i > 0 {
b.WriteString(",")
}
fmt.Fprintf(&b, "{%s}", extractType(tw, tp.EmbeddedType(i)))
}
// We note whether the interface is comparable so that we can
// distinguish the underlying type of `comparable` from an
// empty interface.
if tp.IsComparable() {
b.WriteString(";comparable")
}
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%s;interfacetype", b.String()))
case *types.Tuple:
// key: comma-separated labels of the component types
var b strings.Builder
for i := 0; i < tp.Len(); i++ {
compLbl := extractType(tw, tp.At(i).Type())
if i > 0 {
b.WriteString(",")
}
fmt.Fprintf(&b, "{%s}", compLbl)
}
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%s;tupletype", b.String()))
case *types.Signature:
// key: parameter type labels, then `;`, then result type labels, then an optional
// `;variadic` marker
var b strings.Builder
params, results := tp.Params(), tp.Results()
if params != nil {
for i := 0; i < params.Len(); i++ {
paramLbl := extractType(tw, params.At(i).Type())
if i > 0 {
b.WriteString(",")
}
fmt.Fprintf(&b, "{%s}", paramLbl)
}
}
b.WriteString(";")
if results != nil {
for i := 0; i < results.Len(); i++ {
resultLbl := extractType(tw, results.At(i).Type())
if i > 0 {
b.WriteString(",")
}
fmt.Fprintf(&b, "{%s}", resultLbl)
}
}
if tp.Variadic() {
b.WriteString(";variadic")
}
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%s;signaturetype", b.String()))
case *types.Map:
key := extractType(tw, tp.Key())
value := extractType(tw, tp.Elem())
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s},{%s};maptype", key, value))
case *types.Chan:
// key: channel direction and element type label
dir := tp.Dir()
elem := extractType(tw, tp.Elem())
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%v,{%s};chantype", dir, elem))
case *types.Named:
// The key is built from the label of the type's declaring object, not from the
// underlying type. `exists` declared here is local to this case and shadows the
// outer variable, so it does not affect the function's second result.
// NOTE(review): `lbl` passed to LookupObjectID is the value produced by the failed
// map lookup above, not a label for this type - confirm that is what is intended.
origintp := tp.Origin()
entitylbl, exists := tw.Labeler.LookupObjectID(origintp.Obj(), lbl)
if entitylbl == trap.InvalidLabel {
panic(fmt.Sprintf("Cannot construct label for defined type %v (underlying object is %v).\n", origintp, origintp.Obj()))
}
if !exists {
extractObject(tw, origintp.Obj(), entitylbl)
}
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%s};definedtype", entitylbl))
case *types.TypeParam:
// key: label of the parent object, index of the type parameter, and its name
parentlbl := getTypeParamParentLabel(tw, tp)
idx := tp.Index()
lbl = tw.Labeler.GlobalID(fmt.Sprintf("{%v},%d,%s;typeparamtype", parentlbl, idx, tp.Obj().Name()))
case *types.Union:
// key: `|`-separated term type labels, each prefixed with `~` if the term has a tilde
var b strings.Builder
for i := 0; i < tp.Len(); i++ {
compLbl := extractType(tw, tp.Term(i).Type())
if i > 0 {
b.WriteString("|")
}
if tp.Term(i).Tilde() {
b.WriteString("~")
}
fmt.Fprintf(&b, "{%s}", compLbl)
}
lbl = tw.Labeler.GlobalID(fmt.Sprintf("%s;typesetliteraltype", b.String()))
default:
log.Fatalf("(getTypeLabel) unexpected type %T", tp)
}
// remember the label so that later lookups of the same type report it as existing
tw.Labeler.TypeLabels[tp] = lbl
}
return lbl, exists
}
// extractKeyType makes sure `key` has been extracted and records it as the key
// type of the map type labelled `mp`.
func extractKeyType(tw *trap.Writer, mp trap.Label, key types.Type) {
	keyLbl := extractType(tw, key)
	dbscheme.KeyTypeTable.Emit(tw, mp, keyLbl)
}
// extractElementType makes sure `element` has been extracted and records it as
// the element type of the container type labelled `container`.
func extractElementType(tw *trap.Writer, container trap.Label, element types.Type) {
	elemLbl := extractType(tw, element)
	dbscheme.ElementTypeTable.Emit(tw, container, elemLbl)
}
// extractBaseType makes sure `base` has been extracted and records it as the
// base type of the pointer type labelled `ptr`.
func extractBaseType(tw *trap.Writer, ptr trap.Label, base types.Type) {
	baseLbl := extractType(tw, base)
	dbscheme.BaseTypeTable.Emit(tw, ptr, baseLbl)
}
// extractUnderlyingType makes sure `underlying` has been extracted and records
// it as the underlying type of the defined type labelled `defined`.
func extractUnderlyingType(tw *trap.Writer, defined trap.Label, underlying types.Type) {
	underlyingLbl := extractType(tw, underlying)
	dbscheme.UnderlyingTypeTable.Emit(tw, defined, underlyingLbl)
}
// extractComponentType makes sure `component` has been extracted and records
// it as the component of the type labelled `parent` at index `idx`, under the
// name `name`.
func extractComponentType(tw *trap.Writer, parent trap.Label, idx int, name string, component types.Type) {
	componentLbl := extractType(tw, component)
	dbscheme.ComponentTypesTable.Emit(tw, parent, idx, name, componentLbl)
}
// extractNumLines extracts lines-of-code and lines-of-comments information for the
// given file
//
// The total line count comes from the file set entry for `ast`. Lines of code
// are counted by re-reading `fileName` from disk and tokenizing it: every line
// touched by at least one token (other than an illegal token or an
// automatically inserted semicolon) counts once. Lines of comments are the sum
// of the line spans of all comments in `ast.Comments`. The process is
// terminated via `log.Fatalf` if the file cannot be read.
func extractNumLines(tw *trap.Writer, fileName string, ast *ast.File) {
	fset := tw.Package.Fset
	f := fset.File(ast.Pos())

	lineCount := f.LineCount()

	// count lines of code by tokenizing
	linesOfCode := 0
	src, err := os.ReadFile(fileName)
	if err != nil {
		// include the underlying cause so that the failure can be diagnosed
		log.Fatalf("Unable to read file %s: %v", fileName, err)
	}
	var s scanner.Scanner
	// `lastCodeLine` is the last line already counted as a line of code
	lastCodeLine := -1
	s.Init(f, src, nil, 0)
	for {
		pos, tok, lit := s.Scan()
		if tok == token.EOF {
			break
		} else if tok != token.ILLEGAL && !(tok == token.SEMICOLON && lit == "\n") {
			// specifically exclude newlines that are treated as semicolons
			tkStartLine := f.Position(pos).Line
			// a token whose text contains newlines spans several lines
			tkEndLine := tkStartLine + strings.Count(lit, "\n")
			if tkEndLine > lastCodeLine {
				if tkStartLine <= lastCodeLine {
					// if the start line is the same as the last code line we've seen we don't want to double
					// count it
					// note tkStartLine < lastCodeLine should not be possible
					linesOfCode += tkEndLine - lastCodeLine
				} else {
					linesOfCode += tkEndLine - tkStartLine + 1
				}
				lastCodeLine = tkEndLine
			}
		}
	}

	// count lines of comments by iterating over ast.Comments
	linesOfComments := 0
	for _, cg := range ast.Comments {
		for _, g := range cg.List {
			startPos, endPos := fset.Position(g.Pos()), fset.Position(g.End())
			linesOfComments += endPos.Line - startPos.Line + 1
		}
	}

	dbscheme.NumlinesTable.Emit(tw, tw.Labeler.FileLabel(), lineCount, linesOfCode, linesOfComments)
}
// For a type `t` which is the type of a field of an interface type, return
// whether `t` is a type set literal which is not a union type. Note that a
// field of an interface must be a method signature, an embedded interface type
// or a type set literal, so `t` qualifies exactly when its underlying type is
// none of interface, union or signature. A nil `t` never qualifies.
func isNonUnionTypeSetLiteral(t types.Type) bool {
	if t == nil {
		return false
	}
	under := t.Underlying()
	if _, isInterface := under.(*types.Interface); isInterface {
		return false
	}
	if _, isUnion := under.(*types.Union); isUnion {
		return false
	}
	if _, isSignature := under.(*types.Signature); isSignature {
		return false
	}
	return true
}
// createUnionFromType wraps `t` in a union that has exactly one term: `t`
// itself, without a tilde.
func createUnionFromType(t types.Type) *types.Union {
	onlyTerm := types.NewTerm(false, t)
	return types.NewUnion([]*types.Term{onlyTerm})
}
// makeTypeSetLiteralsUnionTyped walks over `fields` and gives every type set
// literal that is not already union typed a single-term union type instead.
// The new type is stored as an override in `tw.TypesOverride`, keyed by the
// field's type expression; expressions that already have an override are left
// alone. Type set literals can only occur in two places: a type parameter
// declaration or a type in an interface.
func makeTypeSetLiteralsUnionTyped(tw *trap.Writer, fields *ast.FieldList) {
	if fields == nil {
		return
	}
	for _, field := range fields.List {
		typeExpr := field.Type
		if _, overridden := tw.TypesOverride[typeExpr]; overridden {
			continue
		}
		if tp := typeOf(tw, typeExpr); isNonUnionTypeSetLiteral(tp) {
			tw.TypesOverride[typeExpr] = createUnionFromType(tp)
		}
	}
}
// typeOf returns the type of the expression `e`. An entry for `e` in
// `tw.TypesOverride` takes precedence; otherwise the type recorded by the type
// checker in `tw.Package.TypesInfo` is returned.
func typeOf(tw *trap.Writer, e ast.Expr) types.Type {
	override, overridden := tw.TypesOverride[e]
	if !overridden {
		return tw.Package.TypesInfo.TypeOf(e)
	}
	return override
}
// flattenBinaryExprTree extracts the leaves of the binary expression tree
// rooted at `e` as direct children of `parent`, in left-to-right order. The
// first leaf is given child index `idx` and each following leaf the next
// index; the index following the last leaf is returned. Any `*ast.BinaryExpr`
// is descended into, whatever its operator; every other expression is a leaf.
func flattenBinaryExprTree(tw *trap.Writer, e ast.Expr, parent trap.Label, idx int) int {
	if inner, isBinary := e.(*ast.BinaryExpr); isBinary {
		afterLeft := flattenBinaryExprTree(tw, inner.X, parent, idx)
		return flattenBinaryExprTree(tw, inner.Y, parent, afterLeft)
	}
	extractExpr(tw, e, parent, idx, false)
	return idx + 1
}
// extractTypeParamDecls extracts the type parameter declarations in `fields`
// as children of `parent`, numbered 0, 1, 2, ... A nil or empty list produces
// no output. Within a declaration, the constraint type is child 0 and the
// names are children 1, 2, ...
func extractTypeParamDecls(tw *trap.Writer, fields *ast.FieldList, parent trap.Label) {
	if fields == nil || len(fields.List) == 0 {
		return
	}

	// A type set literal may appear as the constraint of a type parameter, so
	// make sure all of them are union typed before extracting anything.
	makeTypeSetLiteralsUnionTyped(tw, fields)

	for pos, decl := range fields.List {
		declLbl := tw.Labeler.LocalID(decl)
		dbscheme.TypeParamDeclsTable.Emit(tw, declLbl, parent, pos)
		extractNodeLocation(tw, decl, declLbl)
		for n, name := range decl.Names {
			extractExpr(tw, name, declLbl, n+1, false)
		}
		extractExpr(tw, decl.Type, declLbl, 0, false)
		extractDoc(tw, decl.Doc, declLbl)
	}
}
// populateTypeParamParents registers `parent` as the parent of every type
// parameter in `typeparams`. A nil list is ignored.
func populateTypeParamParents(typeparams *types.TypeParamList, parent types.Object) {
	if typeparams == nil {
		return
	}
	count := typeparams.Len()
	for i := 0; i < count; i++ {
		setTypeParamParent(typeparams.At(i), parent)
	}
}
// getObjectBeingUsed looks up `ident` in `tw.Package.TypesInfo.Uses`. To avoid
// returning objects that belong to instantiated types, variables and functions
// are replaced by their origin; any other object (including the nil result of
// a failed lookup) is returned unchanged.
func getObjectBeingUsed(tw *trap.Writer, ident *ast.Ident) types.Object {
	used := tw.Package.TypesInfo.Uses[ident]
	if v, isVar := used.(*types.Var); isVar {
		return v.Origin()
	}
	if fn, isFunc := used.(*types.Func); isFunc {
		return fn.Origin()
	}
	return used
}
// trackInstantiatedStructFields maps each field of the struct underlying the
// instantiated type `tp` to the field at the same position in the struct
// underlying the generic type `origintp`, by recording the pair in
// `tw.ObjectsOverride`. The aim is that when an instantiated field turns up in
// `tw.Package.TypesInfo.Uses`, the label of the generic field is used instead.
//
// Nothing is recorded when `tp` is its own origin or when its underlying type
// is not a struct. It is a fatal error if the generic type's underlying type
// is not a struct too, or if the two structs differ in their number of fields.
func trackInstantiatedStructFields(tw *trap.Writer, tp, origintp *types.Named) {
	if tp == origintp {
		return
	}
	instStruct, isStruct := tp.Underlying().(*types.Struct)
	if !isStruct {
		return
	}

	genStruct, genIsStruct := origintp.Underlying().(*types.Struct)
	if !genIsStruct {
		log.Fatalf(
			"Error: underlying type of instantiated type is a struct but underlying type of generic type is %s",
			origintp.Underlying())
	}

	numFields := instStruct.NumFields()
	if numFields != genStruct.NumFields() {
		log.Fatalf(
			"Error: instantiated struct %s has different number of fields than the generic version %s (%d != %d)",
			instStruct, genStruct, numFields, genStruct.NumFields())
	}

	for pos := 0; pos < numFields; pos++ {
		tw.ObjectsOverride[instStruct.Field(pos)] = genStruct.Field(pos)
	}
}
// getTypeParamParentLabel returns the label obtained from
// `tw.Labeler.ScopedObjectID` for the object previously recorded as the
// parent of the type parameter `tp`. The program exits if no parent has been
// recorded for `tp`, or if the callback handed to `ScopedObjectID` is ever
// invoked.
func getTypeParamParentLabel(tw *trap.Writer, tp *types.TypeParam) trap.Label {
	parentObj, found := typeParamParent[tp]
	if !found {
		log.Fatalf("Parent of type parameter does not exist: %s %s", tp.String(), tp.Constraint().String())
	}

	// This callback is not expected to run; if it does, abort.
	failIfInvoked := func() trap.Label {
		log.Fatalf("getTypeLabel() called for parent of type parameter %s", tp.String())
		return trap.InvalidLabel
	}

	label, _ := tw.Labeler.ScopedObjectID(parentObj, failIfInvoked)
	return label
}
// setTypeParamParent records `newobj` as the parent of the type parameter
// `tp`. Recording the same parent again is a no-op; attempting to record a
// different parent for a type parameter that already has one exits the
// program.
func setTypeParamParent(tp *types.TypeParam, newobj types.Object) {
	current, found := typeParamParent[tp]
	if !found {
		typeParamParent[tp] = newobj
		return
	}

	if current != newobj {
		log.Fatalf("Parent of type parameter '%s %s' being set to a different value: '%s' vs '%s'", tp.String(), tp.Constraint().String(), current, newobj)
	}
}
// skipExtractingValueForLeftOperand reports whether the left operand of `be`
// should not have its value extracted because it is an intermediate value in
// a string concatenation. That is taken to be the case when all of the
// following hold:
//   - `be` has a constant value of string kind,
//   - the right operand of `be` is a basic literal, and
//   - the left operand of `be` is not a basic literal.
func skipExtractingValueForLeftOperand(tw *trap.Writer, be *ast.BinaryExpr) bool {
	// `be` must evaluate to a string constant.
	constVal := tw.Package.TypesInfo.Types[be].Value
	if constVal == nil || constVal.Kind() != constant.String {
		return false
	}

	_, rightIsLit := be.Y.(*ast.BasicLit)
	_, leftIsLit := be.X.(*ast.BasicLit)
	return rightIsLit && !leftIsLit
}
// checkObjectNotSpecialized exits the program if `obj` is a specialization
// rather than an origin object. Three kinds of object are checked:
//   - a function object that differs from its `Origin()`,
//   - a variable object that differs from its `Origin()`,
//   - a type name whose type is a defined type that differs from its
//     `Origin()`.
//
// Any other object, including nil, passes the check.
func checkObjectNotSpecialized(obj types.Object) {
	switch o := obj.(type) {
	case *types.Func:
		if origin := o.Origin(); o != origin {
			log.Fatalf("Encountered unexpected specialization %s of generic function object %s", o.FullName(), origin.FullName())
		}
	case *types.Var:
		if origin := o.Origin(); o != origin {
			log.Fatalf("Encountered unexpected specialization %s of generic variable object %s", o.String(), origin.String())
		}
	case *types.TypeName:
		definedType, isNamed := o.Type().(*types.Named)
		if !isNamed {
			return
		}
		if origin := definedType.Origin(); definedType != origin {
			log.Fatalf("Encountered type object for specialization %s of defined type %s", definedType.String(), origin.String())
		}
	}
}