Rust: Unify type inference for tuple indexing expressions

This commit is contained in:
Tom Hvitved
2025-08-07 15:18:27 +02:00
parent b2343f94c1
commit 454ab4db8c
4 changed files with 102 additions and 72 deletions

View File

@@ -119,7 +119,7 @@ class TupleType extends Type, TTuple {
}
/** The unit type `()`. */
class UnitType extends TupleType, TTuple {
class UnitType extends TupleType {
UnitType() { this = TTuple(0) }
override string toString() { result = "()" }

View File

@@ -1135,6 +1135,36 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
)
}
pragma[inline]
private Type inferRootTypeDeref(AstNode n) {
result = inferType(n) and
result != TRefType()
or
// for reference types, lookup members in the type being referenced
result = inferType(n, TypePath::singleton(TRefTypeParameter()))
}
pragma[nomagic]
private Type getFieldExprLookupType(FieldExpr fe, string name) {
result = inferRootTypeDeref(fe.getContainer()) and name = fe.getIdentifier().getText()
}
pragma[nomagic]
private Type getTupleFieldExprLookupType(FieldExpr fe, int pos) {
exists(string name |
result = getFieldExprLookupType(fe, name) and
pos = name.toInt()
)
}
pragma[nomagic]
private TupleTypeParameter resolveTupleTypeFieldExpr(FieldExpr fe) {
exists(int arity, int i |
TTuple(arity) = getTupleFieldExprLookupType(fe, i) and
result = TTupleTypeParameter(arity, i)
)
}
/**
* A matching configuration for resolving types of field expressions
* like `x.field`.
@@ -1158,15 +1188,30 @@ private module FieldExprMatchingInput implements MatchingInputSig {
}
}
abstract class Declaration extends AstNode {
private newtype TDeclaration =
TStructFieldDecl(StructField sf) or
TTupleFieldDecl(TupleField tf) or
TTupleTypeParameterDecl(TupleTypeParameter ttp)
abstract class Declaration extends TDeclaration {
TypeParameter getTypeParameter(TypeParameterPosition ppos) { none() }
abstract Type getDeclaredType(DeclarationPosition dpos, TypePath path);
abstract string toString();
abstract Location getLocation();
}
abstract private class StructOrTupleFieldDecl extends Declaration {
abstract AstNode getAstNode();
abstract TypeRepr getTypeRepr();
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
dpos.isSelf() and
// no case for variants as those can only be destructured using pattern matching
exists(Struct s | s.getStructField(_) = this or s.getTupleField(_) = this |
exists(Struct s | this.getAstNode() = [s.getStructField(_).(AstNode), s.getTupleField(_)] |
result = TStruct(s) and
path.isEmpty()
or
@@ -1177,14 +1222,55 @@ private module FieldExprMatchingInput implements MatchingInputSig {
dpos.isField() and
result = this.getTypeRepr().(TypeMention).resolveTypeAt(path)
}
override string toString() { result = this.getAstNode().toString() }
override Location getLocation() { result = this.getAstNode().getLocation() }
}
private class StructFieldDecl extends Declaration instanceof StructField {
override TypeRepr getTypeRepr() { result = StructField.super.getTypeRepr() }
private class StructFieldDecl extends StructOrTupleFieldDecl, TStructFieldDecl {
private StructField sf;
StructFieldDecl() { this = TStructFieldDecl(sf) }
override AstNode getAstNode() { result = sf }
override TypeRepr getTypeRepr() { result = sf.getTypeRepr() }
}
private class TupleFieldDecl extends Declaration instanceof TupleField {
override TypeRepr getTypeRepr() { result = TupleField.super.getTypeRepr() }
private class TupleFieldDecl extends StructOrTupleFieldDecl, TTupleFieldDecl {
private TupleField tf;
TupleFieldDecl() { this = TTupleFieldDecl(tf) }
override AstNode getAstNode() { result = tf }
override TypeRepr getTypeRepr() { result = tf.getTypeRepr() }
}
private class TupleTypeParameterDecl extends Declaration, TTupleTypeParameterDecl {
private TupleTypeParameter ttp;
TupleTypeParameterDecl() { this = TTupleTypeParameterDecl(ttp) }
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
dpos.isSelf() and
(
result = ttp.getTupleType() and
path.isEmpty()
or
result = ttp and
path = TypePath::singleton(ttp)
)
or
dpos.isField() and
result = ttp and
path.isEmpty()
}
override string toString() { result = ttp.toString() }
override Location getLocation() { result = ttp.getLocation() }
}
class AccessPosition = DeclarationPosition;
@@ -1206,7 +1292,12 @@ private module FieldExprMatchingInput implements MatchingInputSig {
Declaration getTarget() {
// mutual recursion; resolving fields requires resolving types and vice versa
result = [resolveStructFieldExpr(this).(AstNode), resolveTupleFieldExpr(this)]
result =
[
TStructFieldDecl(resolveStructFieldExpr(this)).(TDeclaration),
TTupleFieldDecl(resolveTupleFieldExpr(this)),
TTupleTypeParameterDecl(resolveTupleTypeFieldExpr(this))
]
}
}
@@ -1266,42 +1357,6 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
)
}
pragma[nomagic]
private Type inferTupleIndexExprType(FieldExpr fe, TypePath path) {
exists(int i, TypePath path0 |
fe.getIdentifier().getText() = i.toString() and
result = inferType(fe.getContainer(), path0) and
path0.isCons(TTupleTypeParameter(_, i), path) and
fe.getIdentifier().getText() = i.toString()
)
}
/** Infers the type of `t` in `t.n` when `t` is a tuple. */
private Type inferTupleContainerExprType(Expr e, TypePath path) {
// NOTE: For a field expression `t.n` where `n` is a number `t` might be a
// tuple as in:
// ```rust
// let t = (Default::default(), 2);
// let s: String = t.0;
// ```
// But it could also be a tuple struct as in:
// ```rust
// struct T(String, u32);
// let t = T(Default::default(), 2);
// let s: String = t.0;
// ```
// We need type information to flow from `t.n` to tuple type parameters of `t`
// in the former case but not the latter case. Hence we include the condition
// that the root type of `t` must be a tuple type.
exists(int i, TypePath path0, FieldExpr fe, int arity |
e = fe.getContainer() and
fe.getIdentifier().getText() = i.toString() and
arity = inferType(fe.getContainer()).(TupleType).getArity() and
result = inferType(fe, path0) and
path = TypePath::cons(TTupleTypeParameter(arity, i), path0)
)
}
/** Gets the root type of the reference node `ref`. */
pragma[nomagic]
private Type inferRefNodeType(AstNode ref) {
@@ -2230,20 +2285,6 @@ private module Cached {
result = resolveFunctionCallTarget(call)
}
pragma[inline]
private Type inferRootTypeDeref(AstNode n) {
result = inferType(n) and
result != TRefType()
or
// for reference types, lookup members in the type being referenced
result = inferType(n, TypePath::singleton(TRefTypeParameter()))
}
pragma[nomagic]
private Type getFieldExprLookupType(FieldExpr fe, string name) {
result = inferRootTypeDeref(fe.getContainer()) and name = fe.getIdentifier().getText()
}
/**
* Gets the struct field that the field expression `fe` resolves to, if any.
*/
@@ -2252,14 +2293,6 @@ private module Cached {
exists(string name | result = getFieldExprLookupType(fe, name).getStructField(name))
}
pragma[nomagic]
private Type getTupleFieldExprLookupType(FieldExpr fe, int pos) {
exists(string name |
result = getFieldExprLookupType(fe, name) and
pos = name.toInt()
)
}
/**
* Gets the tuple field that the field expression `fe` resolves to, if any.
*/
@@ -2341,10 +2374,6 @@ private module Cached {
or
result = inferFieldExprType(n, path)
or
result = inferTupleIndexExprType(n, path)
or
result = inferTupleContainerExprType(n, path)
or
result = inferRefNodeType(n) and
path.isEmpty()
or

View File

@@ -2487,7 +2487,7 @@ mod tuples {
let x = pair.0; // $ type=x:i32
let y = &S1::get_pair(); // $ target=get_pair
y.0.foo(); // $ MISSING: target=foo
y.0.foo(); // $ target=foo
}
}

View File

@@ -4856,6 +4856,7 @@ inferType
| main.rs:2490:9:2490:9 | y | &T | file://:0:0:0:0 | (T_2) |
| main.rs:2490:9:2490:9 | y | &T.0(2) | main.rs:2447:5:2448:16 | S1 |
| main.rs:2490:9:2490:9 | y | &T.1(2) | main.rs:2447:5:2448:16 | S1 |
| main.rs:2490:9:2490:11 | y.0 | | main.rs:2447:5:2448:16 | S1 |
| main.rs:2497:13:2497:23 | boxed_value | | {EXTERNAL LOCATION} | Box |
| main.rs:2497:13:2497:23 | boxed_value | A | {EXTERNAL LOCATION} | Global |
| main.rs:2497:13:2497:23 | boxed_value | T | {EXTERNAL LOCATION} | i32 |