Merge pull request #19214 from paldepind/rust-ti-associated

Rust: Associated types
This commit is contained in:
Simon Friis Vindum
2025-04-08 13:46:36 +02:00
committed by GitHub
6 changed files with 931 additions and 610 deletions

View File

@@ -26,5 +26,15 @@ module Impl {
*/
class Trait extends Generated::Trait {
override string toStringImpl() { result = "trait " + this.getName().getText() }
/**
* Gets the number of generic parameters of this trait.
*/
int getNumberOfGenericParams() {
result = this.getGenericParamList().getNumberOfGenericParams()
or
not this.hasGenericParamList() and
result = 0
}
}
}

View File

@@ -2,9 +2,10 @@
private import rust
private import PathResolution
private import TypeInference
private import TypeMention
private import codeql.rust.internal.CachedStages
private import codeql.rust.elements.internal.generated.Raw
private import codeql.rust.elements.internal.generated.Synth
cached
newtype TType =
@@ -15,6 +16,7 @@ newtype TType =
TArrayType() or // todo: add size?
TRefType() or // todo: add mut?
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
TRefTypeParameter() or
TSelfTypeParameter(Trait t)
@@ -144,6 +146,9 @@ class TraitType extends Type, TTrait {
override TypeParameter getTypeParameter(int i) {
result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i))
or
result =
any(AssociatedTypeTypeParameter param | param.getTrait() = trait and param.getIndex() = i)
}
pragma[nomagic]
@@ -297,6 +302,14 @@ abstract class TypeParameter extends Type {
override TypeParameter getTypeParameter(int i) { none() }
}
private class RawTypeParameter = @type_param or @trait or @type_alias;
private predicate id(RawTypeParameter x, RawTypeParameter y) { x = y }
private predicate idOfRaw(RawTypeParameter x, int y) = equivalenceRelation(id/2)(x, y)
int idOfTypeParameterAstNode(AstNode node) { idOfRaw(Synth::convertAstNodeToRaw(node), result) }
/** A type parameter from source code. */
class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
private TypeParam typeParam;
@@ -320,6 +333,59 @@ class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
}
}
/**
* Gets the type alias that is the `i`th type parameter of `trait`. Type aliases
* are numbered consecutively but in arbitrary order, starting from the index
* following the last ordinary type parameter.
*/
predicate traitAliasIndex(Trait trait, int i, TypeAlias typeAlias) {
typeAlias =
rank[i + 1 - trait.getNumberOfGenericParams()](TypeAlias alias |
trait.(TraitItemNode).getADescendant() = alias
|
alias order by idOfTypeParameterAstNode(alias)
)
}
/**
* A type parameter corresponding to an associated type in a trait.
*
* We treat associated type declarations in traits as type parameters. E.g., a
* trait such as
* ```rust
* trait ATrait {
* type AssociatedType;
* // ...
* }
* ```
* is treated as if it was
* ```rust
* trait ATrait<AssociatedType> {
* // ...
* }
* ```
*/
class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypeParameter {
private TypeAlias typeAlias;
AssociatedTypeTypeParameter() { this = TAssociatedTypeTypeParameter(typeAlias) }
TypeAlias getTypeAlias() { result = typeAlias }
/** Gets the trait that contains this associated type declaration. */
TraitItemNode getTrait() { result.getAnAssocItem() = typeAlias }
int getIndex() { traitAliasIndex(_, result, typeAlias) }
override Function getMethod(string name) { none() }
override string toString() { result = typeAlias.getName().getText() }
override Location getLocation() { result = typeAlias.getLocation() }
override TypeMention getABaseTypeMention() { none() }
}
/** An implicit reference type parameter. */
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
override Function getMethod(string name) { none() }

View File

