Merge pull request #21027 from hvitved/rust/type-inference-matching-specialization

Rust: Also use specialized types when inferring types for calls
This commit is contained in:
Tom Hvitved
2025-12-17 11:03:44 +01:00
committed by GitHub
4 changed files with 5290 additions and 5102 deletions

View File

@@ -1,6 +1,7 @@
/** Provides functionality for inferring types. */
private import codeql.util.Boolean
private import codeql.util.Option
private import rust
private import PathResolution
private import Type
@@ -234,6 +235,107 @@ private class NonMethodFunction extends Function {
NonMethodFunction() { not this.hasSelfParam() }
}
private module ImplOrTraitItemNodeOption = Option<ImplOrTraitItemNode>;
private class ImplOrTraitItemNodeOption = ImplOrTraitItemNodeOption::Option;
private class FunctionDeclaration extends Function {
private ImplOrTraitItemNodeOption parent;
FunctionDeclaration() {
not this = any(ImplOrTraitItemNode i).getAnAssocItem() and parent.isNone()
or
this = parent.asSome().getASuccessor(_)
}
/** Holds if this function is associated with `i`. */
predicate isAssoc(ImplOrTraitItemNode i) { i = parent.asSome() }
/** Holds if this is a free function. */
predicate isFree() { parent.isNone() }
/** Holds if this function is valid for `i`. */
predicate isFor(ImplOrTraitItemNodeOption i) { i = parent }
/**
* Holds if this function is valid for `i`. If `i` is a trait or `impl` block then
* this function must be declared directly inside `i`.
*/
predicate isDirectlyFor(ImplOrTraitItemNodeOption i) {
i.isNone() and
this.isFree()
or
this = i.asSome().getAnAssocItem()
}
TypeParameter getTypeParameter(ImplOrTraitItemNodeOption i, TypeParameterPosition ppos) {
i = parent and
(
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
or
typeParamMatchPosition(i.asSome().getTypeParam(_), result, ppos)
or
ppos.isImplicit() and result = TSelfTypeParameter(i.asSome())
or
ppos.isImplicit() and result.(AssociatedTypeTypeParameter).getTrait() = i.asSome()
or
ppos.isImplicit() and this = result.(ImplTraitTypeTypeParameter).getFunction()
)
}
pragma[nomagic]
Type getParameterType(ImplOrTraitItemNodeOption i, FunctionPosition pos, TypePath path) {
i = parent and
(
not pos.isReturn() and
result = getAssocFunctionTypeAt(this, i.asSome(), pos, path)
or
i.isNone() and
result = this.getParam(pos.asPosition()).getTypeRepr().(TypeMention).resolveTypeAt(path)
)
}
private Type resolveRetType(ImplOrTraitItemNodeOption i, TypePath path) {
i = parent and
(
result =
getAssocFunctionTypeAt(this, i.asSome(), any(FunctionPosition pos | pos.isReturn()), path)
or
i.isNone() and
result = getReturnTypeMention(this).resolveTypeAt(path)
)
}
Type getReturnType(ImplOrTraitItemNodeOption i, TypePath path) {
if this.isAsync()
then
i = parent and
path.isEmpty() and
result = getFutureTraitType()
or
exists(TypePath suffix |
result = this.resolveRetType(i, suffix) and
path = TypePath::cons(getDynFutureOutputTypeParameter(), suffix)
)
else result = this.resolveRetType(i, path)
}
Type getDeclaredType(ImplOrTraitItemNodeOption i, FunctionPosition pos, TypePath path) {
result = this.getParameterType(i, pos, path)
or
pos.isReturn() and
result = this.getReturnType(i, path)
}
string toStringExt(ImplOrTraitItemNode i) {
i = parent.asSome() and
if this = i.getAnAssocItem()
then result = this.toString()
else
result = this + " [" + [i.(Impl).getSelfTy().toString(), i.(Trait).getName().toString()] + "]"
}
}
pragma[nomagic]
private TypeMention getCallExprTypeMentionArgument(CallExpr ce, TypeArgumentPosition apos) {
exists(Path p, int i | p = CallExprImpl::getFunctionPath(ce) |
@@ -308,13 +410,12 @@ module CertainTypeInference {
}
pragma[nomagic]
private Type getCallExprType(CallExpr ce, Path p, Function f, TypePath tp) {
callResolvesTo(ce, p, f) and
result =
[
f.(MethodCallMatchingInput::Declaration).getReturnType(tp),
f.(NonMethodCallMatchingInput::Declaration).getReturnType(tp)
]
private Type getCallExprType(CallExpr ce, Path p, FunctionDeclaration f, TypePath path) {
exists(ImplOrTraitItemNodeOption i |
callResolvesTo(ce, p, f) and
result = f.getReturnType(i, path) and
f.isDirectlyFor(i)
)
}
pragma[nomagic]
@@ -2084,62 +2185,36 @@ private module MethodResolution {
private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSig {
import FunctionPositionMatchingInput
final class Declaration extends Function {
private class MethodDeclaration extends Method, FunctionDeclaration { }
private newtype TDeclaration =
TMethodFunctionDeclaration(ImplOrTraitItemNode i, MethodDeclaration m) { m.isAssoc(i) }
final class Declaration extends TMethodFunctionDeclaration {
ImplOrTraitItemNode parent;
ImplOrTraitItemNodeOption someParent;
MethodDeclaration m;
Declaration() {
this = TMethodFunctionDeclaration(parent, m) and
someParent.asSome() = parent
}
predicate isMethod(ImplOrTraitItemNode i, Method method) {
this = TMethodFunctionDeclaration(i, method)
}
TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
or
exists(ImplOrTraitItemNode i | this = i.getAnAssocItem() |
typeParamMatchPosition(i.getTypeParam(_), result, ppos)
or
ppos.isImplicit() and result = TSelfTypeParameter(i)
or
ppos.isImplicit() and
result.(AssociatedTypeTypeParameter).getTrait() = i
)
or
ppos.isImplicit() and
this = result.(ImplTraitTypeTypeParameter).getFunction()
}
pragma[nomagic]
Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(Param p, int i |
p = this.getParam(i) and
i = dpos.asPosition() and
result = p.getTypeRepr().(TypeMention).resolveTypeAt(path)
)
or
dpos.isSelf() and
exists(SelfParam self |
self = pragma[only_bind_into](this.getSelfParam()) and
result = getSelfParamTypeMention(self).resolveTypeAt(path)
)
}
private Type resolveRetType(TypePath path) {
result = getReturnTypeMention(this).resolveTypeAt(path)
}
pragma[nomagic]
Type getReturnType(TypePath path) {
if this.isAsync()
then
path.isEmpty() and
result = getFutureTraitType()
or
exists(TypePath suffix |
result = this.resolveRetType(suffix) and
path = TypePath::cons(getDynFutureOutputTypeParameter(), suffix)
)
else result = this.resolveRetType(path)
result = m.getTypeParameter(someParent, ppos)
}
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = this.getParameterType(dpos, path)
or
dpos.isReturn() and
result = this.getReturnType(path)
result = m.getDeclaredType(someParent, dpos, path)
}
string toString() { result = m.toStringExt(parent) }
Location getLocation() { result = m.getLocation() }
}
class AccessEnvironment = string;
@@ -2208,14 +2283,19 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi
result = this.getInferredNonSelfType(apos, path)
}
Declaration getTarget(ImplOrTraitItemNode i, string derefChainBorrow) {
Method getTarget(ImplOrTraitItemNode i, string derefChainBorrow) {
exists(string derefChain, boolean borrow |
derefChainBorrow = encodeDerefChainBorrow(derefChain, borrow) and
result = this.resolveCallTarget(i, derefChain, borrow) // mutual recursion; resolving method calls requires resolving types and vice versa
)
}
Declaration getTarget(string derefChainBorrow) { result = this.getTarget(_, derefChainBorrow) }
Declaration getTarget(string derefChainBorrow) {
exists(ImplOrTraitItemNode i, Method m |
m = this.getTarget(i, derefChainBorrow) and
result = TMethodFunctionDeclaration(i, m)
)
}
/**
* Holds if the return type of this call at `path` may have to be inferred
@@ -2467,13 +2547,6 @@ private module NonMethodResolution {
NonMethodArgsAreInstantiationsOf::argsAreInstantiationsOf(this, i, result)
}
pragma[inline]
ItemNode resolveCallTarget() {
result = this.resolveCallTargetViaPathResolution()
or
result = this.resolveCallTargetViaTypeInference(_)
}
pragma[nomagic]
NonMethodFunction resolveTraitFunctionViaPathResolution(TraitItemNode trait) {
this.hasTrait() and
@@ -2594,6 +2667,72 @@ private module NonMethodResolution {
ArgsAreInstantiationsOf<NonMethodArgsAreInstantiationsOfInput>;
}
abstract private class TupleLikeConstructor extends Addressable {
abstract TypeParameter getTypeParameter(TypeParameterPosition ppos);
abstract Type getParameterType(FunctionPosition pos, TypePath path);
abstract Type getReturnType(TypePath path);
Type getDeclaredType(FunctionPosition pos, TypePath path) {
result = this.getParameterType(pos, path)
or
pos.isReturn() and
result = this.getReturnType(path)
or
pos.isSelf() and
result = this.getReturnType(path)
}
}
private class TupleStruct extends TupleLikeConstructor, Struct {
TupleStruct() { this.isTuple() }
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
}
override Type getParameterType(FunctionPosition pos, TypePath path) {
exists(int i |
result = this.getTupleField(i).getTypeRepr().(TypeMention).resolveTypeAt(path) and
i = pos.asPosition()
)
}
override Type getReturnType(TypePath path) {
result = TStruct(this) and
path.isEmpty()
or
result = TTypeParamTypeParameter(this.getGenericParamList().getATypeParam()) and
path = TypePath::singleton(result)
}
}
private class TupleVariant extends TupleLikeConstructor, Variant {
TupleVariant() { this.isTuple() }
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getEnum().getGenericParamList().getATypeParam(), result, ppos)
}
override Type getParameterType(FunctionPosition pos, TypePath path) {
exists(int i |
result = this.getTupleField(i).getTypeRepr().(TypeMention).resolveTypeAt(path) and
i = pos.asPosition()
)
}
override Type getReturnType(TypePath path) {
exists(Enum enum | enum = this.getEnum() |
result = TEnum(enum) and
path.isEmpty()
or
result = TTypeParamTypeParameter(enum.getGenericParamList().getATypeParam()) and
path = TypePath::singleton(result)
)
}
}
/**
* A matching configuration for resolving types of calls like
* `foo::bar(baz)` where the target is not a method.
@@ -2604,7 +2743,15 @@ private module NonMethodResolution {
private module NonMethodCallMatchingInput implements MatchingInputSig {
import FunctionPositionMatchingInput
abstract class Declaration extends AstNode {
private class NonMethodFunctionDeclaration extends NonMethodFunction, FunctionDeclaration { }
private newtype TDeclaration =
TNonMethodFunctionDeclaration(ImplOrTraitItemNodeOption i, NonMethodFunctionDeclaration f) {
f.isFor(i)
} or
TTupleLikeConstructorDeclaration(TupleLikeConstructor tc)
abstract class Declaration extends TDeclaration {
abstract TypeParameter getTypeParameter(TypeParameterPosition ppos);
pragma[nomagic]
@@ -2618,69 +2765,20 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
dpos.isReturn() and
result = this.getReturnType(path)
}
abstract string toString();
abstract Location getLocation();
}
abstract additional class TupleDeclaration extends Declaration {
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = super.getDeclaredType(dpos, path)
or
dpos.isSelf() and
result = this.getReturnType(path)
}
}
private class NonMethodFunctionDecl extends Declaration, TNonMethodFunctionDeclaration {
private ImplOrTraitItemNodeOption i;
private NonMethodFunctionDeclaration f;
private class TupleStructDecl extends TupleDeclaration, Struct {
TupleStructDecl() { this.isTuple() }
NonMethodFunctionDecl() { this = TNonMethodFunctionDeclaration(i, f) }
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
}
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(int pos |
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
pos = dpos.asPosition()
)
}
override Type getReturnType(TypePath path) {
result = TStruct(this) and
path.isEmpty()
or
result = TTypeParamTypeParameter(this.getGenericParamList().getATypeParam()) and
path = TypePath::singleton(result)
}
}
private class TupleVariantDecl extends TupleDeclaration, Variant {
TupleVariantDecl() { this.isTuple() }
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getEnum().getGenericParamList().getATypeParam(), result, ppos)
}
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(int pos |
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
pos = dpos.asPosition()
)
}
override Type getReturnType(TypePath path) {
exists(Enum enum | enum = this.getEnum() |
result = TEnum(enum) and
path.isEmpty()
or
result = TTypeParamTypeParameter(enum.getGenericParamList().getATypeParam()) and
path = TypePath::singleton(result)
)
}
}
private class NonMethodFunctionDecl extends Declaration, NonMethodFunction instanceof MethodCallMatchingInput::Declaration
{
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
result = MethodCallMatchingInput::Declaration.super.getTypeParameter(ppos)
result = f.getTypeParameter(i, ppos)
}
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
@@ -2701,20 +2799,42 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
//
// we need to match `i32` against the type parameter `T` of the `impl` block.
dpos.isSelf() and
exists(ImplOrTraitItemNode i |
this = i.getAnAssocItem() and
result = resolveImplOrTraitType(i, path)
)
result = resolveImplOrTraitType(i.asSome(), path)
or
exists(FunctionPosition fpos |
result = MethodCallMatchingInput::Declaration.super.getParameterType(fpos, path) and
dpos = fpos.getFunctionCallAdjusted(this)
)
result = f.getParameterType(i, dpos, path)
}
override Type getReturnType(TypePath path) {
result = MethodCallMatchingInput::Declaration.super.getReturnType(path)
override Type getReturnType(TypePath path) { result = f.getReturnType(i, path) }
override string toString() {
i.isNone() and result = f.toString()
or
result = f.toStringExt(i.asSome())
}
override Location getLocation() { result = f.getLocation() }
}
private class TupleLikeConstructorDeclaration extends Declaration,
TTupleLikeConstructorDeclaration
{
TupleLikeConstructor tc;
TupleLikeConstructorDeclaration() { this = TTupleLikeConstructorDeclaration(tc) }
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
result = tc.getTypeParameter(ppos)
}
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
result = tc.getParameterType(dpos, path)
}
override Type getReturnType(TypePath path) { result = tc.getReturnType(path) }
override string toString() { result = tc.toString() }
override Location getLocation() { result = tc.getLocation() }
}
class Access extends NonMethodResolution::NonMethodCall, ContextTyping::ContextTypedCallCand {
@@ -2731,8 +2851,22 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
result = inferType(this.getNodeAt(apos), path)
}
pragma[inline]
Declaration getTarget() {
result = this.resolveCallTarget() // potential mutual recursion; resolving some associated function calls requires resolving types
exists(ImplOrTraitItemNodeOption i, NonMethodFunctionDeclaration f |
result = TNonMethodFunctionDeclaration(i, f)
|
f = this.resolveCallTargetViaTypeInference(i.asSome()) // mutual recursion; resolving some associated function calls requires resolving types
or
f = this.resolveTraitFunctionViaPathResolution(i.asSome())
or
f = this.resolveCallTargetViaPathResolution() and
f.isDirectlyFor(i)
)
or
exists(ItemNode i | i = this.resolveCallTargetViaPathResolution() |
result = TTupleLikeConstructorDeclaration(i)
)
}
/**
@@ -2741,21 +2875,17 @@ private module NonMethodCallMatchingInput implements MatchingInputSig {
*/
pragma[nomagic]
predicate hasUnknownTypeAt(FunctionPosition pos, TypePath path) {
exists(ImplOrTraitItemNode i |
this.hasUnknownTypeAt(i,
[
this.resolveCallTargetViaPathResolution().(NonMethodFunction),
this.resolveCallTargetViaTypeInference(i),
this.resolveTraitFunctionViaPathResolution(i)
], pos, path)
exists(ImplOrTraitItemNodeOption i, NonMethodFunctionDeclaration f |
TNonMethodFunctionDeclaration(i, f) = this.getTarget() and
this.hasUnknownTypeAt(i.asSome(), f, pos, path)
)
or
// Tuple declarations, such as `Result::Ok(...)`, may also be context typed
exists(TupleDeclaration td, TypeParameter tp |
td = this.resolveCallTargetViaPathResolution() and
exists(TupleLikeConstructor tc, TypeParameter tp |
tc = this.resolveCallTargetViaPathResolution() and
pos.isReturn() and
tp = td.getReturnType(path) and
not tp = td.getParameterType(_, _) and
tp = tc.getReturnType(path) and
not tp = tc.getParameterType(_, _) and
// check that no explicit type arguments have been supplied for `tp`
not exists(TypeArgumentPosition tapos |
exists(this.getTypeArgument(tapos, _)) and
@@ -2793,9 +2923,9 @@ private module OperationMatchingInput implements MatchingInputSig {
class Declaration extends MethodCallMatchingInput::Declaration {
private Method getSelfOrImpl() {
result = this
result = m
or
this.implements(result)
m.implements(result)
}
pragma[nomagic]
@@ -2812,30 +2942,19 @@ private module OperationMatchingInput implements MatchingInputSig {
)
}
pragma[nomagic]
private Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(TypePath path0 |
result = super.getParameterType(dpos, path0) and
if this.borrowsAt(dpos) then path0.isCons(getRefTypeParameter(), path) else path0 = path
)
}
pragma[nomagic]
private predicate derefsReturn() { this.getSelfOrImpl() = any(DerefTrait t).getDerefFunction() }
pragma[nomagic]
private Type getReturnType(TypePath path) {
exists(TypePath path0 |
result = super.getReturnType(path0) and
if this.derefsReturn() then path0.isCons(getRefTypeParameter(), path) else path0 = path
)
}
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = this.getParameterType(dpos, path)
or
dpos.isReturn() and
result = this.getReturnType(path)
exists(TypePath path0 |
result = super.getDeclaredType(dpos, path0) and
if
this.borrowsAt(dpos)
or
dpos.isReturn() and this.derefsReturn()
then path0.isCons(getRefTypeParameter(), path)
else path0 = path
)
}
}
@@ -2848,7 +2967,9 @@ private module OperationMatchingInput implements MatchingInputSig {
}
Declaration getTarget() {
result = this.resolveCallTarget(_, _, _) // mutual recursion
exists(ImplOrTraitItemNode i |
result.isMethod(i, this.resolveCallTarget(i, _, _)) // mutual recursion
)
}
}
}
@@ -3315,7 +3436,7 @@ private Type inferStructPatType(AstNode n, TypePath path) {
private module TupleStructPatMatchingInput implements MatchingInputSig {
import FunctionPositionMatchingInput
class Declaration = NonMethodCallMatchingInput::TupleDeclaration;
class Declaration = TupleLikeConstructor;
class Access extends TupleStructPat {
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
@@ -3408,12 +3529,9 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
* first-class function.
*/
final private class InvokedClosureExpr extends Expr {
private CallExpr call;
private CallExprImpl::DynamicCallExpr call;
InvokedClosureExpr() {
call.getFunction() = this and
(not this instanceof PathExpr or this = any(Variable v).getAnAccess())
}
InvokedClosureExpr() { call.getFunction() = this }
Type getTypeAt(TypePath path) { result = inferType(this, path) }

View File

@@ -15,21 +15,21 @@ multipleResolvedTargets
| invalid/main.rs:76:13:76:17 | * ... |
| main.rs:1077:14:1077:18 | * ... |
| main.rs:1159:26:1159:30 | * ... |
| main.rs:1504:14:1504:21 | * ... |
| main.rs:1504:16:1504:20 | * ... |
| main.rs:1509:14:1509:18 | * ... |
| main.rs:1540:27:1540:29 | * ... |
| main.rs:1654:17:1654:24 | * ... |
| main.rs:1654:18:1654:24 | * ... |
| main.rs:1792:17:1792:21 | * ... |
| main.rs:1807:28:1807:32 | * ... |
| main.rs:2440:13:2440:18 | * ... |
| main.rs:1503:14:1503:21 | * ... |
| main.rs:1503:16:1503:20 | * ... |
| main.rs:1508:14:1508:18 | * ... |
| main.rs:1539:27:1539:29 | * ... |
| main.rs:1653:17:1653:24 | * ... |
| main.rs:1653:18:1653:24 | * ... |
| main.rs:1791:17:1791:21 | * ... |
| main.rs:1806:28:1806:32 | * ... |
| main.rs:2439:13:2439:18 | * ... |
| main.rs:2633:13:2633:31 | ...::from(...) |
| main.rs:2634:13:2634:31 | ...::from(...) |
| main.rs:2635:13:2635:31 | ...::from(...) |
| main.rs:2636:13:2636:31 | ...::from(...) |
| main.rs:2641:13:2641:31 | ...::from(...) |
| main.rs:2642:13:2642:31 | ...::from(...) |
| main.rs:2643:13:2643:31 | ...::from(...) |
| main.rs:2644:13:2644:31 | ...::from(...) |
| main.rs:3067:13:3067:17 | x.f() |
| main.rs:3072:13:3072:17 | x.f() |
| pattern_matching.rs:273:13:273:27 | * ... |
| pattern_matching.rs:273:14:273:27 | * ... |

View File

@@ -1424,7 +1424,6 @@ mod option_methods {
x2.set(S); // $ target=MyOption::set
println!("{:?}", x2);
// missing type `S` from `MyOption<S>` (but can resolve `MyTrait<S>`)
let mut x3 = MyOption::new(); // $ target=new
x3.call_set(S); // $ target=call_set
println!("{:?}", x3);
@@ -3038,7 +3037,13 @@ mod context_typed {
mod literal_overlap {
trait MyTrait {
// MyTrait::f
fn f(self) -> Self;
// MyTrait::g
fn g(&self, other: &Self) -> &Self {
self.f() // $ target=Reff
}
}
impl MyTrait for i32 {
@@ -3067,6 +3072,16 @@ mod literal_overlap {
x = x.f(); // $ target=usizef $ SPURIOUS: target=i32f
x
}
fn g() {
let x: usize = 0;
let y = &1;
let z = x.g(y); // $ target=MyTrait::g
let x = 0; // $ SPURIOUS: type=x:i32 $ MISSING: type=x:usize
let y: usize = 1;
let z = x.max(y); // $ target=max
}
}
mod blanket_impl;