Merge pull request #19038 from paldepind/rust-type-inference-tweaks

Rust: Small type inference tweaks
This commit is contained in:
Simon Friis Vindum
2025-03-17 14:09:08 +01:00
committed by GitHub
4 changed files with 1048 additions and 898 deletions

View File

@@ -578,14 +578,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
Declaration getTarget() {
result =
[
CallExprImpl::getResolvedFunction(this).(AstNode),
this.(CallExpr).getStruct(),
this.(CallExpr).getVariant(),
// mutual recursion; resolving method calls requires resolving types and vice versa
resolveMethodCallExpr(this)
]
result = CallExprImpl::getResolvedFunction(this)
or
result = resolveMethodCallExpr(this) // mutual recursion; resolving method calls requires resolving types and vice versa
}
}
@@ -908,7 +903,7 @@ private module Cached {
}
/**
* Gets a method that the method call `mce` infers to, if any.
* Gets a method that the method call `mce` resolves to, if any.
*/
cached
Function resolveMethodCallExpr(MethodCallExpr mce) {
@@ -922,7 +917,7 @@ private module Cached {
}
/**
* Gets the record field that the field expression `fe` infers to, if any.
* Gets the record field that the field expression `fe` resolves to, if any.
*/
cached
RecordField resolveRecordFieldExpr(FieldExpr fe) {
@@ -938,7 +933,7 @@ private module Cached {
}
/**
* Gets the tuple field that the field expression `fe` infers to, if any.
* Gets the tuple field that the field expression `fe` resolves to, if any.
*/
cached
TupleField resolveTupleFieldExpr(FieldExpr fe) {

View File

@@ -1,4 +1,69 @@
mod m1 {
mod field_access {
#[derive(Debug)]
struct S;
#[derive(Debug)]
struct MyThing {
a: S,
}
#[derive(Debug)]
enum MyOption<T> {
MyNone(),
MySome(T),
}
#[derive(Debug)]
struct GenericThing<A> {
a: A,
}
struct OptionS {
a: MyOption<S>,
}
fn simple_field_access() {
let x = MyThing { a: S };
println!("{:?}", x.a);
}
fn generic_field_access() {
// Explicit type argument
let x = GenericThing::<S> { a: S };
println!("{:?}", x.a);
// Implicit type argument
let y = GenericThing { a: S };
println!("{:?}", x.a);
// The type of the field `a` can only be infered from the concrete type
// in the struct declaration.
let x = OptionS {
a: MyOption::MyNone(),
};
println!("{:?}", x.a);
// The type of the field `a` can only be infered from the type argument
let x = GenericThing::<MyOption<S>> {
a: MyOption::MyNone(),
};
println!("{:?}", x.a);
let mut x = GenericThing {
a: MyOption::MyNone(),
};
// Only after this access can we infer the type parameter of `x`
let a: MyOption<S> = x.a;
println!("{:?}", a);
}
pub fn f() {
simple_field_access();
generic_field_access();
}
}
mod method_impl {
pub struct Foo {}
impl Foo {
@@ -25,7 +90,7 @@ mod m1 {
}
}
mod m2 {
mod method_non_parametric_impl {
#[derive(Debug)]
struct MyThing<A> {
a: A,
@@ -58,6 +123,10 @@ mod m2 {
let x = MyThing { a: S1 };
let y = MyThing { a: S2 };
// simple field access
println!("{:?}", x.a);
println!("{:?}", y.a);
println!("{:?}", x.m1()); // missing call target
println!("{:?}", y.m1().a); // missing call target
@@ -69,7 +138,7 @@ mod m2 {
}
}
mod m3 {
mod method_non_parametric_trait_impl {
#[derive(Debug)]
struct MyThing<A> {
a: A,
@@ -122,7 +191,7 @@ mod m3 {
}
}
mod m4 {
mod function_trait_bounds {
#[derive(Debug)]
struct MyThing<A> {
a: A,
@@ -175,7 +244,7 @@ mod m4 {
}
}
mod m5 {
mod trait_associated_type {
trait MyTrait {
type AssociatedType;
@@ -210,7 +279,7 @@ mod m5 {
}
}
mod m6 {
mod generic_enum {
#[derive(Debug)]
enum MyEnum<A> {
C1(A),
@@ -240,7 +309,7 @@ mod m6 {
}
}
mod m7 {
mod method_supertraits {
#[derive(Debug)]
struct MyThing<A> {
a: A,
@@ -325,7 +394,7 @@ mod m7 {
}
}
mod m8 {
mod function_trait_bounds_2 {
use std::convert::From;
use std::fmt::Debug;
@@ -374,7 +443,7 @@ mod m8 {
}
}
mod m9 {
mod option_methods {
#[derive(Debug)]
enum MyOption<T> {
MyNone(),
@@ -456,7 +525,7 @@ mod m9 {
}
}
mod m10 {
mod method_call_type_conversion {
#[derive(Debug, Copy, Clone)]
struct S<T>(T);
@@ -508,7 +577,7 @@ mod m10 {
}
}
mod m11 {
mod trait_implicit_self_borrow {
trait MyTrait {
fn foo(&self) -> &Self;
@@ -531,7 +600,7 @@ mod m11 {
}
}
mod m12 {
mod implicit_self_borrow {
struct S;
struct MyStruct<T>(T);
@@ -548,7 +617,7 @@ mod m12 {
}
}
mod m13 {
mod borrowed_typed {
struct S;
impl S {
@@ -578,18 +647,19 @@ mod m13 {
}
fn main() {
m1::f();
m1::g(m1::Foo {}, m1::Foo {});
m2::f();
m3::f();
m4::f();
m5::f();
m6::f();
m7::f();
m8::f();
m9::f();
m10::f();
m11::f();
m12::f();
m13::f();
field_access::f();
method_impl::f();
method_impl::g(method_impl::Foo {}, method_impl::Foo {});
method_non_parametric_impl::f();
method_non_parametric_trait_impl::f();
function_trait_bounds::f();
trait_associated_type::f();
generic_enum::f();
method_supertraits::f();
function_trait_bounds_2::f();
option_methods::f();
method_call_type_conversion::f();
trait_implicit_self_borrow::f();
implicit_self_borrow::f();
borrowed_typed::f();
}

View File

@@ -81,21 +81,24 @@ signature module InputSig1<LocationSig Location> {
module Make1<LocationSig Location, InputSig1<Location> Input1> {
private import Input1
private import codeql.util.DenseRank
private module DenseRankInput implements DenseRankInputSig {
class Ranked = TypeParameter;
private module TypeParameter {
private import codeql.util.DenseRank
predicate getRank = getTypeParameterId/1;
}
private module DenseRankInput implements DenseRankInputSig {
class Ranked = TypeParameter;
private int getTypeParameterRank(TypeParameter tp) {
tp = DenseRank<DenseRankInput>::denseRank(result)
}
predicate getRank = getTypeParameterId/1;
}
bindingset[s]
private predicate decodeTypePathComponent(string s, TypeParameter tp) {
getTypeParameterRank(tp) = s.toInt()
private int getTypeParameterRank(TypeParameter tp) {
tp = DenseRank<DenseRankInput>::denseRank(result)
}
string encode(TypeParameter tp) { result = getTypeParameterRank(tp).toString() }
bindingset[s]
TypeParameter decode(string s) { encode(result) = s }
}
final private class String = string;
@@ -132,10 +135,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
bindingset[this]
private TypeParameter getTypeParameter(int i) {
exists(string s |
s = this.splitAt(".", i) and
decodeTypePathComponent(s, result)
)
result = TypeParameter::decode(this.splitAt(".", i))
}
/** Gets a textual representation of this type path. */
@@ -180,13 +180,13 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
/** Holds if this path starts with `tp`, followed by `suffix`. */
bindingset[this]
predicate isCons(TypeParameter tp, TypePath suffix) {
decodeTypePathComponent(this, tp) and
tp = TypeParameter::decode(this) and
suffix.isEmpty()
or
exists(int first |
first = min(this.indexOf(".")) and
suffix = this.suffix(first + 1) and
decodeTypePathComponent(this.prefix(first), tp)
tp = TypeParameter::decode(this.prefix(first))
)
}
}
@@ -197,7 +197,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
TypePath nil() { result.isEmpty() }
/** Gets the singleton type path `tp`. */
TypePath singleton(TypeParameter tp) { result = getTypeParameterRank(tp).toString() }
TypePath singleton(TypeParameter tp) { result = TypeParameter::encode(tp) }
/**
* Gets the type path obtained by appending the singleton type path `tp`
@@ -559,11 +559,11 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
private predicate directTypeMatch(
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
) {
not exists(getTypeArgument(a, target, tp, _)) and
exists(AccessPosition apos, DeclarationPosition dpos, TypePath pathToTypeParam |
adjustedAccessType(a, apos, target, pathToTypeParam.append(path), t) and
tp = target.getDeclaredType(dpos, pathToTypeParam) and
not exists(getTypeArgument(a, target, tp, _)) and
accessDeclarationPositionMatch(apos, dpos)
accessDeclarationPositionMatch(apos, dpos) and
adjustedAccessType(a, apos, target, pathToTypeParam.append(path), t)
)
}
@@ -672,7 +672,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
*
* class Sub<T4> : Mid<C<T4>> { }
*
* new Sub<int>().Method();
* new Sub<int>().Method(); // Note: `Sub<int>` is a subtype of `Base<C<C<int>>>`
* // ^^^^^^^^^^^^^^^^^^^^^^^ `a`
* ```
*
@@ -688,14 +688,19 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
private predicate baseTypeMatch(
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
) {
not exists(getTypeArgument(a, target, tp, _)) and
exists(AccessPosition apos, DeclarationPosition dpos, Type base, TypePath pathToTypeParam |
accessBaseType(a, apos, target, base, pathToTypeParam.append(path), t) and
declarationBaseType(target, dpos, base, pathToTypeParam, tp) and
not exists(getTypeArgument(a, target, tp, _)) and
accessDeclarationPositionMatch(apos, dpos)
)
}
/**
* Holds if for `a` and corresponding `target`, the type parameter `tp` is
* matched by a type argument at the access with type `t` and type path
* `path`.
*/
pragma[nomagic]
private predicate explicitTypeMatch(
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
@@ -708,8 +713,10 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
private predicate implicitTypeMatch(
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
) {
// We can get the type of `tp` from one of the access positions
directTypeMatch(a, target, path, t, tp)
or
// We can get the type of `tp` by going up the type hiearchy
baseTypeMatch(a, target, path, t, tp)
}
@@ -717,8 +724,12 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
private predicate typeMatch(
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
) {
// A type given at the access corresponds directly to the type parameter
// at the target.
explicitTypeMatch(a, target, path, t, tp)
or
// No explicit type argument, so we deduce the parameter from other
// information
implicitTypeMatch(a, target, path, t, tp)
}
@@ -763,12 +774,14 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
pragma[nomagic]
Type inferAccessType(Access a, AccessPosition apos, TypePath path) {
exists(DeclarationPosition dpos | accessDeclarationPositionMatch(apos, dpos) |
// A suffix of `path` leads to a type parameter in the target
exists(Declaration target, TypePath prefix, TypeParameter tp, TypePath suffix |
tp = target.getDeclaredType(pragma[only_bind_into](dpos), prefix) and
typeMatch(a, target, suffix, result, tp) and
path = prefix.append(suffix)
path = prefix.append(suffix) and
typeMatch(a, target, suffix, result, tp)
)
or
// `path` corresponds directly to a concrete type in the declaration
exists(Declaration target |
result = target.getDeclaredType(pragma[only_bind_into](dpos), path) and
target = a.getTarget() and