Merge pull request #21067 from paldepind/rust/type-inference-use-type-item

Rust: Refactor type inference to use new `TypeItem` class
This commit is contained in:
Simon Friis Vindum
2025-12-19 14:47:33 +01:00
committed by GitHub
5 changed files with 92 additions and 147 deletions

View File

@@ -30,7 +30,8 @@ class StreamCipherInit extends Cryptography::CryptographicOperation::Range {
// extract the algorithm name from the type of `ce` or its receiver.
exists(Type t, TypePath tp |
t = inferType([call, call.(MethodCall).getReceiver()], tp) and
rawAlgorithmName = t.(StructType).getStruct().(Addressable).getCanonicalPath().splitAt("::")
rawAlgorithmName =
t.(StructType).getTypeItem().(Addressable).getCanonicalPath().splitAt("::")
) and
algorithmName = simplifyAlgorithmName(rawAlgorithmName) and
// only match a known cryptographic algorithm

View File

@@ -32,10 +32,8 @@ private predicate dynTraitTypeParameter(Trait trait, AstNode n) {
cached
newtype TType =
TStruct(Struct s) { Stages::TypeInferenceStage::ref() } or
TEnum(Enum e) or
TDataType(TypeItem ti) { Stages::TypeInferenceStage::ref() } or
TTrait(Trait t) or
TUnion(Union u) or
TImplTraitType(ImplTraitTypeRepr impl) or
TDynTraitType(Trait t) { t = any(DynTraitTypeRepr dt).getTrait() } or
TNeverType() or
@@ -92,7 +90,7 @@ abstract class Type extends TType {
class TupleType extends StructType {
private int arity;
TupleType() { arity = this.getStruct().(Builtins::TupleType).getArity() }
TupleType() { arity = this.getTypeItem().(Builtins::TupleType).getArity() }
/** Gets the arity of this tuple type. */
int getArity() { result = arity }
@@ -112,48 +110,55 @@ class UnitType extends TupleType {
override string toString() { result = "()" }
}
/** A struct type. */
class StructType extends Type, TStruct {
private Struct struct;
class DataType extends Type, TDataType {
private TypeItem typeItem;
StructType() { this = TStruct(struct) }
DataType() { this = TDataType(typeItem) }
/** Gets the struct that this struct type represents. */
Struct getStruct() { result = struct }
/** Gets the type item that this data type represents. */
TypeItem getTypeItem() { result = typeItem }
override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(struct.getGenericParamList().getTypeParam(i))
result = TTypeParamTypeParameter(typeItem.getGenericParamList().getTypeParam(i))
}
override TypeMention getTypeParameterDefault(int i) {
result = struct.getGenericParamList().getTypeParam(i).getDefaultType()
result = typeItem.getGenericParamList().getTypeParam(i).getDefaultType()
}
override string toString() { result = struct.getName().getText() }
override string toString() { result = typeItem.getName().getText() }
override Location getLocation() { result = struct.getLocation() }
override Location getLocation() { result = typeItem.getLocation() }
}
/** A struct type. */
class StructType extends DataType {
private Struct struct;
StructType() { struct = super.getTypeItem() }
/** Gets the struct that this struct type represents. */
override Struct getTypeItem() { result = struct }
}
/** An enum type. */
class EnumType extends Type, TEnum {
class EnumType extends DataType {
private Enum enum;
EnumType() { this = TEnum(enum) }
EnumType() { enum = super.getTypeItem() }
/** Gets the enum that this enum type represents. */
Enum getEnum() { result = enum }
override Enum getTypeItem() { result = enum }
}
override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(enum.getGenericParamList().getTypeParam(i))
}
/** A union type. */
class UnionType extends DataType {
private Union union;
override TypeMention getTypeParameterDefault(int i) {
result = enum.getGenericParamList().getTypeParam(i).getDefaultType()
}
UnionType() { union = super.getTypeItem() }
override string toString() { result = enum.getName().getText() }
override Location getLocation() { result = enum.getLocation() }
/** Gets the union that this union type represents. */
override Union getTypeItem() { result = union }
}
/** A trait type. */
@@ -186,35 +191,13 @@ class TraitType extends Type, TTrait {
override Location getLocation() { result = trait.getLocation() }
}
/** A union type. */
class UnionType extends Type, TUnion {
private Union union;
UnionType() { this = TUnion(union) }
/** Gets the union that this union type represents. */
Union getUnion() { result = union }
override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(union.getGenericParamList().getTypeParam(i))
}
override TypeMention getTypeParameterDefault(int i) {
result = union.getGenericParamList().getTypeParam(i).getDefaultType()
}
override string toString() { result = union.getName().getText() }
override Location getLocation() { result = union.getLocation() }
}
/**
* An array type.
*
* Array types like `[i64; 5]` are modeled as normal generic types.
*/
class ArrayType extends StructType {
ArrayType() { this.getStruct() instanceof Builtins::ArrayType }
ArrayType() { this.getTypeItem() instanceof Builtins::ArrayType }
override string toString() { result = "[;]" }
}
@@ -227,13 +210,13 @@ TypeParamTypeParameter getArrayTypeParameter() {
abstract class RefType extends StructType { }
class RefMutType extends RefType {
RefMutType() { this.getStruct() instanceof Builtins::RefMutType }
RefMutType() { this.getTypeItem() instanceof Builtins::RefMutType }
override string toString() { result = "&mut" }
}
class RefSharedType extends RefType {
RefSharedType() { this.getStruct() instanceof Builtins::RefSharedType }
RefSharedType() { this.getTypeItem() instanceof Builtins::RefSharedType }
override string toString() { result = "&" }
}
@@ -330,7 +313,7 @@ class ImplTraitReturnType extends ImplTraitType {
* with a single type argument.
*/
class SliceType extends StructType {
SliceType() { this.getStruct() instanceof Builtins::SliceType }
SliceType() { this.getTypeItem() instanceof Builtins::SliceType }
override string toString() { result = "[]" }
}
@@ -356,13 +339,13 @@ TypeParamTypeParameter getPtrTypeParameter() {
}
class PtrMutType extends PtrType {
PtrMutType() { this.getStruct() instanceof Builtins::PtrMutType }
PtrMutType() { this.getTypeItem() instanceof Builtins::PtrMutType }
override string toString() { result = "*mut" }
}
class PtrConstType extends PtrType {
PtrConstType() { this.getStruct() instanceof Builtins::PtrConstType }
PtrConstType() { this.getTypeItem() instanceof Builtins::PtrConstType }
override string toString() { result = "*const" }
}
@@ -624,7 +607,7 @@ pragma[nomagic]
predicate validSelfType(Type t) {
t instanceof RefType
or
exists(Struct s | t = TStruct(s) |
exists(Struct s | t = TDataType(s) |
s instanceof BoxStruct or
s instanceof RcStruct or
s instanceof ArcStruct or

View File

@@ -619,7 +619,7 @@ private Type inferLogicalOperationType(AstNode n, TypePath path) {
exists(Builtins::Bool t, BinaryLogicalOperation be |
n = [be, be.getLhs(), be.getRhs()] and
path.isEmpty() and
result = TStruct(t)
result = TDataType(t)
)
}
@@ -887,14 +887,14 @@ private module StructExprMatchingInput implements MatchingInputSig {
}
abstract class Declaration extends AstNode {
abstract TypeParam getATypeParam();
final TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getATypeParam(), result, ppos)
typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos)
}
abstract StructField getField(string name);
abstract TypeItem getTypeItem();
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
// type of a field
exists(TypeMention tp |
@@ -906,45 +906,28 @@ private module StructExprMatchingInput implements MatchingInputSig {
dpos.isStructPos() and
result = this.getTypeParameter(_) and
path = TypePath::singleton(result)
or
// type of the struct or enum itself
dpos.isStructPos() and
path.isEmpty() and
result = TDataType(this.getTypeItem())
}
}
private class StructDecl extends Declaration, Struct {
StructDecl() { this.isStruct() or this.isUnit() }
override TypeParam getATypeParam() { result = this.getGenericParamList().getATypeParam() }
override StructField getField(string name) { result = this.getStructField(name) }
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = super.getDeclaredType(dpos, path)
or
// type of the struct itself
dpos.isStructPos() and
path.isEmpty() and
result = TStruct(this)
}
override TypeItem getTypeItem() { result = this }
}
private class StructVariantDecl extends Declaration, Variant {
StructVariantDecl() { this.isStruct() or this.isUnit() }
Enum getEnum() { result.getVariantList().getAVariant() = this }
override TypeParam getATypeParam() {
result = this.getEnum().getGenericParamList().getATypeParam()
}
override StructField getField(string name) { result = this.getStructField(name) }
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = super.getDeclaredType(dpos, path)
or
// type of the enum itself
dpos.isStructPos() and
path.isEmpty() and
result = TEnum(this.getEnum())
}
override TypeItem getTypeItem() { result = this.getEnum() }
}
class AccessPosition = DeclarationPosition;
@@ -2841,11 +2824,21 @@ private module NonMethodResolution {
}
abstract private class TupleLikeConstructor extends Addressable {
abstract TypeParameter getTypeParameter(TypeParameterPosition ppos);
final TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos)
}
abstract Type getParameterType(FunctionPosition pos, TypePath path);
abstract TypeItem getTypeItem();
abstract Type getReturnType(TypePath path);
abstract TupleField getTupleField(int i);
Type getReturnType(TypePath path) {
result = TDataType(this.getTypeItem()) and
path.isEmpty()
or
result = TTypeParamTypeParameter(this.getTypeItem().getGenericParamList().getATypeParam()) and
path = TypePath::singleton(result)
}
Type getDeclaredType(FunctionPosition pos, TypePath path) {
result = this.getParameterType(pos, path)
@@ -2856,54 +2849,26 @@ abstract private class TupleLikeConstructor extends Addressable {
pos.isSelf() and
result = this.getReturnType(path)
}
}
private class TupleStruct extends TupleLikeConstructor, Struct {
TupleStruct() { this.isTuple() }
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
}
override Type getParameterType(FunctionPosition pos, TypePath path) {
exists(int i |
result = this.getTupleField(i).getTypeRepr().(TypeMention).resolveTypeAt(path) and
i = pos.asPosition()
)
}
override Type getReturnType(TypePath path) {
result = TStruct(this) and
path.isEmpty()
or
result = TTypeParamTypeParameter(this.getGenericParamList().getATypeParam()) and
path = TypePath::singleton(result)
Type getParameterType(FunctionPosition pos, TypePath path) {
result = this.getTupleField(pos.asPosition()).getTypeRepr().(TypeMention).resolveTypeAt(path)
}
}
private class TupleVariant extends TupleLikeConstructor, Variant {
TupleVariant() { this.isTuple() }
private class TupleLikeStruct extends TupleLikeConstructor instanceof Struct {
TupleLikeStruct() { this.isTuple() }
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getEnum().getGenericParamList().getATypeParam(), result, ppos)
}
override TypeItem getTypeItem() { result = this }
override Type getParameterType(FunctionPosition pos, TypePath path) {
exists(int i |
result = this.getTupleField(i).getTypeRepr().(TypeMention).resolveTypeAt(path) and
i = pos.asPosition()
)
}
override TupleField getTupleField(int i) { result = Struct.super.getTupleField(i) }
}
override Type getReturnType(TypePath path) {
exists(Enum enum | enum = this.getEnum() |
result = TEnum(enum) and
path.isEmpty()
or
result = TTypeParamTypeParameter(enum.getGenericParamList().getATypeParam()) and
path = TypePath::singleton(result)
)
}
private class TupleLikeVariant extends TupleLikeConstructor instanceof Variant {
TupleLikeVariant() { this.isTuple() }
override TypeItem getTypeItem() { result = super.getEnum() }
override TupleField getTupleField(int i) { result = Variant.super.getTupleField(i) }
}
/**
@@ -3224,7 +3189,7 @@ private module FieldExprMatchingInput implements MatchingInputSig {
dpos.isSelf() and
// no case for variants as those can only be destructured using pattern matching
exists(Struct s | this.getAstNode() = [s.getStructField(_).(AstNode), s.getTupleField(_)] |
result = TStruct(s) and
result = TDataType(s) and
path.isEmpty()
or
result = TTypeParamTypeParameter(s.getGenericParamList().getATypeParam()) and
@@ -3374,15 +3339,15 @@ private Type inferTryExprType(TryExpr te, TypePath path) {
}
pragma[nomagic]
private StructType getStrStruct() { result = TStruct(any(Builtins::Str s)) }
private StructType getStrStruct() { result = TDataType(any(Builtins::Str s)) }
pragma[nomagic]
private StructType getStringStruct() { result = TStruct(any(StringStruct s)) }
private StructType getStringStruct() { result = TDataType(any(StringStruct s)) }
pragma[nomagic]
private Type inferLiteralType(LiteralExpr le, TypePath path, boolean certain) {
path.isEmpty() and
exists(Builtins::BuiltinType t | result = TStruct(t) |
exists(Builtins::BuiltinType t | result = TDataType(t) |
le instanceof CharLiteralExpr and
t instanceof Builtins::Char and
certain = true
@@ -3502,7 +3467,7 @@ private Type inferArrayExprType(ArrayExpr ae) { exists(ae) and result instanceof
* Gets the root type of the range expression `re`.
*/
pragma[nomagic]
private Type inferRangeExprType(RangeExpr re) { result = TStruct(getRangeType(re)) }
private Type inferRangeExprType(RangeExpr re) { result = TDataType(getRangeType(re)) }
/**
* According to [the Rust reference][1]: _"array and slice-typed expressions
@@ -3519,7 +3484,7 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
// TODO: Method resolution to the `std::ops::Index` trait can handle the
// `Index` instances for slices and arrays.
exists(TypePath exprPath, Builtins::BuiltinType t |
TStruct(t) = inferType(ie.getIndex()) and
TDataType(t) = inferType(ie.getIndex()) and
(
// also allow `i32`, since that is currently the type that we infer for
// integer literals like `0`
@@ -3879,11 +3844,11 @@ private module Cached {
*/
cached
StructField resolveStructFieldExpr(FieldExpr fe, boolean isDereferenced) {
exists(string name, Type ty |
exists(string name, DataType ty |
ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), isDereferenced)
|
result = ty.(StructType).getStruct().getStructField(pragma[only_bind_into](name)) or
result = ty.(UnionType).getUnion().getStructField(pragma[only_bind_into](name))
result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or
result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name))
)
}
@@ -3896,7 +3861,7 @@ private module Cached {
result =
getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), isDereferenced)
.(StructType)
.getStruct()
.getTypeItem()
.getTupleField(pragma[only_bind_into](i))
)
}

