Merge pull request #21273 from paldepind/rust/tp-assoc

Rust: Implement support for associated types accessed on type parameters
This commit is contained in:
Simon Friis Vindum
2026-02-11 13:39:55 +01:00
committed by GitHub
7 changed files with 843 additions and 520 deletions

View File

@@ -0,0 +1,64 @@
/**
* Provides classes and helper predicates for associated types.
*/
private import rust
private import codeql.rust.internal.PathResolution
private import TypeMention
private import Type
private import TypeInference
/** An associated type, that is, a type alias in a trait block. */
final class AssocType extends TypeAlias {
Trait trait;
AssocType() { this = trait.getAssocItemList().getAnAssocItem() }
Trait getTrait() { result = trait }
string getText() { result = this.getName().getText() }
}
/** Gets an associated type of `trait` or of a supertrait of `trait`. */
AssocType getTraitAssocType(Trait trait) { result.getTrait() = trait.getSupertrait*() }
/** Holds if `path` is of the form `<type as trait>::name` */
pragma[nomagic]
predicate pathTypeAsTraitAssoc(Path path, TypeRepr typeRepr, Path traitPath, string name) {
exists(PathSegment segment |
segment = path.getQualifier().getSegment() and
typeRepr = segment.getTypeRepr() and
traitPath = segment.getTraitTypeRepr().getPath() and
name = path.getText()
)
}
/**
* Holds if `assoc` is accessed on `tp` in `path`.
*
* That is, this is the case when `path` is of the form `<tp as
* Trait>::AssocType` or `tp::AssocType`; and `AssocType` resolves to `assoc`.
*/
predicate tpAssociatedType(TypeParam tp, AssocType assoc, Path path) {
resolvePath(path.getQualifier()) = tp and
resolvePath(path) = assoc
or
exists(PathTypeRepr typeRepr, Path traitPath, string name |
pathTypeAsTraitAssoc(path, typeRepr, traitPath, name) and
tp = resolvePath(typeRepr.getPath()) and
assoc = resolvePath(traitPath).(TraitItemNode).getAssocItem(name)
)
}
/**
* Holds if `bound` is a type bound for `tp` that gives rise to `assoc` being
* present for `tp`.
*/
predicate tpBoundAssociatedType(
TypeParam tp, TypeBound bound, Path path, TraitItemNode trait, AssocType assoc
) {
bound = tp.getATypeBound() and
path = bound.getTypeRepr().(PathTypeRepr).getPath() and
trait = resolvePath(path) and
assoc = getTraitAssocType(trait)
}

View File

