Fix nil checks to stop creating unused labels

In go, an interface with value nil does not compare equal to nil. This
is known as "typed nils". So our existing nil checks weren't working,
which shows why we needed more nil checks inside the type switches. The
solution is to explicitly check for each type we care about.
This commit is contained in:
Owen Mansel-Chan
2025-05-15 13:40:57 +01:00
parent d39e7c2066
commit 401c60654e

View File

@@ -936,7 +936,16 @@ func emitScopeNodeInfo(tw *trap.Writer, nd ast.Node, lbl trap.Label) {
// extractExpr extracts AST information for the given expression and all its subexpressions
func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, skipExtractingValue bool) {
if expr == nil {
if expr == nil || expr == (*ast.Ident)(nil) || expr == (*ast.BasicLit)(nil) ||
expr == (*ast.Ellipsis)(nil) || expr == (*ast.FuncLit)(nil) ||
expr == (*ast.CompositeLit)(nil) || expr == (*ast.SelectorExpr)(nil) ||
expr == (*ast.IndexListExpr)(nil) || expr == (*ast.SliceExpr)(nil) ||
expr == (*ast.TypeAssertExpr)(nil) || expr == (*ast.CallExpr)(nil) ||
expr == (*ast.StarExpr)(nil) || expr == (*ast.KeyValueExpr)(nil) ||
expr == (*ast.UnaryExpr)(nil) || expr == (*ast.BinaryExpr)(nil) ||
expr == (*ast.ArrayType)(nil) || expr == (*ast.StructType)(nil) ||
expr == (*ast.FuncType)(nil) || expr == (*ast.InterfaceType)(nil) ||
expr == (*ast.MapType)(nil) || expr == (*ast.ChanType)(nil) {
return
}
@@ -948,9 +957,6 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
case *ast.BadExpr:
kind = dbscheme.BadExpr.Index()
case *ast.Ident:
if expr == nil {
return
}
kind = dbscheme.IdentExpr.Index()
dbscheme.LiteralsTable.Emit(tw, lbl, expr.Name, expr.Name)
def := tw.Package.TypesInfo.Defs[expr]
@@ -984,15 +990,9 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
}
}
case *ast.Ellipsis:
if expr == nil {
return
}
kind = dbscheme.EllipsisExpr.Index()
extractExpr(tw, expr.Elt, lbl, 0, false)
case *ast.BasicLit:
if expr == nil {
return
}
value := ""
switch expr.Kind {
case token.INT:
@@ -1016,36 +1016,21 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
}
dbscheme.LiteralsTable.Emit(tw, lbl, value, expr.Value)
case *ast.FuncLit:
if expr == nil {
return
}
kind = dbscheme.FuncLitExpr.Index()
extractExpr(tw, expr.Type, lbl, 0, false)
extractStmt(tw, expr.Body, lbl, 1)
case *ast.CompositeLit:
if expr == nil {
return
}
kind = dbscheme.CompositeLitExpr.Index()
extractExpr(tw, expr.Type, lbl, 0, false)
extractExprs(tw, expr.Elts, lbl, 1, 1)
case *ast.ParenExpr:
if expr == nil {
return
}
kind = dbscheme.ParenExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.SelectorExpr:
if expr == nil {
return
}
kind = dbscheme.SelectorExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Sel, lbl, 1, false)
case *ast.IndexExpr:
if expr == nil {
return
}
typeofx := typeOf(tw, expr.X)
if typeofx == nil {
// We are missing type information for `expr.X`, so we cannot
@@ -1065,9 +1050,6 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Index, lbl, 1, false)
case *ast.IndexListExpr:
if expr == nil {
return
}
typeofx := typeOf(tw, expr.X)
if typeofx == nil {
// We are missing type information for `expr.X`, so we cannot
@@ -1084,18 +1066,12 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
extractExpr(tw, expr.X, lbl, 0, false)
extractExprs(tw, expr.Indices, lbl, 1, 1)
case *ast.SliceExpr:
if expr == nil {
return
}
kind = dbscheme.SliceExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Low, lbl, 1, false)
extractExpr(tw, expr.High, lbl, 2, false)
extractExpr(tw, expr.Max, lbl, 3, false)
case *ast.TypeAssertExpr:
if expr == nil {
return
}
kind = dbscheme.TypeAssertExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
// expr.Type can be `nil` if this is the `x.(type)` in a type switch.
@@ -1103,9 +1079,6 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
extractExpr(tw, expr.Type, lbl, 1, false)
}
case *ast.CallExpr:
if expr == nil {
return
}
kind = dbscheme.CallOrConversionExpr.Index()
extractExpr(tw, expr.Fun, lbl, 0, false)
extractExprs(tw, expr.Args, lbl, 1, 1)
@@ -1113,22 +1086,13 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
dbscheme.HasEllipsisTable.Emit(tw, lbl)
}
case *ast.StarExpr:
if expr == nil {
return
}
kind = dbscheme.StarExpr.Index()
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.KeyValueExpr:
if expr == nil {
return
}
kind = dbscheme.KeyValueExpr.Index()
extractExpr(tw, expr.Key, lbl, 0, false)
extractExpr(tw, expr.Value, lbl, 1, false)
case *ast.UnaryExpr:
if expr == nil {
return
}
if expr.Op == token.TILDE {
kind = dbscheme.TypeSetLiteralExpr.Index()
} else {
@@ -1140,9 +1104,6 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
}
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.BinaryExpr:
if expr == nil {
return
}
_, isUnionType := typeOf(tw, expr).(*types.Union)
if expr.Op == token.OR && isUnionType {
kind = dbscheme.TypeSetLiteralExpr.Index()
@@ -1158,46 +1119,28 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, ski
extractExpr(tw, expr.Y, lbl, 1, false)
}
case *ast.ArrayType:
if expr == nil {
return
}
kind = dbscheme.ArrayTypeExpr.Index()
extractExpr(tw, expr.Len, lbl, 0, false)
extractExpr(tw, expr.Elt, lbl, 1, false)
case *ast.StructType:
if expr == nil {
return
}
kind = dbscheme.StructTypeExpr.Index()
extractFields(tw, expr.Fields, lbl, 0, 1)
case *ast.FuncType:
if expr == nil {
return
}
kind = dbscheme.FuncTypeExpr.Index()
extractFields(tw, expr.Params, lbl, 0, 1)
extractFields(tw, expr.Results, lbl, -1, -1)
emitScopeNodeInfo(tw, expr, lbl)
case *ast.InterfaceType:
if expr == nil {
return
}
kind = dbscheme.InterfaceTypeExpr.Index()
// expr.Methods contains methods, embedded interfaces and type set
// literals.
makeTypeSetLiteralsUnionTyped(tw, expr.Methods)
extractFields(tw, expr.Methods, lbl, 0, 1)
case *ast.MapType:
if expr == nil {
return
}
kind = dbscheme.MapTypeExpr.Index()
extractExpr(tw, expr.Key, lbl, 0, false)
extractExpr(tw, expr.Value, lbl, 1, false)
case *ast.ChanType:
if expr == nil {
return
}
tp := dbscheme.ChanTypeExprs[expr.Dir]
if tp == nil {
log.Fatalf("unsupported channel direction %v", expr.Dir)
@@ -1299,7 +1242,15 @@ func extractFields(tw *trap.Writer, fields *ast.FieldList, parent trap.Label, id
// extractStmt extracts AST information for a given statement and all other statements or expressions
// nested inside it
func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
if stmt == nil {
if stmt == nil || stmt == (*ast.DeclStmt)(nil) ||
stmt == (*ast.LabeledStmt)(nil) || stmt == (*ast.ExprStmt)(nil) ||
stmt == (*ast.SendStmt)(nil) || stmt == (*ast.IncDecStmt)(nil) ||
stmt == (*ast.AssignStmt)(nil) || stmt == (*ast.GoStmt)(nil) ||
stmt == (*ast.DeferStmt)(nil) || stmt == (*ast.BranchStmt)(nil) ||
stmt == (*ast.BlockStmt)(nil) || stmt == (*ast.IfStmt)(nil) ||
stmt == (*ast.CaseClause)(nil) || stmt == (*ast.SwitchStmt)(nil) ||
stmt == (*ast.TypeSwitchStmt)(nil) || stmt == (*ast.CommClause)(nil) ||
stmt == (*ast.ForStmt)(nil) || stmt == (*ast.RangeStmt)(nil) {
return
}
@@ -1309,37 +1260,22 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
case *ast.BadStmt:
kind = dbscheme.BadStmtType.Index()
case *ast.DeclStmt:
if stmt == nil {
return
}
kind = dbscheme.DeclStmtType.Index()
extractDecl(tw, stmt.Decl, lbl, 0)
case *ast.EmptyStmt:
kind = dbscheme.EmptyStmtType.Index()
case *ast.LabeledStmt:
if stmt == nil {
return
}
kind = dbscheme.LabeledStmtType.Index()
extractExpr(tw, stmt.Label, lbl, 0, false)
extractStmt(tw, stmt.Stmt, lbl, 1)
case *ast.ExprStmt:
if stmt == nil {
return
}
kind = dbscheme.ExprStmtType.Index()
extractExpr(tw, stmt.X, lbl, 0, false)
case *ast.SendStmt:
if stmt == nil {
return
}
kind = dbscheme.SendStmtType.Index()
extractExpr(tw, stmt.Chan, lbl, 0, false)
extractExpr(tw, stmt.Value, lbl, 1, false)
case *ast.IncDecStmt:
if stmt == nil {
return
}
if stmt.Tok == token.INC {
kind = dbscheme.IncStmtType.Index()
} else if stmt.Tok == token.DEC {
@@ -1349,9 +1285,6 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
}
extractExpr(tw, stmt.X, lbl, 0, false)
case *ast.AssignStmt:
if stmt == nil {
return
}
tp := dbscheme.AssignStmtTypes[stmt.Tok]
if tp == nil {
log.Fatalf("unsupported assignment statement with operator %v", stmt.Tok)
@@ -1360,24 +1293,15 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
extractExprs(tw, stmt.Lhs, lbl, -1, -1)
extractExprs(tw, stmt.Rhs, lbl, 1, 1)
case *ast.GoStmt:
if stmt == nil {
return
}
kind = dbscheme.GoStmtType.Index()
extractExpr(tw, stmt.Call, lbl, 0, false)
case *ast.DeferStmt:
if stmt == nil {
return
}
kind = dbscheme.DeferStmtType.Index()
extractExpr(tw, stmt.Call, lbl, 0, false)
case *ast.ReturnStmt:
kind = dbscheme.ReturnStmtType.Index()
extractExprs(tw, stmt.Results, lbl, 0, 1)
case *ast.BranchStmt:
if stmt == nil {
return
}
switch stmt.Tok {
case token.BREAK:
kind = dbscheme.BreakStmtType.Index()
@@ -1392,16 +1316,10 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
}
extractExpr(tw, stmt.Label, lbl, 0, false)
case *ast.BlockStmt:
if stmt == nil {
return
}
kind = dbscheme.BlockStmtType.Index()
extractStmts(tw, stmt.List, lbl, 0, 1)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.IfStmt:
if stmt == nil {
return
}
kind = dbscheme.IfStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Cond, lbl, 1, false)
@@ -1409,35 +1327,23 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
extractStmt(tw, stmt.Else, lbl, 3)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.CaseClause:
if stmt == nil {
return
}
kind = dbscheme.CaseClauseType.Index()
extractExprs(tw, stmt.List, lbl, -1, -1)
extractStmts(tw, stmt.Body, lbl, 0, 1)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.SwitchStmt:
if stmt == nil {
return
}
kind = dbscheme.ExprSwitchStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Tag, lbl, 1, false)
extractStmt(tw, stmt.Body, lbl, 2)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.TypeSwitchStmt:
if stmt == nil {
return
}
kind = dbscheme.TypeSwitchStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractStmt(tw, stmt.Assign, lbl, 1)
extractStmt(tw, stmt.Body, lbl, 2)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.CommClause:
if stmt == nil {
return
}
kind = dbscheme.CommClauseType.Index()
extractStmt(tw, stmt.Comm, lbl, 0)
extractStmts(tw, stmt.Body, lbl, 1, 1)
@@ -1446,9 +1352,6 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
kind = dbscheme.SelectStmtType.Index()
extractStmt(tw, stmt.Body, lbl, 0)
case *ast.ForStmt:
if stmt == nil {
return
}
kind = dbscheme.ForStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Cond, lbl, 1, false)
@@ -1456,9 +1359,6 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
extractStmt(tw, stmt.Body, lbl, 3)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.RangeStmt:
if stmt == nil {
return
}
kind = dbscheme.RangeStmtType.Index()
extractExpr(tw, stmt.Key, lbl, 0, false)
extractExpr(tw, stmt.Value, lbl, 1, false)
@@ -1486,15 +1386,15 @@ func extractStmts(tw *trap.Writer, stmts []ast.Stmt, parent trap.Label, idx int,
// extractDecl extracts AST information for the given declaration
func extractDecl(tw *trap.Writer, decl ast.Decl, parent trap.Label, idx int) {
if decl == (*ast.FuncDecl)(nil) || decl == (*ast.GenDecl)(nil) {
return
}
lbl := tw.Labeler.LocalID(decl)
var kind int
switch decl := decl.(type) {
case *ast.BadDecl:
kind = dbscheme.BadDeclType.Index()
case *ast.GenDecl:
if decl == nil {
return
}
switch decl.Tok {
case token.IMPORT:
kind = dbscheme.ImportDeclType.Index()
@@ -1512,9 +1412,6 @@ func extractDecl(tw *trap.Writer, decl ast.Decl, parent trap.Label, idx int) {
}
extractDoc(tw, decl.Doc, lbl)
case *ast.FuncDecl:
if decl == nil {
return
}
kind = dbscheme.FuncDeclType.Index()
extractFields(tw, decl.Recv, lbl, -1, -1)
extractExpr(tw, decl.Name, lbl, 0, false)