From a70abdd007535d716036e73064608c529a03cddc Mon Sep 17 00:00:00 2001 From: Copilot Date: Thu, 7 May 2026 20:00:55 +0000 Subject: [PATCH] Python: project via as* helpers outside characteristic predicates Style cleanup: avoid naming newtype branch constructors (TPyStmt, TPyExpr, TBlockStmt, TPattern, TBoolExprPair, TScope) outside the char-preds that classify their wrappers. Method bodies and helper predicates now use the as* projections instead: // Before: result = TBlockStmt(ifStmt.getBody()) // After: result.asStmtList() = ifStmt.getBody() // Before: result = TPyStmt(matchStmt.getCase(index)) // After: result.asStmt() = matchStmt.getCase(index) Adds: - AstNode.asStmtList() - the inverse of TBlockStmt(_). - BinaryExpr.getIndex() - exposes the synthetic-pair index, used internally by getRightOperand to find the next pair without naming TBoolExprPair. No behaviour change: all 24 NewCfg evaluation-order tests pass; all 11 shared-CFG consistency queries report 0 violations on CPython. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../controlflow/internal/AstNodeImpl.qll | 45 +++++++++++-------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll index 8b4cf1189f4..681b05cdf05 100644 --- a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll +++ b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll @@ -101,6 +101,9 @@ module Ast implements AstSig { /** Gets the underlying Python `Pattern`, if this node wraps one. */ Py::Pattern asPattern() { this = TPattern(result) } + /** Gets the underlying Python `StmtList`, if this node is a `BlockStmt`. */ + Py::StmtList asStmtList() { this = TBlockStmt(result) } + /** * Gets the child of this AST node at the specified (zero-based) * index, in evaluation order. Subclasses with children override @@ -133,7 +136,7 @@ module Ast implements AstSig { } /** Gets the body of callable `c`. */ - AstNode callableGetBody(Callable c) { result = TBlockStmt(c.asScope().getBody()) } + AstNode callableGetBody(Callable c) { result.asStmtList() = c.asScope().getBody() } /** * A parameter of a callable. @@ -300,10 +303,10 @@ module Ast implements AstSig { Expr getCondition() { result.asExpr() = ifStmt.getTest() } /** Gets the `then` (true) branch of this `if` statement. */ - Stmt getThen() { result = TBlockStmt(ifStmt.getBody()) } + Stmt getThen() { result.asStmtList() = ifStmt.getBody() } /** Gets the `else` (false) branch, if any. */ - Stmt getElse() { result = TBlockStmt(ifStmt.getOrelse()) } + Stmt getElse() { result.asStmtList() = ifStmt.getOrelse() } override AstNode getChild(int index) { index = 0 and result = this.getCondition() @@ -335,10 +338,10 @@ module Ast implements AstSig { /** Gets the boolean condition of this `while` loop. */ Expr getCondition() { result.asExpr() = whileStmt.getTest() } - override Stmt getBody() { result = TBlockStmt(whileStmt.getBody()) } + override Stmt getBody() { result.asStmtList() = whileStmt.getBody() } /** Gets the `else` branch of this `while` loop, if any. */ - Stmt getElse() { result = TBlockStmt(whileStmt.getOrelse()) } + Stmt getElse() { result.asStmtList() = whileStmt.getOrelse() } override AstNode getChild(int index) { index = 0 and result = this.getCondition() @@ -381,10 +384,10 @@ module Ast implements AstSig { /** Gets the collection being iterated. */ Expr getCollection() { result.asExpr() = forStmt.getIter() } - override Stmt getBody() { result = TBlockStmt(forStmt.getBody()) } + override Stmt getBody() { result.asStmtList() = forStmt.getBody() } /** Gets the `else` branch of this `for` loop, if any. */ - Stmt getElse() { result = TBlockStmt(forStmt.getOrelse()) } + Stmt getElse() { result.asStmtList() = forStmt.getOrelse() } override AstNode getChild(int index) { index = 0 and result = this.getCollection() @@ -453,7 +456,7 @@ module Ast implements AstSig { Expr getOptionalVars() { result.asExpr() = withStmt.getOptionalVars() } - Stmt getBody() { result = TBlockStmt(withStmt.getBody()) } + Stmt getBody() { result.asStmtList() = withStmt.getBody() } override AstNode getChild(int index) { index = 0 and result = this.getContextExpr() @@ -498,14 +501,14 @@ module Ast implements AstSig { TryStmt() { this = TPyStmt(tryStmt) } - Stmt getBody() { result = TBlockStmt(tryStmt.getBody()) } + Stmt getBody() { result.asStmtList() = tryStmt.getBody() } /** Gets the `else` branch of this `try` statement, if any. */ - Stmt getElse() { result = TBlockStmt(tryStmt.getOrelse()) } + Stmt getElse() { result.asStmtList() = tryStmt.getOrelse() } - Stmt getFinally() { result = TBlockStmt(tryStmt.getFinalbody()) } + Stmt getFinally() { result.asStmtList() = tryStmt.getFinalbody() } - CatchClause getCatch(int index) { result = TPyStmt(tryStmt.getHandler(index)) } + CatchClause getCatch(int index) { result.asStmt() = tryStmt.getHandler(index) } override AstNode getChild(int index) { index = 0 and result = this.getBody() @@ -550,9 +553,9 @@ module Ast implements AstSig { /** Gets the body of this exception handler. */ Stmt getBody() { - result = TBlockStmt(handler.(Py::ExceptStmt).getBody()) + result.asStmtList() = handler.(Py::ExceptStmt).getBody() or - result = TBlockStmt(handler.(Py::ExceptGroupStmt).getBody()) + result.asStmtList() = handler.(Py::ExceptGroupStmt).getBody() } override AstNode getChild(int index) { @@ -572,7 +575,7 @@ module Ast implements AstSig { Expr getExpr() { result.asExpr() = matchStmt.getSubject() } - Case getCase(int index) { result = TPyStmt(matchStmt.getCase(index)) } + Case getCase(int index) { result.asStmt() = matchStmt.getCase(index) } Stmt getStmt(int index) { none() } @@ -589,11 +592,11 @@ module Ast implements AstSig { Case() { this = TPyStmt(caseStmt) } - AstNode getAPattern() { result = TPattern(caseStmt.getPattern()) } + AstNode getAPattern() { result.asPattern() = caseStmt.getPattern() } Expr getGuard() { result.asExpr() = caseStmt.getGuard().(Py::Guard).getTest() } - AstNode getBody() { result = TBlockStmt(caseStmt.getBody()) } + AstNode getBody() { result.asStmtList() = caseStmt.getBody() } /** Holds if this case is a wildcard pattern (`case _:`). */ predicate isWildcard() { caseStmt.getPattern() instanceof Py::MatchWildcardPattern } @@ -649,6 +652,9 @@ module Ast implements AstSig { /** Gets the underlying Python `BoolExpr`. */ Py::BoolExpr getBoolExpr() { result = be } + /** Gets the (zero-based) index of this pair within its `BoolExpr`. */ + int getIndex() { result = index } + override string toString() { result = be.getOperator() } override Py::Location getLocation() { result = be.getValue(index).getLocation() } @@ -664,7 +670,10 @@ module Ast implements AstSig { index = count(be.getAValue()) - 2 and result.asExpr() = be.getValue(index + 1) or // Non-last pair: right operand is the next synthetic pair. - index < count(be.getAValue()) - 2 and result = TBoolExprPair(be, index + 1) + index < count(be.getAValue()) - 2 and + exists(BinaryExpr next | + next.getBoolExpr() = be and next.getIndex() = index + 1 and result = next + ) } override AstNode getChild(int childIndex) {