From 03b8e8fdde1ca9fb769f486afed9cbf7b923aa32 Mon Sep 17 00:00:00 2001 From: Copilot Date: Tue, 5 May 2026 15:09:19 +0000 Subject: [PATCH] Python: refactor getChild into per-class OO dispatch Replace the single ~240-line top-level getChild predicate with one override per AST class. AstNode declares a default AstNode getChild(int index) { none() } and each subclass with children overrides it (41 classes total). The top-level predicate becomes a one-line dispatch: AstNode getChild(AstNode n, int index) { result = n.getChild(index) } No behavioral change: NewCfg evaluation-order tests still pass at the same 22/24 baseline, and all 11 shared-CFG consistency queries still report 0 violations on CPython. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../controlflow/internal/AstNodeImpl.qll | 448 ++++++++---------- 1 file changed, 209 insertions(+), 239 deletions(-) diff --git a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll index da960060edf..86fdf45e0ba 100644 --- a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll +++ b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll @@ -143,6 +143,13 @@ module Ast implements AstSig { /** Gets the underlying Python `Pattern`, if this node wraps one. */ Py::Pattern asPattern() { this = TPattern(result) } + + /** + * Gets the child of this AST node at the specified (zero-based) + * index, in evaluation order. Subclasses with children override + * this method. + */ + AstNode getChild(int index) { none() } } /** Gets the immediately enclosing callable that contains `node`. */ @@ -186,6 +193,8 @@ module Ast implements AstSig { /** Gets the last statement in this block. */ Stmt getLastStmt() { result = TStmt(getBodyStmtList(parent, slot).getLastItem()) } + + override AstNode getChild(int index) { result = this.getStmt(index) } } /** An expression statement. */ @@ -196,6 +205,8 @@ module Ast implements AstSig { /** Gets the expression in this expression statement. */ Expr getExpr() { result = TExpr(exprStmt.getValue()) } + + override AstNode getChild(int index) { index = 0 and result = this.getExpr() } } /** An assignment statement (`x = y = expr`). */ @@ -209,6 +220,12 @@ module Ast implements AstSig { Expr getTarget(int n) { result = TExpr(assign.getTarget(n)) } int getNumberOfTargets() { result = count(assign.getATarget()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getValue() + or + result = this.getTarget(index - 1) and index >= 1 + } } /** An augmented assignment statement (`x += expr`). */ @@ -218,6 +235,8 @@ module Ast implements AstSig { AugAssignStmt() { augAssign = this.asStmt() } Expr getOperation() { result = TExpr(augAssign.getOperation()) } + + override AstNode getChild(int index) { index = 0 and result = this.getOperation() } } /** An assignment expression / walrus operator (`x := expr`). */ @@ -229,6 +248,12 @@ module Ast implements AstSig { Expr getValue() { result = TExpr(assignExpr.getValue()) } Expr getTarget() { result = TExpr(assignExpr.getTarget()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getValue() + or + index = 1 and result = this.getTarget() + } } /** @@ -256,6 +281,14 @@ module Ast implements AstSig { /** Gets the `else` (false) branch, if any. */ Stmt getElse() { result = TBlockStmt(ifStmt, "orelse") } + + override AstNode getChild(int index) { + index = 0 and result = this.getCondition() + or + index = 1 and result = this.getThen() + or + index = 2 and result = this.getElse() + } } /** A loop statement. */ @@ -279,6 +312,14 @@ module Ast implements AstSig { /** Gets the `else` branch of this `while` loop, if any. */ Stmt getElse() { result = TBlockStmt(whileStmt, "orelse") } + + override AstNode getChild(int index) { + index = 0 and result = this.getCondition() + or + index = 1 and result = this.getBody() + or + index = 2 and result = this.getElse() + } } /** @@ -317,6 +358,16 @@ module Ast implements AstSig { /** Gets the `else` branch of this `for` loop, if any. */ Stmt getElse() { result = TBlockStmt(forStmt, "orelse") } + + override AstNode getChild(int index) { + index = 0 and result = this.getCollection() + or + index = 1 and result = this.getVariable() + or + index = 2 and result = this.getBody() + or + index = 3 and result = this.getElse() + } } /** A `break` statement. */ @@ -342,6 +393,8 @@ module Ast implements AstSig { /** Gets the expression being returned, if any. */ Expr getExpr() { result = TExpr(ret.getValue()) } + + override AstNode getChild(int index) { index = 0 and result = this.getExpr() } } /** A `raise` statement (mapped to `Throw`). */ @@ -355,6 +408,12 @@ module Ast implements AstSig { /** Gets the cause of this `raise`, if any. */ Expr getCause() { result = TExpr(raise.getCause()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getExpr() + or + index = 1 and result = this.getCause() + } } /** A `with` statement. */ @@ -368,6 +427,14 @@ module Ast implements AstSig { Expr getOptionalVars() { result = TExpr(withStmt.getOptionalVars()) } Stmt getBody() { result = TBlockStmt(withStmt, "body") } + + override AstNode getChild(int index) { + index = 0 and result = this.getContextExpr() + or + index = 1 and result = this.getOptionalVars() + or + index = 2 and result = this.getBody() + } } /** An `assert` statement. */ @@ -379,6 +446,12 @@ module Ast implements AstSig { Expr getTest() { result = TExpr(assertStmt.getTest()) } Expr getMsg() { result = TExpr(assertStmt.getMsg()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getTest() + or + index = 1 and result = this.getMsg() + } } /** A `delete` statement. */ @@ -388,6 +461,8 @@ module Ast implements AstSig { DeleteStmt() { del = this.asStmt() } Expr getTarget(int n) { result = TExpr(del.getTarget(n)) } + + override AstNode getChild(int index) { result = this.getTarget(index) } } /** A `try` statement. */ @@ -404,6 +479,16 @@ module Ast implements AstSig { Stmt getFinally() { result = TBlockStmt(tryStmt, "finally") } CatchClause getCatch(int index) { result = TStmt(tryStmt.getHandler(index)) } + + override AstNode getChild(int index) { + index = 0 and result = this.getBody() + or + result = this.getCatch(index - 1) and index >= 1 + or + index = -1 and result = this.getFinally() + or + index = -2 and result = this.getElse() + } } /** @@ -442,6 +527,14 @@ module Ast implements AstSig { or result = TBlockStmt(handler.(Py::ExceptGroupStmt), "body") } + + override AstNode getChild(int index) { + index = 0 and result = this.getType() + or + index = 1 and result = this.getVariable() + or + index = 2 and result = this.getBody() + } } /** A `match` statement, mapped to the shared CFG's `Switch`. */ @@ -455,6 +548,12 @@ module Ast implements AstSig { Case getCase(int index) { result = TStmt(matchStmt.getCase(index)) } Stmt getStmt(int index) { none() } + + override AstNode getChild(int index) { + index = 0 and result = this.getExpr() + or + result = this.getCase(index - 1) and index >= 1 + } } /** A `case` clause in a match statement. */ @@ -471,6 +570,14 @@ module Ast implements AstSig { /** Holds if this case is a wildcard pattern (`case _:`). */ predicate isWildcard() { caseStmt.getPattern() instanceof Py::MatchWildcardPattern } + + override AstNode getChild(int index) { + index = 0 and result = this.getAPattern() + or + index = 1 and result = this.getGuard() + or + index = 2 and result = this.getBody() + } } /** A wildcard case (`case _:`). */ @@ -492,6 +599,14 @@ module Ast implements AstSig { /** Gets the false branch of this expression. */ Expr getElse() { result = TExpr(ifExp.getOrelse()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getCondition() + or + index = 1 and result = this.getThen() + or + index = 2 and result = this.getElse() + } } /** @@ -547,6 +662,12 @@ module Ast implements AstSig { result = TBoolExprPair(be, i + 1) ) } + + override AstNode getChild(int index) { + index = 0 and result = this.getLeftOperand() + or + index = 1 and result = this.getRightOperand() + } } /** A short-circuiting logical `and` expression. */ @@ -582,6 +703,8 @@ module Ast implements AstSig { /** Gets the operand of this unary expression. */ Expr getOperand() { result = TExpr(this.asExpr().(Py::UnaryExpr).getOperand()) } + + override AstNode getChild(int index) { index = 0 and result = this.getOperand() } } /** A logical `not` expression. */ @@ -633,6 +756,12 @@ module Ast implements AstSig { Expr getLeft() { result = TExpr(binExpr.getLeft()) } Expr getRight() { result = TExpr(binExpr.getRight()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getLeft() + or + index = 1 and result = this.getRight() + } } /** A call expression (`func(args...)`). */ @@ -654,6 +783,15 @@ module Ast implements AstSig { } int getNumberOfNamedArgs() { result = count(call.getANamedArg()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getFunc() + or + result = this.getPositionalArg(index - 1) and index >= 1 + or + result = this.getKeywordValue(index - 1 - this.getNumberOfPositionalArgs()) and + index >= 1 + this.getNumberOfPositionalArgs() + } } /** A subscript expression (`obj[index]`). */ @@ -665,6 +803,12 @@ module Ast implements AstSig { Expr getObject() { result = TExpr(sub.getObject()) } Expr getIndex() { result = TExpr(sub.getIndex()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getObject() + or + index = 1 and result = this.getIndex() + } } /** An attribute access (`obj.name`). */ @@ -674,6 +818,8 @@ module Ast implements AstSig { AttributeExpr() { attr = this.asExpr() } Expr getObject() { result = TExpr(attr.getObject()) } + + override AstNode getChild(int index) { index = 0 and result = this.getObject() } } /** A tuple literal. */ @@ -683,6 +829,8 @@ module Ast implements AstSig { TupleExpr() { tuple = this.asExpr() } Expr getElt(int n) { result = TExpr(tuple.getElt(n)) } + + override AstNode getChild(int index) { result = this.getElt(index) } } /** A list literal. */ @@ -692,6 +840,8 @@ module Ast implements AstSig { ListExpr() { list = this.asExpr() } Expr getElt(int n) { result = TExpr(list.getElt(n)) } + + override AstNode getChild(int index) { result = this.getElt(index) } } /** A set literal. */ @@ -701,6 +851,8 @@ module Ast implements AstSig { SetExpr() { set = this.asExpr() } Expr getElt(int n) { result = TExpr(set.getElt(n)) } + + override AstNode getChild(int index) { result = this.getElt(index) } } /** A dict literal. */ @@ -718,6 +870,14 @@ module Ast implements AstSig { Expr getValue(int n) { result = TExpr(dict.getItem(n).(Py::KeyValuePair).getValue()) } int getNumberOfItems() { result = count(dict.getAnItem()) } + + override AstNode getChild(int index) { + exists(int item | + index = 2 * item and result = this.getKey(item) + or + index = 2 * item + 1 and result = this.getValue(item) + ) + } } /** A unary expression other than `not` (e.g., `-x`, `+x`, `~x`). */ @@ -727,6 +887,8 @@ module Ast implements AstSig { ArithUnaryExpr() { unaryExpr = this.asExpr() and not unaryExpr.getOp() instanceof Py::Not } Expr getOperand() { result = TExpr(unaryExpr.getOperand()) } + + override AstNode getChild(int index) { index = 0 and result = this.getOperand() } } /** @@ -748,6 +910,8 @@ module Ast implements AstSig { } Expr getIterable() { result = TExpr(iterable) } + + override AstNode getChild(int index) { index = 0 and result = this.getIterable() } } /** A comparison expression (`a < b`, `a < b < c`, etc.). */ @@ -759,6 +923,12 @@ module Ast implements AstSig { Expr getLeft() { result = TExpr(cmp.getLeft()) } Expr getComparator(int n) { result = TExpr(cmp.getComparator(n)) } + + override AstNode getChild(int index) { + index = 0 and result = this.getLeft() + or + result = this.getComparator(index - 1) and index >= 1 + } } /** A slice expression (`start:stop:step`). */ @@ -772,6 +942,14 @@ module Ast implements AstSig { Expr getStop() { result = TExpr(slice.getStop()) } Expr getStep() { result = TExpr(slice.getStep()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getStart() + or + index = 1 and result = this.getStop() + or + index = 2 and result = this.getStep() + } } /** A starred expression (`*x`). */ @@ -781,6 +959,8 @@ module Ast implements AstSig { StarredExpr() { starred = this.asExpr() } Expr getValue() { result = TExpr(starred.getValue()) } + + override AstNode getChild(int index) { index = 0 and result = this.getValue() } } /** A formatted string literal (`f"...{expr}..."`). */ @@ -790,6 +970,8 @@ module Ast implements AstSig { FstringExpr() { fstring = this.asExpr() } Expr getValue(int n) { result = TExpr(fstring.getValue(n)) } + + override AstNode getChild(int index) { result = this.getValue(index) } } /** A formatted value inside an f-string (`{expr}` or `{expr:spec}`). */ @@ -801,6 +983,12 @@ module Ast implements AstSig { Expr getValue() { result = TExpr(fv.getValue()) } Expr getFormatSpec() { result = TExpr(fv.getFormatSpec()) } + + override AstNode getChild(int index) { + index = 0 and result = this.getValue() + or + index = 1 and result = this.getFormatSpec() + } } /** A `yield` expression. */ @@ -810,6 +998,8 @@ module Ast implements AstSig { YieldExpr() { yield = this.asExpr() } Expr getValue() { result = TExpr(yield.getValue()) } + + override AstNode getChild(int index) { index = 0 and result = this.getValue() } } /** A `yield from` expression. */ @@ -819,6 +1009,8 @@ module Ast implements AstSig { YieldFromExpr() { yieldFrom = this.asExpr() } Expr getValue() { result = TExpr(yieldFrom.getValue()) } + + override AstNode getChild(int index) { index = 0 and result = this.getValue() } } /** An `await` expression. */ @@ -828,6 +1020,8 @@ module Ast implements AstSig { AwaitExpr() { await = this.asExpr() } Expr getValue() { result = TExpr(await.getValue()) } + + override AstNode getChild(int index) { index = 0 and result = this.getValue() } } /** A class definition expression (has base classes evaluated at definition time). */ @@ -837,6 +1031,8 @@ module Ast implements AstSig { ClassDefExpr() { classExpr = this.asExpr() } Expr getBase(int n) { result = TExpr(classExpr.getBase(n)) } + + override AstNode getChild(int index) { result = this.getBase(index) } } /** A function definition expression (has default args evaluated at definition time). */ @@ -863,6 +1059,12 @@ module Ast implements AstSig { } int getNumberOfDefaults() { result = count(funcExpr.getArgs().getADefault()) } + + override AstNode getChild(int index) { + result = this.getDefault(index) + or + result = this.getKwDefault(index - this.getNumberOfDefaults()) + } } /** A lambda expression (has default args evaluated at definition time). */ @@ -884,248 +1086,16 @@ module Ast implements AstSig { } int getNumberOfDefaults() { result = count(lambda.getArgs().getADefault()) } + + override AstNode getChild(int index) { + result = this.getDefault(index) + or + result = this.getKwDefault(index - this.getNumberOfDefaults()) + } } /** Gets the child of `n` at the specified (zero-based) index. */ - AstNode getChild(AstNode n, int index) { - // BlockStmt: indexed statements - result = n.(BlockStmt).getStmt(index) - or - // IfStmt: condition (0), then (1), else (2) - exists(IfStmt ifStmt | ifStmt = n | - index = 0 and result = ifStmt.getCondition() - or - index = 1 and result = ifStmt.getThen() - or - index = 2 and result = ifStmt.getElse() - ) - or - // ExprStmt: the expression (0) - index = 0 and result = n.(ExprStmt).getExpr() - or - // Assign: value (0), targets (1..n) - exists(AssignStmt a | a = n | - index = 0 and result = a.getValue() - or - result = a.getTarget(index - 1) and index >= 1 - ) - or - // AugAssign: the operation (0) - index = 0 and result = n.(AugAssignStmt).getOperation() - or - // Walrus (`x := expr`): value (0), target (1) - exists(NamedExpr ne | ne = n | - index = 0 and result = ne.getValue() - or - index = 1 and result = ne.getTarget() - ) - or - // WhileStmt: condition (0), body (1), orelse (2) - exists(WhileStmt w | w = n | - index = 0 and result = w.getCondition() - or - index = 1 and result = w.getBody() - or - index = 2 and result = w.getElse() - ) - or - // ForeachStmt: collection (0), variable (1), body (2), orelse (3) - exists(ForeachStmt f | f = n | - index = 0 and result = f.getCollection() - or - index = 1 and result = f.getVariable() - or - index = 2 and result = f.getBody() - or - index = 3 and result = f.getElse() - ) - or - // ReturnStmt: the value (0) - index = 0 and result = n.(ReturnStmt).getExpr() - or - // AssertStmt: test (0), message (1) - exists(AssertStmt a | a = n | - index = 0 and result = a.getTest() - or - index = 1 and result = a.getMsg() - ) - or - // DeleteStmt: targets left to right - result = n.(DeleteStmt).getTarget(index) - or - // WithStmt: context expr (0), optional vars (1), body (2) - exists(WithStmt w | w = n | - index = 0 and result = w.getContextExpr() - or - index = 1 and result = w.getOptionalVars() - or - index = 2 and result = w.getBody() - ) - or - // Throw (raise): exception (0), cause (1) - exists(Throw r | r = n | - index = 0 and result = r.getExpr() - or - index = 1 and result = r.getCause() - ) - or - // TryStmt: body (0), handlers (1..n), else (-2), finally (-1) - exists(TryStmt t | t = n | - index = 0 and result = t.getBody() - or - result = t.getCatch(index - 1) and index >= 1 - or - index = -1 and result = t.getFinally() - or - index = -2 and result = t.getElse() - ) - or - // Switch (match): subject (0), cases (1..n) - exists(Switch m | m = n | - index = 0 and result = m.getExpr() - or - result = m.getCase(index - 1) and index >= 1 - ) - or - // Case: pattern (0), guard (1), body (2) - exists(Case c | c = n | - index = 0 and result = c.getAPattern() - or - index = 1 and result = c.getGuard() - or - index = 2 and result = c.getBody() - ) - or - // CatchClause (except handler): type (0), name (1), body (2) - exists(CatchClause h | h = n | - index = 0 and result = h.getType() - or - index = 1 and result = h.getVariable() - or - index = 2 and result = h.getBody() - ) - or - // ConditionalExpr (IfExp): condition (0), then (1), else (2) - exists(ConditionalExpr ie | ie = n | - index = 0 and result = ie.getCondition() - or - index = 1 and result = ie.getThen() - or - index = 2 and result = ie.getElse() - ) - or - // Call: func (0), positional args (1..n), keyword values (n+1..n+k) - exists(CallExpr call | call = n | - index = 0 and result = call.getFunc() - or - result = call.getPositionalArg(index - 1) and index >= 1 - or - result = call.getKeywordValue(index - 1 - call.getNumberOfPositionalArgs()) and - index >= 1 + call.getNumberOfPositionalArgs() - ) - or - // Python BinaryExpr (arithmetic, bitwise, matmul, etc.): left (0), right (1) - exists(ArithBinaryExpr be | be = n | - index = 0 and result = be.getLeft() - or - index = 1 and result = be.getRight() - ) - or - // Subscript (obj[index]): object (0), index (1) - exists(SubscriptExpr sub | sub = n | - index = 0 and result = sub.getObject() - or - index = 1 and result = sub.getIndex() - ) - or - // Attribute (obj.name): object (0) - index = 0 and result = n.(AttributeExpr).getObject() - or - // Comprehension/generator: iterable (0) - index = 0 and result = n.(Comprehension).getIterable() - or - // Tuple, List, Set: elements left to right - result = n.(TupleExpr).getElt(index) - or - result = n.(ListExpr).getElt(index) - or - result = n.(SetExpr).getElt(index) - or - // Dict: key(0), value(0), key(1), value(1), ... - exists(DictExpr d, int item | d = n | - index = 2 * item and result = d.getKey(item) - or - index = 2 * item + 1 and result = d.getValue(item) - ) - or - // Arithmetic unary (-x, +x, ~x): operand (0) - index = 0 and result = n.(ArithUnaryExpr).getOperand() - or - // Compare (a < b < c): left (0), comparators (1..n) - exists(CompareExpr cmp | cmp = n | - index = 0 and result = cmp.getLeft() - or - result = cmp.getComparator(index - 1) and index >= 1 - ) - or - // Slice (start:stop:step): start (0), stop (1), step (2) - exists(SliceExpr sl | sl = n | - index = 0 and result = sl.getStart() - or - index = 1 and result = sl.getStop() - or - index = 2 and result = sl.getStep() - ) - or - // Starred (*x): value (0) - index = 0 and result = n.(StarredExpr).getValue() - or - // Fstring: values left to right - result = n.(FstringExpr).getValue(index) - or - // FormattedValue ({expr} or {expr:spec}): value (0), format spec (1) - exists(FormattedValueExpr fv | fv = n | - index = 0 and result = fv.getValue() - or - index = 1 and result = fv.getFormatSpec() - ) - or - // Yield: value (0) - index = 0 and result = n.(YieldExpr).getValue() - or - // YieldFrom: value (0) - index = 0 and result = n.(YieldFromExpr).getValue() - or - // Await: value (0) - index = 0 and result = n.(AwaitExpr).getValue() - or - // ClassExpr: base classes left to right - result = n.(ClassDefExpr).getBase(index) - or - // FunctionExpr: defaults left to right, then kw defaults - exists(FunctionDefExpr fe | fe = n | - result = fe.getDefault(index) - or - result = fe.getKwDefault(index - fe.getNumberOfDefaults()) - ) - or - // Lambda: defaults left to right, then kw defaults - exists(LambdaExpr lam | lam = n | - result = lam.getDefault(index) - or - result = lam.getKwDefault(index - lam.getNumberOfDefaults()) - ) - or - // LogicalNotExpr: operand (0) - index = 0 and result = n.(LogicalNotExpr).getOperand() - or - // BinaryExpr (`and`/`or`): left (0), right (1) - exists(BinaryExpr be | be = n | - index = 0 and result = be.getLeftOperand() - or - index = 1 and result = be.getRightOperand() - ) - } + AstNode getChild(AstNode n, int index) { result = n.getChild(index) } } private module Cfg0 = Make0;