diff --git a/ql/src/codeql_ql/ast/Ast.qll b/ql/src/codeql_ql/ast/Ast.qll index 0cc407e4d68..ea5b4c00753 100644 --- a/ql/src/codeql_ql/ast/Ast.qll +++ b/ql/src/codeql_ql/ast/Ast.qll @@ -4,6 +4,14 @@ private import codeql_ql.ast.internal.Module private import codeql_ql.ast.internal.Predicate private import codeql_ql.ast.internal.Type +bindingset[name, i] +private string indexedMember(string name, int i) { result = name + "(" + i.toString() + ")" } + +bindingset[name, index] +private string stringIndexedMember(string name, string index) { + result = name + "(\"" + index + "\")" +} + /** An AST node of a QL program */ class AstNode extends TAstNode { string toString() { result = getAPrimaryQlClass() } @@ -20,6 +28,13 @@ class AstNode extends TAstNode { not result = this } + /** + * Gets a child of this node, which can also be retrieved using a predicate + * named `pred`. + */ + cached + AstNode getAChild(string pred) { none() } + string getAPrimaryQlClass() { result = "???" } } @@ -30,6 +45,10 @@ class TopLevel extends TTopLevel, AstNode { ModuleMember getAMember() { toGenerated(result) = file.getChild(_).getChild(_) } + override ModuleMember getAChild(string pred) { + pred = "getAMember" and result = this.getAMember() + } + override string getAPrimaryQlClass() { result = "TopLevel" } } @@ -51,6 +70,18 @@ class Select extends TSelect, AstNode { toGenerated(result) = sel.getChild(_).(Generated::OrderBys).getChild(i).getChild(0) } + override AstNode getAChild(string pred) { + pred = "getWhere" and result = this.getWhere() + or + exists(int i | + pred = indexedMember("getVarDecl", i) and result = this.getVarDecl(i) + or + pred = indexedMember("getAsExpr", i) and result = this.getAsExpr(i) + or + pred = indexedMember("getOrderBy", i) and result = this.getOrderBy(i) + ) + } + override string getAPrimaryQlClass() { result = "Select" } // TODO: Getters for VarDecls, Where-clause, selects. } @@ -73,7 +104,15 @@ class Predicate extends TPredicate, AstNode { * Gets the `i`th parameter of the predicate. */ VarDecl getParameter(int i) { none() } + // TODO: ReturnType. + override AstNode getAChild(string pred) { + pred = "getBody" and result = this.getBody() + or + exists(int i | pred = indexedMember("getParameter", i) and result = this.getParameter(i)) + } + + override string getAPrimaryQlClass() { result = "Predicate" } } class PredicateExpr extends TPredicateExpr, AstNode { @@ -111,6 +150,8 @@ class PredicateExpr extends TPredicateExpr, AstNode { this in [result.(ClasslessPredicate).getAlias(), result.(HigherOrderFormula).getInput(_)] } + override AstNode getAChild(string pred) { pred = "getQualifier" and result = this.getQualifier() } + override string getAPrimaryQlClass() { result = "PredicateExpr" } } @@ -144,6 +185,14 @@ class ClasslessPredicate extends TClasslessPredicate, Predicate, ModuleDeclarati toGenerated(result) = rank[i](Generated::VarDecl decl, int index | decl = pred.getChild(index) | decl order by index) } + + override AstNode getAChild(string pred_name) { + pred_name = "getAlias" and result = this.getAlias() + or + pred_name = "getBody" and result = this.getBody() + or + exists(int i | pred_name = indexedMember("getParameter", i) and result = this.getParameter(i)) + } } /** @@ -166,6 +215,12 @@ class ClassPredicate extends TClassPredicate, Predicate { toGenerated(result) = rank[i](Generated::VarDecl decl, int index | decl = pred.getChild(index) | decl order by index) } + + override AstNode getAChild(string pred_name) { + pred_name = "getBody" and result = this.getBody() + or + exists(int i | pred_name = indexedMember("getParameter", i) and result = this.getParameter(i)) + } } /** @@ -183,6 +238,8 @@ class CharPred extends TCharPred, Predicate { override string getName() { result = getParent().getName() } override Class getParent() { result.getCharPred() = this } + + override AstNode getAChild(string pred_name) { pred_name = "getBody" and result = this.getBody() } } /** @@ -211,6 +268,8 @@ class VarDecl extends TVarDecl, AstNode { } TypeExpr getType() { toGenerated(result) = var.getChild(0) } + + override AstNode getAChild(string pred) { pred = "getType" and result = this.getType() } } /** @@ -298,6 +357,12 @@ class Module extends TModule, ModuleDeclaration { ModuleExpr getAlias() { toGenerated(result) = mod.getAFieldOrChild().(Generated::ModuleAliasBody).getChild() } + + override AstNode getAChild(string pred) { + pred = "getAlias" and result = this.getAlias() + or + pred = "getAMember" and result = this.getAMember() + } } /** @@ -384,6 +449,23 @@ class Class extends TClass, TypeDeclaration, ModuleDeclaration { TypeExpr getUnionMember() { toGenerated(result) = cls.getChild(_).(Generated::TypeUnionBody).getChild(_) } + + override AstNode getAChild(string pred) { + pred = "getAliasType" and result = this.getAliasType() + or + pred = "getUnionMember" and result = this.getUnionMember() + or + pred = "getAField" and result = this.getAField() + or + pred = "getCharPred" and result = this.getCharPred() + or + pred = "getASuperType" and result = this.getASuperType() + or + exists(string name | + pred = stringIndexedMember("getClassPredicate", name) and + result = this.getClassPredicate(name) + ) + } } /** @@ -409,6 +491,8 @@ class NewType extends TNewType, TypeDeclaration, ModuleDeclaration { * Gets a branch in this `newtype`. */ NewTypeBranch getABranch() { toGenerated(result) = type.getChild().getChild(_) } + + override AstNode getAChild(string pred) { pred = "getABranch" and result = this.getABranch() } } /** @@ -433,6 +517,12 @@ class NewTypeBranch extends TNewTypeBranch, TypeDeclaration { Formula getBody() { toGenerated(result) = branch.getChild(_).(Generated::Body).getChild() } override NewType getParent() { result.getABranch() = this } + + override AstNode getAChild(string pred) { + pred = "getBody" and result = this.getBody() + or + exists(int i | pred = indexedMember("getField", i) and result = this.getField(i)) + } } class Call extends TCall, AstNode { @@ -466,6 +556,12 @@ class PredicateCall extends TPredicateCall, Call { string getPredicateName() { result = expr.getChild(0).(Generated::AritylessPredicateExpr).getName().getValue() } + + override AstNode getAChild(string pred) { + exists(int i | pred = indexedMember("getArgument", i) and result = this.getArgument(i)) + or + pred = "getQualifier" and result = this.getQualifier() + } } class MemberCall extends TMemberCall, Call { @@ -498,6 +594,14 @@ class MemberCall extends TMemberCall, Call { } Expr getBase() { toGenerated(result) = expr.getChild(0) } + + override AstNode getAChild(string pred) { + pred = "getBase" and result = this.getBase() + or + pred = "getSuperType" and result = this.getSuperType() + or + exists(int i | pred = indexedMember("getArgument", i) and result = this.getArgument(i)) + } } class NoneCall extends TNoneCall, Call, Formula { @@ -528,6 +632,12 @@ class InlineCast extends TInlineCast, Expr { } Expr getBase() { toGenerated(result) = expr.getChild(0) } + + override AstNode getAChild(string pred) { + pred = "getType" and result = this.getType() + or + pred = "getBase" and result = this.getBase() + } } /** An entity that resolves to a module. */ @@ -608,6 +718,8 @@ class Conjunction extends TConjunction, AstNode, Formula { /** Gets an operand to this formula. */ Formula getAnOperand() { toGenerated(result) in [conj.getLeft(), conj.getRight()] } + + override AstNode getAChild(string pred) { pred = "getAnOperand" and result = this.getAnOperand() } } /** An `or` formula, with 2 or more operands. */ @@ -620,6 +732,8 @@ class Disjunction extends TDisjunction, AstNode { /** Gets an operand to this formula. */ Formula getAnOperand() { toGenerated(result) in [disj.getLeft(), disj.getRight()] } + + override AstNode getAChild(string pred) { pred = "getAnOperand" and result = this.getAnOperand() } } /** @@ -700,6 +814,12 @@ class ComparisonFormula extends TComparisonFormula, Formula { ComparisonSymbol getSymbol() { result = this.getOperator().getSymbol() } override string getAPrimaryQlClass() { result = "ComparisonFormula" } + + override AstNode getAChild(string pred) { + pred = "getLeftOperand" and result = this.getLeftOperand() + or + pred = "getRightOperand" and result = this.getRightOperand() + } } /** A quantifier formula, such as `exists` or `forall`. */ @@ -742,6 +862,16 @@ class Quantifier extends TQuantifier, Formula { predicate hasExpr() { exists(getExpr()) } override string getAPrimaryQlClass() { result = "Quantifier" } + + override AstNode getAChild(string pred) { + exists(int i | pred = indexedMember("getArgument", i) and result = this.getArgument(i)) + or + pred = "getRange" and result = this.getRange() + or + pred = "getFormula" and result = this.getFormula() + or + pred = "getExpr" and result = this.getExpr() + } } /** An `exists` quantifier. */ @@ -781,6 +911,14 @@ class IfFormula extends TIfFormula, Formula { Formula getElsePart() { toGenerated(result) = ifterm.getSecond() } override string getAPrimaryQlClass() { result = "IfFormula" } + + override AstNode getAChild(string pred) { + pred = "getCondition" and result = this.getCondition() + or + pred = "getThenPart" and result = this.getThenPart() + or + pred = "getElsePart" and result = this.getElsePart() + } } /** @@ -798,6 +936,12 @@ class Implication extends TImplication, Formula { Formula getRightOperand() { toGenerated(result) = imp.getRight() } override string getAPrimaryQlClass() { result = "Implication" } + + override AstNode getAChild(string pred) { + pred = "getLeftOperand" and result = this.getLeftOperand() + or + pred = "getRightOperand" and result = this.getRightOperand() + } } /** @@ -817,6 +961,12 @@ class InstanceOf extends TInstanceOf, Formula { /** Gets the type being checked. */ //QLTypeExpr getType() { result = getTypeRef().getType() } override string getAPrimaryQlClass() { result = "InstanceOf" } + + override AstNode getAChild(string pred) { + pred = "getExpr" and result = this.getExpr() + or + pred = "getType" and result = this.getType() + } } class InFormula extends TInFormula, Formula { @@ -829,6 +979,12 @@ class InFormula extends TInFormula, Formula { Expr getRange() { toGenerated(result) = inexpr.getRight() } override string getAPrimaryQlClass() { result = "InFormula" } + + override AstNode getAChild(string pred) { + pred = "getExpr" and result = this.getExpr() + or + pred = "getRange" and result = this.getRange() + } } class HigherOrderFormula extends THigherOrderFormula, Formula { @@ -845,6 +1001,14 @@ class HigherOrderFormula extends THigherOrderFormula, Formula { string getName() { result = hop.getName().getValue() } override string getAPrimaryQlClass() { result = "HigherOrderFormula" } + + override AstNode getAChild(string pred) { + exists(int i | + pred = indexedMember("getInput", i) and result = this.getInput(i) + or + pred = indexedMember("getArgument", i) and result = this.getArgument(i) + ) + } } class ExprAggregate extends TExprAggregate, Expr { @@ -878,6 +1042,14 @@ class ExprAggregate extends TExprAggregate, Expr { } override string getAPrimaryQlClass() { result = "ExprAggregate[" + kind + "]" } + + override AstNode getAChild(string pred) { + exists(int i | + pred = indexedMember("getAsExpr", i) and result = this.getAsExpr(i) + or + pred = indexedMember("getOrderBy", i) and result = this.getOrderBy(i) + ) + } } /** An aggregate expression, such as `count` or `sum`. */ @@ -923,6 +1095,16 @@ class Aggregate extends TAggregate, Expr { } override string getAPrimaryQlClass() { result = "Aggregate[" + kind + "]" } + + override AstNode getAChild(string pred) { + exists(int i | + pred = indexedMember("getArgument", i) and result = this.getArgument(i) + or + pred = indexedMember("getAsExpr", i) and result = this.getAsExpr(i) + or + pred = indexedMember("getOrderBy", i) and result = this.getOrderBy(i) + ) + } } /** @@ -937,6 +1119,8 @@ class Rank extends Aggregate { * The `i` in `rank[i]( | | )`. */ Expr getRankExpr() { toGenerated(result) = agg.getChild(1) } + + override AstNode getAChild(string pred) { pred = "getRankExpr" and result = this.getRankExpr() } } /** @@ -970,6 +1154,8 @@ class AsExpr extends TAsExpr, AstNode { or result.(Select).getAsExpr(_) = this } + + override AstNode getAChild(string pred) { pred = "getInnerExpr" and result = this.getInnerExpr() } } class Identifier extends TIdentifier, Expr { @@ -992,6 +1178,8 @@ class Negation extends TNegation, Formula { Formula getFormula() { toGenerated(result) = neg.getChild() } override string getAPrimaryQlClass() { result = "Negation" } + + override AstNode getAChild(string pred) { pred = "getFormula" and result = this.getFormula() } } /** An expression, such as `x+4`. */ @@ -1021,6 +1209,10 @@ class ExprAnnotation extends TExprAnnotation, Expr { Expr getExpression() { toGenerated(result) = expr_anno.getChild() } override string getAPrimaryQlClass() { result = "ExprAnnotation" } + + override AstNode getAChild(string pred) { + pred = "getExpression" and result = this.getExpression() + } } /** A function symbol, such as `+` or `*`. */ @@ -1053,6 +1245,12 @@ class AddSubExpr extends TAddSubExpr, BinOpExpr { /** Gets the operator of the binary expression. */ FunctionSymbol getOperator() { result = operator } + + override AstNode getAChild(string pred) { + pred = "getLeftOperand" and result = this.getLeftOperand() + or + pred = "getRightOperand" and result = this.getRightOperand() + } } /** @@ -1093,6 +1291,12 @@ class MulDivModExpr extends TMulDivModExpr, BinOpExpr { /** Gets the operator of the binary expression. */ FunctionSymbol getOperator() { result = operator } + + override AstNode getAChild(string pred) { + pred = "getLeftOperand" and result = this.getLeftOperand() + or + pred = "getRightOperand" and result = this.getRightOperand() + } } /** @@ -1141,6 +1345,12 @@ class Range extends TRange, Expr { Expr getHighEndpoint() { toGenerated(result) = range.getUpper() } override string getAPrimaryQlClass() { result = "Range" } + + override AstNode getAChild(string pred) { + pred = "getLowEndpoint" and result = this.getLowEndpoint() + or + pred = "getHighEndpoint" and result = this.getHighEndpoint() + } } /** @@ -1157,6 +1367,10 @@ class Set extends TSet, Expr { Expr getElement(int i) { toGenerated(result) = set.getChild(i) } override string getAPrimaryQlClass() { result = "Set" } + + override AstNode getAChild(string pred) { + exists(int i | pred = indexedMember("getElement", i) and result = getElement(i)) + } } /** A unary operation expression, such as `-(x*y)` */ @@ -1170,6 +1384,10 @@ class UnaryExpr extends TUnaryExpr, Expr { /** Gets the operator of the unary expression as a string. */ FunctionSymbol getOperator() { result = unaryexpr.getChild(0).toString() } + + override string getAPrimaryQlClass() { result = "UnaryExpr" } + + override AstNode getAChild(string pred) { pred = "getOperand" and result = this.getOperand() } } /** A "don't care" expression, denoted by `_`. */ @@ -1221,10 +1439,10 @@ class ModuleExpr extends TModuleExpr, ModuleRef { override AstNode getParent() { result = super.getParent() or - result.(PredicateCall).getQualifier() = this - or - result.(PredicateExpr).getQualifier() = this - or + result.(PredicateCall).getQualifier() = this or + result.(PredicateExpr).getQualifier() = this or result.(Module).getAlias() = this } + + override AstNode getAChild(string pred) { pred = "getQualifier" and result = this.getQualifier() } }