diff --git a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll index 15ec5dbfa73..d6ac7ceb110 100644 --- a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll +++ b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll @@ -29,7 +29,8 @@ private module Ast { * Only created for inner pairs (index >= 1); the outermost pair (index 0) * is represented by the original `BoolExpr` node via `TExprNode`. */ - TBoolExprPair(Py::BoolExpr be, int index) { index = [1 .. count(be.getAValue()) - 2] } + TBoolExprPair(Py::BoolExpr be, int index) { index = [1 .. count(be.getAValue()) - 2] } or + TPatternNode(Py::Pattern p) /** * An AST node for the shared CFG. Each branch of the newtype gets a @@ -122,6 +123,21 @@ private module Ast { } } + class PatternNode extends Node, TPatternNode { + private Py::Pattern pattern; + + PatternNode() { this = TPatternNode(pattern) } + + /** Gets the underlying Python pattern. */ + Py::Pattern asPattern() { result = pattern } + + override string toString() { result = pattern.toString() } + + override Py::Location getLocation() { result = pattern.getLocation() } + + override ScopeNode getEnclosingScope() { result.asScope() = pattern.getScope() } + } + /** An `if` statement. */ class IfNode extends StmtNode { private Py::If ifStmt; @@ -289,6 +305,8 @@ private module Ast { CaseNode() { caseStmt = this.asStmt() } + PatternNode getPattern() { result.asPattern() = caseStmt.getPattern() } + ExprNode getGuard() { result.asExpr() = caseStmt.getGuard().(Py::Guard).getTest() } StmtListNode getBody() { result.asStmtList() = caseStmt.getBody() } @@ -778,9 +796,11 @@ module AstSigImpl implements AstSig { or // Case: guard (0), body (1) exists(Ast::CaseNode c | c = n | - index = 0 and result = c.getGuard() + index = 0 and result = c.getPattern() or - index = 1 and result = c.getBody() + index = 1 and result = c.getGuard() + or + index = 2 and result = c.getBody() ) or // CatchClause (except handler): type (0), name (1), body (2) @@ -1101,7 +1121,7 @@ module AstSigImpl implements AstSig { class Case extends Stmt { Case() { this instanceof Ast::CaseNode } - AstNode getAPattern() { none() } + AstNode getAPattern() { result = this.(Ast::CaseNode).getPattern() } Expr getGuard() { result = this.(Ast::CaseNode).getGuard() }