From dc0344e2fc199a719aa939081c1faad7be9c8e88 Mon Sep 17 00:00:00 2001 From: Taus Date: Tue, 21 Apr 2026 13:54:26 +0000 Subject: [PATCH] Python: More AstNodeImpl improvements Co-authored-by: yoff --- .../controlflow/internal/AstNodeImpl.qll | 529 +++++++++++++++--- 1 file changed, 443 insertions(+), 86 deletions(-) diff --git a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll index 526733a340a..47df5c0f619 100644 --- a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll +++ b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll @@ -17,7 +17,17 @@ private module Ast { TStmtNode(Py::Stmt s) or TExprNode(Py::Expr e) or TScopeNode(Py::Scope sc) or - TStmtListNode(Py::StmtList sl) + TStmtListNode(Py::StmtList sl) or + /** + * A synthetic node representing an intermediate pair in a multi-operand + * `and`/`or` expression. For `a and b and c` (values 0,1,2), we + * synthesize a right-nested tree: the pair at index 1 represents + * `b and c`, which becomes the right operand of the outermost pair. + * + * Only created for inner pairs (index >= 1); the outermost pair (index 0) + * is represented by the original `BoolExpr` node via `TExprNode`. + */ + TBoolExprPair(Py::BoolExpr be, int index) { index = [1 .. count(be.getAValue()) - 2] } /** * An AST node for the shared CFG. Each branch of the newtype gets a @@ -135,6 +145,226 @@ private module Ast { /** Gets the expression in this statement. */ ExprNode getValue() { result.asExpr() = exprStmt.getValue() } } + + /** A `while` statement. */ + class WhileNode extends StmtNode { + private Py::While whileStmt; + + WhileNode() { whileStmt = this.asStmt() } + + ExprNode getTest() { result.asExpr() = whileStmt.getTest() } + + StmtListNode getBody() { result.asStmtList() = whileStmt.getBody() } + + StmtListNode getOrelse() { result.asStmtList() = whileStmt.getOrelse() } + } + + /** A `for` statement. */ + class ForNode extends StmtNode { + private Py::For forStmt; + + ForNode() { forStmt = this.asStmt() } + + ExprNode getTarget() { result.asExpr() = forStmt.getTarget() } + + ExprNode getIter() { result.asExpr() = forStmt.getIter() } + + StmtListNode getBody() { result.asStmtList() = forStmt.getBody() } + + StmtListNode getOrelse() { result.asStmtList() = forStmt.getOrelse() } + } + + /** A `return` statement. */ + class ReturnNode extends StmtNode { + private Py::Return ret; + + ReturnNode() { ret = this.asStmt() } + + ExprNode getValue() { result.asExpr() = ret.getValue() } + } + + /** A `raise` statement. */ + class RaiseNode extends StmtNode { + private Py::Raise raise; + + RaiseNode() { raise = this.asStmt() } + + ExprNode getException() { result.asExpr() = raise.getException() } + + ExprNode getCause() { result.asExpr() = raise.getCause() } + } + + /** A `break` statement. */ + class BreakNode extends StmtNode { + BreakNode() { this.asStmt() instanceof Py::Break } + } + + /** A `continue` statement. */ + class ContinueNode extends StmtNode { + ContinueNode() { this.asStmt() instanceof Py::Continue } + } + + /** A `try` statement. */ + class TryNode extends StmtNode { + private Py::Try tryStmt; + + TryNode() { tryStmt = this.asStmt() } + + StmtListNode getBody() { result.asStmtList() = tryStmt.getBody() } + + StmtListNode getOrelse() { result.asStmtList() = tryStmt.getOrelse() } + + StmtListNode getFinalbody() { result.asStmtList() = tryStmt.getFinalbody() } + + ExceptionHandlerNode getHandler(int i) { result.asStmt() = tryStmt.getHandler(i) } + } + + /** An exception handler (`except` or `except*`). */ + class ExceptionHandlerNode extends StmtNode { + private Py::ExceptionHandler handler; + + ExceptionHandlerNode() { handler = this.asStmt() } + + ExprNode getType() { result.asExpr() = handler.getType() } + + ExprNode getName() { result.asExpr() = handler.getName() } + + StmtListNode getBody() { + result.asStmtList() = handler.(Py::ExceptStmt).getBody() or + result.asStmtList() = handler.(Py::ExceptGroupStmt).getBody() + } + } + + /** A conditional expression (`x if cond else y`). */ + class IfExpNode extends ExprNode { + private Py::IfExp ifExp; + + IfExpNode() { ifExp = this.asExpr() } + + ExprNode getTest() { result.asExpr() = ifExp.getTest() } + + ExprNode getBody() { result.asExpr() = ifExp.getBody() } + + ExprNode getOrelse() { result.asExpr() = ifExp.getOrelse() } + } + + /** A Python binary expression (arithmetic, bitwise, matmul, etc.). */ + class BinaryExprNode extends ExprNode { + private Py::BinaryExpr binExpr; + + BinaryExprNode() { binExpr = this.asExpr() } + + ExprNode getLeft() { result.asExpr() = binExpr.getLeft() } + + ExprNode getRight() { result.asExpr() = binExpr.getRight() } + } + + /** A subscript expression (`obj[index]`). */ + class SubscriptNode extends ExprNode { + private Py::Subscript sub; + + SubscriptNode() { sub = this.asExpr() } + + ExprNode getObject() { result.asExpr() = sub.getObject() } + + ExprNode getIndex() { result.asExpr() = sub.getIndex() } + } + + /** + * A `not` expression. This is a `UnaryExpr` whose operator is `Not`. + */ + class NotExprNode extends ExprNode { + private Py::UnaryExpr notExpr; + + NotExprNode() { notExpr = this.asExpr() and notExpr.getOp() instanceof Py::Not } + + ExprNode getOperand() { result.asExpr() = notExpr.getOperand() } + } + + /** + * A boolean expression (`and`/`or`) with exactly 2 operands. + * For 2-operand BoolExprs, the `TExprNode` itself serves as the + * logical and/or expression. + */ + class BoolExpr2Node extends ExprNode { + private Py::BoolExpr boolExpr; + + BoolExpr2Node() { boolExpr = this.asExpr() and count(boolExpr.getAValue()) = 2 } + + predicate isAnd() { boolExpr.getOp() instanceof Py::And } + + predicate isOr() { boolExpr.getOp() instanceof Py::Or } + + ExprNode getLeftOperand() { result.asExpr() = boolExpr.getValue(0) } + + ExprNode getRightOperand() { result.asExpr() = boolExpr.getValue(1) } + } + + /** + * The outermost pair of a multi-operand (3+) boolean expression. + * Represented by the original `BoolExpr` node (`TExprNode`). + * Left operand is `getValue(0)`, right operand is `TBoolExprPair(be, 1)`. + */ + class BoolExprOuterNode extends ExprNode { + private Py::BoolExpr boolExpr; + + BoolExprOuterNode() { boolExpr = this.asExpr() and count(boolExpr.getAValue()) > 2 } + + predicate isAnd() { boolExpr.getOp() instanceof Py::And } + + predicate isOr() { boolExpr.getOp() instanceof Py::Or } + + Node getLeftOperand() { result = TExprNode(boolExpr.getValue(0)) } + + Node getRightOperand() { result = TBoolExprPair(boolExpr, 1) } + } + + /** + * A synthetic intermediate node in a multi-operand boolean expression. + * Pair at index `i` has left=`getValue(i)` and right=pair at `i+1` + * (or `getValue(n-1)` for the last pair). + */ + class BoolExprPairNode extends Node, TBoolExprPair { + private Py::BoolExpr boolExpr; + private int index; + + BoolExprPairNode() { this = TBoolExprPair(boolExpr, index) } + + override string toString() { result = boolExpr.getOperator() } + + override Py::Location getLocation() { result = boolExpr.getValue(index).getLocation() } + + override ScopeNode getEnclosingScope() { + result.asScope() = boolExpr.getValue(index).getScope() + } + + predicate isAnd() { boolExpr.getOp() instanceof Py::And } + + predicate isOr() { boolExpr.getOp() instanceof Py::Or } + + Node getLeftOperand() { result = TExprNode(boolExpr.getValue(index)) } + + Node getRightOperand() { + // Last pair: right operand is the final value + index = count(boolExpr.getAValue()) - 2 and + result = TExprNode(boolExpr.getValue(index + 1)) + or + // Not last pair: right operand is the next synthetic pair + index < count(boolExpr.getAValue()) - 2 and + result = TBoolExprPair(boolExpr, index + 1) + } + } + + /** A `True` or `False` literal. */ + class BoolLiteralNode extends ExprNode { + BoolLiteralNode() { this.asExpr() instanceof Py::True or this.asExpr() instanceof Py::False } + + boolean getBoolValue() { + this.asExpr() instanceof Py::True and result = true + or + this.asExpr() instanceof Py::False and result = false + } + } } /** Provides an implementation of the AST signature for Python. */ @@ -143,6 +373,7 @@ module AstSigImpl implements AstSig { /** Gets the child of `n` at the specified (zero-based) index. */ AstNode getChild(AstNode n, int index) { + // IfStmt: condition (0), then branch (1), else branch (2) exists(Ast::IfNode ifNode | ifNode = n | index = 0 and result = ifNode.getTest() or @@ -151,9 +382,101 @@ module AstSigImpl implements AstSig { index = 2 and result = ifNode.getOrelse() ) or + // BlockStmt (StmtList): indexed statements result = n.(Ast::StmtListNode).getItem(index) or + // ExprStmt: the expression (0) index = 0 and result = n.(Ast::ExprStmtNode).getValue() + or + // WhileStmt: condition (0), body (1) + // Note: Python while/else is not directly supported by the shared library. + exists(Ast::WhileNode w | w = n | + index = 0 and result = w.getTest() + or + index = 1 and result = w.getBody() + ) + or + // ForStmt (mapped as ForeachStmt): collection (0), variable (1), body (2) + exists(Ast::ForNode f | f = n | + index = 0 and result = f.getIter() + or + index = 1 and result = f.getTarget() + or + index = 2 and result = f.getBody() + ) + or + // ReturnStmt: the value (0) + index = 0 and result = n.(Ast::ReturnNode).getValue() + or + // ThrowStmt (raise): the exception (0), the cause (1) + exists(Ast::RaiseNode r | r = n | + index = 0 and result = r.getException() + or + index = 1 and result = r.getCause() + ) + or + // TryStmt: body (0), handlers (1..n), finally (-1) + exists(Ast::TryNode t | t = n | + index = 0 and result = t.getBody() + or + result = t.getHandler(index - 1) and index >= 1 + ) + or + // CatchClause (except handler): type (0), name (1), body (2) + exists(Ast::ExceptionHandlerNode h | h = n | + index = 0 and result = h.getType() + or + index = 1 and result = h.getName() + or + index = 2 and result = h.getBody() + ) + or + // ConditionalExpr (IfExp): condition (0), then (1), else (2) + exists(Ast::IfExpNode ie | ie = n | + index = 0 and result = ie.getTest() + or + index = 1 and result = ie.getBody() + or + index = 2 and result = ie.getOrelse() + ) + or + // Python BinaryExpr (arithmetic, bitwise, matmul, etc.): left (0), right (1) + exists(Ast::BinaryExprNode be | be = n | + index = 0 and result = be.getLeft() + or + index = 1 and result = be.getRight() + ) + or + // Subscript (obj[index]): object (0), index (1) + exists(Ast::SubscriptNode sub | sub = n | + index = 0 and result = sub.getObject() + or + index = 1 and result = sub.getIndex() + ) + or + // LogicalNotExpr: operand (0) + index = 0 and result = n.(Ast::NotExprNode).getOperand() + or + // 2-operand BoolExpr: left (0), right (1) + exists(Ast::BoolExpr2Node be | be = n | + index = 0 and result = be.getLeftOperand() + or + index = 1 and result = be.getRightOperand() + ) + or + // Multi-operand BoolExpr (outermost): left (0), right (1) + exists(Ast::BoolExprOuterNode be | be = n | + index = 0 and result = be.getLeftOperand() + or + index = 1 and result = be.getRightOperand() + ) + or + // Synthetic BoolExpr pair: left (0), right (1) + exists(Ast::BoolExprPairNode bp | bp = n | + index = 0 and result = bp.getLeftOperand() + or + index = 1 and result = bp.getRightOperand() + ) } Callable getEnclosingCallable(AstNode node) { result = node.getEnclosingScope() } @@ -173,8 +496,10 @@ module AstSigImpl implements AstSig { Stmt() { this instanceof Ast::StmtNode or this instanceof Ast::StmtListNode } } - /** An expression. */ - class Expr extends Ast::ExprNode { } + /** An expression. Includes `TExprNode` and synthetic `TBoolExprPair` nodes. */ + class Expr extends AstNode { + Expr() { this instanceof Ast::ExprNode or this instanceof Ast::BoolExprPairNode } + } /** A block of statements, wrapping Python's `StmtList`. */ class BlockStmt extends Stmt, Ast::StmtListNode { @@ -210,113 +535,107 @@ module AstSigImpl implements AstSig { Stmt getElse() { result = this.getOrelse() } } - // ===== Stub types for constructs not yet implemented ===== - /** A loop statement. Not yet implemented for Python. */ + // ===== Loop statements ===== + /** A loop statement. */ class LoopStmt extends Stmt { - LoopStmt() { none() } + LoopStmt() { this instanceof Ast::WhileNode or this instanceof Ast::ForNode } /** Gets the body of this loop statement. */ Stmt getBody() { none() } } - /** A `while` loop statement. Not yet implemented for Python. */ - class WhileStmt extends LoopStmt { + /** A `while` loop statement. */ + class WhileStmt extends LoopStmt instanceof Ast::WhileNode { /** Gets the boolean condition of this `while` loop. */ - Expr getCondition() { none() } + Expr getCondition() { result = this.(Ast::WhileNode).getTest() } + + override Stmt getBody() { result = this.(Ast::WhileNode).getBody() } } /** A `do-while` loop statement. Python has no do-while construct. */ class DoStmt extends LoopStmt { - /** Gets the boolean condition of this `do-while` loop. */ + DoStmt() { none() } + Expr getCondition() { none() } } /** A C-style `for` loop. Python has no C-style for loop. */ class ForStmt extends LoopStmt { - /** Gets the initializer expression at the specified position. */ + ForStmt() { none() } + Expr getInit(int index) { none() } - /** Gets the boolean condition of this `for` loop. */ Expr getCondition() { none() } - /** Gets the update expression at the specified position. */ Expr getUpdate(int index) { none() } } - /** A for-each loop. Not yet implemented for Python. */ + /** A for-each loop (`for x in iterable:`). */ class ForeachStmt extends LoopStmt { + ForeachStmt() { this instanceof Ast::ForNode } + /** Gets the loop variable. */ - Expr getVariable() { none() } + Expr getVariable() { result = this.(Ast::ForNode).getTarget() } /** Gets the collection being iterated. */ - Expr getCollection() { none() } + Expr getCollection() { result = this.(Ast::ForNode).getIter() } + + override Stmt getBody() { result = this.(Ast::ForNode).getBody() } } - /** A `break` statement. Not yet implemented for Python. */ - class BreakStmt extends Stmt { - BreakStmt() { none() } - } + // ===== Abrupt completion statements ===== + /** A `break` statement. */ + class BreakStmt extends Stmt, Ast::BreakNode { } - /** A `continue` statement. Not yet implemented for Python. */ - class ContinueStmt extends Stmt { - ContinueStmt() { none() } - } - - /** A `return` statement. Not yet implemented for Python. */ - class ReturnStmt extends Stmt { - ReturnStmt() { none() } + /** A `continue` statement. */ + class ContinueStmt extends Stmt, Ast::ContinueNode { } + /** A `return` statement. */ + class ReturnStmt extends Stmt, Ast::ReturnNode { /** Gets the expression being returned, if any. */ - Expr getExpr() { none() } + Expr getExpr() { result = this.getValue() } } - /** A `throw`/`raise` statement. Not yet implemented for Python. */ - class ThrowStmt extends Stmt { - ThrowStmt() { none() } - - /** Gets the expression being thrown. */ - Expr getExpr() { none() } + /** A `raise` statement (mapped to `ThrowStmt`). */ + class ThrowStmt extends Stmt, Ast::RaiseNode { + /** Gets the expression being raised. */ + Expr getExpr() { result = this.getException() } } - /** A `try` statement. Not yet implemented for Python. */ + // ===== Try/except ===== + /** A `try` statement. */ class TryStmt extends Stmt { - TryStmt() { none() } + TryStmt() { this instanceof Ast::TryNode } - /** Gets the body of this `try` statement. */ - Stmt getBody() { none() } + Stmt getBody() { result = this.(Ast::TryNode).getBody() } - /** Gets the `catch` clause at the specified position. */ - CatchClause getCatch(int index) { none() } + CatchClause getCatch(int index) { result = this.(Ast::TryNode).getHandler(index) } - /** Gets the `finally` block of this `try` statement, if any. */ - Stmt getFinally() { none() } + Stmt getFinally() { result = this.(Ast::TryNode).getFinalbody() } } - /** A catch clause. Not yet implemented for Python. */ - class CatchClause extends AstNode { - CatchClause() { none() } + AstNode getTryElse(TryStmt try) { result = try.(Ast::TryNode).getOrelse() } - /** Gets the variable declared by this catch clause. */ - AstNode getVariable() { none() } + /** An except clause in a try statement. */ + class CatchClause extends Stmt { + CatchClause() { this instanceof Ast::ExceptionHandlerNode } + + AstNode getVariable() { result = this.(Ast::ExceptionHandlerNode).getName() } - /** Gets the guard condition, if any. */ Expr getCondition() { none() } - /** Gets the body of this catch clause. */ - Stmt getBody() { none() } + Stmt getBody() { result = this.(Ast::ExceptionHandlerNode).getBody() } } + // ===== Switch/match — stubs for now ===== /** A switch/match statement. Not yet implemented for Python. */ class Switch extends AstNode { Switch() { none() } - /** Gets the expression being switched on. */ Expr getExpr() { none() } - /** Gets the case at the specified position. */ Case getCase(int index) { none() } - /** Gets the statement at the specified position. */ Stmt getStmt(int index) { none() } } @@ -324,70 +643,96 @@ module AstSigImpl implements AstSig { class Case extends AstNode { Case() { none() } - /** Gets a pattern being matched. */ AstNode getAPattern() { none() } - /** Gets the guard expression, if any. */ Expr getGuard() { none() } - /** Gets the body of this case. */ AstNode getBody() { none() } } /** A default case. Not yet implemented for Python. */ class DefaultCase extends Case { } - /** A ternary conditional expression. Not yet implemented for Python. */ - class ConditionalExpr extends Expr { - ConditionalExpr() { none() } - + // ===== Expression types ===== + /** A conditional expression (`x if cond else y`). */ + class ConditionalExpr extends Expr, Ast::IfExpNode { /** Gets the condition of this expression. */ - Expr getCondition() { none() } + Expr getCondition() { result = this.getTest() } /** Gets the true branch of this expression. */ - Expr getThen() { none() } + Expr getThen() { result = Ast::IfExpNode.super.getBody() } /** Gets the false branch of this expression. */ - Expr getElse() { none() } + Expr getElse() { result = this.getOrelse() } } - /** A binary expression. Not yet implemented for Python. */ + /** + * A binary expression for the shared CFG. In Python, this covers + * `and`/`or` expressions (both real 2-operand and synthetic pairs). + */ class BinaryExpr extends Expr { - BinaryExpr() { none() } + BinaryExpr() { + this instanceof Ast::BoolExpr2Node or + this instanceof Ast::BoolExprOuterNode or + this instanceof Ast::BoolExprPairNode + } /** Gets the left operand. */ - Expr getLeftOperand() { none() } + Expr getLeftOperand() { + result = this.(Ast::BoolExpr2Node).getLeftOperand() + or + result = this.(Ast::BoolExprOuterNode).getLeftOperand() + or + result = this.(Ast::BoolExprPairNode).getLeftOperand() + } /** Gets the right operand. */ - Expr getRightOperand() { none() } + Expr getRightOperand() { + result = this.(Ast::BoolExpr2Node).getRightOperand() + or + result = this.(Ast::BoolExprOuterNode).getRightOperand() + or + result = this.(Ast::BoolExprPairNode).getRightOperand() + } } - /** A short-circuiting logical AND expression. Not yet implemented for Python. */ - class LogicalAndExpr extends BinaryExpr { } + /** A short-circuiting logical `and` expression. */ + class LogicalAndExpr extends BinaryExpr { + LogicalAndExpr() { + this.(Ast::BoolExpr2Node).isAnd() or + this.(Ast::BoolExprOuterNode).isAnd() or + this.(Ast::BoolExprPairNode).isAnd() + } + } - /** A short-circuiting logical OR expression. Not yet implemented for Python. */ - class LogicalOrExpr extends BinaryExpr { } + /** A short-circuiting logical `or` expression. */ + class LogicalOrExpr extends BinaryExpr { + LogicalOrExpr() { + this.(Ast::BoolExpr2Node).isOr() or + this.(Ast::BoolExprOuterNode).isOr() or + this.(Ast::BoolExprPairNode).isOr() + } + } /** A null-coalescing expression. Python has no null-coalescing operator. */ - class NullCoalescingExpr extends BinaryExpr { } - - /** A unary expression. Not yet implemented for Python. */ - class UnaryExpr extends Expr { - UnaryExpr() { none() } - - /** Gets the operand. */ - Expr getOperand() { none() } + class NullCoalescingExpr extends BinaryExpr { + NullCoalescingExpr() { none() } } - /** A logical NOT expression. Not yet implemented for Python. */ + /** A unary expression. Exists for the `not` subclass. */ + class UnaryExpr extends Expr { + UnaryExpr() { this instanceof Ast::NotExprNode } + + Expr getOperand() { result = this.(Ast::NotExprNode).getOperand() } + } + + /** A logical `not` expression. */ class LogicalNotExpr extends UnaryExpr { } - /** A boolean literal expression. Not yet implemented for Python. */ - class BooleanLiteral extends Expr { - BooleanLiteral() { none() } - + /** A boolean literal expression (`True` or `False`). */ + class BooleanLiteral extends Expr, Ast::BoolLiteralNode { /** Gets the boolean value of this literal. */ - boolean getValue() { none() } + boolean getValue() { result = this.getBoolValue() } } } @@ -427,3 +772,15 @@ private module Input implements InputSig1, InputSig2 { import CfgCachedStage import Public + +/** + * Maps a new-CFG AST wrapper node to the corresponding Python AST node, if any. + * Entry, exit, and synthetic nodes have no corresponding Python AST node. + */ +Py::AstNode astNodeToPyNode(AstSigImpl::AstNode n) { + result = n.(Ast::ExprNode).asExpr() + or + result = n.(Ast::StmtNode).asStmt() + or + result = n.(Ast::ScopeNode).asScope() +}