Merge pull request #20122 from paldepind/rust/type-inference-dyn-assoc

Rust: Fix type inference for trait objects for traits with associated types
This commit is contained in:
Simon Friis Vindum
2025-07-26 12:40:09 +02:00
committed by GitHub
7 changed files with 3457 additions and 3191 deletions

View File

@@ -7,6 +7,27 @@ private import codeql.rust.internal.CachedStages
private import codeql.rust.elements.internal.generated.Raw
private import codeql.rust.elements.internal.generated.Synth
/**
* Holds if a dyn trait type should have a type parameter associated with `n`. A
* dyn trait type inherits the type parameters of the trait it implements. That
* includes the type parameters corresponding to associated types.
*
* For instance in
* ```rust
* trait SomeTrait<A> {
* type AssociatedType;
* }
* ```
* this predicate holds for the nodes `A` and `type AssociatedType`.
*/
private predicate dynTraitTypeParameter(Trait trait, AstNode n) {
trait = any(DynTraitTypeRepr dt).getTrait() and
(
n = trait.getGenericParamList().getATypeParam() or
n = trait.(TraitItemNode).getAnAssocItem().(TypeAlias)
)
}
cached
newtype TType =
TTuple(int arity) {
@@ -30,9 +51,7 @@ newtype TType =
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
TArrayTypeParameter() or
TDynTraitTypeParameter(TypeParam tp) {
tp = any(DynTraitTypeRepr dt).getTrait().getGenericParamList().getATypeParam()
} or
TDynTraitTypeParameter(AstNode n) { dynTraitTypeParameter(_, n) } or
TRefTypeParameter() or
TSelfTypeParameter(Trait t) or
TSliceTypeParameter()
@@ -406,15 +425,35 @@ class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
}
class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
private TypeParam typeParam;
private AstNode n;
DynTraitTypeParameter() { this = TDynTraitTypeParameter(typeParam) }
DynTraitTypeParameter() { this = TDynTraitTypeParameter(n) }
TypeParam getTypeParam() { result = typeParam }
Trait getTrait() { dynTraitTypeParameter(result, n) }
override string toString() { result = "dyn(" + typeParam.toString() + ")" }
/** Gets the dyn trait type that this type parameter belongs to. */
DynTraitType getDynTraitType() { result.getTrait() = this.getTrait() }
override Location getLocation() { result = typeParam.getLocation() }
/** Gets the `TypeParam` of this dyn trait type parameter, if any. */
TypeParam getTypeParam() { result = n }
/** Gets the `TypeAlias` of this dyn trait type parameter, if any. */
TypeAlias getTypeAlias() { result = n }
/** Gets the trait type parameter that this dyn trait type parameter corresponds to. */
TypeParameter getTraitTypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() = n
or
result.(AssociatedTypeTypeParameter).getTypeAlias() = n
}
private string toStringInner() {
result = [this.getTypeParam().toString(), this.getTypeAlias().getName().toString()]
}
override string toString() { result = "dyn(" + this.toStringInner() + ")" }
override Location getLocation() { result = n.getLocation() }
}
/** An implicit reference type parameter. */
@@ -503,8 +542,7 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
final class DynTypeAbstraction extends TypeAbstraction, DynTraitTypeRepr {
override TypeParameter getATypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() =
this.getTrait().getGenericParamList().getATypeParam()
result = any(DynTraitTypeParameter tp | tp.getTrait() = this.getTrait()).getTraitTypeParameter()
}
}

View File

@@ -97,7 +97,11 @@ private module Input1 implements InputSig1<Location> {
id = 2
or
kind = 1 and
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
id =
idOfTypeParameterAstNode([
tp0.(DynTraitTypeParameter).getTypeParam().(AstNode),
tp0.(DynTraitTypeParameter).getTypeAlias()
])
or
kind = 2 and
exists(AstNode node | id = idOfTypeParameterAstNode(node) |

View File

@@ -324,10 +324,10 @@ class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
result = dynType
or
exists(DynTraitTypeParameter tp, TypePath path0, TypePath suffix |
tp = dynType.getTypeParameter(_) and
dynType = tp.getDynTraitType() and
path = TypePath::cons(tp, suffix) and
result = super.getTypeBoundList().getBound(0).getTypeRepr().(TypeMention).resolveTypeAt(path0) and
path0.isCons(TTypeParamTypeParameter(tp.getTypeParam()), suffix)
path0.isCons(tp.getTraitTypeParameter(), suffix)
)
}
}
@@ -363,10 +363,10 @@ class DynTypeBoundListMention extends TypeMention instanceof TypeBoundList {
path.isEmpty() and
result.(DynTraitType).getTrait() = trait
or
exists(TypeParam param |
param = trait.getGenericParamList().getATypeParam() and
path = TypePath::singleton(TDynTraitTypeParameter(param)) and
result = TTypeParamTypeParameter(param)
exists(DynTraitTypeParameter tp |
trait = tp.getTrait() and
path = TypePath::singleton(tp) and
result = tp.getTraitTypeParameter()
)
}
}

View File

