diff --git a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll index bcf8f4f6470..ccd36356572 100644 --- a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll +++ b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll @@ -17,42 +17,6 @@ private import codeql.util.Void /** Provides the Python implementation of the shared CFG `AstSig`. */ module Ast implements AstSig { - /** - * Maps a `(parent, slot)` pair to the `Py::StmtList` that holds the items - * of the `BlockStmt` for that slot. The slot string distinguishes between - * the multiple bodies that some parents have (e.g. `if` has `body` and - * `orelse`). - */ - private Py::StmtList getBodyStmtList(Py::AstNode parent, string slot) { - result = parent.(Py::Scope).getBody() and slot = "body" - or - result = parent.(Py::If).getBody() and slot = "body" - or - result = parent.(Py::If).getOrelse() and slot = "orelse" - or - result = parent.(Py::While).getBody() and slot = "body" - or - result = parent.(Py::While).getOrelse() and slot = "orelse" - or - result = parent.(Py::For).getBody() and slot = "body" - or - result = parent.(Py::For).getOrelse() and slot = "orelse" - or - result = parent.(Py::With).getBody() and slot = "body" - or - result = parent.(Py::Try).getBody() and slot = "body" - or - result = parent.(Py::Try).getOrelse() and slot = "orelse" - or - result = parent.(Py::Try).getFinalbody() and slot = "finally" - or - result = parent.(Py::Case).getBody() and slot = "body" - or - result = parent.(Py::ExceptStmt).getBody() and slot = "body" - or - result = parent.(Py::ExceptGroupStmt).getBody() and slot = "body" - } - private newtype TAstNode = TStmt(Py::Stmt s) or TExpr(Py::Expr e) { not e instanceof Py::BoolExpr } or @@ -68,11 +32,42 @@ module Ast implements AstSig { */ TBoolExprPair(Py::BoolExpr be, int index) { index = [0 .. count(be.getAValue()) - 2] } or /** - * A synthetic block statement, identifying one body slot of the - * `parent` AST node. The `slot` string disambiguates among multiple - * bodies of the same parent (`"body"`, `"orelse"`, `"finally"`). + * A synthetic block statement, wrapping a `Py::StmtList`. Each list of + * statements that represents an imperative block (a function/class/module + * body, an `if`/`while`/`for` branch, a `try`/`except`/`finally` body, + * etc.) becomes one `BlockStmt` node in the CFG. Lists used in other + * roles (e.g. `Try.getHandlers()`, which is iterated as catch clauses) + * are excluded. */ - TBlockStmt(Py::AstNode parent, string slot) { exists(getBodyStmtList(parent, slot)) } + TBlockStmt(Py::StmtList sl) { + sl = any(Py::Scope p).getBody() + or + sl = any(Py::If p).getBody() + or + sl = any(Py::If p).getOrelse() + or + sl = any(Py::While p).getBody() + or + sl = any(Py::While p).getOrelse() + or + sl = any(Py::For p).getBody() + or + sl = any(Py::For p).getOrelse() + or + sl = any(Py::With p).getBody() + or + sl = any(Py::Try p).getBody() + or + sl = any(Py::Try p).getOrelse() + or + sl = any(Py::Try p).getFinalbody() + or + sl = any(Py::Case p).getBody() + or + sl = any(Py::ExceptStmt p).getBody() + or + sl = any(Py::ExceptGroupStmt p).getBody() + } /** An AST node visible to the shared CFG. */ class AstNode extends TAstNode { @@ -135,7 +130,7 @@ module Ast implements AstSig { } /** Gets the body of callable `c`. */ - AstNode callableGetBody(Callable c) { result = TBlockStmt(c.asScope(), "body") } + AstNode callableGetBody(Callable c) { result = TBlockStmt(c.asScope().getBody()) } /** * A parameter of a callable. @@ -195,29 +190,26 @@ module Ast implements AstSig { * sequence of statements. */ class BlockStmt extends Stmt, TBlockStmt { - private Py::AstNode parent; - private string slot; + private Py::StmtList sl; - BlockStmt() { this = TBlockStmt(parent, slot) } + BlockStmt() { this = TBlockStmt(sl) } /** Gets the `n`th (zero-based) statement in this block. */ - Stmt getStmt(int n) { result = TStmt(getBodyStmtList(parent, slot).getItem(n)) } + Stmt getStmt(int n) { result.asStmt() = sl.getItem(n) } /** Gets the last statement in this block. */ - Stmt getLastStmt() { result = TStmt(getBodyStmtList(parent, slot).getLastItem()) } + Stmt getLastStmt() { result.asStmt() = sl.getLastItem() } - override string toString() { result = "block:" + slot } + override string toString() { result = sl.toString() } - // BlockStmt has no native location; approximate with the first + // `Py::StmtList` has no native location; approximate with the first // item's location. - override Py::Location getLocation() { - result = getBodyStmtList(parent, slot).getItem(0).getLocation() - } + override Py::Location getLocation() { result = sl.getItem(0).getLocation() } override Callable getEnclosingCallable() { - result.asScope() = parent.(Py::Scope) + result.asScope() = sl.getParent().(Py::Scope) or - result.asScope() = parent.(Py::Stmt).getScope() + result.asScope() = sl.getParent().(Py::Stmt).getScope() } override AstNode getChild(int index) { result = this.getStmt(index) } @@ -303,10 +295,10 @@ module Ast implements AstSig { Expr getCondition() { result.asExpr() = ifStmt.getTest() } /** Gets the `then` (true) branch of this `if` statement. */ - Stmt getThen() { result = TBlockStmt(ifStmt, "body") } + Stmt getThen() { result = TBlockStmt(ifStmt.getBody()) } /** Gets the `else` (false) branch, if any. */ - Stmt getElse() { result = TBlockStmt(ifStmt, "orelse") } + Stmt getElse() { result = TBlockStmt(ifStmt.getOrelse()) } override AstNode getChild(int index) { index = 0 and result = this.getCondition() @@ -334,10 +326,10 @@ module Ast implements AstSig { /** Gets the boolean condition of this `while` loop. */ Expr getCondition() { result.asExpr() = whileStmt.getTest() } - override Stmt getBody() { result = TBlockStmt(whileStmt, "body") } + override Stmt getBody() { result = TBlockStmt(whileStmt.getBody()) } /** Gets the `else` branch of this `while` loop, if any. */ - Stmt getElse() { result = TBlockStmt(whileStmt, "orelse") } + Stmt getElse() { result = TBlockStmt(whileStmt.getOrelse()) } override AstNode getChild(int index) { index = 0 and result = this.getCondition() @@ -380,10 +372,10 @@ module Ast implements AstSig { /** Gets the collection being iterated. */ Expr getCollection() { result.asExpr() = forStmt.getIter() } - override Stmt getBody() { result = TBlockStmt(forStmt, "body") } + override Stmt getBody() { result = TBlockStmt(forStmt.getBody()) } /** Gets the `else` branch of this `for` loop, if any. */ - Stmt getElse() { result = TBlockStmt(forStmt, "orelse") } + Stmt getElse() { result = TBlockStmt(forStmt.getOrelse()) } override AstNode getChild(int index) { index = 0 and result = this.getCollection() @@ -452,7 +444,7 @@ module Ast implements AstSig { Expr getOptionalVars() { result.asExpr() = withStmt.getOptionalVars() } - Stmt getBody() { result = TBlockStmt(withStmt, "body") } + Stmt getBody() { result = TBlockStmt(withStmt.getBody()) } override AstNode getChild(int index) { index = 0 and result = this.getContextExpr() @@ -497,12 +489,12 @@ module Ast implements AstSig { TryStmt() { tryStmt = this.asStmt() } - Stmt getBody() { result = TBlockStmt(tryStmt, "body") } + Stmt getBody() { result = TBlockStmt(tryStmt.getBody()) } /** Gets the `else` branch of this `try` statement, if any. */ - Stmt getElse() { result = TBlockStmt(tryStmt, "orelse") } + Stmt getElse() { result = TBlockStmt(tryStmt.getOrelse()) } - Stmt getFinally() { result = TBlockStmt(tryStmt, "finally") } + Stmt getFinally() { result = TBlockStmt(tryStmt.getFinalbody()) } CatchClause getCatch(int index) { result = TStmt(tryStmt.getHandler(index)) } @@ -549,9 +541,9 @@ module Ast implements AstSig { /** Gets the body of this exception handler. */ Stmt getBody() { - result = TBlockStmt(handler.(Py::ExceptStmt), "body") + result = TBlockStmt(handler.(Py::ExceptStmt).getBody()) or - result = TBlockStmt(handler.(Py::ExceptGroupStmt), "body") + result = TBlockStmt(handler.(Py::ExceptGroupStmt).getBody()) } override AstNode getChild(int index) { @@ -592,7 +584,7 @@ module Ast implements AstSig { Expr getGuard() { result.asExpr() = caseStmt.getGuard().(Py::Guard).getTest() } - AstNode getBody() { result = TBlockStmt(caseStmt, "body") } + AstNode getBody() { result = TBlockStmt(caseStmt.getBody()) } /** Holds if this case is a wildcard pattern (`case _:`). */ predicate isWildcard() { caseStmt.getPattern() instanceof Py::MatchWildcardPattern }