@@ -8,11 +8,7 @@ private import codeql.rust.elements.internal.generated.Raw
private import codeql.rust.elements.internal.generated.Synth
private import codeql.rust.frameworks.stdlib.Stdlib
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
/** Gets a type alias of `trait` or of a supertrait of `trait`. */
private TypeAlias getTraitTypeAlias(Trait trait) {
result = trait.getSupertrait*().getAssocItemList().getAnAssocItem()
}
private import AssociatedType
/**
* Holds if a dyn trait type for the trait `trait` should have a type parameter
@@ -31,7 +27,7 @@ private TypeAlias getTraitTypeAlias(Trait trait) {
*/
private predicate dynTraitTypeParameter(Trait trait, AstNode n) {
trait = any(DynTraitTypeRepr dt).getTrait() and
n = [trait.getGenericParamList().getATypeParam().(AstNode), getTraitTypeAlias(trait)]
n = [trait.getGenericParamList().getATypeParam().(AstNode), getTraitAssocType(trait)]
}
cached
@@ -43,8 +39,11 @@ newtype TType =
TNeverType() or
TUnknownType() or
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(Trait trait, TypeAlias typeAlias) {
getTraitTypeAlias(trait) = typeAlias
TAssociatedTypeTypeParameter(Trait trait, AssocType typeAlias) {
getTraitAssocType(trait) = typeAlias
} or
TTypeParamAssociatedTypeTypeParameter(TypeParam tp, AssocType assoc) {
tpAssociatedType(tp, assoc, _)
} or
TDynTraitTypeParameter(Trait trait, AstNode n) { dynTraitTypeParameter(trait, n) } or
TImplTraitTypeParameter(ImplTraitTypeRepr implTrait, TypeParam tp) {
@@ -464,6 +463,52 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
override Location getLocation() { result = typeAlias.getLocation() }
}
/**
* A type parameter corresponding to an associated type accessed on a type
* parameter, for example `T::AssociatedType` where `T` is a type parameter.
*
* These type parameters are created when a function signature accesses an
* associated type on a type parameter. For example, in
* ```rust
* fn foo<T: SomeTrait>(arg: T::Assoc) { }
* ```
* we create a `TypeParamAssociatedTypeTypeParameter` for `Assoc` on `T` and the
* mention `T::Assoc` resolves to this type parameter. If denoting the type
* parameter by `T_Assoc` then the above function is treated as if it was
* ```rust
* fn foo<T: SomeTrait<Assoc = T_Assoc>, T_Assoc>(arg: T_Assoc) { }
* ```
*/
class TypeParamAssociatedTypeTypeParameter extends TypeParameter,
TTypeParamAssociatedTypeTypeParameter
{
private TypeParam typeParam;
private AssocType assoc;
TypeParamAssociatedTypeTypeParameter() {
this = TTypeParamAssociatedTypeTypeParameter(typeParam, assoc)
}
/** Gets the type parameter that this associated type is accessed on. */
TypeParam getTypeParam() { result = typeParam }
/** Gets the associated type alias. */
AssocType getTypeAlias() { result = assoc }
/** Gets a path that accesses this type parameter. */
Path getAPath() { tpAssociatedType(typeParam, assoc, result) }
override ItemNode getDeclaringItem() { result.getTypeParam(_) = typeParam }
override string toString() {
result =
typeParam.toString() + "::" + assoc.getName().getText() + "[" +
assoc.getTrait().getName().getText() + "]"
}
override Location getLocation() { result = typeParam.getLocation() }
}
/** Gets the associated type type-parameter corresponding directly to `typeAlias`. */
AssociatedTypeTypeParameter getAssociatedTypeTypeParameter(TypeAlias typeAlias) {
result.isDirect() and result.getTypeAlias() = typeAlias

View File

@@ -108,6 +108,10 @@ private module Input implements InputSig1<Location>, InputSig2<PreTypeMention> {
id2 = idOfTypeParameterAstNode(tp0.(AssociatedTypeTypeParameter).getTypeAlias())
or
kind = 4 and
id1 = idOfTypeParameterAstNode(tp0.(TypeParamAssociatedTypeTypeParameter).getTypeParam()) and
id2 = idOfTypeParameterAstNode(tp0.(TypeParamAssociatedTypeTypeParameter).getTypeAlias())
or
kind = 5 and
id1 = 0 and
exists(AstNode node | id2 = idOfTypeParameterAstNode(node) |
node = tp0.(TypeParamTypeParameter).getTypeParam() or
@@ -270,13 +274,21 @@ private class FunctionDeclaration extends Function {
this = i.asSome().getAnAssocItem()
}
TypeParam getTypeParam(ImplOrTraitItemNodeOption i) {
i = parent and
result = [this.getGenericParamList().getATypeParam(), i.asSome().getTypeParam(_)]
}
TypeParameter getTypeParameter(ImplOrTraitItemNodeOption i, TypeParameterPosition ppos) {
typeParamMatchPosition(this.getTypeParam(i), result, ppos)
or
// For every `TypeParam` of this function, any associated types accessed on
// the type parameter are also type parameters.
ppos.isImplicit() and
result.(TypeParamAssociatedTypeTypeParameter).getTypeParam() = this.getTypeParam(i)
or
i = parent and
(
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
or
typeParamMatchPosition(i.asSome().getTypeParam(_), result, ppos)
or
ppos.isImplicit() and result = TSelfTypeParameter(i.asSome())
or
ppos.isImplicit() and result.(AssociatedTypeTypeParameter).getTrait() = i.asSome()

View File

@@ -6,6 +6,7 @@ private import codeql.rust.frameworks.stdlib.Stdlib
private import Type
private import TypeAbstraction
private import TypeInference
private import AssociatedType
bindingset[trait, name]
pragma[inline_late]
@@ -319,6 +320,22 @@ private module MkTypeMention<getAdditionalPathTypeAtSig/2 getAdditionalPathTypeA
tp = TAssociatedTypeTypeParameter(resolved, alias) and
path.isEmpty()
)
or
// If this path is a type parameter bound, then any associated types
// accessed on the type parameter, which originate from this bound, should
// be instantiated into the bound, as explained in the comment for
// `TypeParamAssociatedTypeTypeParameter`.
// ```rust
// fn foo<T: SomeTrait<Assoc = T_Assoc>, T_Assoc>(arg: T_Assoc) { }
// ^^^^^^^^^ ^^^^^ ^^^^^^^
// this path result
// ```
exists(TypeParam typeParam, Trait trait, AssocType assoc |
tpBoundAssociatedType(typeParam, _, this, trait, assoc) and
tp = TAssociatedTypeTypeParameter(resolved, assoc) and
result = TTypeParamAssociatedTypeTypeParameter(typeParam, assoc) and
path.isEmpty()
)
}
bindingset[name]
@@ -372,6 +389,8 @@ private module MkTypeMention<getAdditionalPathTypeAtSig/2 getAdditionalPathTypeA
or
// Handles paths of the form `Self::AssocType` within a trait block
result = TAssociatedTypeTypeParameter(resolvePath(this.getQualifier()), resolved)
or
result.(TypeParamAssociatedTypeTypeParameter).getAPath() = this
}
override Type resolvePathTypeAt(TypePath typePath) {
@@ -690,11 +709,10 @@ private predicate pathConcreteTypeAssocType(
|
// path of the form `<Type as Trait>::AssocType`
// ^^^ tm ^^^^^^^^^ name
exists(string name |
name = path.getText() and
trait = resolvePath(qualifier.getSegment().getTraitTypeRepr().getPath()) and
getTraitAssocType(trait, name) = alias and
tm = qualifier.getSegment().getTypeRepr()
exists(string name, Path traitPath |
pathTypeAsTraitAssoc(path, tm, traitPath, name) and
trait = resolvePath(traitPath) and
getTraitAssocType(trait, name) = alias
)
or
// path of the form `Self::AssocType` within an `impl` block

View File

@@ -7,7 +7,7 @@ impl<A> Wrapper<A> {
}
}
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone, Copy)]
struct S;
#[derive(Debug, Default)]
@@ -260,13 +260,65 @@ mod type_param_access_associated_type {
)
}
// Associated type accessed on a type parameter of an impl block
impl<TI> Wrapper<TI>
where
TI: GetSet,
{
fn extract(&self) -> TI::Output {
self.0.get() // $ fieldof=Wrapper target=GetSet::get
}
}
// Associated type accessed on another associated type
fn tp_nested_assoc_type<T: GetSet>(thing: T) -> <<T as GetSet>::Output as GetSet>::Output
where
<T as GetSet>::Output: GetSet,
{
thing.get().get() // $ target=GetSet::get target=GetSet::get
}
pub trait GetSetWrap {
type Assoc: GetSet;
// GetSetWrap::get_wrap
fn get_wrap(&self) -> Self::Assoc;
}
impl GetSetWrap for S {
type Assoc = S;
// S::get_wrap
fn get_wrap(&self) -> Self::Assoc {
S
}
}
// Nested associated type accessed on a type parameter of an impl block
impl<TI> Wrapper<TI>
where
TI: GetSetWrap,
{
fn extract2(&self) -> <<TI as GetSetWrap>::Assoc as GetSet>::Output {
self.0.get_wrap().get() // $ fieldof=Wrapper target=GetSetWrap::get_wrap $ MISSING: target=GetSet::get
}
}
pub fn test() {
let _o1 = tp_with_as(S); // $ target=tp_with_as MISSING: type=_o1:S3
let _o2 = tp_without_as(S); // $ target=tp_without_as MISSING: type=_o2:S3
let _o1 = tp_with_as(S); // $ target=tp_with_as type=_o1:S3
let _o2 = tp_without_as(S); // $ target=tp_without_as type=_o2:S3
let (
_o3, // $ MISSING: type=_o3:S3
_o4, // $ MISSING: type=_o4:bool
_o4, // $ type=_o4:bool
) = tp_assoc_from_supertrait(S); // $ target=tp_assoc_from_supertrait
let _o5 = tp_nested_assoc_type(Wrapper(S)); // $ target=tp_nested_assoc_type MISSING: type=_o5:S3
let w = Wrapper(S);
let _extracted = w.extract(); // $ target=extract type=_extracted:S3
let _extracted2 = w.extract2(); // $ target=extract2 MISSING: type=_extracted2:S3
}
}

View File

@@ -1748,7 +1748,7 @@ mod overloadable_operators {
let i64_mul = 17i64 * 18i64; // $ type=i64_mul:i64 target=mul
let i64_div = 19i64 / 20i64; // $ type=i64_div:i64 target=div
let i64_rem = 21i64 % 22i64; // $ type=i64_rem:i64 target=rem
let i64_param_add = param_add(1i64, 2i64); // $ target=param_add $ MISSING: type=i64_param_add:i64
let i64_param_add = param_add(1i64, 2i64); // $ target=param_add $ type=i64_param_add:i64
// Arithmetic assignment operators
let mut i64_add_assign = 23i64;
@@ -2053,7 +2053,7 @@ mod indexers {
let xs: [S; 1] = [S];
let x = xs[0].foo(); // $ target=foo type=x:S target=index
let y = param_index(vec, 0); // $ target=param_index $ MISSING: type=y:S
let y = param_index(vec, 0); // $ target=param_index $ type=y:S
analyze_slice(&xs); // $ target=analyze_slice
}