@@ -1,8 +1,8 @@
multipleCallTargets
| dereference.rs:61:15:61:24 | e1.deref() |
| main.rs:2213:13:2213:31 | ...::from(...) |
| main.rs:2214:13:2214:31 | ...::from(...) |
| main.rs:2215:13:2215:31 | ...::from(...) |
| main.rs:2221:13:2221:31 | ...::from(...) |
| main.rs:2222:13:2222:31 | ...::from(...) |
| main.rs:2223:13:2223:31 | ...::from(...) |
| main.rs:2253:13:2253:31 | ...::from(...) |
| main.rs:2254:13:2254:31 | ...::from(...) |
| main.rs:2255:13:2255:31 | ...::from(...) |
| main.rs:2261:13:2261:31 | ...::from(...) |
| main.rs:2262:13:2262:31 | ...::from(...) |
| main.rs:2263:13:2263:31 | ...::from(...) |

View File

@@ -12,6 +12,12 @@ trait GenericGet<A> {
fn get(&self) -> A;
}
trait AssocTrait<GP> {
type AP;
// AssocTrait::get
fn get(&self) -> (GP, Self::AP);
}
#[derive(Clone, Debug)]
struct MyStruct {
value: i32,
@@ -36,6 +42,17 @@ impl<A: Clone + Debug> GenericGet<A> for GenStruct<A> {
}
}
impl<GGP> AssocTrait<GGP> for GenStruct<GGP>
where
GGP: Clone + Debug,
{
type AP = bool;
// GenStruct<GGP>::get
fn get(&self) -> (GGP, bool) {
(self.value.clone(), true) // $ fieldof=GenStruct target=clone
}
}
fn get_a<A, G: GenericGet<A> + ?Sized>(a: &G) -> A {
a.get() // $ target=GenericGet::get
}
@@ -58,10 +75,34 @@ fn test_poly_dyn_trait() {
let _result = (*obj).get(); // $ target=deref target=GenericGet::get type=_result:bool
}
fn assoc_dyn_get<A, B>(a: &dyn AssocTrait<A, AP = B>) -> (A, B) {
a.get() // $ target=AssocTrait::get
}
fn assoc_get<A, B, T: AssocTrait<A, AP = B> + ?Sized>(a: &T) -> (A, B) {
a.get() // $ target=AssocTrait::get
}
fn test_assoc_type(obj: &dyn AssocTrait<i64, AP = bool>) {
let (
_gp, // $ type=_gp:i64
_ap, // $ type=_ap:bool
) = (*obj).get(); // $ target=deref target=AssocTrait::get
let (
_gp, // $ type=_gp:i64
_ap, // $ type=_ap:bool
) = assoc_dyn_get(obj); // $ target=assoc_dyn_get
let (
_gp, // $ type=_gp:i64
_ap, // $ type=_ap:bool
) = assoc_get(obj); // $ target=assoc_get
}
pub fn test() {
test_basic_dyn_trait(&MyStruct { value: 42 }); // $ target=test_basic_dyn_trait
test_generic_dyn_trait(&GenStruct {
value: "".to_string(),
}); // $ target=test_generic_dyn_trait
test_poly_dyn_trait(); // $ target=test_poly_dyn_trait
test_assoc_type(&GenStruct { value: 100 }); // $ target=test_assoc_type
}

View File

@@ -653,7 +653,7 @@ mod function_trait_bounds {
}
}
mod trait_associated_type {
mod associated_type_in_trait {
#[derive(Debug)]
struct Wrapper<A> {
field: A,
@@ -803,6 +803,46 @@ mod trait_associated_type {
}
}
mod associated_type_in_supertrait {
trait Supertrait {
type Content;
fn insert(content: Self::Content);
}
trait Subtrait: Supertrait {
// Subtrait::get_content
fn get_content(&self) -> Self::Content;
}
struct MyType<T>(T);
impl<T> Supertrait for MyType<T> {
type Content = T;
fn insert(_content: Self::Content) {
println!("Inserting content: ");
}
}
impl<T: Clone> Subtrait for MyType<T> {
// MyType::get_content
fn get_content(&self) -> Self::Content {
(*self).0.clone() // $ fieldof=MyType target=clone target=deref
}
}
fn get_content<T: Subtrait>(item: &T) -> T::Content {
item.get_content() // $ target=Subtrait::get_content
}
fn test() {
let item1 = MyType(42i64);
let _content1 = item1.get_content(); // $ target=MyType::get_content MISSING: type=_content1:i64
let item2 = MyType(true);
let _content2 = get_content(&item2); // $ target=get_content MISSING: type=_content2:bool
}
}
mod generic_enum {
#[derive(Debug)]
enum MyEnum<A> {
@@ -2469,7 +2509,7 @@ fn main() {
method_non_parametric_impl::f(); // $ target=f
method_non_parametric_trait_impl::f(); // $ target=f
function_trait_bounds::f(); // $ target=f
trait_associated_type::f(); // $ target=f
associated_type_in_trait::f(); // $ target=f
generic_enum::f(); // $ target=f
method_supertraits::f(); // $ target=f
function_trait_bounds_2::f(); // $ target=f