Rust: Disambiguate calls to associated functions

This commit is contained in:
Tom Hvitved
2025-07-08 13:15:27 +02:00
parent 95c2b9f8f7
commit ebde0bdc47
6 changed files with 282 additions and 92 deletions

View File

@@ -13,6 +13,7 @@ private import codeql.rust.elements.Resolvable
*/
module Impl {
private import rust
private import codeql.rust.internal.TypeInference as TypeInference
pragma[nomagic]
Resolvable getCallResolvable(CallExprBase call) {
@@ -27,7 +28,7 @@ module Impl {
*/
class CallExprBase extends Generated::CallExprBase {
/** Gets the static target of this call, if any. */
Callable getStaticTarget() { none() } // overridden by subclasses, but cannot be made abstract
final Function getStaticTarget() { result = TypeInference::resolveCallTarget(this) }
override Expr getArg(int index) { result = this.getArgList().getArg(index) }
}

View File

@@ -14,7 +14,6 @@ private import codeql.rust.elements.PathExpr
module Impl {
private import rust
private import codeql.rust.internal.PathResolution as PathResolution
private import codeql.rust.internal.TypeInference as TypeInference
pragma[nomagic]
Path getFunctionPath(CallExpr ce) { result = ce.getFunction().(PathExpr).getPath() }
@@ -37,15 +36,6 @@ module Impl {
class CallExpr extends Generated::CallExpr {
override string toStringImpl() { result = this.getFunction().toAbbreviatedString() + "(...)" }
override Callable getStaticTarget() {
// If this call is to a trait method, e.g., `Trait::foo(bar)`, then check
// if type inference can resolve it to the correct trait implementation.
result = TypeInference::resolveMethodCallTarget(this)
or
not exists(TypeInference::resolveMethodCallTarget(this)) and
result = getResolvedFunction(this)
}
/** Gets the struct that this call resolves to, if any. */
Struct getStruct() { result = getResolvedFunction(this) }

View File

@@ -40,6 +40,9 @@ module Impl {
/** Gets the trait targeted by this call, if any. */
abstract Trait getTrait();
/** Holds if this call targets a trait. */
predicate hasTrait() { exists(this.getTrait()) }
/** Gets the name of the method called if this call is a method call. */
abstract string getMethodName();
@@ -59,12 +62,7 @@ module Impl {
Expr getReceiver() { result = this.getArgument(TSelfArgumentPosition()) }
/** Gets the static target of this call, if any. */
Function getStaticTarget() {
result = TypeInference::resolveMethodCallTarget(this)
or
not exists(TypeInference::resolveMethodCallTarget(this)) and
result = this.(CallExpr).getStaticTarget()
}
Function getStaticTarget() { result = TypeInference::resolveCallTarget(this) }
/** Gets a runtime target of this call, if any. */
pragma[nomagic]
@@ -78,23 +76,44 @@ module Impl {
}
}
/** Holds if the call expression dispatches to a method. */
private predicate callIsMethodCall(CallExpr call, Path qualifier, string methodName) {
exists(Path path, Function f |
path = call.getFunction().(PathExpr).getPath() and
f = resolvePath(path) and
f.getParamList().hasSelfParam() and
qualifier = path.getQualifier() and
path.getSegment().getIdentifier().getText() = methodName
private predicate callHasQualifier(CallExpr call, Path path, Path qualifier) {
path = call.getFunction().(PathExpr).getPath() and
qualifier = path.getQualifier()
}
private predicate callHasTraitQualifier(CallExpr call, Trait qualifier) {
exists(RelevantPath qualifierPath |
callHasQualifier(call, _, qualifierPath) and
qualifier = resolvePath(qualifierPath) and
// When the qualifier is `Self` and resolves to a trait, it's inside a
// trait method's default implementation. This is not a dispatch whose
// target is inferred from the type of the receiver, but should always
// resolve to the function in the trait block as path resolution does.
not qualifierPath.isUnqualified("Self")
)
}
private class CallExprCall extends Call instanceof CallExpr {
CallExprCall() { not callIsMethodCall(this, _, _) }
/** Holds if the call expression dispatches to a method. */
private predicate callIsMethodCall(
CallExpr call, Path qualifier, string methodName, boolean selfIsRef
) {
exists(Path path, Function f |
callHasQualifier(call, path, qualifier) and
f = resolvePath(path) and
path.getSegment().getIdentifier().getText() = methodName and
exists(SelfParam self |
self = f.getParamList().getSelfParam() and
if self.isRef() then selfIsRef = true else selfIsRef = false
)
)
}
class CallExprCall extends Call instanceof CallExpr {
CallExprCall() { not callIsMethodCall(this, _, _, _) }
override string getMethodName() { none() }
override Trait getTrait() { none() }
override Trait getTrait() { callHasTraitQualifier(this, result) }
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) { none() }
@@ -103,22 +122,23 @@ module Impl {
}
}
private class CallExprMethodCall extends Call instanceof CallExpr {
class CallExprMethodCall extends Call instanceof CallExpr {
Path qualifier;
string methodName;
boolean selfIsRef;
CallExprMethodCall() { callIsMethodCall(this, qualifier, methodName) }
CallExprMethodCall() { callIsMethodCall(this, qualifier, methodName, selfIsRef) }
/**
* Holds if this call must have an explicit borrow for the `self` argument,
* because the corresponding parameter is `&self`. Explicit borrows are not
* needed when using method call syntax.
*/
predicate hasExplicitSelfBorrow() { selfIsRef = true }
override string getMethodName() { result = methodName }
override Trait getTrait() {
result = resolvePath(qualifier) and
// When the qualifier is `Self` and resolves to a trait, it's inside a
// trait method's default implementation. This is not a dispatch whose
// target is inferred from the type of the receiver, but should always
// resolve to the function in the trait block as path resolution does.
qualifier.toString() != "Self"
}
override Trait getTrait() { callHasTraitQualifier(this, result) }
override predicate implicitBorrowAt(ArgumentPosition pos, boolean certain) { none() }

View File

@@ -6,8 +6,6 @@
private import rust
private import codeql.rust.elements.internal.generated.MethodCallExpr
private import codeql.rust.internal.PathResolution
private import codeql.rust.internal.TypeInference
/**
* INTERNAL: This module contains the customizable definition of `MethodCallExpr` and should not
@@ -23,8 +21,6 @@ module Impl {
* ```
*/
class MethodCallExpr extends Generated::MethodCallExpr {
override Function getStaticTarget() { result = resolveMethodCallTarget(this) }
private string toStringPart(int index) {
index = 0 and
result = this.getReceiver().toAbbreviatedString()

View File

@@ -425,6 +425,8 @@ final class TraitTypeAbstraction extends TypeAbstraction, Trait {
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()
or
result.(AssociatedTypeTypeParameter).getTrait() = this
or
result.(SelfTypeParameter).getTrait() = this
}
}

View File

@@ -11,6 +11,7 @@ private import codeql.rust.frameworks.stdlib.Stdlib
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
private import codeql.rust.elements.Call
private import codeql.rust.elements.internal.CallImpl::Impl as CallImpl
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
class Type = T::Type;
@@ -522,6 +523,15 @@ private Type inferPathExprType(PathExpr pe, TypePath path) {
)
}
/** Gets the explicit type qualifier of the call `ce`, if any. */
private Type getTypeQualifier(CallExpr ce, TypePath path) {
exists(PathExpr pe, TypeMention tm |
pe = ce.getFunction() and
tm = pe.getPath().getQualifier() and
result = tm.resolveTypeAt(path)
)
}
/**
* A matching configuration for resolving types of call expressions
* like `foo::bar(baz)` and `foo.bar(baz)`.
@@ -724,8 +734,6 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
}
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
final class Access extends Call {
pragma[nomagic]
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
@@ -761,17 +769,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
or
// The `Self` type is supplied explicitly as a type qualifier, e.g. `Foo::<Bar>::baz()`
apos = TArgumentAccessPosition(CallImpl::TSelfArgumentPosition(), false, false) and
exists(PathExpr pe, TypeMention tm |
pe = this.(CallExpr).getFunction() and
tm = pe.getPath().getQualifier() and
result = tm.resolveTypeAt(path)
)
result = getTypeQualifier(this, path)
}
Declaration getTarget() {
result = resolveMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
or
result = CallExprImpl::getResolvedFunction(this)
result = resolveFunctionCallTarget(this) // potential mutual recursion; resolving some associated function calls requires resolving types
}
}
@@ -1220,15 +1224,28 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
)
}
pragma[nomagic]
private Type inferCastExprType(CastExpr ce, TypePath path) {
result = ce.getTypeRepr().(TypeMention).resolveTypeAt(path)
}
final class MethodCall extends Call {
MethodCall() { exists(this.getReceiver()) }
private Type getReceiverTypeAt(TypePath path) {
result = inferType(super.getReceiver(), path)
or
result = getTypeQualifier(this, path)
}
/** Gets the type of the receiver of the method call at `path`. */
Type getTypeAt(TypePath path) {
if this.receiverImplicitlyBorrowed()
if
this.receiverImplicitlyBorrowed() or
this.(CallImpl::CallExprMethodCall).hasExplicitSelfBorrow()
then
exists(TypePath path0, Type t0 |
t0 = inferType(super.getReceiver(), path0) and
t0 = this.getReceiverTypeAt(path0) and
(
path0.isCons(TRefTypeParameter(), path)
or
@@ -1256,7 +1273,7 @@ final class MethodCall extends Call {
t0.(StructType).asItemNode() instanceof StringStruct and
result.(StructType).asItemNode() instanceof Builtins::Str
)
else result = inferType(super.getReceiver(), path)
else result = this.getReceiverTypeAt(path)
}
}
@@ -1349,8 +1366,6 @@ private predicate implSiblingCandidate(
// contains the same `impl` block so considering both would give spurious
// siblings).
not exists(impl.getAttributeMacroExpansion()) and
// We use this for resolving methods, so exclude traits that do not have methods.
exists(Function f | f = trait.getASuccessor(_) and f.getParamList().hasSelfParam()) and
selfTy = impl.getSelfTy() and
rootType = selfTy.resolveType()
}
@@ -1385,42 +1400,49 @@ private predicate implSiblings(TraitItemNode trait, Impl impl1, Impl impl2) {
pragma[nomagic]
private predicate implHasSibling(Impl impl, Trait trait) { implSiblings(trait, impl, _) }
pragma[nomagic]
private predicate functionTypeAtPath(Function f, int pos, TypePath path, Type type) {
exists(TypeMention tm | type = tm.resolveTypeAt(path) |
tm = f.getParam(pos).getTypeRepr()
or
pos = -1 and
tm = f.getRetType().getTypeRepr()
)
}
/**
* Holds if a type parameter of `trait` occurs in the method with the name
* `methodName` at the `pos`th parameter at `path`.
* Holds if type parameter `tp` of `trait` occurs in the function with the name
* `functionName` at the `pos`th parameter at `path`.
*
* The special position `-1` refers to the return type of the function, which
* is sometimes needed to disambiguate associated function calls like
* `Default::default()` (in this case, `tp` is the special `Self` type parameter).
*/
bindingset[trait]
pragma[inline_late]
private predicate traitTypeParameterOccurrence(
TraitItemNode trait, string methodName, int pos, TypePath path
TraitItemNode trait, Function f, string functionName, int pos, TypePath path, TypeParameter tp
) {
exists(Function f | f = trait.getASuccessor(methodName) |
f.getParam(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) =
trait.(TraitTypeAbstraction).getATypeParameter()
)
}
bindingset[f, pos, path]
pragma[inline_late]
private predicate methodTypeAtPath(Function f, int pos, TypePath path, Type type) {
f.getParam(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) = type
f = trait.getAssocItem(functionName) and
functionTypeAtPath(f, pos, path, tp) and
tp = trait.(TraitTypeAbstraction).getATypeParameter()
}
/**
* Holds if resolving the method `f` in `impl` with the name `methodName`
* Holds if resolving the function `f` in `impl` with the name `functionName`
* requires inspecting the types of applied _arguments_ in order to determine
* whether it is the correct resolution.
*/
pragma[nomagic]
private predicate methodResolutionDependsOnArgument(
Impl impl, string methodName, Function f, int pos, TypePath path, Type type
private predicate functionResolutionDependsOnArgument(
ImplItemNode impl, string functionName, Function f, int pos, TypePath path, Type type
) {
/*
* As seen in the example below, when an implementation has a sibling for a
* trait we find occurrences of a type parameter of the trait in a method
* trait we find occurrences of a type parameter of the trait in a function
* signature in the trait. We then find the type given in the implementation
* at the same position, which is a position that might disambiguate the
* method from its siblings.
* function from its siblings.
*
* ```rust
* trait MyTrait<T> {
@@ -1442,9 +1464,10 @@ private predicate methodResolutionDependsOnArgument(
exists(TraitItemNode trait |
implHasSibling(impl, trait) and
traitTypeParameterOccurrence(trait, methodName, pos, path) and
methodTypeAtPath(getMethodSuccessor(impl, methodName), pos, path, type) and
f = getMethodSuccessor(impl, methodName)
traitTypeParameterOccurrence(trait, _, functionName, pos, path, _) and
functionTypeAtPath(f, pos, path, type) and
f = impl.getAssocItem(functionName) and
pos >= 0
)
}
@@ -1484,11 +1507,12 @@ private Function getMethodFromImpl(MethodCall mc) {
name = mc.getMethodName() and
result = getMethodSuccessor(impl, name)
|
not methodResolutionDependsOnArgument(impl, _, _, _, _, _)
not functionResolutionDependsOnArgument(impl, name, _, _, _, _)
or
exists(int pos, TypePath path, Type type |
methodResolutionDependsOnArgument(impl, name, result, pos, path, type) and
inferType(mc.getPositionalArgument(pos), path) = type
functionResolutionDependsOnArgument(impl, name, result, pos, pragma[only_bind_into](path),
type) and
inferType(mc.getPositionalArgument(pos), pragma[only_bind_into](path)) = type
)
)
}
@@ -1499,6 +1523,162 @@ private Function getTraitMethod(ImplTraitReturnType trait, string name) {
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
}
pragma[nomagic]
private Function resolveMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}
pragma[nomagic]
private predicate assocFuncResolutionDependsOnArgument(Function f, Impl impl, int pos) {
functionResolutionDependsOnArgument(impl, _, f, pos, _, _) and
not f.getParamList().hasSelfParam()
}
private class FunctionCallExpr extends CallImpl::CallExprCall {
ItemNode getResolvedFunction() { result = CallExprImpl::getResolvedFunction(this) }
/**
* Holds if the target of this call is ambigous, and type information is required
* to disambiguate.
*/
predicate isAmbigous() {
this.hasTrait()
or
assocFuncResolutionDependsOnArgument(this.getResolvedFunction(), _, _)
}
/**
* Gets a target candidate of this ambigous call, which belongs to `impl`.
*
* In order for the candidate to be a match, the argument type at `pos` must be
* checked against the type of the function at the same position.
*
* `resolved` is the corresponding function resolved through path resolution.
*/
pragma[nomagic]
Function getAnAmbigousCandidate(ImplItemNode impl, int pos, Function resolved) {
resolved = this.getResolvedFunction() and
(
exists(TraitItemNode trait |
trait = this.getTrait() and
result.implements(resolved) and
result = impl.getAnAssocItem()
|
assocFuncResolutionDependsOnArgument(result, impl, pos)
or
exists(TypeParameter tp | traitTypeParameterOccurrence(trait, resolved, _, pos, _, tp) |
pos >= 0
or
// We only check that the context of the call provides relevant type information
// when no argument can
not traitTypeParameterOccurrence(trait, resolved, _, any(int pos0 | pos0 >= 0), _, tp)
)
)
or
result = resolved and
assocFuncResolutionDependsOnArgument(result, impl, pos)
)
}
/**
* Same as `getAnAmbigousCandidate`, ranks the positions to be checked.
*/
Function getAnAmbigousCandidateRanked(ImplItemNode impl, int pos, Function f, int rnk) {
pos = rank[rnk + 1](int pos0 | result = this.getAnAmbigousCandidate(impl, pos0, f) | pos0)
}
}
private newtype TAmbigousAssocFunctionCallExpr =
MkAmbigousAssocFunctionCallExpr(FunctionCallExpr call, Function resolved, int pos) {
exists(call.getAnAmbigousCandidate(_, pos, resolved))
}
private class AmbigousAssocFunctionCallExpr extends MkAmbigousAssocFunctionCallExpr {
FunctionCallExpr call;
Function resolved;
int pos;
AmbigousAssocFunctionCallExpr() { this = MkAmbigousAssocFunctionCallExpr(call, resolved, pos) }
pragma[nomagic]
Type getTypeAt(TypePath path) {
result = inferType(call.(CallExpr).getArg(pos), path)
or
pos = -1 and
result = inferType(call, path)
}
string toString() { result = call.toString() }
Location getLocation() { result = call.getLocation() }
}
private module AmbigousAssocFuncIsInstantiationOfInput implements
IsInstantiationOfInputSig<AmbigousAssocFunctionCallExpr>
{
pragma[nomagic]
predicate potentialInstantiationOf(
AmbigousAssocFunctionCallExpr ce, TypeAbstraction impl, TypeMention constraint
) {
exists(FunctionCallExpr call, Function resolved, Function cand, int pos |
ce = MkAmbigousAssocFunctionCallExpr(call, resolved, pos) and
cand = call.getAnAmbigousCandidate(impl, pos, resolved)
|
constraint = cand.getParam(pos).getTypeRepr()
or
pos = -1 and
constraint = cand.getRetType().getTypeRepr()
)
}
}
/**
* Gets the target of `call`, where resolution does not rely on type inference.
*/
pragma[nomagic]
private ItemNode resolveUnambigousFunctionCallTarget(FunctionCallExpr call) {
result = call.getResolvedFunction() and
not call.isAmbigous()
}
pragma[nomagic]
private Function resolveAmbigousFunctionCallTargetFromIndex(FunctionCallExpr call, int index) {
exists(Impl impl, int pos, Function resolved |
IsInstantiationOf<AmbigousAssocFunctionCallExpr, AmbigousAssocFuncIsInstantiationOfInput>::isInstantiationOf(MkAmbigousAssocFunctionCallExpr(call,
resolved, pos), impl, _) and
result = call.getAnAmbigousCandidateRanked(impl, pos, resolved, index)
|
index = 0
or
result = resolveAmbigousFunctionCallTargetFromIndex(call, index - 1)
)
}
/**
* Gets the target of `call`, where resolution relies on type inference.
*/
pragma[nomagic]
private Function resolveAmbigousFunctionCallTarget(FunctionCallExpr call) {
result =
resolveAmbigousFunctionCallTargetFromIndex(call,
max(int index | result = call.getAnAmbigousCandidateRanked(_, _, _, index)))
}
pragma[inline]
private ItemNode resolveFunctionCallTarget(FunctionCallExpr call) {
result = resolveUnambigousFunctionCallTarget(call)
or
result = resolveAmbigousFunctionCallTarget(call)
}
cached
private module Cached {
private import codeql.rust.internal.CachedStages
@@ -1527,18 +1707,12 @@ private module Cached {
)
}
/** Gets a method that the method call `mc` resolves to, if any. */
/** Gets a function that `call` resolves to, if any. */
cached
Function resolveMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
Function resolveCallTarget(Call call) {
result = resolveMethodCallTarget(call)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
result = resolveFunctionCallTarget(call)
}
pragma[inline]
@@ -1660,6 +1834,8 @@ private module Cached {
result = inferIndexExprType(n, path)
or
result = inferForLoopExprType(n, path)
or
result = inferCastExprType(n, path)
}
}
@@ -1685,9 +1861,9 @@ private module Debug {
result = inferType(n, path)
}
Function debugResolveMethodCallTarget(Call mce) {
mce = getRelevantLocatable() and
result = resolveMethodCallTarget(mce)
Function debugResolveCallTarget(Call c) {
c = getRelevantLocatable() and
result = resolveCallTarget(c)
}
predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
@@ -1705,6 +1881,11 @@ private module Debug {
tm.resolveTypeAt(path) = type
}
Type debugInferAnnotatedType(AstNode n, TypePath path) {
n = getRelevantLocatable() and
result = inferAnnotatedType(n, path)
}
pragma[nomagic]
private int countTypesAtPath(AstNode n, TypePath path, Type t) {
t = inferType(n, path) and