diff --git a/ql/src/codeql_ql/ast/Ast.qll b/ql/src/codeql_ql/ast/Ast.qll index 21e443fa73b..ea0b33e6b5d 100644 --- a/ql/src/codeql_ql/ast/Ast.qll +++ b/ql/src/codeql_ql/ast/Ast.qll @@ -1392,11 +1392,17 @@ class HigherOrderFormula extends THigherOrderFormula, Formula { } } +class Aggregate extends TAggregate, Expr { + string getKind() { none() } + + Generated::Aggregate getAggregate() { none() } +} + /** * An aggregate containing an expression. * E.g. `min(getAPredicate().getArity())`. */ -class ExprAggregate extends TExprAggregate, Expr { +class ExprAggregate extends TExprAggregate, Aggregate { Generated::Aggregate agg; Generated::ExprAggregateBody body; string kind; @@ -1411,7 +1417,9 @@ class ExprAggregate extends TExprAggregate, Expr { * Gets the kind of aggregate. * E.g. for `min(foo())` the result is "min". */ - string getKind() { result = kind } + override string getKind() { result = kind } + + override Generated::Aggregate getAggregate() { result = agg } /** * Gets the ith "as" expression of this aggregate, if any. @@ -1457,13 +1465,13 @@ class ExprAggregate extends TExprAggregate, Expr { } /** An aggregate expression, such as `count` or `sum`. */ -class Aggregate extends TAggregate, Expr { +class FullAggregate extends TFullAggregate, Aggregate { Generated::Aggregate agg; string kind; Generated::FullAggregateBody body; - Aggregate() { - this = TAggregate(agg) and + FullAggregate() { + this = TFullAggregate(agg) and kind = agg.getChild(0).(Generated::AggId).getValue() and body = agg.getChild(_) } @@ -1472,7 +1480,9 @@ class Aggregate extends TAggregate, Expr { * Gets the kind of aggregate. * E.g. for `min(int i | foo(i))` the result is "foo". */ - string getKind() { result = kind } + override string getKind() { result = kind } + + override Generated::Aggregate getAggregate() { result = agg } /** Gets the ith declared argument of this quantifier. */ VarDecl getArgument(int i) { toGenerated(result) = body.getChild(i) } @@ -1502,7 +1512,7 @@ class Aggregate extends TAggregate, Expr { result = body.getOrderBys().getChild(i).getChild(1).(Generated::Direction).getValue() } - override string getAPrimaryQlClass() { result = "Aggregate[" + kind + "]" } + override string getAPrimaryQlClass() { kind != "rank" and result = "FullAggregate[" + kind + "]" } override Type getType() { exists(PrimitiveType prim | prim = result | @@ -1540,14 +1550,14 @@ class Aggregate extends TAggregate, Expr { * A "rank" expression, such as `rank[4](int i | i = [5 .. 15] | i)`. */ class Rank extends Aggregate { - Rank() { kind = "rank" } + Rank() { this.getKind() = "rank" } override string getAPrimaryQlClass() { result = "Rank" } /** * The `i` in `rank[i]( | | )`. */ - Expr getRankExpr() { toGenerated(result) = agg.getChild(1) } + Expr getRankExpr() { toGenerated(result) = this.getAggregate().getChild(1) } override AstNode getAChild(string pred) { result = super.getAChild(pred) diff --git a/ql/src/codeql_ql/ast/internal/AstNodes.qll b/ql/src/codeql_ql/ast/internal/AstNodes.qll index fa1ee013e50..509c11d1b79 100644 --- a/ql/src/codeql_ql/ast/internal/AstNodes.qll +++ b/ql/src/codeql_ql/ast/internal/AstNodes.qll @@ -21,7 +21,9 @@ newtype TAstNode = TComparisonFormula(Generated::CompTerm comp) or TComparisonOp(Generated::Compop op) or TQuantifier(Generated::Quantified quant) or - TAggregate(Generated::Aggregate agg) { agg.getChild(_) instanceof Generated::FullAggregateBody } or + TFullAggregate(Generated::Aggregate agg) { + agg.getChild(_) instanceof Generated::FullAggregateBody + } or TExprAggregate(Generated::Aggregate agg) { agg.getChild(_) instanceof Generated::ExprAggregateBody } or @@ -63,9 +65,11 @@ class TFormula = class TBinOpExpr = TAddSubExpr or TMulDivModExpr; +class TAggregate = TFullAggregate or TExprAggregate; + class TExpr = - TBinOpExpr or TLiteral or TAggregate or TExprAggregate or TIdentifier or TInlineCast or TCall or - TUnaryExpr or TExprAnnotation or TDontCare or TRange or TSet or TAsExpr or TSuper; + TBinOpExpr or TLiteral or TAggregate or TIdentifier or TInlineCast or + TCall or TUnaryExpr or TExprAnnotation or TDontCare or TRange or TSet or TAsExpr or TSuper; class TCall = TPredicateCall or TMemberCall or TNoneCall or TAnyCall; @@ -77,7 +81,7 @@ private Generated::AstNode toGeneratedFormula(AST::AstNode n) { n = TComparisonFormula(result) or n = TComparisonOp(result) or n = TQuantifier(result) or - n = TAggregate(result) or + n = TFullAggregate(result) or n = TIdentifier(result) or n = TNegation(result) or n = TIfFormula(result) or @@ -94,7 +98,7 @@ private Generated::AstNode toGeneratedExpr(AST::AstNode n) { n = TSet(result) or n = TExprAnnotation(result) or n = TLiteral(result) or - n = TAggregate(result) or + n = TFullAggregate(result) or n = TExprAggregate(result) or n = TIdentifier(result) or n = TUnaryExpr(result) or diff --git a/ql/src/codeql_ql/ast/internal/Variable.qll b/ql/src/codeql_ql/ast/internal/Variable.qll index b5791f457ae..af3d7a9eb4c 100644 --- a/ql/src/codeql_ql/ast/internal/Variable.qll +++ b/ql/src/codeql_ql/ast/internal/Variable.qll @@ -2,7 +2,7 @@ import ql import codeql_ql.ast.internal.AstNodes private class TScope = - TClass or TAggregate or TExprAggregate or TQuantifier or TSelect or TPredicate or TNewTypeBranch; + TClass or TFullAggregate or TExprAggregate or TQuantifier or TSelect or TPredicate or TNewTypeBranch; /** A variable scope. */ class VariableScope extends TScope, AstNode { @@ -24,7 +24,7 @@ class VariableScope extends TScope, AstNode { or decl = this.(Select).getExpr(_).(AsExpr) or - decl = this.(Aggregate).getExpr(_).(AsExpr) + decl = this.(FullAggregate).getExpr(_).(AsExpr) or decl = this.(ExprAggregate).getExpr(_).(AsExpr) or