From 8d245e6bc2efc26d8b1f4e769a3a9ef6a689a65c Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Fri, 28 May 2021 15:00:59 +0200 Subject: [PATCH] Resolve `newtype` constructor calls --- ql/src/codeql_ql/ast/Ast.qll | 2 + ql/src/codeql_ql/ast/internal/Module.qll | 6 +- ql/src/codeql_ql/ast/internal/Predicate.qll | 55 +++++++++++++---- ql/src/codeql_ql/ast/internal/Type.qll | 19 +++++- ql/src/codeql_ql/ast/internal/Variable.qll | 65 +++++++++++---------- ql/src/codeql_ql/printAstAst.qll | 4 +- 6 files changed, 103 insertions(+), 48 deletions(-) diff --git a/ql/src/codeql_ql/ast/Ast.qll b/ql/src/codeql_ql/ast/Ast.qll index 00c12352921..0b71428c6f7 100644 --- a/ql/src/codeql_ql/ast/Ast.qll +++ b/ql/src/codeql_ql/ast/Ast.qll @@ -596,6 +596,8 @@ class NewTypeBranch extends TNewTypeBranch, TypeDeclaration { /** Gets the body of this branch. */ Formula getBody() { toGenerated(result) = branch.getChild(_).(Generated::Body).getChild() } + NewType getNewType() { result.getABranch() = this } + override AstNode getAChild(string pred) { result = super.getAChild(pred) or diff --git a/ql/src/codeql_ql/ast/internal/Module.qll b/ql/src/codeql_ql/ast/internal/Module.qll index 2083cb514a9..12ed38267d1 100644 --- a/ql/src/codeql_ql/ast/internal/Module.qll +++ b/ql/src/codeql_ql/ast/internal/Module.qll @@ -166,7 +166,11 @@ predicate resolveModuleExpr(ModuleExpr me, FileOrModule m) { ) } -boolean getPublicBool(ModuleMember m) { if m.isPrivate() then result = false else result = true } +boolean getPublicBool(AstNode n) { + if n.(ModuleMember).isPrivate() or n.(NewTypeBranch).getNewType().isPrivate() + then result = false + else result = true +} /** * Holds if `container` defines module `m` with name `name`. diff --git a/ql/src/codeql_ql/ast/internal/Predicate.qll b/ql/src/codeql_ql/ast/internal/Predicate.qll index bc835ad9369..473600bfb9f 100644 --- a/ql/src/codeql_ql/ast/internal/Predicate.qll +++ b/ql/src/codeql_ql/ast/internal/Predicate.qll @@ -1,18 +1,30 @@ import ql private import Builtins private import codeql_ql.ast.internal.Module +private import codeql_ql.ast.internal.AstNodes as AstNodes -private predicate definesPredicate(FileOrModule m, string name, ClasslessPredicate p, boolean public) { +private class TClasslessPredicateOrNewTypeBranch = + AstNodes::TClasslessPredicate or AstNodes::TNewTypeBranch; + +string getPredicateName(TClasslessPredicateOrNewTypeBranch p) { + result = p.(ClasslessPredicate).getName() or + result = p.(NewTypeBranch).getName() +} + +private predicate definesPredicate( + FileOrModule m, string name, int arity, TClasslessPredicateOrNewTypeBranch p, boolean public +) { m = getEnclosingModule(p) and - name = p.getName() and - public = getPublicBool(p) + name = getPredicateName(p) and + public = getPublicBool(p) and + arity = [p.(ClasslessPredicate).getArity(), count(p.(NewTypeBranch).getField(_))] or // import X exists(Import imp, FileOrModule m0 | m = getEnclosingModule(imp) and m0 = imp.getResolvedModule() and not exists(imp.importedAs()) and - definesPredicate(m0, name, p, true) and + definesPredicate(m0, name, arity, p, true) and public = getPublicBool(imp) ) or @@ -21,7 +33,8 @@ private predicate definesPredicate(FileOrModule m, string name, ClasslessPredica m = getEnclosingModule(alias) and name = alias.getName() and resolvePredicateExpr(alias.getAlias(), p) and - public = getPublicBool(alias) + public = getPublicBool(alias) and + arity = alias.getArity() ) } @@ -34,8 +47,7 @@ predicate resolvePredicateExpr(PredicateExpr pe, ClasslessPredicate p) { m = pe.getQualifier().getResolvedModule() and public = true | - definesPredicate(m, pe.getName(), p, public) and - count(p.getParameter(_)) = pe.getArity() + definesPredicate(m, pe.getName(), count(p.getParameter(_)), p, public) ) } @@ -54,8 +66,7 @@ private predicate resolvePredicateCall(PredicateCall pc, PredicateOrBuiltin p) { m = pc.getQualifier().getResolvedModule() and public = true | - definesPredicate(m, pc.getPredicateName(), p.getDeclaration(), public) and - p.getArity() = pc.getNumberOfArguments() + definesPredicate(m, pc.getPredicateName(), pc.getNumberOfArguments(), p.getDeclaration(), public) ) } @@ -74,6 +85,7 @@ predicate resolveCall(Call c, PredicateOrBuiltin p) { private newtype TPredOrBuiltin = TPred(Predicate p) or + TNewTypeBranch(NewTypeBranch b) or TBuiltinClassless(string ret, string name, string args) { isBuiltinClassless(ret, name, args) } or TBuiltinMember(string qual, string ret, string name, string args) { isBuiltinMember(qual, ret, name, args) @@ -101,7 +113,7 @@ class PredicateOrBuiltin extends TPredOrBuiltin { ) } - Predicate getDeclaration() { none() } + AstNode getDeclaration() { none() } Type getDeclaringType() { none() } @@ -127,6 +139,7 @@ private class DefinedPredicate extends PredicateOrBuiltin, TPred { override Type getParameterType(int i) { result = decl.getParameter(i).getType() } + // Can be removed when all types can be resolved override int getArity() { result = decl.getArity() } override Type getDeclaringType() { @@ -140,6 +153,27 @@ private class DefinedPredicate extends PredicateOrBuiltin, TPred { } } +private class DefinedNewTypeBranch extends PredicateOrBuiltin, TNewTypeBranch { + NewTypeBranch b; + + DefinedNewTypeBranch() { this = TNewTypeBranch(b) } + + override NewTypeBranch getDeclaration() { result = b } + + override string getName() { result = b.getName() } + + override NewTypeBranchType getReturnType() { result.getDeclaration() = b } + + override Type getParameterType(int i) { result = b.getField(i).getType() } + + // Can be removed when all types can be resolved + override int getArity() { result = count(b.getField(_)) } + + override Type getDeclaringType() { none() } + + override predicate isPrivate() { b.getNewType().isPrivate() } +} + private class TBuiltin = TBuiltinClassless or TBuiltinMember; class BuiltinPredicate extends PredicateOrBuiltin, TBuiltin { } @@ -183,6 +217,7 @@ module PredConsistency { query predicate noResolveCall(Call c) { not resolveCall(c, _) and + not c instanceof NoneCall and not c.getLocation().getFile().getAbsolutePath().regexpMatch(".*/(test|examples)/.*") } diff --git a/ql/src/codeql_ql/ast/internal/Type.qll b/ql/src/codeql_ql/ast/internal/Type.qll index 63a837545cd..804c0eb53bf 100644 --- a/ql/src/codeql_ql/ast/internal/Type.qll +++ b/ql/src/codeql_ql/ast/internal/Type.qll @@ -61,11 +61,24 @@ class Type extends TType { ) } + pragma[nomagic] + private predicate getClassPredicate0(string name, int arity, PredicateOrBuiltin p, Type t) { + p = classPredCandidate(this, name, arity) and + t = p.getDeclaringType().getASuperType+() + } + + pragma[nomagic] + private predicate getClassPredicate1( + string name, int arity, PredicateOrBuiltin p1, PredicateOrBuiltin p2 + ) { + getClassPredicate0(name, arity, p1, p2.getDeclaringType()) and + p2 = classPredCandidate(this, name, arity) + } + + pragma[nomagic] PredicateOrBuiltin getClassPredicate(string name, int arity) { result = classPredCandidate(this, name, arity) and - not exists(PredicateOrBuiltin other | other = classPredCandidate(this, name, arity) | - other.getDeclaringType().getASuperType+() = result.getDeclaringType() - ) + not getClassPredicate1(name, arity, _, result) } } diff --git a/ql/src/codeql_ql/ast/internal/Variable.qll b/ql/src/codeql_ql/ast/internal/Variable.qll index 64ebd70e046..379196d6faa 100644 --- a/ql/src/codeql_ql/ast/internal/Variable.qll +++ b/ql/src/codeql_ql/ast/internal/Variable.qll @@ -10,34 +10,43 @@ class VariableScope extends TScope, AstNode { VariableScope getOuterScope() { result = scopeOf(this) } /** Gets a variable declared directly in this scope. */ - VarDef getADefinition() { result.getParent() = this } + VarDef getADefinition(string name) { + result.getParent() = this and + name = result.getName() + } /** Holds if this scope contains variable `decl`, either directly or inherited. */ - predicate containsVar(VarDef decl) { - not this instanceof Class and - decl = this.getADefinition() - or - decl = this.(Select).getExpr(_).(AsExpr) - or - decl = this.(Aggregate).getExpr(_).(AsExpr) - or - decl = this.(ExprAggregate).getExpr(_).(AsExpr) - or - this.getOuterScope().containsVar(decl) and - not this.getADefinition().getName() = decl.getName() + predicate containsVar(VarDef decl, string name) { + name = decl.getName() and + ( + not this instanceof Class and + decl = this.getADefinition(name) + or + decl = this.(Select).getExpr(_).(AsExpr) + or + decl = this.(Aggregate).getExpr(_).(AsExpr) + or + decl = this.(ExprAggregate).getExpr(_).(AsExpr) + or + this.getOuterScope().containsVar(decl, name) and + not exists(this.getADefinition(name)) + ) } /** Holds if this scope contains field `decl`, either directly or inherited. */ - predicate containsField(VarDef decl) { - decl = this.(Class).getAField() - or - this.getOuterScope().containsField(decl) and - not this.getADefinition().getName() = decl.getName() - or - exists(VariableScope sup | - sup = this.(Class).getASuperType().getResolvedType().(ClassType).getDeclaration() and - sup.containsField(decl) and - not this.(Class).getAField().getName() = decl.getName() + predicate containsField(VarDef decl, string name) { + name = decl.getName() and + ( + decl = this.(Class).getAField() + or + this.getOuterScope().containsField(decl, name) and + not exists(this.getADefinition(name)) + or + exists(VariableScope sup | + sup = this.(Class).getASuperType().getResolvedType().(ClassType).getDeclaration() and + sup.containsField(decl, name) and + not this.(Class).getAField().getName() = name + ) ) } } @@ -56,15 +65,9 @@ private string getName(Identifier i) { ) } -predicate resolveVariable(Identifier i, VarDef decl) { - scopeOf(i).containsVar(decl) and - decl.getName() = getName(i) -} +predicate resolveVariable(Identifier i, VarDef decl) { scopeOf(i).containsVar(decl, getName(i)) } -predicate resolveField(Identifier i, VarDef decl) { - scopeOf(i).containsField(decl) and - decl.getName() = getName(i) -} +predicate resolveField(Identifier i, VarDef decl) { scopeOf(i).containsField(decl, getName(i)) } module VarConsistency { query predicate multipleVarDefs(VarAccess v, VarDef decl) { diff --git a/ql/src/codeql_ql/printAstAst.qll b/ql/src/codeql_ql/printAstAst.qll index 4471a405a44..ff1a47e8e5c 100644 --- a/ql/src/codeql_ql/printAstAst.qll +++ b/ql/src/codeql_ql/printAstAst.qll @@ -72,9 +72,7 @@ class PrintAstNode extends AstNode { /** * Gets the child node that is accessed using the predicate `edgeName`. */ - PrintAstNode getChild(string edgeName) { - result = this.getAChild(edgeName) - } + PrintAstNode getChild(string edgeName) { result = this.getAChild(edgeName) } } private predicate shouldPrintNode(AstNode n) {