mirror of
https://github.com/github/codeql.git
synced 2025-12-16 16:53:25 +01:00
Merge pull request #19038 from paldepind/rust-type-inference-tweaks
Rust: Small type inference tweaks
This commit is contained in:
@@ -578,14 +578,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
|
||||
}
|
||||
|
||||
Declaration getTarget() {
|
||||
result =
|
||||
[
|
||||
CallExprImpl::getResolvedFunction(this).(AstNode),
|
||||
this.(CallExpr).getStruct(),
|
||||
this.(CallExpr).getVariant(),
|
||||
// mutual recursion; resolving method calls requires resolving types and vice versa
|
||||
resolveMethodCallExpr(this)
|
||||
]
|
||||
result = CallExprImpl::getResolvedFunction(this)
|
||||
or
|
||||
result = resolveMethodCallExpr(this) // mutual recursion; resolving method calls requires resolving types and vice versa
|
||||
}
|
||||
}
|
||||
|
||||
@@ -908,7 +903,7 @@ private module Cached {
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a method that the method call `mce` infers to, if any.
|
||||
* Gets a method that the method call `mce` resolves to, if any.
|
||||
*/
|
||||
cached
|
||||
Function resolveMethodCallExpr(MethodCallExpr mce) {
|
||||
@@ -922,7 +917,7 @@ private module Cached {
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the record field that the field expression `fe` infers to, if any.
|
||||
* Gets the record field that the field expression `fe` resolves to, if any.
|
||||
*/
|
||||
cached
|
||||
RecordField resolveRecordFieldExpr(FieldExpr fe) {
|
||||
@@ -938,7 +933,7 @@ private module Cached {
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the tuple field that the field expression `fe` infers to, if any.
|
||||
* Gets the tuple field that the field expression `fe` resolves to, if any.
|
||||
*/
|
||||
cached
|
||||
TupleField resolveTupleFieldExpr(FieldExpr fe) {
|
||||
|
||||
@@ -1,4 +1,69 @@
|
||||
mod m1 {
|
||||
mod field_access {
|
||||
#[derive(Debug)]
|
||||
struct S;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MyThing {
|
||||
a: S,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum MyOption<T> {
|
||||
MyNone(),
|
||||
MySome(T),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GenericThing<A> {
|
||||
a: A,
|
||||
}
|
||||
|
||||
struct OptionS {
|
||||
a: MyOption<S>,
|
||||
}
|
||||
|
||||
fn simple_field_access() {
|
||||
let x = MyThing { a: S };
|
||||
println!("{:?}", x.a);
|
||||
}
|
||||
|
||||
fn generic_field_access() {
|
||||
// Explicit type argument
|
||||
let x = GenericThing::<S> { a: S };
|
||||
println!("{:?}", x.a);
|
||||
|
||||
// Implicit type argument
|
||||
let y = GenericThing { a: S };
|
||||
println!("{:?}", x.a);
|
||||
|
||||
// The type of the field `a` can only be infered from the concrete type
|
||||
// in the struct declaration.
|
||||
let x = OptionS {
|
||||
a: MyOption::MyNone(),
|
||||
};
|
||||
println!("{:?}", x.a);
|
||||
|
||||
// The type of the field `a` can only be infered from the type argument
|
||||
let x = GenericThing::<MyOption<S>> {
|
||||
a: MyOption::MyNone(),
|
||||
};
|
||||
println!("{:?}", x.a);
|
||||
|
||||
let mut x = GenericThing {
|
||||
a: MyOption::MyNone(),
|
||||
};
|
||||
// Only after this access can we infer the type parameter of `x`
|
||||
let a: MyOption<S> = x.a;
|
||||
println!("{:?}", a);
|
||||
}
|
||||
|
||||
pub fn f() {
|
||||
simple_field_access();
|
||||
generic_field_access();
|
||||
}
|
||||
}
|
||||
|
||||
mod method_impl {
|
||||
pub struct Foo {}
|
||||
|
||||
impl Foo {
|
||||
@@ -25,7 +90,7 @@ mod m1 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m2 {
|
||||
mod method_non_parametric_impl {
|
||||
#[derive(Debug)]
|
||||
struct MyThing<A> {
|
||||
a: A,
|
||||
@@ -58,6 +123,10 @@ mod m2 {
|
||||
let x = MyThing { a: S1 };
|
||||
let y = MyThing { a: S2 };
|
||||
|
||||
// simple field access
|
||||
println!("{:?}", x.a);
|
||||
println!("{:?}", y.a);
|
||||
|
||||
println!("{:?}", x.m1()); // missing call target
|
||||
println!("{:?}", y.m1().a); // missing call target
|
||||
|
||||
@@ -69,7 +138,7 @@ mod m2 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m3 {
|
||||
mod method_non_parametric_trait_impl {
|
||||
#[derive(Debug)]
|
||||
struct MyThing<A> {
|
||||
a: A,
|
||||
@@ -122,7 +191,7 @@ mod m3 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m4 {
|
||||
mod function_trait_bounds {
|
||||
#[derive(Debug)]
|
||||
struct MyThing<A> {
|
||||
a: A,
|
||||
@@ -175,7 +244,7 @@ mod m4 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m5 {
|
||||
mod trait_associated_type {
|
||||
trait MyTrait {
|
||||
type AssociatedType;
|
||||
|
||||
@@ -210,7 +279,7 @@ mod m5 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m6 {
|
||||
mod generic_enum {
|
||||
#[derive(Debug)]
|
||||
enum MyEnum<A> {
|
||||
C1(A),
|
||||
@@ -240,7 +309,7 @@ mod m6 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m7 {
|
||||
mod method_supertraits {
|
||||
#[derive(Debug)]
|
||||
struct MyThing<A> {
|
||||
a: A,
|
||||
@@ -325,7 +394,7 @@ mod m7 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m8 {
|
||||
mod function_trait_bounds_2 {
|
||||
use std::convert::From;
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -374,7 +443,7 @@ mod m8 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m9 {
|
||||
mod option_methods {
|
||||
#[derive(Debug)]
|
||||
enum MyOption<T> {
|
||||
MyNone(),
|
||||
@@ -456,7 +525,7 @@ mod m9 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m10 {
|
||||
mod method_call_type_conversion {
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
struct S<T>(T);
|
||||
@@ -508,7 +577,7 @@ mod m10 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m11 {
|
||||
mod trait_implicit_self_borrow {
|
||||
trait MyTrait {
|
||||
fn foo(&self) -> &Self;
|
||||
|
||||
@@ -531,7 +600,7 @@ mod m11 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m12 {
|
||||
mod implicit_self_borrow {
|
||||
struct S;
|
||||
|
||||
struct MyStruct<T>(T);
|
||||
@@ -548,7 +617,7 @@ mod m12 {
|
||||
}
|
||||
}
|
||||
|
||||
mod m13 {
|
||||
mod borrowed_typed {
|
||||
struct S;
|
||||
|
||||
impl S {
|
||||
@@ -578,18 +647,19 @@ mod m13 {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
m1::f();
|
||||
m1::g(m1::Foo {}, m1::Foo {});
|
||||
m2::f();
|
||||
m3::f();
|
||||
m4::f();
|
||||
m5::f();
|
||||
m6::f();
|
||||
m7::f();
|
||||
m8::f();
|
||||
m9::f();
|
||||
m10::f();
|
||||
m11::f();
|
||||
m12::f();
|
||||
m13::f();
|
||||
field_access::f();
|
||||
method_impl::f();
|
||||
method_impl::g(method_impl::Foo {}, method_impl::Foo {});
|
||||
method_non_parametric_impl::f();
|
||||
method_non_parametric_trait_impl::f();
|
||||
function_trait_bounds::f();
|
||||
trait_associated_type::f();
|
||||
generic_enum::f();
|
||||
method_supertraits::f();
|
||||
function_trait_bounds_2::f();
|
||||
option_methods::f();
|
||||
method_call_type_conversion::f();
|
||||
trait_implicit_self_borrow::f();
|
||||
implicit_self_borrow::f();
|
||||
borrowed_typed::f();
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -81,21 +81,24 @@ signature module InputSig1<LocationSig Location> {
|
||||
|
||||
module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
private import Input1
|
||||
private import codeql.util.DenseRank
|
||||
|
||||
private module DenseRankInput implements DenseRankInputSig {
|
||||
class Ranked = TypeParameter;
|
||||
private module TypeParameter {
|
||||
private import codeql.util.DenseRank
|
||||
|
||||
predicate getRank = getTypeParameterId/1;
|
||||
}
|
||||
private module DenseRankInput implements DenseRankInputSig {
|
||||
class Ranked = TypeParameter;
|
||||
|
||||
private int getTypeParameterRank(TypeParameter tp) {
|
||||
tp = DenseRank<DenseRankInput>::denseRank(result)
|
||||
}
|
||||
predicate getRank = getTypeParameterId/1;
|
||||
}
|
||||
|
||||
bindingset[s]
|
||||
private predicate decodeTypePathComponent(string s, TypeParameter tp) {
|
||||
getTypeParameterRank(tp) = s.toInt()
|
||||
private int getTypeParameterRank(TypeParameter tp) {
|
||||
tp = DenseRank<DenseRankInput>::denseRank(result)
|
||||
}
|
||||
|
||||
string encode(TypeParameter tp) { result = getTypeParameterRank(tp).toString() }
|
||||
|
||||
bindingset[s]
|
||||
TypeParameter decode(string s) { encode(result) = s }
|
||||
}
|
||||
|
||||
final private class String = string;
|
||||
@@ -132,10 +135,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
|
||||
bindingset[this]
|
||||
private TypeParameter getTypeParameter(int i) {
|
||||
exists(string s |
|
||||
s = this.splitAt(".", i) and
|
||||
decodeTypePathComponent(s, result)
|
||||
)
|
||||
result = TypeParameter::decode(this.splitAt(".", i))
|
||||
}
|
||||
|
||||
/** Gets a textual representation of this type path. */
|
||||
@@ -180,13 +180,13 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
/** Holds if this path starts with `tp`, followed by `suffix`. */
|
||||
bindingset[this]
|
||||
predicate isCons(TypeParameter tp, TypePath suffix) {
|
||||
decodeTypePathComponent(this, tp) and
|
||||
tp = TypeParameter::decode(this) and
|
||||
suffix.isEmpty()
|
||||
or
|
||||
exists(int first |
|
||||
first = min(this.indexOf(".")) and
|
||||
suffix = this.suffix(first + 1) and
|
||||
decodeTypePathComponent(this.prefix(first), tp)
|
||||
tp = TypeParameter::decode(this.prefix(first))
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -197,7 +197,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
TypePath nil() { result.isEmpty() }
|
||||
|
||||
/** Gets the singleton type path `tp`. */
|
||||
TypePath singleton(TypeParameter tp) { result = getTypeParameterRank(tp).toString() }
|
||||
TypePath singleton(TypeParameter tp) { result = TypeParameter::encode(tp) }
|
||||
|
||||
/**
|
||||
* Gets the type path obtained by appending the singleton type path `tp`
|
||||
@@ -559,11 +559,11 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
private predicate directTypeMatch(
|
||||
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
|
||||
) {
|
||||
not exists(getTypeArgument(a, target, tp, _)) and
|
||||
exists(AccessPosition apos, DeclarationPosition dpos, TypePath pathToTypeParam |
|
||||
adjustedAccessType(a, apos, target, pathToTypeParam.append(path), t) and
|
||||
tp = target.getDeclaredType(dpos, pathToTypeParam) and
|
||||
not exists(getTypeArgument(a, target, tp, _)) and
|
||||
accessDeclarationPositionMatch(apos, dpos)
|
||||
accessDeclarationPositionMatch(apos, dpos) and
|
||||
adjustedAccessType(a, apos, target, pathToTypeParam.append(path), t)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -672,7 +672,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
*
|
||||
* class Sub<T4> : Mid<C<T4>> { }
|
||||
*
|
||||
* new Sub<int>().Method();
|
||||
* new Sub<int>().Method(); // Note: `Sub<int>` is a subtype of `Base<C<C<int>>>`
|
||||
* // ^^^^^^^^^^^^^^^^^^^^^^^ `a`
|
||||
* ```
|
||||
*
|
||||
@@ -688,14 +688,19 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
private predicate baseTypeMatch(
|
||||
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
|
||||
) {
|
||||
not exists(getTypeArgument(a, target, tp, _)) and
|
||||
exists(AccessPosition apos, DeclarationPosition dpos, Type base, TypePath pathToTypeParam |
|
||||
accessBaseType(a, apos, target, base, pathToTypeParam.append(path), t) and
|
||||
declarationBaseType(target, dpos, base, pathToTypeParam, tp) and
|
||||
not exists(getTypeArgument(a, target, tp, _)) and
|
||||
accessDeclarationPositionMatch(apos, dpos)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Holds if for `a` and corresponding `target`, the type parameter `tp` is
|
||||
* matched by a type argument at the access with type `t` and type path
|
||||
* `path`.
|
||||
*/
|
||||
pragma[nomagic]
|
||||
private predicate explicitTypeMatch(
|
||||
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
|
||||
@@ -708,8 +713,10 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
private predicate implicitTypeMatch(
|
||||
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
|
||||
) {
|
||||
// We can get the type of `tp` from one of the access positions
|
||||
directTypeMatch(a, target, path, t, tp)
|
||||
or
|
||||
// We can get the type of `tp` by going up the type hiearchy
|
||||
baseTypeMatch(a, target, path, t, tp)
|
||||
}
|
||||
|
||||
@@ -717,8 +724,12 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
private predicate typeMatch(
|
||||
Access a, Declaration target, TypePath path, Type t, TypeParameter tp
|
||||
) {
|
||||
// A type given at the access corresponds directly to the type parameter
|
||||
// at the target.
|
||||
explicitTypeMatch(a, target, path, t, tp)
|
||||
or
|
||||
// No explicit type argument, so we deduce the parameter from other
|
||||
// information
|
||||
implicitTypeMatch(a, target, path, t, tp)
|
||||
}
|
||||
|
||||
@@ -763,12 +774,14 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
|
||||
pragma[nomagic]
|
||||
Type inferAccessType(Access a, AccessPosition apos, TypePath path) {
|
||||
exists(DeclarationPosition dpos | accessDeclarationPositionMatch(apos, dpos) |
|
||||
// A suffix of `path` leads to a type parameter in the target
|
||||
exists(Declaration target, TypePath prefix, TypeParameter tp, TypePath suffix |
|
||||
tp = target.getDeclaredType(pragma[only_bind_into](dpos), prefix) and
|
||||
typeMatch(a, target, suffix, result, tp) and
|
||||
path = prefix.append(suffix)
|
||||
path = prefix.append(suffix) and
|
||||
typeMatch(a, target, suffix, result, tp)
|
||||
)
|
||||
or
|
||||
// `path` corresponds directly to a concrete type in the declaration
|
||||
exists(Declaration target |
|
||||
result = target.getDeclaredType(pragma[only_bind_into](dpos), path) and
|
||||
target = a.getTarget() and
|
||||
|
||||
Reference in New Issue
Block a user