go extractor: avoid long string concatenations

When we see "a" + "b" + "c" + "d", do not add a
row to the constvalues table for the intermiediate
strings "ab" and "abc". We still have entries for
the string literals ("a", "b", "c", and "d") and
the whole string concatenation ("abcd").
This commit is contained in:
Owen Mansel-Chan
2024-03-11 09:57:02 +00:00
parent 820c14577a
commit da8cc13506

View File

@@ -794,7 +794,7 @@ func extractLocalScope(tw *trap.Writer, scope *types.Scope, parentScopeLabel tra
func extractFileNode(tw *trap.Writer, nd *ast.File) {
lbl := tw.Labeler.FileLabel()
extractExpr(tw, nd.Name, lbl, 0)
extractExpr(tw, nd.Name, lbl, 0, false)
for i, decl := range nd.Decls {
extractDecl(tw, decl, lbl, i)
@@ -851,7 +851,7 @@ func emitScopeNodeInfo(tw *trap.Writer, nd ast.Node, lbl trap.Label) {
}
// extractExpr extracts AST information for the given expression and all its subexpressions
func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int, skipExtractingValue bool) {
if expr == nil {
return
}
@@ -900,7 +900,7 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
return
}
kind = dbscheme.EllipsisExpr.Index()
extractExpr(tw, expr.Elt, lbl, 0)
extractExpr(tw, expr.Elt, lbl, 0, false)
case *ast.BasicLit:
if expr == nil {
return
@@ -932,28 +932,28 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
return
}
kind = dbscheme.FuncLitExpr.Index()
extractExpr(tw, expr.Type, lbl, 0)
extractExpr(tw, expr.Type, lbl, 0, false)
extractStmt(tw, expr.Body, lbl, 1)
case *ast.CompositeLit:
if expr == nil {
return
}
kind = dbscheme.CompositeLitExpr.Index()
extractExpr(tw, expr.Type, lbl, 0)
extractExpr(tw, expr.Type, lbl, 0, false)
extractExprs(tw, expr.Elts, lbl, 1, 1)
case *ast.ParenExpr:
if expr == nil {
return
}
kind = dbscheme.ParenExpr.Index()
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.SelectorExpr:
if expr == nil {
return
}
kind = dbscheme.SelectorExpr.Index()
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.Sel, lbl, 1)
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Sel, lbl, 1, false)
case *ast.IndexExpr:
if expr == nil {
return
@@ -974,8 +974,8 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
kind = dbscheme.IndexExpr.Index()
}
}
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.Index, lbl, 1)
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Index, lbl, 1, false)
case *ast.IndexListExpr:
if expr == nil {
return
@@ -993,30 +993,30 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
kind = dbscheme.GenericTypeInstantiationExpr.Index()
}
}
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.X, lbl, 0, false)
extractExprs(tw, expr.Indices, lbl, 1, 1)
case *ast.SliceExpr:
if expr == nil {
return
}
kind = dbscheme.SliceExpr.Index()
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.Low, lbl, 1)
extractExpr(tw, expr.High, lbl, 2)
extractExpr(tw, expr.Max, lbl, 3)
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Low, lbl, 1, false)
extractExpr(tw, expr.High, lbl, 2, false)
extractExpr(tw, expr.Max, lbl, 3, false)
case *ast.TypeAssertExpr:
if expr == nil {
return
}
kind = dbscheme.TypeAssertExpr.Index()
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.Type, lbl, 1)
extractExpr(tw, expr.X, lbl, 0, false)
extractExpr(tw, expr.Type, lbl, 1, false)
case *ast.CallExpr:
if expr == nil {
return
}
kind = dbscheme.CallOrConversionExpr.Index()
extractExpr(tw, expr.Fun, lbl, 0)
extractExpr(tw, expr.Fun, lbl, 0, false)
extractExprs(tw, expr.Args, lbl, 1, 1)
if expr.Ellipsis.IsValid() {
dbscheme.HasEllipsisTable.Emit(tw, lbl)
@@ -1026,14 +1026,14 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
return
}
kind = dbscheme.StarExpr.Index()
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.KeyValueExpr:
if expr == nil {
return
}
kind = dbscheme.KeyValueExpr.Index()
extractExpr(tw, expr.Key, lbl, 0)
extractExpr(tw, expr.Value, lbl, 1)
extractExpr(tw, expr.Key, lbl, 0, false)
extractExpr(tw, expr.Value, lbl, 1, false)
case *ast.UnaryExpr:
if expr == nil {
return
@@ -1047,7 +1047,7 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
}
kind = tp.Index()
}
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.X, lbl, 0, false)
case *ast.BinaryExpr:
if expr == nil {
return
@@ -1062,16 +1062,17 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
log.Fatalf("unsupported binary operator %s", expr.Op)
}
kind = tp.Index()
extractExpr(tw, expr.X, lbl, 0)
extractExpr(tw, expr.Y, lbl, 1)
skipLeft := skipExtractingValueForLeftOperand(tw, expr)
extractExpr(tw, expr.X, lbl, 0, skipLeft)
extractExpr(tw, expr.Y, lbl, 1, false)
}
case *ast.ArrayType:
if expr == nil {
return
}
kind = dbscheme.ArrayTypeExpr.Index()
extractExpr(tw, expr.Len, lbl, 0)
extractExpr(tw, expr.Elt, lbl, 1)
extractExpr(tw, expr.Len, lbl, 0, false)
extractExpr(tw, expr.Elt, lbl, 1, false)
case *ast.StructType:
if expr == nil {
return
@@ -1100,8 +1101,8 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
return
}
kind = dbscheme.MapTypeExpr.Index()
extractExpr(tw, expr.Key, lbl, 0)
extractExpr(tw, expr.Value, lbl, 1)
extractExpr(tw, expr.Key, lbl, 0, false)
extractExpr(tw, expr.Value, lbl, 1, false)
case *ast.ChanType:
if expr == nil {
return
@@ -1111,13 +1112,15 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
log.Fatalf("unsupported channel direction %v", expr.Dir)
}
kind = tp.Index()
extractExpr(tw, expr.Value, lbl, 0)
extractExpr(tw, expr.Value, lbl, 0, false)
default:
log.Fatalf("unknown expression of type %T", expr)
}
dbscheme.ExprsTable.Emit(tw, lbl, kind, parent, idx)
extractNodeLocation(tw, expr, lbl)
extractValueOf(tw, expr, lbl)
if !skipExtractingValue {
extractValueOf(tw, expr, lbl)
}
}
// extractExprs extracts AST information for a list of expressions, which are children of
@@ -1128,7 +1131,7 @@ func extractExpr(tw *trap.Writer, expr ast.Expr, parent trap.Label, idx int) {
func extractExprs(tw *trap.Writer, exprs []ast.Expr, parent trap.Label, idx int, dir int) {
if exprs != nil {
for _, expr := range exprs {
extractExpr(tw, expr, parent, idx)
extractExpr(tw, expr, parent, idx, false)
idx += dir
}
}
@@ -1194,11 +1197,11 @@ func extractFields(tw *trap.Writer, fields *ast.FieldList, parent trap.Label, id
extractNodeLocation(tw, field, lbl)
if field.Names != nil {
for i, name := range field.Names {
extractExpr(tw, name, lbl, i+1)
extractExpr(tw, name, lbl, i+1, false)
}
}
extractExpr(tw, field.Type, lbl, 0)
extractExpr(tw, field.Tag, lbl, -1)
extractExpr(tw, field.Type, lbl, 0, false)
extractExpr(tw, field.Tag, lbl, -1, false)
extractDoc(tw, field.Doc, lbl)
idx += dir
}
@@ -1229,21 +1232,21 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
return
}
kind = dbscheme.LabeledStmtType.Index()
extractExpr(tw, stmt.Label, lbl, 0)
extractExpr(tw, stmt.Label, lbl, 0, false)
extractStmt(tw, stmt.Stmt, lbl, 1)
case *ast.ExprStmt:
if stmt == nil {
return
}
kind = dbscheme.ExprStmtType.Index()
extractExpr(tw, stmt.X, lbl, 0)
extractExpr(tw, stmt.X, lbl, 0, false)
case *ast.SendStmt:
if stmt == nil {
return
}
kind = dbscheme.SendStmtType.Index()
extractExpr(tw, stmt.Chan, lbl, 0)
extractExpr(tw, stmt.Value, lbl, 1)
extractExpr(tw, stmt.Chan, lbl, 0, false)
extractExpr(tw, stmt.Value, lbl, 1, false)
case *ast.IncDecStmt:
if stmt == nil {
return
@@ -1255,7 +1258,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
} else {
log.Fatalf("unsupported increment/decrement operator %v", stmt.Tok)
}
extractExpr(tw, stmt.X, lbl, 0)
extractExpr(tw, stmt.X, lbl, 0, false)
case *ast.AssignStmt:
if stmt == nil {
return
@@ -1272,13 +1275,13 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
return
}
kind = dbscheme.GoStmtType.Index()
extractExpr(tw, stmt.Call, lbl, 0)
extractExpr(tw, stmt.Call, lbl, 0, false)
case *ast.DeferStmt:
if stmt == nil {
return
}
kind = dbscheme.DeferStmtType.Index()
extractExpr(tw, stmt.Call, lbl, 0)
extractExpr(tw, stmt.Call, lbl, 0, false)
case *ast.ReturnStmt:
kind = dbscheme.ReturnStmtType.Index()
extractExprs(tw, stmt.Results, lbl, 0, 1)
@@ -1298,7 +1301,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
default:
log.Fatalf("unsupported branch statement type %v", stmt.Tok)
}
extractExpr(tw, stmt.Label, lbl, 0)
extractExpr(tw, stmt.Label, lbl, 0, false)
case *ast.BlockStmt:
if stmt == nil {
return
@@ -1312,7 +1315,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
}
kind = dbscheme.IfStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Cond, lbl, 1)
extractExpr(tw, stmt.Cond, lbl, 1, false)
extractStmt(tw, stmt.Body, lbl, 2)
extractStmt(tw, stmt.Else, lbl, 3)
emitScopeNodeInfo(tw, stmt, lbl)
@@ -1330,7 +1333,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
}
kind = dbscheme.ExprSwitchStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Tag, lbl, 1)
extractExpr(tw, stmt.Tag, lbl, 1, false)
extractStmt(tw, stmt.Body, lbl, 2)
emitScopeNodeInfo(tw, stmt, lbl)
case *ast.TypeSwitchStmt:
@@ -1359,7 +1362,7 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
}
kind = dbscheme.ForStmtType.Index()
extractStmt(tw, stmt.Init, lbl, 0)
extractExpr(tw, stmt.Cond, lbl, 1)
extractExpr(tw, stmt.Cond, lbl, 1, false)
extractStmt(tw, stmt.Post, lbl, 2)
extractStmt(tw, stmt.Body, lbl, 3)
emitScopeNodeInfo(tw, stmt, lbl)
@@ -1368,9 +1371,9 @@ func extractStmt(tw *trap.Writer, stmt ast.Stmt, parent trap.Label, idx int) {
return
}
kind = dbscheme.RangeStmtType.Index()
extractExpr(tw, stmt.Key, lbl, 0)
extractExpr(tw, stmt.Value, lbl, 1)
extractExpr(tw, stmt.X, lbl, 2)
extractExpr(tw, stmt.Key, lbl, 0, false)
extractExpr(tw, stmt.Value, lbl, 1, false)
extractExpr(tw, stmt.X, lbl, 2, false)
extractStmt(tw, stmt.Body, lbl, 3)
emitScopeNodeInfo(tw, stmt, lbl)
default:
@@ -1428,8 +1431,8 @@ func extractDecl(tw *trap.Writer, decl ast.Decl, parent trap.Label, idx int) {
}
kind = dbscheme.FuncDeclType.Index()
extractFields(tw, decl.Recv, lbl, -1, -1)
extractExpr(tw, decl.Name, lbl, 0)
extractExpr(tw, decl.Type, lbl, 1)
extractExpr(tw, decl.Name, lbl, 0, false)
extractExpr(tw, decl.Type, lbl, 1, false)
extractStmt(tw, decl.Body, lbl, 2)
extractDoc(tw, decl.Doc, lbl)
extractTypeParamDecls(tw, decl.Type.TypeParams, lbl)
@@ -1455,8 +1458,8 @@ func extractSpec(tw *trap.Writer, spec ast.Spec, parent trap.Label, idx int) {
return
}
kind = dbscheme.ImportSpecType.Index()
extractExpr(tw, spec.Name, lbl, 0)
extractExpr(tw, spec.Path, lbl, 1)
extractExpr(tw, spec.Name, lbl, 0, false)
extractExpr(tw, spec.Path, lbl, 1, false)
extractDoc(tw, spec.Doc, lbl)
case *ast.ValueSpec:
if spec == nil {
@@ -1464,9 +1467,9 @@ func extractSpec(tw *trap.Writer, spec ast.Spec, parent trap.Label, idx int) {
}
kind = dbscheme.ValueSpecType.Index()
for i, name := range spec.Names {
extractExpr(tw, name, lbl, -(1 + i))
extractExpr(tw, name, lbl, -(1 + i), false)
}
extractExpr(tw, spec.Type, lbl, 0)
extractExpr(tw, spec.Type, lbl, 0, false)
extractExprs(tw, spec.Values, lbl, 1, 1)
extractDoc(tw, spec.Doc, lbl)
case *ast.TypeSpec:
@@ -1478,9 +1481,9 @@ func extractSpec(tw *trap.Writer, spec ast.Spec, parent trap.Label, idx int) {
} else {
kind = dbscheme.TypeDefSpecType.Index()
}
extractExpr(tw, spec.Name, lbl, 0)
extractExpr(tw, spec.Name, lbl, 0, false)
extractTypeParamDecls(tw, spec.TypeParams, lbl)
extractExpr(tw, spec.Type, lbl, 1)
extractExpr(tw, spec.Type, lbl, 1, false)
extractDoc(tw, spec.Doc, lbl)
}
dbscheme.SpecsTable.Emit(tw, lbl, kind, parent, idx)
@@ -1909,7 +1912,7 @@ func flattenBinaryExprTree(tw *trap.Writer, e ast.Expr, parent trap.Label, idx i
idx = flattenBinaryExprTree(tw, binaryexpr.X, parent, idx)
idx = flattenBinaryExprTree(tw, binaryexpr.Y, parent, idx)
} else {
extractExpr(tw, e, parent, idx)
extractExpr(tw, e, parent, idx, false)
idx = idx + 1
}
return idx
@@ -1931,10 +1934,10 @@ func extractTypeParamDecls(tw *trap.Writer, fields *ast.FieldList, parent trap.L
extractNodeLocation(tw, field, lbl)
if field.Names != nil {
for i, name := range field.Names {
extractExpr(tw, name, lbl, i+1)
extractExpr(tw, name, lbl, i+1, false)
}
}
extractExpr(tw, field.Type, lbl, 0)
extractExpr(tw, field.Type, lbl, 0, false)
extractDoc(tw, field.Doc, lbl)
idx += 1
}
@@ -2023,3 +2026,24 @@ func setTypeParamParent(tp *types.TypeParam, newobj types.Object) {
log.Fatalf("Parent of type parameter '%s %s' being set to a different value: '%s' vs '%s'", tp.String(), tp.Constraint().String(), obj, newobj)
}
}
// skipExtractingValueForLeftOperand returns true if the left operand of `be`
// should not have its value extracted because it is an intermediate value in a
// string concatenation - specifically that the right operand is a string
// literal
func skipExtractingValueForLeftOperand(tw *trap.Writer, be *ast.BinaryExpr) bool {
// check `be` has string type
tpVal := tw.Package.TypesInfo.Types[be]
if tpVal.Value == nil || tpVal.Value.Kind() != constant.String {
return false
}
// check that the right operand of `be` is a basic literal
if _, isBasicLit := be.Y.(*ast.BasicLit); !isBasicLit {
return false
}
// check that the left operand of `be` is not a basic literal
if _, isBasicLit := be.X.(*ast.BasicLit); isBasicLit {
return false
}
return true
}