Rust: Type inference for .await expressions

This commit is contained in:
Tom Hvitved
2025-05-26 21:25:07 +02:00
parent e6109cfcf1
commit 821f2fd681
6 changed files with 229 additions and 14 deletions

View File

@@ -49,3 +49,19 @@ class ResultEnum extends Enum {
/** Gets the `Err` variant. */
Variant getErr() { result = this.getVariant("Err") }
}
/**
* The [`Future` trait][1].
*
* [1]: https://doc.rust-lang.org/std/future/trait.Future.html
*/
class FutureTrait extends Trait {
FutureTrait() { this.getCanonicalPath() = "core::future::future::Future" }
/** Gets the `Output` associated type. */
pragma[nomagic]
TypeAlias getOutputType() {
result = this.getAssocItemList().getAnAssocItem() and
result.getName().getText() = "Output"
}
}

View File

@@ -15,10 +15,14 @@ newtype TType =
TTrait(Trait t) or
TArrayType() or // todo: add size?
TRefType() or // todo: add mut?
TImplTraitType(int bounds) {
bounds = any(ImplTraitTypeRepr impl).getTypeBoundList().getNumberOfBounds()
} or
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
TRefTypeParameter() or
TSelfTypeParameter(Trait t)
TSelfTypeParameter(Trait t) or
TImplTraitTypeParameter(ImplTraitType t, int i) { i in [0 .. t.getNumberOfBounds() - 1] }
/**
* A type without type arguments.
@@ -115,6 +119,9 @@ class TraitType extends Type, TTrait {
TraitType() { this = TTrait(trait) }
/** Gets the underlying trait. */
Trait getTrait() { result = trait }
override StructField getStructField(string name) { none() }
override TupleField getTupleField(int i) { none() }
@@ -176,6 +183,33 @@ class RefType extends Type, TRefType {
override Location getLocation() { result instanceof EmptyLocation }
}
/**
* An [`impl Trait`][1] type.
*
* We represent `impl Trait` types as generic types with as many type parameters
* as there are bounds.
*
* [1] https://doc.rust-lang.org/book/ch10-02-traits.html#traits-as-parameters
*/
class ImplTraitType extends Type, TImplTraitType {
private int bounds;
ImplTraitType() { this = TImplTraitType(bounds) }
/** Gets the number of bounds of this `impl Trait` type. */
int getNumberOfBounds() { result = bounds }
override StructField getStructField(string name) { none() }
override TupleField getTupleField(int i) { none() }
override TypeParameter getTypeParameter(int i) { result = TImplTraitTypeParameter(this, i) }
override string toString() { result = "impl Trait ..." }
override Location getLocation() { result instanceof EmptyLocation }
}
/** A type parameter. */
abstract class TypeParameter extends Type {
override StructField getStructField(string name) { none() }
@@ -281,6 +315,26 @@ class SelfTypeParameter extends TypeParameter, TSelfTypeParameter {
override Location getLocation() { result = trait.getLocation() }
}
/**
* An `impl Trait` type parameter.
*/
class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
private ImplTraitType implTraitType;
private int i;
ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTraitType, i) }
/** Gets the `impl Trait` type that this parameter belongs to. */
ImplTraitType getImplTraitType() { result = implTraitType }
/** Gets the index of this type parameter. */
int getIndex() { result = i }
override string toString() { result = "impl Trait<" + i.toString() + ">" }
override Location getLocation() { result instanceof EmptyLocation }
}
/**
* A type abstraction. I.e., a place in the program where type variables are
* introduced.

View File

@@ -77,6 +77,16 @@ private module Input1 implements InputSig1<Location> {
apos.asMethodTypeArgumentPosition() = ppos.asTypeParam().getPosition()
}
private int getImplTraitTypeParameterId(ImplTraitTypeParameter tp) {
tp =
rank[result](ImplTraitTypeParameter tp0, int bounds, int i |
bounds = tp0.getImplTraitType().getNumberOfBounds() and
i = tp0.getIndex()
|
tp0 order by bounds, i
)
}
int getTypeParameterId(TypeParameter tp) {
tp =
rank[result](TypeParameter tp0, int kind, int id |
@@ -90,6 +100,9 @@ private module Input1 implements InputSig1<Location> {
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
node = tp0.(SelfTypeParameter).getTrait()
)
or
kind = 2 and
id = getImplTraitTypeParameterId(tp0)
|
tp0 order by kind, id
)
@@ -228,7 +241,11 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
or
n1 = n2.(ParenExpr).getExpr()
or
n1 = n2.(BlockExpr).getStmtList().getTailExpr()
n2 =
any(BlockExpr be |
not be.isAsync() and
n1 = be.getStmtList().getTailExpr()
)
or
n1 = n2.(IfExpr).getABranch()
or
@@ -1010,6 +1027,29 @@ private StructType inferLiteralType(LiteralExpr le) {
)
}
pragma[nomagic]
private AssociatedTypeTypeParameter getFutureOutputTypeParameter() {
result.getTypeAlias() = any(FutureTrait ft).getOutputType()
}
pragma[nomagic]
private Type inferAwaitExprType(AwaitExpr ae, TypePath path) {
exists(TypePath exprPath | result = inferType(ae.getExpr(), exprPath) |
exprPath
.isCons(TImplTraitTypeParameter(_, _),
any(TypePath path0 | path0.isCons(getFutureOutputTypeParameter(), path)))
or
path = exprPath and
not (
exprPath = TypePath::singleton(TImplTraitTypeParameter(_, _)) and
result.(TraitType).getTrait() instanceof FutureTrait
) and
not exprPath
.isCons(TImplTraitTypeParameter(_, _),
any(TypePath path0 | path0.isCons(getFutureOutputTypeParameter(), _)))
)
}
private module MethodCall {
/** An expression that calls a method. */
abstract private class MethodCallImpl extends Expr {
@@ -1119,12 +1159,17 @@ private predicate methodCandidateTrait(Type type, Trait trait, string name, int
}
private module IsInstantiationOfInput implements IsInstantiationOfInputSig<MethodCall> {
pragma[nomagic]
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getArity()
}
pragma[nomagic]
predicate potentialInstantiationOf(MethodCall mc, TypeAbstraction impl, TypeMention constraint) {
exists(Type rootType, string name, int arity |
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getArity() and
isMethodCall(mc, rootType, name, arity) and
constraint = impl.(ImplTypeAbstraction).getSelfTy()
|
methodCandidateTrait(rootType, mc.getTrait(), name, arity, impl)
@@ -1161,6 +1206,12 @@ private Function getMethodFromImpl(MethodCall mc) {
)
}
bindingset[trait, name]
pragma[inline_late]
private Function getTraitMethod(TraitType trait, string name) {
result = getMethodSuccessor(trait.getTrait(), name)
}
/**
* Gets a method that the method call `mc` resolves to based on type inference,
* if any.
@@ -1172,6 +1223,11 @@ private Function inferMethodCallTarget(MethodCall mc) {
// The type of the receiver is a type parameter and the method comes from a
// trait bound on the type parameter.
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result =
getTraitMethod(mc.getTypeAt(TypePath::singleton(TImplTraitTypeParameter(_, _))),
mc.getMethodName())
}
cached
@@ -1347,6 +1403,8 @@ private module Cached {
or
result = inferLiteralType(n) and
path.isEmpty()
or
result = inferAwaitExprType(n, path)
}
}
@@ -1363,7 +1421,7 @@ private module Debug {
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
filepath.matches("%/main.rs") and
startline = 948
startline = 1334
)
}

View File

@@ -15,7 +15,7 @@ abstract class TypeMention extends AstNode {
/** Gets the sub mention at `path`. */
pragma[nomagic]
private TypeMention getMentionAt(TypePath path) {
TypeMention getMentionAt(TypePath path) {
path.isEmpty() and
result = this
or
@@ -150,6 +150,54 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
not exists(resolved.(TypeAlias).getTypeRepr()) and
result = super.resolveTypeAt(typePath)
}
pragma[nomagic]
private TypeAlias getResolvedTraitAlias(string name) {
exists(TraitItemNode trait |
trait = resolvePath(path) and
result = trait.getAnAssocItem() and
name = result.getName().getText()
)
}
pragma[nomagic]
private TypeRepr getAssocTypeArg(string name) {
exists(AssocTypeArg arg |
arg = path.getSegment().getGenericArgList().getAGenericArg() and
result = arg.getTypeRepr() and
name = arg.getIdentifier().getText()
)
}
/** Gets the type argument for the associated type `alias`, if any. */
pragma[nomagic]
private TypeRepr getAnAssocTypeArgument(TypeAlias alias) {
exists(string name |
alias = this.getResolvedTraitAlias(name) and
result = this.getAssocTypeArg(name)
)
}
override TypeMention getMentionAt(TypePath tp) {
result = super.getMentionAt(tp)
or
exists(TypeAlias alias, AssociatedTypeTypeParameter attp, TypeMention arg, TypePath suffix |
arg = this.getAnAssocTypeArgument(alias) and
result = arg.getMentionAt(suffix) and
tp = TypePath::cons(attp, suffix) and
attp.getTypeAlias() = alias
)
}
}
class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr {
override TypeMention getTypeArgument(int i) {
result = super.getTypeBoundList().getBound(i).getTypeRepr()
}
override ImplTraitType resolveType() {
result.getNumberOfBounds() = super.getTypeBoundList().getNumberOfBounds()
}
}
private TypeParameter pathGetTypeParameter(TypeAlias alias, int i) {

View File

@@ -1664,9 +1664,9 @@ mod async_ {
}
pub async fn f() {
f1().await.f(); // $ MISSING: method=S1f
f2().await.f(); // $ MISSING: method=S1f
f3().await.f(); // $ MISSING: method=S1f
f1().await.f(); // $ method=S1f
f2().await.f(); // $ method=S1f
f3().await.f(); // $ method=S1f
}
}
@@ -1696,8 +1696,8 @@ mod impl_trait {
pub fn f() {
let x = f1();
x.f1(); // $ MISSING: method=Trait1f1
x.f2(); // $ MISSING: method=Trait2f2
x.f1(); // $ method=Trait1f1
x.f2(); // $ method=Trait2f2
}
}

View File

@@ -2377,8 +2377,12 @@ inferType
| main.rs:1639:18:1639:21 | SelfParam | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1642:25:1644:5 | { ... } | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1643:9:1643:10 | S1 | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1646:41:1650:5 | { ... } | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1647:9:1649:9 | { ... } | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1646:41:1650:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1646:41:1650:5 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
| main.rs:1646:41:1650:5 | { ... } | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
| main.rs:1647:9:1649:9 | { ... } | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1647:9:1649:9 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
| main.rs:1647:9:1649:9 | { ... } | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
| main.rs:1648:13:1648:14 | S1 | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1657:17:1657:46 | SelfParam | | {EXTERNAL LOCATION} | Pin |
| main.rs:1657:17:1657:46 | SelfParam | Ptr | file://:0:0:0:0 | & |
@@ -2390,9 +2394,26 @@ inferType
| main.rs:1658:13:1658:38 | ...::Ready(...) | | {EXTERNAL LOCATION} | Poll |
| main.rs:1658:13:1658:38 | ...::Ready(...) | T | main.rs:1636:5:1636:14 | S1 |
| main.rs:1658:36:1658:37 | S1 | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1662:41:1664:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1662:41:1664:5 | { ... } | | main.rs:1652:5:1652:14 | S2 |
| main.rs:1662:41:1664:5 | { ... } | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
| main.rs:1662:41:1664:5 | { ... } | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
| main.rs:1663:9:1663:10 | S2 | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1663:9:1663:10 | S2 | | main.rs:1652:5:1652:14 | S2 |
| main.rs:1663:9:1663:10 | S2 | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
| main.rs:1663:9:1663:10 | S2 | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
| main.rs:1667:9:1667:12 | f1(...) | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1667:9:1667:18 | await ... | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1668:9:1668:12 | f2(...) | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1668:9:1668:12 | f2(...) | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
| main.rs:1668:9:1668:12 | f2(...) | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
| main.rs:1668:9:1668:18 | await ... | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1668:9:1668:18 | await ... | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1669:9:1669:12 | f3(...) | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1669:9:1669:12 | f3(...) | impl Trait<0> | file:///RUSTUP_HOME/toolchain/lib/rustlib/src/rust/library/core/src/future/future.rs:7:1:105:1 | trait Future |
| main.rs:1669:9:1669:12 | f3(...) | impl Trait<0>.Output | main.rs:1636:5:1636:14 | S1 |
| main.rs:1669:9:1669:18 | await ... | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1669:9:1669:18 | await ... | | main.rs:1636:5:1636:14 | S1 |
| main.rs:1678:15:1678:19 | SelfParam | | file://:0:0:0:0 | & |
| main.rs:1678:15:1678:19 | SelfParam | &T | main.rs:1677:5:1679:5 | Self [trait Trait1] |
| main.rs:1682:15:1682:19 | SelfParam | | file://:0:0:0:0 | & |
@@ -2401,8 +2422,26 @@ inferType
| main.rs:1686:15:1686:19 | SelfParam | &T | main.rs:1675:5:1675:14 | S1 |
| main.rs:1690:15:1690:19 | SelfParam | | file://:0:0:0:0 | & |
| main.rs:1690:15:1690:19 | SelfParam | &T | main.rs:1675:5:1675:14 | S1 |
| main.rs:1693:37:1695:5 | { ... } | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1693:37:1695:5 | { ... } | | main.rs:1675:5:1675:14 | S1 |
| main.rs:1693:37:1695:5 | { ... } | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
| main.rs:1693:37:1695:5 | { ... } | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
| main.rs:1694:9:1694:10 | S1 | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1694:9:1694:10 | S1 | | main.rs:1675:5:1675:14 | S1 |
| main.rs:1694:9:1694:10 | S1 | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
| main.rs:1694:9:1694:10 | S1 | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
| main.rs:1698:13:1698:13 | x | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1698:13:1698:13 | x | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
| main.rs:1698:13:1698:13 | x | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
| main.rs:1698:17:1698:20 | f1(...) | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1698:17:1698:20 | f1(...) | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
| main.rs:1698:17:1698:20 | f1(...) | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
| main.rs:1699:9:1699:9 | x | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1699:9:1699:9 | x | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
| main.rs:1699:9:1699:9 | x | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
| main.rs:1700:9:1700:9 | x | | file://:0:0:0:0 | impl Trait ... |
| main.rs:1700:9:1700:9 | x | impl Trait<0> | main.rs:1677:5:1679:5 | trait Trait1 |
| main.rs:1700:9:1700:9 | x | impl Trait<1> | main.rs:1681:5:1683:5 | trait Trait2 |
| main.rs:1706:5:1706:20 | ...::f(...) | | main.rs:67:5:67:21 | Foo |
| main.rs:1707:5:1707:60 | ...::g(...) | | main.rs:67:5:67:21 | Foo |
| main.rs:1707:20:1707:38 | ...::Foo {...} | | main.rs:67:5:67:21 | Foo |