diff --git a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll index 212f1c20053..bcf8f4f6470 100644 --- a/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll +++ b/python/ql/lib/semmle/python/controlflow/internal/AstNodeImpl.qll @@ -55,18 +55,18 @@ module Ast implements AstSig { private newtype TAstNode = TStmt(Py::Stmt s) or - TExpr(Py::Expr e) or + TExpr(Py::Expr e) { not e instanceof Py::BoolExpr } or TScope(Py::Scope sc) or TPattern(Py::Pattern p) or /** - * A synthetic intermediate node in a multi-operand `and`/`or` + * A synthetic node representing an operand pair of an `and`/`or` * expression. For `a and b and c` (operands 0, 1, 2) we model the - * operation as a right-nested tree where the inner pair at index 1 - * represents `b and c` and is the right operand of the outer pair. - * The outermost pair (index 0) is represented by the underlying - * `Py::BoolExpr` itself via `TExpr`. + * operation as a right-nested tree: pair 0 represents the whole + * expression with left=a and right=pair 1; pair 1 represents + * `b and c` with left=b and right=c. Each Python `Py::BoolExpr` + * with `n` operands has `n - 1` such pairs (indices `0 .. n - 2`). */ - TBoolExprPair(Py::BoolExpr be, int index) { index = [1 .. count(be.getAValue()) - 2] } or + TBoolExprPair(Py::BoolExpr be, int index) { index = [0 .. count(be.getAValue()) - 2] } or /** * A synthetic block statement, identifying one body slot of the * `parent` AST node. The `slot` string disambiguates among multiple @@ -88,8 +88,17 @@ module Ast implements AstSig { /** Gets the underlying Python `Stmt`, if this node wraps one. */ Py::Stmt asStmt() { this = TStmt(result) } - /** Gets the underlying Python `Expr`, if this node wraps one. */ - Py::Expr asExpr() { this = TExpr(result) } + /** + * Gets the underlying Python `Expr`, if this node wraps one. Boolean + * expressions are represented by `TBoolExprPair(_, 0)`; this + * predicate also recovers the underlying `Py::BoolExpr` from such a + * representation. + */ + Py::Expr asExpr() { + this = TExpr(result) + or + this = TBoolExprPair(result, 0) + } /** Gets the underlying Python `Scope`, if this node wraps one. */ Py::Scope asScope() { this = TScope(result) } @@ -105,20 +114,6 @@ module Ast implements AstSig { AstNode getChild(int index) { none() } } - /** Implementation of `AstNode` predicates for synthetic `TBoolExprPair` nodes. */ - private class BoolExprPair extends Expr, TBoolExprPair { - private Py::BoolExpr be; - private int index; - - BoolExprPair() { this = TBoolExprPair(be, index) } - - override string toString() { result = be.getOperator() } - - override Py::Location getLocation() { result = be.getValue(index).getLocation() } - - override Callable getEnclosingCallable() { result.asScope() = be.getScope() } - } - /** Gets the immediately enclosing callable that contains `node`. */ Callable getEnclosingCallable(AstNode node) { result = node.getEnclosingCallable() } @@ -174,7 +169,7 @@ module Ast implements AstSig { Expr() { this instanceof TExpr or this instanceof TBoolExprPair } // For `TExpr` instances, delegate to the wrapped Python expression. - // `BoolExprPair` (the only `TBoolExprPair` subclass) provides its own overrides. + // `BinaryExpr` (the only `TBoolExprPair` subclass) provides its own overrides. override string toString() { result = this.asExpr().toString() } override Py::Location getLocation() { result = this.asExpr().getLocation() } @@ -235,7 +230,7 @@ module Ast implements AstSig { ExprStmt() { exprStmt = this.asStmt() } /** Gets the expression in this expression statement. */ - Expr getExpr() { result = TExpr(exprStmt.getValue()) } + Expr getExpr() { result.asExpr() = exprStmt.getValue() } override AstNode getChild(int index) { index = 0 and result = this.getExpr() } } @@ -246,9 +241,9 @@ module Ast implements AstSig { AssignStmt() { assign = this.asStmt() } - Expr getValue() { result = TExpr(assign.getValue()) } + Expr getValue() { result.asExpr() = assign.getValue() } - Expr getTarget(int n) { result = TExpr(assign.getTarget(n)) } + Expr getTarget(int n) { result.asExpr() = assign.getTarget(n) } int getNumberOfTargets() { result = count(assign.getATarget()) } @@ -265,7 +260,7 @@ module Ast implements AstSig { AugAssignStmt() { augAssign = this.asStmt() } - Expr getOperation() { result = TExpr(augAssign.getOperation()) } + Expr getOperation() { result.asExpr() = augAssign.getOperation() } override AstNode getChild(int index) { index = 0 and result = this.getOperation() } } @@ -276,9 +271,9 @@ module Ast implements AstSig { NamedExpr() { assignExpr = this.asExpr() } - Expr getValue() { result = TExpr(assignExpr.getValue()) } + Expr getValue() { result.asExpr() = assignExpr.getValue() } - Expr getTarget() { result = TExpr(assignExpr.getTarget()) } + Expr getTarget() { result.asExpr() = assignExpr.getTarget() } override AstNode getChild(int index) { index = 0 and result = this.getValue() @@ -305,7 +300,7 @@ module Ast implements AstSig { Py::If asIf() { result = ifStmt } /** Gets the condition of this `if` statement. */ - Expr getCondition() { result = TExpr(ifStmt.getTest()) } + Expr getCondition() { result.asExpr() = ifStmt.getTest() } /** Gets the `then` (true) branch of this `if` statement. */ Stmt getThen() { result = TBlockStmt(ifStmt, "body") } @@ -337,7 +332,7 @@ module Ast implements AstSig { WhileStmt() { whileStmt = this.asStmt() } /** Gets the boolean condition of this `while` loop. */ - Expr getCondition() { result = TExpr(whileStmt.getTest()) } + Expr getCondition() { result.asExpr() = whileStmt.getTest() } override Stmt getBody() { result = TBlockStmt(whileStmt, "body") } @@ -380,10 +375,10 @@ module Ast implements AstSig { ForeachStmt() { forStmt = this.asStmt() } /** Gets the loop variable. */ - Expr getVariable() { result = TExpr(forStmt.getTarget()) } + Expr getVariable() { result.asExpr() = forStmt.getTarget() } /** Gets the collection being iterated. */ - Expr getCollection() { result = TExpr(forStmt.getIter()) } + Expr getCollection() { result.asExpr() = forStmt.getIter() } override Stmt getBody() { result = TBlockStmt(forStmt, "body") } @@ -423,7 +418,7 @@ module Ast implements AstSig { ReturnStmt() { ret = this.asStmt() } /** Gets the expression being returned, if any. */ - Expr getExpr() { result = TExpr(ret.getValue()) } + Expr getExpr() { result.asExpr() = ret.getValue() } override AstNode getChild(int index) { index = 0 and result = this.getExpr() } } @@ -435,10 +430,10 @@ module Ast implements AstSig { Throw() { raise = this.asStmt() } /** Gets the expression being raised. */ - Expr getExpr() { result = TExpr(raise.getException()) } + Expr getExpr() { result.asExpr() = raise.getException() } /** Gets the cause of this `raise`, if any. */ - Expr getCause() { result = TExpr(raise.getCause()) } + Expr getCause() { result.asExpr() = raise.getCause() } override AstNode getChild(int index) { index = 0 and result = this.getExpr() @@ -453,9 +448,9 @@ module Ast implements AstSig { WithStmt() { withStmt = this.asStmt() } - Expr getContextExpr() { result = TExpr(withStmt.getContextExpr()) } + Expr getContextExpr() { result.asExpr() = withStmt.getContextExpr() } - Expr getOptionalVars() { result = TExpr(withStmt.getOptionalVars()) } + Expr getOptionalVars() { result.asExpr() = withStmt.getOptionalVars() } Stmt getBody() { result = TBlockStmt(withStmt, "body") } @@ -474,9 +469,9 @@ module Ast implements AstSig { AssertStmt() { assertStmt = this.asStmt() } - Expr getTest() { result = TExpr(assertStmt.getTest()) } + Expr getTest() { result.asExpr() = assertStmt.getTest() } - Expr getMsg() { result = TExpr(assertStmt.getMsg()) } + Expr getMsg() { result.asExpr() = assertStmt.getMsg() } override AstNode getChild(int index) { index = 0 and result = this.getTest() @@ -491,7 +486,7 @@ module Ast implements AstSig { DeleteStmt() { del = this.asStmt() } - Expr getTarget(int n) { result = TExpr(del.getTarget(n)) } + Expr getTarget(int n) { result.asExpr() = del.getTarget(n) } override AstNode getChild(int index) { result = this.getTarget(index) } } @@ -544,10 +539,10 @@ module Ast implements AstSig { CatchClause() { handler = this.asStmt() } /** Gets the type expression of this exception handler. */ - Expr getType() { result = TExpr(handler.getType()) } + Expr getType() { result.asExpr() = handler.getType() } /** Gets the variable name of this exception handler, if any. */ - AstNode getVariable() { result = TExpr(handler.getName()) } + AstNode getVariable() { result.asExpr() = handler.getName() } /** Holds: catch clauses do not have a `Condition` in Python's model. */ Expr getCondition() { none() } @@ -574,7 +569,7 @@ module Ast implements AstSig { Switch() { matchStmt = this.asStmt() } - Expr getExpr() { result = TExpr(matchStmt.getSubject()) } + Expr getExpr() { result.asExpr() = matchStmt.getSubject() } Case getCase(int index) { result = TStmt(matchStmt.getCase(index)) } @@ -595,7 +590,7 @@ module Ast implements AstSig { AstNode getAPattern() { result = TPattern(caseStmt.getPattern()) } - Expr getGuard() { result = TExpr(caseStmt.getGuard().(Py::Guard).getTest()) } + Expr getGuard() { result.asExpr() = caseStmt.getGuard().(Py::Guard).getTest() } AstNode getBody() { result = TBlockStmt(caseStmt, "body") } @@ -623,13 +618,13 @@ module Ast implements AstSig { ConditionalExpr() { ifExp = this.asExpr() } /** Gets the condition of this expression. */ - Expr getCondition() { result = TExpr(ifExp.getTest()) } + Expr getCondition() { result.asExpr() = ifExp.getTest() } /** Gets the true branch of this expression. */ - Expr getThen() { result = TExpr(ifExp.getBody()) } + Expr getThen() { result.asExpr() = ifExp.getBody() } /** Gets the false branch of this expression. */ - Expr getElse() { result = TExpr(ifExp.getOrelse()) } + Expr getElse() { result.asExpr() = ifExp.getOrelse() } override AstNode getChild(int index) { index = 0 and result = this.getCondition() @@ -641,84 +636,51 @@ module Ast implements AstSig { } /** - * A binary expression for the shared CFG. In Python, this covers - * `and`/`or` expressions (both real 2-operand and synthetic pairs). + * A binary expression for the shared CFG. In Python, this covers all + * `and`/`or` expression operand pairs. */ - class BinaryExpr extends Expr { - BinaryExpr() { - exists(Py::BoolExpr be | this = TExpr(be) and count(be.getAValue()) >= 2) - or - this instanceof TBoolExprPair - } + class BinaryExpr extends Expr, TBoolExprPair { + private Py::BoolExpr be; + private int index; + + BinaryExpr() { this = TBoolExprPair(be, index) } + + /** Gets the underlying Python `BoolExpr`. */ + Py::BoolExpr getBoolExpr() { result = be } + + override string toString() { result = be.getOperator() } + + override Py::Location getLocation() { result = be.getValue(index).getLocation() } + + override Callable getEnclosingCallable() { result.asScope() = be.getScope() } /** Gets the left operand of this binary expression. */ - Expr getLeftOperand() { - exists(Py::BoolExpr be | this = TExpr(be) and result = TExpr(be.getValue(0))) - or - exists(Py::BoolExpr be, int i | - this = TBoolExprPair(be, i) and result = TExpr(be.getValue(i)) - ) - } + Expr getLeftOperand() { result.asExpr() = be.getValue(index) } /** Gets the right operand of this binary expression. */ Expr getRightOperand() { - // 2-operand BoolExpr: right operand is value(1). - exists(Py::BoolExpr be | - this = TExpr(be) and - count(be.getAValue()) = 2 and - result = TExpr(be.getValue(1)) - ) + // Last pair: right operand is the final value. + index = count(be.getAValue()) - 2 and result.asExpr() = be.getValue(index + 1) or - // 3+ operand BoolExpr (outermost): right operand is the synthetic - // pair at index 1. - exists(Py::BoolExpr be | - this = TExpr(be) and - count(be.getAValue()) > 2 and - result = TBoolExprPair(be, 1) - ) - or - // Last synthetic pair: right operand is the final value. - exists(Py::BoolExpr be, int i, int n | - this = TBoolExprPair(be, i) and - n = count(be.getAValue()) and - i = n - 2 and - result = TExpr(be.getValue(i + 1)) - ) - or - // Non-last synthetic pair: right operand is the next pair. - exists(Py::BoolExpr be, int i, int n | - this = TBoolExprPair(be, i) and - n = count(be.getAValue()) and - i < n - 2 and - result = TBoolExprPair(be, i + 1) - ) + // Non-last pair: right operand is the next synthetic pair. + index < count(be.getAValue()) - 2 and result = TBoolExprPair(be, index + 1) } - override AstNode getChild(int index) { - index = 0 and result = this.getLeftOperand() + override AstNode getChild(int childIndex) { + childIndex = 0 and result = this.getLeftOperand() or - index = 1 and result = this.getRightOperand() + childIndex = 1 and result = this.getRightOperand() } } /** A short-circuiting logical `and` expression. */ class LogicalAndExpr extends BinaryExpr { - LogicalAndExpr() { - exists(Py::BoolExpr be | - be.getOp() instanceof Py::And and - (this = TExpr(be) or this = TBoolExprPair(be, _)) - ) - } + LogicalAndExpr() { this.getBoolExpr().getOp() instanceof Py::And } } /** A short-circuiting logical `or` expression. */ class LogicalOrExpr extends BinaryExpr { - LogicalOrExpr() { - exists(Py::BoolExpr be | - be.getOp() instanceof Py::Or and - (this = TExpr(be) or this = TBoolExprPair(be, _)) - ) - } + LogicalOrExpr() { this.getBoolExpr().getOp() instanceof Py::Or } } /** A null-coalescing expression. Python has no null-coalescing operator. */ @@ -733,7 +695,7 @@ module Ast implements AstSig { UnaryExpr() { this.asExpr().(Py::UnaryExpr).getOp() instanceof Py::Not } /** Gets the operand of this unary expression. */ - Expr getOperand() { result = TExpr(this.asExpr().(Py::UnaryExpr).getOperand()) } + Expr getOperand() { result.asExpr() = this.asExpr().(Py::UnaryExpr).getOperand() } override AstNode getChild(int index) { index = 0 and result = this.getOperand() } } @@ -784,9 +746,9 @@ module Ast implements AstSig { ArithBinaryExpr() { binExpr = this.asExpr() } - Expr getLeft() { result = TExpr(binExpr.getLeft()) } + Expr getLeft() { result.asExpr() = binExpr.getLeft() } - Expr getRight() { result = TExpr(binExpr.getRight()) } + Expr getRight() { result.asExpr() = binExpr.getRight() } override AstNode getChild(int index) { index = 0 and result = this.getLeft() @@ -801,16 +763,16 @@ module Ast implements AstSig { CallExpr() { call = this.asExpr() } - Expr getFunc() { result = TExpr(call.getFunc()) } + Expr getFunc() { result.asExpr() = call.getFunc() } - Expr getPositionalArg(int n) { result = TExpr(call.getPositionalArg(n)) } + Expr getPositionalArg(int n) { result.asExpr() = call.getPositionalArg(n) } int getNumberOfPositionalArgs() { result = count(call.getAPositionalArg()) } Expr getKeywordValue(int n) { - result = TExpr(call.getNamedArg(n).(Py::Keyword).getValue()) + result.asExpr() = call.getNamedArg(n).(Py::Keyword).getValue() or - result = TExpr(call.getNamedArg(n).(Py::DictUnpacking).getValue()) + result.asExpr() = call.getNamedArg(n).(Py::DictUnpacking).getValue() } int getNumberOfNamedArgs() { result = count(call.getANamedArg()) } @@ -831,9 +793,9 @@ module Ast implements AstSig { SubscriptExpr() { sub = this.asExpr() } - Expr getObject() { result = TExpr(sub.getObject()) } + Expr getObject() { result.asExpr() = sub.getObject() } - Expr getIndex() { result = TExpr(sub.getIndex()) } + Expr getIndex() { result.asExpr() = sub.getIndex() } override AstNode getChild(int index) { index = 0 and result = this.getObject() @@ -848,7 +810,7 @@ module Ast implements AstSig { AttributeExpr() { attr = this.asExpr() } - Expr getObject() { result = TExpr(attr.getObject()) } + Expr getObject() { result.asExpr() = attr.getObject() } override AstNode getChild(int index) { index = 0 and result = this.getObject() } } @@ -859,7 +821,7 @@ module Ast implements AstSig { TupleExpr() { tuple = this.asExpr() } - Expr getElt(int n) { result = TExpr(tuple.getElt(n)) } + Expr getElt(int n) { result.asExpr() = tuple.getElt(n) } override AstNode getChild(int index) { result = this.getElt(index) } } @@ -870,7 +832,7 @@ module Ast implements AstSig { ListExpr() { list = this.asExpr() } - Expr getElt(int n) { result = TExpr(list.getElt(n)) } + Expr getElt(int n) { result.asExpr() = list.getElt(n) } override AstNode getChild(int index) { result = this.getElt(index) } } @@ -881,7 +843,7 @@ module Ast implements AstSig { SetExpr() { set = this.asExpr() } - Expr getElt(int n) { result = TExpr(set.getElt(n)) } + Expr getElt(int n) { result.asExpr() = set.getElt(n) } override AstNode getChild(int index) { result = this.getElt(index) } } @@ -896,9 +858,9 @@ module Ast implements AstSig { * Gets the key of the `n`th item (at child index `2*n`); the value is * at child index `2*n + 1`. */ - Expr getKey(int n) { result = TExpr(dict.getItem(n).(Py::KeyValuePair).getKey()) } + Expr getKey(int n) { result.asExpr() = dict.getItem(n).(Py::KeyValuePair).getKey() } - Expr getValue(int n) { result = TExpr(dict.getItem(n).(Py::KeyValuePair).getValue()) } + Expr getValue(int n) { result.asExpr() = dict.getItem(n).(Py::KeyValuePair).getValue() } int getNumberOfItems() { result = count(dict.getAnItem()) } @@ -917,7 +879,7 @@ module Ast implements AstSig { ArithUnaryExpr() { unaryExpr = this.asExpr() and not unaryExpr.getOp() instanceof Py::Not } - Expr getOperand() { result = TExpr(unaryExpr.getOperand()) } + Expr getOperand() { result.asExpr() = unaryExpr.getOperand() } override AstNode getChild(int index) { index = 0 and result = this.getOperand() } } @@ -940,7 +902,7 @@ module Ast implements AstSig { iterable = this.asExpr().(Py::GeneratorExp).getIterable() } - Expr getIterable() { result = TExpr(iterable) } + Expr getIterable() { result.asExpr() = iterable } override AstNode getChild(int index) { index = 0 and result = this.getIterable() } } @@ -951,9 +913,9 @@ module Ast implements AstSig { CompareExpr() { cmp = this.asExpr() } - Expr getLeft() { result = TExpr(cmp.getLeft()) } + Expr getLeft() { result.asExpr() = cmp.getLeft() } - Expr getComparator(int n) { result = TExpr(cmp.getComparator(n)) } + Expr getComparator(int n) { result.asExpr() = cmp.getComparator(n) } override AstNode getChild(int index) { index = 0 and result = this.getLeft() @@ -968,11 +930,11 @@ module Ast implements AstSig { SliceExpr() { slice = this.asExpr() } - Expr getStart() { result = TExpr(slice.getStart()) } + Expr getStart() { result.asExpr() = slice.getStart() } - Expr getStop() { result = TExpr(slice.getStop()) } + Expr getStop() { result.asExpr() = slice.getStop() } - Expr getStep() { result = TExpr(slice.getStep()) } + Expr getStep() { result.asExpr() = slice.getStep() } override AstNode getChild(int index) { index = 0 and result = this.getStart() @@ -989,7 +951,7 @@ module Ast implements AstSig { StarredExpr() { starred = this.asExpr() } - Expr getValue() { result = TExpr(starred.getValue()) } + Expr getValue() { result.asExpr() = starred.getValue() } override AstNode getChild(int index) { index = 0 and result = this.getValue() } } @@ -1000,7 +962,7 @@ module Ast implements AstSig { FstringExpr() { fstring = this.asExpr() } - Expr getValue(int n) { result = TExpr(fstring.getValue(n)) } + Expr getValue(int n) { result.asExpr() = fstring.getValue(n) } override AstNode getChild(int index) { result = this.getValue(index) } } @@ -1011,9 +973,9 @@ module Ast implements AstSig { FormattedValueExpr() { fv = this.asExpr() } - Expr getValue() { result = TExpr(fv.getValue()) } + Expr getValue() { result.asExpr() = fv.getValue() } - Expr getFormatSpec() { result = TExpr(fv.getFormatSpec()) } + Expr getFormatSpec() { result.asExpr() = fv.getFormatSpec() } override AstNode getChild(int index) { index = 0 and result = this.getValue() @@ -1028,7 +990,7 @@ module Ast implements AstSig { YieldExpr() { yield = this.asExpr() } - Expr getValue() { result = TExpr(yield.getValue()) } + Expr getValue() { result.asExpr() = yield.getValue() } override AstNode getChild(int index) { index = 0 and result = this.getValue() } } @@ -1039,7 +1001,7 @@ module Ast implements AstSig { YieldFromExpr() { yieldFrom = this.asExpr() } - Expr getValue() { result = TExpr(yieldFrom.getValue()) } + Expr getValue() { result.asExpr() = yieldFrom.getValue() } override AstNode getChild(int index) { index = 0 and result = this.getValue() } } @@ -1050,7 +1012,7 @@ module Ast implements AstSig { AwaitExpr() { await = this.asExpr() } - Expr getValue() { result = TExpr(await.getValue()) } + Expr getValue() { result.asExpr() = await.getValue() } override AstNode getChild(int index) { index = 0 and result = this.getValue() } } @@ -1061,7 +1023,7 @@ module Ast implements AstSig { ClassDefExpr() { classExpr = this.asExpr() } - Expr getBase(int n) { result = TExpr(classExpr.getBase(n)) } + Expr getBase(int n) { result.asExpr() = classExpr.getBase(n) } override AstNode getChild(int index) { result = this.getBase(index) } } @@ -1079,14 +1041,14 @@ module Ast implements AstSig { * renumber here to obtain contiguous indices. */ Expr getDefault(int n) { - result = - TExpr(rank[n + 1](Py::Expr d, int i | d = funcExpr.getArgs().getDefault(i) | d order by i)) + result.asExpr() = + rank[n + 1](Py::Expr d, int i | d = funcExpr.getArgs().getDefault(i) | d order by i) } /** Gets the `n`th default for a keyword-only argument, in evaluation order. */ Expr getKwDefault(int n) { - result = - TExpr(rank[n + 1](Py::Expr d, int i | d = funcExpr.getArgs().getKwDefault(i) | d order by i)) + result.asExpr() = + rank[n + 1](Py::Expr d, int i | d = funcExpr.getArgs().getKwDefault(i) | d order by i) } int getNumberOfDefaults() { result = count(funcExpr.getArgs().getADefault()) } @@ -1106,14 +1068,14 @@ module Ast implements AstSig { /** Gets the `n`th default for a positional argument, in evaluation order. */ Expr getDefault(int n) { - result = - TExpr(rank[n + 1](Py::Expr d, int i | d = lambda.getArgs().getDefault(i) | d order by i)) + result.asExpr() = + rank[n + 1](Py::Expr d, int i | d = lambda.getArgs().getDefault(i) | d order by i) } /** Gets the `n`th default for a keyword-only argument, in evaluation order. */ Expr getKwDefault(int n) { - result = - TExpr(rank[n + 1](Py::Expr d, int i | d = lambda.getArgs().getKwDefault(i) | d order by i)) + result.asExpr() = + rank[n + 1](Py::Expr d, int i | d = lambda.getArgs().getKwDefault(i) | d order by i) } int getNumberOfDefaults() { result = count(lambda.getArgs().getADefault()) }