@@ -40,17 +40,21 @@ private module Input1 implements InputSig1<Location> {
private newtype TTypeParameterPosition =
TTypeParamTypeParameterPosition(TypeParam tp) or
TSelfTypeParameterPosition()
TImplicitTypeParameterPosition()
class TypeParameterPosition extends TTypeParameterPosition {
TypeParam asTypeParam() { this = TTypeParamTypeParameterPosition(result) }
predicate isSelf() { this = TSelfTypeParameterPosition() }
/**
* Holds if this is the implicit type parameter position used to represent
* parameters that are never passed explicitly as arguments.
*/
predicate isImplicit() { this = TImplicitTypeParameterPosition() }
string toString() {
result = this.asTypeParam().toString()
or
result = "Self" and this.isSelf()
result = "Implicit" and this.isImplicit()
}
}
@@ -69,15 +73,6 @@ private module Input1 implements InputSig1<Location> {
apos.asMethodTypeArgumentPosition() = ppos.asTypeParam().getPosition()
}
/** A raw AST node that might correspond to a type parameter. */
private class RawTypeParameter = @type_param or @trait;
private predicate id(RawTypeParameter x, RawTypeParameter y) { x = y }
private predicate idOfRaw(RawTypeParameter x, int y) = equivalenceRelation(id/2)(x, y)
private int idOf(AstNode node) { idOfRaw(Synth::convertAstNodeToRaw(node), result) }
int getTypeParameterId(TypeParameter tp) {
tp =
rank[result](TypeParameter tp0, int kind, int id |
@@ -86,8 +81,9 @@ private module Input1 implements InputSig1<Location> {
id = 0
or
kind = 1 and
exists(AstNode node | id = idOf(node) |
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
node = tp0.(TypeParamTypeParameter).getTypeParam() or
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
node = tp0.(SelfTypeParameter).getTrait()
)
|
@@ -500,7 +496,10 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
exists(TraitItemNode trait | this = trait.getAnAssocItem() |
typeParamMatchPosition(trait.getTypeParam(_), result, ppos)
or
ppos.isSelf() and result = TSelfTypeParameter(trait)
ppos.isImplicit() and result = TSelfTypeParameter(trait)
or
ppos.isImplicit() and
result.(AssociatedTypeTypeParameter).getTrait() = trait
)
}

View File

@@ -95,6 +95,27 @@ class NonAliasPathMention extends PathMention {
this = node.getASelfPath() and
result = node.(ImplItemNode).getSelfPath().getSegment().getGenericArgList().getTypeArg(i)
)
or
// If `this` is the trait of an `impl` block then any associated types
// defined in the `impl` block are type arguments to the trait.
//
// For instance, for a trait implementation like this
// ```rust
// impl MyTrait for MyType {
// ^^^^^^^ this
// type AssociatedType = i64
// ^^^ result
// // ...
// }
// ```
// the rhs. of the type alias is a type argument to the trait.
exists(ImplItemNode impl, AssociatedTypeTypeParameter param, TypeAlias alias |
this = impl.getTraitPath() and
param.getTrait() = resolvePath(this) and
alias = impl.getASuccessor(param.getTypeAlias().getName().getText()) and
result = alias.getTypeRepr() and
param.getIndex() = i
)
}
override Type resolveType() {
@@ -113,7 +134,11 @@ class NonAliasPathMention extends PathMention {
or
result = TTypeParamTypeParameter(i)
or
result = i.(TypeAlias).getTypeRepr().(TypeReprMention).resolveType()
exists(TypeAlias alias | alias = i |
result.(AssociatedTypeTypeParameter).getTypeAlias() = alias
or
result = alias.getTypeRepr().(TypeReprMention).resolveType()
)
)
}
}
@@ -153,6 +178,17 @@ class TypeParamMention extends TypeMention, TypeParam {
override Type resolveType() { result = TTypeParamTypeParameter(this) }
}
// Used to represent implicit type arguments for associated types in traits.
class TypeAliasMention extends TypeMention, TypeAlias {
private Type t;
TypeAliasMention() { t = TAssociatedTypeTypeParameter(this) }
override TypeReprMention getTypeArgument(int i) { none() }
override Type resolveType() { result = t }
}
/**
* Holds if the `i`th type argument of `selfPath`, belonging to `impl`, resolves
* to type parameter `tp`.
@@ -204,7 +240,11 @@ class ImplMention extends TypeMention, ImplItemNode {
}
class TraitMention extends TypeMention, TraitItemNode {
override TypeMention getTypeArgument(int i) { result = this.getTypeParam(i) }
override TypeMention getTypeArgument(int i) {
result = this.getTypeParam(i)
or
traitAliasIndex(this, i, result)
}
override Type resolveType() { result = TTrait(this) }
}

View File

@@ -329,9 +329,21 @@ mod function_trait_bounds {
}
mod trait_associated_type {
#[derive(Debug)]
struct Wrapper<A> {
field: A,
}
impl<A> Wrapper<A> {
fn unwrap(self) -> A {
self.field // $ fieldof=Wrapper
}
}
trait MyTrait {
type AssociatedType;
// MyTrait::m1
fn m1(self) -> Self::AssociatedType;
fn m2(self) -> Self::AssociatedType
@@ -339,28 +351,129 @@ mod trait_associated_type {
Self::AssociatedType: Default,
Self: Sized,
{
self.m1(); // $ method=MyTrait::m1 type=self.m1():AssociatedType
Self::AssociatedType::default()
}
}
trait MyTraitAssoc2 {
type GenericAssociatedType<AssociatedParam>;
// MyTrait::put
fn put<A>(&self, a: A) -> Self::GenericAssociatedType<A>;
fn putTwo<A>(&self, a: A, b: A) -> Self::GenericAssociatedType<A> {
self.put(a); // $ method=MyTrait::put
self.put(b) // $ method=MyTrait::put
}
}
// A generic trait with multiple associated types.
trait TraitMultipleAssoc<TrG> {
type Assoc1;
type Assoc2;
fn get_zero(&self) -> TrG;
fn get_one(&self) -> Self::Assoc1;
fn get_two(&self) -> Self::Assoc2;
}
#[derive(Debug, Default)]
struct S;
#[derive(Debug, Default)]
struct S2;
#[derive(Debug, Default)]
struct AT;
impl MyTrait for S {
type AssociatedType = S;
type AssociatedType = AT;
// S::m1
fn m1(self) -> Self::AssociatedType {
AT
}
}
impl MyTraitAssoc2 for S {
// Associated type with a type parameter
type GenericAssociatedType<AssociatedParam> = Wrapper<AssociatedParam>;
// S::put
fn put<A>(&self, a: A) -> Wrapper<A> {
Wrapper { field: a }
}
}
impl MyTrait for S2 {
// Associated type definition with a type argument
type AssociatedType = Wrapper<S2>;
fn m1(self) -> Self::AssociatedType {
Wrapper { field: self }
}
}
// NOTE: This implementation is just to make it possible to call `m2` on `S2.`
impl Default for Wrapper<S2> {
fn default() -> Self {
Wrapper { field: S2 }
}
}
// Function that returns an associated type from a trait bound
fn g<T: MyTrait>(thing: T) -> <T as MyTrait>::AssociatedType {
thing.m1() // $ method=MyTrait::m1
}
impl TraitMultipleAssoc<AT> for AT {
type Assoc1 = S;
type Assoc2 = S2;
fn get_zero(&self) -> AT {
AT
}
fn get_one(&self) -> Self::Assoc1 {
S
}
fn get_two(&self) -> Self::Assoc2 {
S2
}
}
pub fn f() {
let x = S;
println!("{:?}", x.m1()); // $ method=S::m1
let x1 = S;
// Call to method in `impl` block
println!("{:?}", x1.m1()); // $ method=S::m1 type=x1.m1():AT
let x = S;
println!("{:?}", x.m2()); // $ method=m2
let x2 = S;
// Call to default method in `trait` block
let y = x2.m2(); // $ method=m2 type=y:AT
println!("{:?}", y);
let x3 = S;
// Call to the method in `impl` block
println!("{:?}", x3.put(1).unwrap()); // $ method=S::put method=unwrap
// Call to default implementation in `trait` block
println!("{:?}", x3.putTwo(2, 3).unwrap()); // $ method=putTwo MISSING: method=unwrap
let x4 = g(S); // $ MISSING: type=x4:AT
println!("{:?}", x4);
let x5 = S2;
println!("{:?}", x5.m1()); // $ method=m1 type=x5.m1():A.S2
let x6 = S2;
println!("{:?}", x6.m2()); // $ method=m2 type=x6.m2():A.S2
let assoc_zero = AT.get_zero(); // $ method=get_zero type=assoc_zero:AT
let assoc_one = AT.get_one(); // $ method=get_one type=assoc_one:S
let assoc_two = AT.get_two(); // $ method=get_two type=assoc_two:S2
}
}