View File

@@ -271,9 +271,7 @@ class NonAliasPathTypeMention extends PathTypeMention {
pragma[nomagic]
private Type resolveRootType() {
result = TStruct(resolved)
or
result = TEnum(resolved)
result = TDataType(resolved)
or
exists(TraitItemNode trait | trait = resolved |
// If this is a `Self` path, then it resolves to the implicit `Self`
@@ -283,8 +281,6 @@ class NonAliasPathTypeMention extends PathTypeMention {
else result = TTrait(trait)
)
or
result = TUnion(resolved)
or
result = TTypeParamTypeParameter(resolved)
or
result = TAssociatedTypeTypeParameter(resolved)

View File

@@ -14,7 +14,7 @@ private import codeql.rust.frameworks.stdlib.Builtins as Builtins
/** A node whose type is a numeric type. */
class NumericTypeBarrier extends DataFlow::Node {
NumericTypeBarrier() {
TypeInference::inferType(this.asExpr()).(StructType).getStruct() instanceof
TypeInference::inferType(this.asExpr()).(StructType).getTypeItem() instanceof
Builtins::NumericType
}
}
@@ -22,14 +22,14 @@ class NumericTypeBarrier extends DataFlow::Node {
/** A node whose type is `bool`. */
class BooleanTypeBarrier extends DataFlow::Node {
BooleanTypeBarrier() {
TypeInference::inferType(this.asExpr()).(StructType).getStruct() instanceof Builtins::Bool
TypeInference::inferType(this.asExpr()).(StructType).getTypeItem() instanceof Builtins::Bool
}
}
/** A node whose type is an integral (integer). */
class IntegralTypeBarrier extends DataFlow::Node {
IntegralTypeBarrier() {
TypeInference::inferType(this.asExpr()).(StructType).getStruct() instanceof
TypeInference::inferType(this.asExpr()).(StructType).getTypeItem() instanceof
Builtins::IntegralType
}
}
@@ -37,7 +37,7 @@ class IntegralTypeBarrier extends DataFlow::Node {
/** A node whose type is a fieldless enum. */
class FieldlessEnumTypeBarrier extends DataFlow::Node {
FieldlessEnumTypeBarrier() {
TypeInference::inferType(this.asExpr()).(EnumType).getEnum().isFieldless()
TypeInference::inferType(this.asExpr()).(EnumType).getTypeItem().isFieldless()
}
}