Merge pull request #21170 from paldepind/rust/type-inference-fns

Rust: Improve type inference for closures and function traits
This commit is contained in:
Tom Hvitved
2026-01-20 11:52:10 +01:00
committed by GitHub
7 changed files with 426 additions and 112 deletions

View File

@@ -143,23 +143,48 @@ class FutureTrait extends Trait {
TypeAlias getOutputType() { result = this.(TraitItemNode).getAssocItem("Output") }
}
/** A function trait `FnOnce`, `FnMut`, or `Fn`. */
abstract private class AnyFnTraitImpl extends Trait {
/** Gets the `Args` type parameter of this trait. */
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
}
final class AnyFnTrait = AnyFnTraitImpl;
/**
* The [`FnOnce` trait][1].
*
* [1]: https://doc.rust-lang.org/std/ops/trait.FnOnce.html
*/
class FnOnceTrait extends Trait {
class FnOnceTrait extends AnyFnTraitImpl {
pragma[nomagic]
FnOnceTrait() { this.getCanonicalPath() = "core::ops::function::FnOnce" }
/** Gets the type parameter of this trait. */
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
/** Gets the `Output` associated type. */
pragma[nomagic]
TypeAlias getOutputType() { result = this.(TraitItemNode).getAssocItem("Output") }
}
/**
* The [`FnMut` trait][1].
*
* [1]: https://doc.rust-lang.org/std/ops/trait.FnMut.html
*/
class FnMutTrait extends AnyFnTraitImpl {
pragma[nomagic]
FnMutTrait() { this.getCanonicalPath() = "core::ops::function::FnMut" }
}
/**
* The [`Fn` trait][1].
*
* [1]: https://doc.rust-lang.org/std/ops/trait.Fn.html
*/
class FnTrait extends AnyFnTraitImpl {
pragma[nomagic]
FnTrait() { this.getCanonicalPath() = "core::ops::function::Fn" }
}
/**
* The [`Iterator` trait][1].
*

View File

@@ -3827,16 +3827,29 @@ private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
_, path, result)
}
/** Gets the path to a closure's return type. */
private TypePath closureReturnPath() {
result = TypePath::singleton(getDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
/**
* Gets the root type of a closure.
*
* We model closures as `dyn Fn` trait object types. A closure might implement
* only `Fn`, `FnMut`, or `FnOnce`. But since `Fn` is a subtrait of the others,
* giving closures the type `dyn Fn` works well in practice -- even if not
* entirely accurate.
*/
private DynTraitType closureRootType() {
result = TDynTraitType(any(FnTrait t)) // always exists because of the mention in `builtins/mentions.rs`
}
/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
/** Gets the path to a closure's return type. */
private TypePath closureReturnPath() {
result =
TypePath::singleton(TDynTraitTypeParameter(any(FnTrait t), any(FnOnceTrait t).getOutputType()))
}
/** Gets the path to a closure with arity `arity`'s `index`th parameter type. */
pragma[nomagic]
private TypePath closureParameterPath(int arity, int index) {
result =
TypePath::cons(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam()),
TypePath::cons(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam()),
TypePath::singleton(getTupleTypeParameter(arity, index)))
}
@@ -3874,9 +3887,7 @@ private Type inferDynamicCallExprType(Expr n, TypePath path) {
or
// _If_ the invoked expression has the type of a closure, then we propagate
// the surrounding types into the closure.
exists(int arity, TypePath path0 |
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
|
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
// Propagate the type of arguments to the parameter types of closure
exists(int index, ArgList args |
n = ce and
@@ -3900,10 +3911,10 @@ private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
path.isEmpty() and
result = TDynTraitType(any(FnOnceTrait t)) // always exists because of the mention in `builtins/mentions.rs`
result = closureRootType()
or
n = ce and
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam())) and
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
result.(TupleType).getArity() = ce.getNumberOfParams()
or
// Propagate return type annotation to body

View File

@@ -213,7 +213,7 @@ class NonAliasPathTypeMention extends PathTypeMention {
// associated types of `Fn` and `FnMut` yet.
//
// [1]: https://doc.rust-lang.org/reference/paths.html#grammar-TypePathFn
exists(FnOnceTrait t, PathSegment s |
exists(AnyFnTrait t, PathSegment s |
t = resolved and
s = this.getSegment() and
s.hasParenthesizedArgList()
@@ -221,7 +221,7 @@ class NonAliasPathTypeMention extends PathTypeMention {
tp = TTypeParamTypeParameter(t.getTypeParam()) and
result = s.getParenthesizedArgList().(TypeMention).resolveTypeAt(path)
or
tp = TAssociatedTypeTypeParameter(t, t.getOutputType()) and
tp = TAssociatedTypeTypeParameter(t, any(FnOnceTrait tr).getOutputType()) and
(
result = s.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
or