diff --git a/ql/src/codeql_ql/ast/Ast.qll b/ql/src/codeql_ql/ast/Ast.qll index 06713d30745..e11c43a569a 100644 --- a/ql/src/codeql_ql/ast/Ast.qll +++ b/ql/src/codeql_ql/ast/Ast.qll @@ -819,7 +819,7 @@ class NewType extends TNewType, TypeDeclaration, ModuleDeclaration { * A branch in a `newtype`. * E.g. `Bar()` or `Baz()` in `newtype Foo = Bar() or Baz()`. */ -class NewTypeBranch extends TNewTypeBranch, PredicateOrBuiltin, TypeDeclaration { +class NewTypeBranch extends TNewTypeBranch, Predicate, TypeDeclaration { QL::DatatypeBranch branch; NewTypeBranch() { this = TNewTypeBranch(branch) } @@ -835,7 +835,7 @@ class NewTypeBranch extends TNewTypeBranch, PredicateOrBuiltin, TypeDeclaration } /** Gets the body of this branch. */ - Formula getBody() { toQL(result) = branch.getChild(_).(QL::Body).getChild() } + override Formula getBody() { toQL(result) = branch.getChild(_).(QL::Body).getChild() } override NewTypeBranchType getReturnType() { result.getDeclaration() = this } diff --git a/ql/src/codeql_ql/ast/internal/AstNodes.qll b/ql/src/codeql_ql/ast/internal/AstNodes.qll index 1b0f92d014e..9e5b8cb99ab 100644 --- a/ql/src/codeql_ql/ast/internal/AstNodes.qll +++ b/ql/src/codeql_ql/ast/internal/AstNodes.qll @@ -193,7 +193,8 @@ QL::AstNode toQL(AST::AstNode n) { n = TAnnotationArg(result) } -class TPredicate = TCharPred or TClasslessPredicate or TClassPredicate or TDBRelation; +class TPredicate = + TCharPred or TClasslessPredicate or TClassPredicate or TDBRelation or TNewTypeBranch; class TPredOrBuiltin = TPredicate or TNewTypeBranch or TBuiltin; diff --git a/ql/src/codeql_ql/ast/internal/Predicate.qll b/ql/src/codeql_ql/ast/internal/Predicate.qll index 44715c1b0ff..fa0f55bf893 100644 --- a/ql/src/codeql_ql/ast/internal/Predicate.qll +++ b/ql/src/codeql_ql/ast/internal/Predicate.qll @@ -3,20 +3,13 @@ private import Builtins private import codeql_ql.ast.internal.Module private import codeql_ql.ast.internal.AstNodes -private class TClasslessPredicateOrNewTypeBranch = TClasslessPredicate or TNewTypeBranch; - -private string getPredicateName(TClasslessPredicateOrNewTypeBranch p) { - result = p.(ClasslessPredicate).getName() or - result = p.(NewTypeBranch).getName() -} - private predicate definesPredicate( - FileOrModule m, string name, int arity, TClasslessPredicateOrNewTypeBranch p, boolean public + FileOrModule m, string name, int arity, Predicate p, boolean public ) { m = getEnclosingModule(p) and - name = getPredicateName(p) and + name = p.getName() and public = getPublicBool(p) and - arity = [p.(ClasslessPredicate).getArity(), count(p.(NewTypeBranch).getField(_))] + arity = p.getArity() or // import X exists(Import imp, FileOrModule m0 | @@ -40,7 +33,7 @@ private predicate definesPredicate( cached private module Cached { cached - predicate resolvePredicateExpr(PredicateExpr pe, ClasslessPredicate p) { + predicate resolvePredicateExpr(PredicateExpr pe, Predicate p) { exists(FileOrModule m, boolean public | not exists(pe.getQualifier()) and m = getEnclosingModule(pe).getEnclosing*() and @@ -49,7 +42,7 @@ private module Cached { m = pe.getQualifier().getResolvedModule() and public = true | - definesPredicate(m, pe.getName(), count(p.getParameter(_)), p, public) + definesPredicate(m, pe.getName(), p.getArity(), p, public) ) } diff --git a/ql/test/callgraph/callgraph.expected b/ql/test/callgraph/callgraph.expected index 4011cfb9925..33fb31cef9d 100644 --- a/ql/test/callgraph/callgraph.expected +++ b/ql/test/callgraph/callgraph.expected @@ -25,4 +25,5 @@ dependsOn exprPredicate | Foo.qll:24:22:24:31 | predicate | Foo.qll:22:3:22:32 | ClasslessPredicate myThing0 | | Foo.qll:26:22:26:31 | predicate | Foo.qll:20:3:20:54 | ClasslessPredicate myThing2 | +| Foo.qll:47:55:47:62 | predicate | Foo.qll:42:20:42:27 | NewTypeBranch MkRoot | | Foo.qll:47:65:47:70 | predicate | Foo.qll:44:9:44:56 | ClasslessPredicate edge |