Merge pull request #19575 from paldepind/rust/function-call-method

Rust: Resolve function calls to traits methods
This commit is contained in:
Simon Friis Vindum
2025-05-27 09:28:36 +02:00
committed by GitHub
9 changed files with 1793 additions and 1635 deletions

View File

@@ -14,6 +14,7 @@ private import codeql.rust.elements.PathExpr
module Impl {
private import rust
private import codeql.rust.internal.PathResolution as PathResolution
private import codeql.rust.internal.TypeInference as TypeInference
pragma[nomagic]
Path getFunctionPath(CallExpr ce) { result = ce.getFunction().(PathExpr).getPath() }
@@ -36,7 +37,14 @@ module Impl {
class CallExpr extends Generated::CallExpr {
override string toStringImpl() { result = this.getFunction().toAbbreviatedString() + "(...)" }
override Callable getStaticTarget() { result = getResolvedFunction(this) }
override Callable getStaticTarget() {
// If this call is to a trait method, e.g., `Trait::foo(bar)`, then check
// if type inference can resolve it to the correct trait implementation.
result = TypeInference::resolveMethodCallTarget(this)
or
not exists(TypeInference::resolveMethodCallTarget(this)) and
result = getResolvedFunction(this)
}
/** Gets the struct that this call resolves to, if any. */
Struct getStruct() { result = getResolvedFunction(this) }

View File

@@ -14,14 +14,6 @@ private import codeql.rust.internal.TypeInference
* be referenced directly.
*/
module Impl {
private predicate isInherentImplFunction(Function f) {
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}
private predicate isTraitImplFunction(Function f) {
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}
// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* A method call expression. For example:
@@ -31,38 +23,7 @@ module Impl {
* ```
*/
class MethodCallExpr extends Generated::MethodCallExpr {
private Function getStaticTargetFrom(boolean fromSource) {
result = resolveMethodCallExpr(this) and
(if result.fromSource() then fromSource = true else fromSource = false) and
(
// prioritize inherent implementation methods first
isInherentImplFunction(result)
or
not isInherentImplFunction(resolveMethodCallExpr(this)) and
(
// then trait implementation methods
isTraitImplFunction(result)
or
not isTraitImplFunction(resolveMethodCallExpr(this)) and
(
// then trait methods with default implementations
result.hasBody()
or
// and finally trait methods without default implementations
not resolveMethodCallExpr(this).hasBody()
)
)
)
}
override Function getStaticTarget() {
// Functions in source code also gets extracted as library code, due to
// this duplication we prioritize functions from source code.
result = this.getStaticTargetFrom(true)
or
not exists(this.getStaticTargetFrom(true)) and
result = this.getStaticTargetFrom(false)
}
override Function getStaticTarget() { result = resolveMethodCallTarget(this) }
private string toStringPart(int index) {
index = 0 and

View File

@@ -678,7 +678,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
Declaration getTarget() {
result = CallExprImpl::getResolvedFunction(this)
or
result = resolveMethodCallExpr(this) // mutual recursion; resolving method calls requires resolving types and vice versa
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
}
}
@@ -1000,6 +1000,150 @@ private StructType inferLiteralType(LiteralExpr le) {
)
}
private module MethodCall {
/** An expression that calls a method. */
abstract private class MethodCallImpl extends Expr {
/** Gets the name of the method targeted. */
abstract string getMethodName();
/** Gets the number of arguments _excluding_ the `self` argument. */
abstract int getArity();
/** Gets the trait targeted by this method call, if any. */
Trait getTrait() { none() }
/** Gets the type of the receiver of the method call at `path`. */
abstract Type getTypeAt(TypePath path);
}
final class MethodCall = MethodCallImpl;
private class MethodCallExprMethodCall extends MethodCallImpl instanceof MethodCallExpr {
override string getMethodName() { result = super.getIdentifier().getText() }
override int getArity() { result = super.getArgList().getNumberOfArgs() }
pragma[nomagic]
override Type getTypeAt(TypePath path) {
exists(TypePath path0 | result = inferType(super.getReceiver(), path0) |
path0.isCons(TRefTypeParameter(), path)
or
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
path = path0
)
}
}
private class CallExprMethodCall extends MethodCallImpl instanceof CallExpr {
TraitItemNode trait;
string methodName;
Expr receiver;
CallExprMethodCall() {
receiver = this.getArgList().getArg(0) and
exists(Path path, Function f |
path = this.getFunction().(PathExpr).getPath() and
f = resolvePath(path) and
f.getParamList().hasSelfParam() and
trait = resolvePath(path.getQualifier()) and
trait.getAnAssocItem() = f and
path.getSegment().getIdentifier().getText() = methodName
)
}
override string getMethodName() { result = methodName }
override int getArity() { result = super.getArgList().getNumberOfArgs() - 1 }
override Trait getTrait() { result = trait }
pragma[nomagic]
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
}
}
import MethodCall
/**
* Holds if a method for `type` with the name `name` and the arity `arity`
* exists in `impl`.
*/
private predicate methodCandidate(Type type, string name, int arity, Impl impl) {
type = impl.getSelfTy().(TypeMention).resolveType() and
exists(Function f |
f = impl.(ImplItemNode).getASuccessor(name) and
f.getParamList().hasSelfParam() and
arity = f.getParamList().getNumberOfParams()
)
}
/**
* Holds if a method for `type` for `trait` with the name `name` and the arity
* `arity` exists in `impl`.
*/
pragma[nomagic]
private predicate methodCandidateTrait(Type type, Trait trait, string name, int arity, Impl impl) {
trait = resolvePath(impl.(ImplItemNode).getTraitPath()) and
methodCandidate(type, name, arity, impl)
}
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
pragma[nomagic]
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
exists(Type rootType, string name, int arity |
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getArity() and
constraint = impl.(ImplTypeAbstraction).getSelfTy()
|
methodCandidateTrait(rootType, mc.getTrait(), name, arity, impl)
or
not exists(mc.getTrait()) and
methodCandidate(rootType, name, arity, impl)
)
}
predicate relevantTypeMention(TypeMention constraint) {
exists(Impl impl | methodCandidate(_, _, _, impl) and constraint = impl.getSelfTy())
}
}
bindingset[item, name]
pragma[inline_late]
private Function getMethodSuccessor(ItemNode item, string name) {
result = item.getASuccessor(name)
}
bindingset[tp, name]
pragma[inline_late]
private Function getTypeParameterMethod(TypeParameter tp, string name) {
result = getMethodSuccessor(tp.(TypeParamTypeParameter).getTypeParam(), name)
or
result = getMethodSuccessor(tp.(SelfTypeParameter).getTrait(), name)
}
/** Gets a method from an `impl` block that matches the method call `mc`. */
private Function getMethodFromImpl(MethodCall mc) {
exists(Impl impl |
IsInstantiationOf<MethodCall, IsInstantiationOfInput>::isInstantiationOf(mc, impl, _) and
result = getMethodSuccessor(impl, mc.getMethodName())
)
}
/**
* Gets a method that the method call `mc` resolves to based on type inference,
* if any.
*/
private Function inferMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
result = getMethodFromImpl(mc)
or
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}
cached
private module Cached {
private import codeql.rust.internal.CachedStages
@@ -1026,92 +1170,49 @@ private module Cached {
)
}
private class ReceiverExpr extends Expr {
MethodCallExpr mce;
ReceiverExpr() { mce.getReceiver() = this }
string getField() { result = mce.getIdentifier().getText() }
int getNumberOfArgs() { result = mce.getArgList().getNumberOfArgs() }
pragma[nomagic]
Type getTypeAt(TypePath path) {
exists(TypePath path0 | result = inferType(this, path0) |
path0.isCons(TRefTypeParameter(), path)
or
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
path = path0
)
}
private predicate isInherentImplFunction(Function f) {
f = any(Impl impl | not impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}
/** Holds if a method for `type` with the name `name` and the arity `arity` exists in `impl`. */
pragma[nomagic]
private predicate methodCandidate(Type type, string name, int arity, Impl impl) {
type = impl.getSelfTy().(TypeReprMention).resolveType() and
exists(Function f |
f = impl.(ImplItemNode).getASuccessor(name) and
f.getParamList().hasSelfParam() and
arity = f.getParamList().getNumberOfParams()
)
private predicate isTraitImplFunction(Function f) {
f = any(Impl impl | impl.hasTrait()).(ImplItemNode).getAnAssocItem()
}
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<ReceiverExpr> {
pragma[nomagic]
predicate potentialInstantiationOf(
ReceiverExpr receiver, TypeAbstraction impl, TypeMention constraint
) {
methodCandidate(receiver.getTypeAt(TypePath::nil()), receiver.getField(),
receiver.getNumberOfArgs(), impl) and
constraint = impl.(ImplTypeAbstraction).getSelfTy()
}
predicate relevantTypeMention(TypeMention constraint) {
exists(Impl impl | methodCandidate(_, _, _, impl) and constraint = impl.getSelfTy())
}
}
bindingset[item, name]
pragma[inline_late]
private Function getMethodSuccessor(ItemNode item, string name) {
result = item.getASuccessor(name)
}
bindingset[tp, name]
pragma[inline_late]
private Function getTypeParameterMethod(TypeParameter tp, string name) {
result = getMethodSuccessor(tp.(TypeParamTypeParameter).getTypeParam(), name)
or
result = getMethodSuccessor(tp.(SelfTypeParameter).getTrait(), name)
}
/**
* Gets the method from an `impl` block with an implementing type that matches
* the type of `receiver` and with a name of the method call in which
* `receiver` occurs, if any.
*/
private Function getMethodFromImpl(ReceiverExpr receiver) {
exists(Impl impl |
IsInstantiationOf<ReceiverExpr, IsInstantiationOfInput>::isInstantiationOf(receiver, impl, _) and
result = getMethodSuccessor(impl, receiver.getField())
)
}
/** Gets a method that the method call `mce` resolves to, if any. */
cached
Function resolveMethodCallExpr(MethodCallExpr mce) {
exists(ReceiverExpr receiver | mce.getReceiver() = receiver |
// The method comes from an `impl` block targeting the type of `receiver`.
result = getMethodFromImpl(receiver)
private Function resolveMethodCallTargetFrom(MethodCall mc, boolean fromSource) {
result = inferMethodCallTarget(mc) and
(if result.fromSource() then fromSource = true else fromSource = false) and
(
// prioritize inherent implementation methods first
isInherentImplFunction(result)
or
// The type of `receiver` is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(receiver.getTypeAt(TypePath::nil()), receiver.getField())
not isInherentImplFunction(inferMethodCallTarget(mc)) and
(
// then trait implementation methods
isTraitImplFunction(result)
or
not isTraitImplFunction(inferMethodCallTarget(mc)) and
(
// then trait methods with default implementations
result.hasBody()
or
// and finally trait methods without default implementations
not inferMethodCallTarget(mc).hasBody()
)
)
)
}
/** Gets a method that the method call `mc` resolves to, if any. */
cached
Function resolveMethodCallTarget(MethodCall mc) {
// Functions in source code also gets extracted as library code, due to
// this duplication we prioritize functions from source code.
result = resolveMethodCallTargetFrom(mc, true)
or
not exists(resolveMethodCallTargetFrom(mc, true)) and
result = resolveMethodCallTargetFrom(mc, false)
}
pragma[inline]
private Type inferRootTypeDeref(AstNode n) {
result = inferType(n) and
@@ -1243,6 +1344,6 @@ private module Debug {
Function debugResolveMethodCallExpr(MethodCallExpr mce) {
mce = getRelevantLocatable() and
result = resolveMethodCallExpr(mce)
result = resolveMethodCallTarget(mce)
}
}

View File

@@ -92,6 +92,24 @@ edges
| main.rs:188:9:188:9 | d [MyInt] | main.rs:189:10:189:10 | d [MyInt] | provenance | |
| main.rs:188:13:188:20 | a.add(...) [MyInt] | main.rs:188:9:188:9 | d [MyInt] | provenance | |
| main.rs:189:10:189:10 | d [MyInt] | main.rs:189:10:189:16 | d.value | provenance | |
| main.rs:201:18:201:21 | SelfParam [MyInt] | main.rs:201:48:203:5 | { ... } [MyInt] | provenance | |
| main.rs:205:26:205:37 | ...: MyInt [MyInt] | main.rs:205:49:207:5 | { ... } [MyInt] | provenance | |
| main.rs:211:9:211:9 | a [MyInt] | main.rs:213:49:213:49 | a [MyInt] | provenance | |
| main.rs:211:13:211:38 | MyInt {...} [MyInt] | main.rs:211:9:211:9 | a [MyInt] | provenance | |
| main.rs:211:28:211:36 | source(...) | main.rs:211:13:211:38 | MyInt {...} [MyInt] | provenance | |
| main.rs:213:9:213:26 | MyInt {...} [MyInt] | main.rs:213:24:213:24 | c | provenance | |
| main.rs:213:24:213:24 | c | main.rs:214:10:214:10 | c | provenance | |
| main.rs:213:30:213:53 | ...::take_self(...) [MyInt] | main.rs:213:9:213:26 | MyInt {...} [MyInt] | provenance | |
| main.rs:213:49:213:49 | a [MyInt] | main.rs:201:18:201:21 | SelfParam [MyInt] | provenance | |
| main.rs:213:49:213:49 | a [MyInt] | main.rs:213:30:213:53 | ...::take_self(...) [MyInt] | provenance | |
| main.rs:217:9:217:9 | b [MyInt] | main.rs:218:54:218:54 | b [MyInt] | provenance | |
| main.rs:217:13:217:39 | MyInt {...} [MyInt] | main.rs:217:9:217:9 | b [MyInt] | provenance | |
| main.rs:217:28:217:37 | source(...) | main.rs:217:13:217:39 | MyInt {...} [MyInt] | provenance | |
| main.rs:218:9:218:26 | MyInt {...} [MyInt] | main.rs:218:24:218:24 | c | provenance | |
| main.rs:218:24:218:24 | c | main.rs:219:10:219:10 | c | provenance | |
| main.rs:218:30:218:55 | ...::take_second(...) [MyInt] | main.rs:218:9:218:26 | MyInt {...} [MyInt] | provenance | |
| main.rs:218:54:218:54 | b [MyInt] | main.rs:205:26:205:37 | ...: MyInt [MyInt] | provenance | |
| main.rs:218:54:218:54 | b [MyInt] | main.rs:218:30:218:55 | ...::take_second(...) [MyInt] | provenance | |
| main.rs:227:32:231:1 | { ... } | main.rs:246:41:246:54 | async_source(...) | provenance | |
| main.rs:228:9:228:9 | a | main.rs:227:32:231:1 | { ... } | provenance | |
| main.rs:228:9:228:9 | a | main.rs:229:10:229:10 | a | provenance | |
@@ -202,6 +220,26 @@ nodes
| main.rs:188:13:188:20 | a.add(...) [MyInt] | semmle.label | a.add(...) [MyInt] |
| main.rs:189:10:189:10 | d [MyInt] | semmle.label | d [MyInt] |
| main.rs:189:10:189:16 | d.value | semmle.label | d.value |
| main.rs:201:18:201:21 | SelfParam [MyInt] | semmle.label | SelfParam [MyInt] |
| main.rs:201:48:203:5 | { ... } [MyInt] | semmle.label | { ... } [MyInt] |
| main.rs:205:26:205:37 | ...: MyInt [MyInt] | semmle.label | ...: MyInt [MyInt] |
| main.rs:205:49:207:5 | { ... } [MyInt] | semmle.label | { ... } [MyInt] |
| main.rs:211:9:211:9 | a [MyInt] | semmle.label | a [MyInt] |
| main.rs:211:13:211:38 | MyInt {...} [MyInt] | semmle.label | MyInt {...} [MyInt] |
| main.rs:211:28:211:36 | source(...) | semmle.label | source(...) |
| main.rs:213:9:213:26 | MyInt {...} [MyInt] | semmle.label | MyInt {...} [MyInt] |
| main.rs:213:24:213:24 | c | semmle.label | c |
| main.rs:213:30:213:53 | ...::take_self(...) [MyInt] | semmle.label | ...::take_self(...) [MyInt] |
| main.rs:213:49:213:49 | a [MyInt] | semmle.label | a [MyInt] |
| main.rs:214:10:214:10 | c | semmle.label | c |
| main.rs:217:9:217:9 | b [MyInt] | semmle.label | b [MyInt] |
| main.rs:217:13:217:39 | MyInt {...} [MyInt] | semmle.label | MyInt {...} [MyInt] |
| main.rs:217:28:217:37 | source(...) | semmle.label | source(...) |
| main.rs:218:9:218:26 | MyInt {...} [MyInt] | semmle.label | MyInt {...} [MyInt] |
| main.rs:218:24:218:24 | c | semmle.label | c |
| main.rs:218:30:218:55 | ...::take_second(...) [MyInt] | semmle.label | ...::take_second(...) [MyInt] |
| main.rs:218:54:218:54 | b [MyInt] | semmle.label | b [MyInt] |
| main.rs:219:10:219:10 | c | semmle.label | c |
| main.rs:227:32:231:1 | { ... } | semmle.label | { ... } |
| main.rs:228:9:228:9 | a | semmle.label | a |
| main.rs:228:13:228:21 | source(...) | semmle.label | source(...) |
@@ -225,6 +263,8 @@ subpaths
| main.rs:143:38:143:38 | a | main.rs:106:27:106:32 | ...: i64 | main.rs:106:42:112:5 | { ... } | main.rs:143:13:143:39 | ...::data_through(...) |
| main.rs:161:24:161:33 | source(...) | main.rs:155:12:155:17 | ...: i64 | main.rs:155:28:157:5 | { ... } [MyInt] | main.rs:161:13:161:34 | ...::new(...) [MyInt] |
| main.rs:186:9:186:9 | a [MyInt] | main.rs:169:12:169:15 | SelfParam [MyInt] | main.rs:169:42:172:5 | { ... } [MyInt] | main.rs:188:13:188:20 | a.add(...) [MyInt] |
| main.rs:213:49:213:49 | a [MyInt] | main.rs:201:18:201:21 | SelfParam [MyInt] | main.rs:201:48:203:5 | { ... } [MyInt] | main.rs:213:30:213:53 | ...::take_self(...) [MyInt] |
| main.rs:218:54:218:54 | b [MyInt] | main.rs:205:26:205:37 | ...: MyInt [MyInt] | main.rs:205:49:207:5 | { ... } [MyInt] | main.rs:218:30:218:55 | ...::take_second(...) [MyInt] |
testFailures
#select
| main.rs:18:10:18:10 | a | main.rs:13:5:13:13 | source(...) | main.rs:18:10:18:10 | a | $@ | main.rs:13:5:13:13 | source(...) | source(...) |
@@ -241,6 +281,8 @@ testFailures
| main.rs:144:10:144:10 | b | main.rs:142:13:142:22 | source(...) | main.rs:144:10:144:10 | b | $@ | main.rs:142:13:142:22 | source(...) | source(...) |
| main.rs:163:10:163:10 | m | main.rs:161:24:161:33 | source(...) | main.rs:163:10:163:10 | m | $@ | main.rs:161:24:161:33 | source(...) | source(...) |
| main.rs:189:10:189:16 | d.value | main.rs:186:28:186:36 | source(...) | main.rs:189:10:189:16 | d.value | $@ | main.rs:186:28:186:36 | source(...) | source(...) |
| main.rs:214:10:214:10 | c | main.rs:211:28:211:36 | source(...) | main.rs:214:10:214:10 | c | $@ | main.rs:211:28:211:36 | source(...) | source(...) |
| main.rs:219:10:219:10 | c | main.rs:217:28:217:37 | source(...) | main.rs:219:10:219:10 | c | $@ | main.rs:217:28:217:37 | source(...) | source(...) |
| main.rs:229:10:229:10 | a | main.rs:228:13:228:21 | source(...) | main.rs:229:10:229:10 | a | $@ | main.rs:228:13:228:21 | source(...) | source(...) |
| main.rs:239:14:239:14 | c | main.rs:238:17:238:25 | source(...) | main.rs:239:14:239:14 | c | $@ | main.rs:238:17:238:25 | source(...) | source(...) |
| main.rs:247:10:247:10 | a | main.rs:228:13:228:21 | source(...) | main.rs:247:10:247:10 | a | $@ | main.rs:228:13:228:21 | source(...) | source(...) |

View File

@@ -211,12 +211,12 @@ fn data_through_trait_method_called_as_function() {
let a = MyInt { value: source(8) };
let b = MyInt { value: 2 };
let MyInt { value: c } = MyTrait::take_self(a, b);
sink(c); // $ MISSING: hasValueFlow=8
sink(c); // $ hasValueFlow=8
let a = MyInt { value: 0 };
let b = MyInt { value: source(37) };
let MyInt { value: c } = MyTrait::take_second(a, b);
sink(c); // $ MISSING: hasValueFlow=37
sink(c); // $ hasValueFlow=37
let a = MyInt { value: 0 };
let b = MyInt { value: source(38) };

View File

@@ -48,10 +48,13 @@
| main.rs:188:13:188:20 | a.add(...) | main.rs:169:5:172:5 | fn add |
| main.rs:189:5:189:17 | sink(...) | main.rs:5:1:7:1 | fn sink |
| main.rs:211:28:211:36 | source(...) | main.rs:1:1:3:1 | fn source |
| main.rs:213:30:213:53 | ...::take_self(...) | main.rs:201:5:203:5 | fn take_self |
| main.rs:214:5:214:11 | sink(...) | main.rs:5:1:7:1 | fn sink |
| main.rs:217:28:217:37 | source(...) | main.rs:1:1:3:1 | fn source |
| main.rs:218:30:218:55 | ...::take_second(...) | main.rs:205:5:207:5 | fn take_second |
| main.rs:219:5:219:11 | sink(...) | main.rs:5:1:7:1 | fn sink |
| main.rs:222:28:222:37 | source(...) | main.rs:1:1:3:1 | fn source |
| main.rs:223:30:223:53 | ...::take_self(...) | main.rs:201:5:203:5 | fn take_self |
| main.rs:224:5:224:11 | sink(...) | main.rs:5:1:7:1 | fn sink |
| main.rs:228:13:228:21 | source(...) | main.rs:1:1:3:1 | fn source |
| main.rs:229:5:229:11 | sink(...) | main.rs:5:1:7:1 | fn sink |

View File

@@ -90,6 +90,32 @@ mod method_impl {
}
}
mod trait_impl {
#[derive(Debug)]
struct MyThing {
field: bool,
}
trait MyTrait<B> {
fn trait_method(self) -> B;
}
impl MyTrait<bool> for MyThing {
// MyThing::trait_method
fn trait_method(self) -> bool {
self.field // $ fieldof=MyThing
}
}
pub fn f() {
let x = MyThing { field: true };
let a = x.trait_method(); // $ type=a:bool method=MyThing::trait_method
let y = MyThing { field: false };
let b = MyTrait::trait_method(y); // $ type=b:bool method=MyThing::trait_method
}
}
mod method_non_parametric_impl {
#[derive(Debug)]
struct MyThing<A> {

View File

@@ -43,7 +43,7 @@ module ResolveTest implements TestSig {
source.fromSource() and
not source.isFromMacroExpansion()
|
target = source.(MethodCallExpr).getStaticTarget() and
target = resolveMethodCallTarget(source) and
functionHasValue(target, value) and
tag = "method"
or