Rust: Unify type inference logic for associated functions

This commit is contained in:
Tom Hvitved
2026-02-23 10:41:43 +01:00
parent ca7017f3d7
commit 1b6f3a43ef
14 changed files with 1273 additions and 1444 deletions

View File

@@ -33,5 +33,11 @@ module Impl {
result = "impl " + trait + this.getSelfTy().toAbbreviatedString() + " { ... }"
)
}
/**
* Holds if this is an inherent `impl` block, that is, one that does not implement a trait.
*/
pragma[nomagic]
predicate isInherent() { not this.hasTrait() }
}
}

View File

@@ -6,7 +6,7 @@ module Impl {
private newtype TArgumentPosition =
TPositionalArgumentPosition(int i) {
i in [0 .. max([any(ParamList l).getNumberOfParams(), any(ArgList l).getNumberOfArgs()]) - 1]
i in [0 .. max([any(ParamList l).getNumberOfParams(), any(ArgList l).getNumberOfArgs()])]
} or
TSelfArgumentPosition() or
TTypeQualifierArgumentPosition()

View File

@@ -41,16 +41,19 @@ private predicate hasFirstNonTrivialTraitBound(TypeParamItemNode tp, Trait trait
*/
pragma[nomagic]
predicate isBlanketLike(ImplItemNode i, TypePath blanketSelfPath, TypeParam blanketTypeParam) {
blanketTypeParam = i.getBlanketImplementationTypeParam() and
blanketSelfPath.isEmpty()
or
exists(TypeMention tm, Type root, TypeParameter tp |
tm = i.(Impl).getSelfTy() and
complexSelfRoot(root, tp) and
tm.getType() = root and
tm.getTypeAt(blanketSelfPath) = TTypeParamTypeParameter(blanketTypeParam) and
blanketSelfPath = TypePath::singleton(tp) and
hasFirstNonTrivialTraitBound(blanketTypeParam, _)
i.(Impl).hasTrait() and
(
blanketTypeParam = i.getBlanketImplementationTypeParam() and
blanketSelfPath.isEmpty()
or
exists(TypeMention tm, Type root, TypeParameter tp |
tm = i.(Impl).getSelfTy() and
complexSelfRoot(root, tp) and
tm.getType() = root and
tm.getTypeAt(blanketSelfPath) = TTypeParamTypeParameter(blanketTypeParam) and
blanketSelfPath = TypePath::singleton(tp) and
hasFirstNonTrivialTraitBound(blanketTypeParam, _)
)
)
}

View File

@@ -5,60 +5,112 @@ private import TypeAbstraction
private import TypeMention
private import TypeInference
private newtype TFunctionPosition =
TArgumentFunctionPosition(ArgumentPosition pos) or
TReturnFunctionPosition()
private signature predicate includeSelfSig();
// We construct `FunctionPosition` and `FunctionPositionAdj` using two different underlying
// `newtype`s in order to prevent unintended mixing of the two
private module MkFunctionPosition<includeSelfSig/0 includeSelf> {
private newtype TFunctionPosition =
TArgumentFunctionPosition(ArgumentPosition pos) {
if pos.isSelf() then includeSelf() else any()
} or
TReturnFunctionPosition()
class FunctionPosition extends TFunctionPosition {
int asPosition() { result = this.asArgumentPosition().asPosition() }
predicate isPosition() { exists(this.asPosition()) }
ArgumentPosition asArgumentPosition() { this = TArgumentFunctionPosition(result) }
predicate isTypeQualifier() { this.asArgumentPosition().isTypeQualifier() }
predicate isReturn() { this = TReturnFunctionPosition() }
TypeMention getTypeMention(Function f) {
result = f.getParam(this.asPosition()).getTypeRepr()
or
this.isReturn() and
result = getReturnTypeMention(f)
}
string toString() {
result = this.asArgumentPosition().toString()
or
this.isReturn() and
result = "(return)"
}
}
}
private predicate any_() { any() }
/**
* A position of a type related to a function.
*
* Either `self`, `return`, or a positional parameter index.
*/
class FunctionPosition extends TFunctionPosition {
final class FunctionPosition extends MkFunctionPosition<any_/0>::FunctionPosition {
predicate isSelf() { this.asArgumentPosition().isSelf() }
int asPosition() { result = this.asArgumentPosition().asPosition() }
predicate isPosition() { exists(this.asPosition()) }
ArgumentPosition asArgumentPosition() { this = TArgumentFunctionPosition(result) }
predicate isTypeQualifier() { this.asArgumentPosition().isTypeQualifier() }
predicate isSelfOrTypeQualifier() { this.isSelf() or this.isTypeQualifier() }
predicate isReturn() { this = TReturnFunctionPosition() }
/** Gets the corresponding position when `f` is invoked via a function call. */
bindingset[f]
FunctionPosition getFunctionCallAdjusted(Function f) {
this.isReturn() and
result = this
override TypeMention getTypeMention(Function f) {
result = super.getTypeMention(f)
or
if f.hasSelfParam()
then
this.isSelf() and result.asPosition() = 0
or
result.asPosition() = this.asPosition() + 1
else result = this
}
TypeMention getTypeMention(Function f) {
this.isSelf() and
result = getSelfParamTypeMention(f.getSelfParam())
or
result = f.getParam(this.asPosition()).getTypeRepr()
or
this.isReturn() and
result = getReturnTypeMention(f)
}
string toString() {
result = this.asArgumentPosition().toString()
/**
* Gets the corresponding position when function call syntax is used, assuming
* this position is for a method.
*/
pragma[nomagic]
FunctionPositionAdj getFunctionCallAdjusted() {
this.isReturn() and result.isReturn()
or
this.isReturn() and
result = "(return)"
this.isTypeQualifier() and
result.isTypeQualifier()
or
this.isSelf() and result.asPosition() = 0
or
result.asPosition() = this.asPosition() + 1
}
/**
* Gets the corresponding position when function call syntax is used, assuming
* this position is _not_ for a method.
*/
pragma[nomagic]
FunctionPositionAdj asAdjusted() {
this.isReturn() and result.isReturn()
or
this.isTypeQualifier() and
result.isTypeQualifier()
or
result.asPosition() = this.asPosition()
}
/**
* Gets the corresponding position when `f` is invoked via function call
* syntax.
*/
bindingset[f]
FunctionPositionAdj getFunctionCallAdjusted(Function f) {
if f.hasSelfParam() then result = this.getFunctionCallAdjusted() else result = this.asAdjusted()
}
}
private predicate none_() { none() }
/**
* A function-call adjust position of a type related to a function.
*
* Either `return` or a positional parameter index.
*/
final class FunctionPositionAdj extends MkFunctionPosition<none_/0>::FunctionPosition {
FunctionPosition asNonAdjusted() { this = result.asAdjusted() }
}
/**
@@ -75,6 +127,20 @@ module FunctionPositionMatchingInput {
}
}
/**
* A helper module for implementing `Matching(WithEnvironment)InputSig` with
* `DeclarationPosition = AccessPosition = FunctionPositionAdj`.
*/
module FunctionPositionAdjMatchingInput {
class DeclarationPosition = FunctionPositionAdj;
class AccessPosition = DeclarationPosition;
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
apos = dpos
}
}
private newtype TAssocFunctionType =
/** An associated function `f` in `parent` should be specialized for `i` at `pos`. */
MkAssocFunctionType(
@@ -197,8 +263,7 @@ class AssocFunctionType extends MkAssocFunctionType {
exists(Function f, ImplOrTraitItemNode i, FunctionPosition pos | this.appliesTo(f, i, pos) |
result = pos.getTypeMention(f)
or
pos.isSelf() and
not f.hasSelfParam() and
pos.isTypeQualifier() and
result = [i.(Impl).getSelfTy().(AstNode), i.(Trait).getName()]
)
}
@@ -209,7 +274,7 @@ class AssocFunctionType extends MkAssocFunctionType {
}
pragma[nomagic]
private Trait getALookupTrait(Type t) {
Trait getALookupTrait(Type t) {
result = t.(TypeParamTypeParameter).getTypeParam().(TypeParamItemNode).resolveABound()
or
result = t.(SelfTypeParameter).getTrait()
@@ -310,12 +375,13 @@ signature module ArgsAreInstantiationsOfInputSig {
* Holds if `f` inside `i` needs to have the type corresponding to type parameter
* `tp` checked.
*
* If `i` is an inherent implementation, `tp` is a type parameter of the type being
* implemented, otherwise `tp` is a type parameter of the trait (being implemented).
* `tp` is a type parameter of the trait being implemented by `f` or the trait to which
* `f` belongs.
*
* `pos` is one of the positions in `f` in which the relevant type occours.
* `posAdj` is one of the function-call adjusted positions in `f` in which the relevant
* type occurs.
*/
predicate toCheck(ImplOrTraitItemNode i, Function f, TypeParameter tp, FunctionPosition pos);
predicate toCheck(ImplOrTraitItemNode i, Function f, TypeParameter tp, FunctionPositionAdj posAdj);
/** A call whose argument types are to be checked. */
class Call {
@@ -323,7 +389,7 @@ signature module ArgsAreInstantiationsOfInputSig {
Location getLocation();
Type getArgType(FunctionPosition pos, TypePath path);
Type getArgType(FunctionPositionAdj posAdj, TypePath path);
predicate hasTargetCand(ImplOrTraitItemNode i, Function f);
}
@@ -337,9 +403,9 @@ signature module ArgsAreInstantiationsOfInputSig {
module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
pragma[nomagic]
private predicate toCheckRanked(
ImplOrTraitItemNode i, Function f, TypeParameter tp, FunctionPosition pos, int rnk
ImplOrTraitItemNode i, Function f, TypeParameter tp, FunctionPositionAdj posAdj, int rnk
) {
Input::toCheck(i, f, tp, pos) and
Input::toCheck(i, f, tp, posAdj) and
tp =
rank[rnk + 1](TypeParameter tp0, int j |
Input::toCheck(i, f, tp0, _) and
@@ -351,53 +417,59 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
pragma[nomagic]
private predicate toCheck(
ImplOrTraitItemNode i, Function f, TypeParameter tp, FunctionPosition pos, AssocFunctionType t
ImplOrTraitItemNode i, Function f, TypeParameter tp, FunctionPositionAdj posAdj,
AssocFunctionType t
) {
Input::toCheck(i, f, tp, pos) and
t.appliesTo(f, i, pos)
exists(FunctionPosition pos |
Input::toCheck(i, f, tp, posAdj) and
t.appliesTo(f, i, pos) and
posAdj = pos.getFunctionCallAdjusted(f)
)
}
private newtype TCallAndPos =
MkCallAndPos(Input::Call call, FunctionPosition pos) { exists(call.getArgType(pos, _)) }
private newtype TCallAndPosAdj =
MkCallAndPosAdj(Input::Call call, FunctionPositionAdj posAdj) {
exists(call.getArgType(posAdj, _))
}
/** A call tagged with a position. */
private class CallAndPos extends MkCallAndPos {
/** A call tagged with a function-call adjusted position. */
private class CallAndPosAdj extends MkCallAndPosAdj {
Input::Call call;
FunctionPosition pos;
FunctionPositionAdj posAdj;
CallAndPos() { this = MkCallAndPos(call, pos) }
CallAndPosAdj() { this = MkCallAndPosAdj(call, posAdj) }
Input::Call getCall() { result = call }
FunctionPosition getPos() { result = pos }
FunctionPositionAdj getPosAdj() { result = posAdj }
Location getLocation() { result = call.getLocation() }
Type getTypeAt(TypePath path) { result = call.getArgType(pos, path) }
Type getTypeAt(TypePath path) { result = call.getArgType(posAdj, path) }
string toString() { result = call.toString() + " [arg " + pos + "]" }
string toString() { result = call.toString() + " [arg " + posAdj + "]" }
}
pragma[nomagic]
private predicate potentialInstantiationOf0(
CallAndPos cp, Input::Call call, TypeParameter tp, FunctionPosition pos, Function f,
CallAndPosAdj cp, Input::Call call, TypeParameter tp, FunctionPositionAdj posAdj, Function f,
TypeAbstraction abs, AssocFunctionType constraint
) {
cp = MkCallAndPos(call, pragma[only_bind_into](pos)) and
cp = MkCallAndPosAdj(call, pragma[only_bind_into](posAdj)) and
call.hasTargetCand(abs, f) and
toCheck(abs, f, tp, pragma[only_bind_into](pos), constraint)
toCheck(abs, f, tp, pragma[only_bind_into](posAdj), constraint)
}
private module ArgIsInstantiationOfToIndexInput implements
IsInstantiationOfInputSig<CallAndPos, AssocFunctionType>
IsInstantiationOfInputSig<CallAndPosAdj, AssocFunctionType>
{
pragma[nomagic]
predicate potentialInstantiationOf(
CallAndPos cp, TypeAbstraction abs, AssocFunctionType constraint
CallAndPosAdj cp, TypeAbstraction abs, AssocFunctionType constraint
) {
exists(Input::Call call, TypeParameter tp, FunctionPosition pos, int rnk, Function f |
potentialInstantiationOf0(cp, call, tp, pos, f, abs, constraint) and
toCheckRanked(abs, f, tp, pos, rnk)
exists(Input::Call call, TypeParameter tp, FunctionPositionAdj posAdj, int rnk, Function f |
potentialInstantiationOf0(cp, call, tp, posAdj, f, abs, constraint) and
toCheckRanked(abs, f, tp, posAdj, rnk)
|
rnk = 0
or
@@ -409,24 +481,25 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
}
private module ArgIsInstantiationOfToIndex =
ArgIsInstantiationOf<CallAndPos, ArgIsInstantiationOfToIndexInput>;
ArgIsInstantiationOf<CallAndPosAdj, ArgIsInstantiationOfToIndexInput>;
pragma[nomagic]
private predicate argIsInstantiationOf(
Input::Call call, FunctionPosition pos, ImplOrTraitItemNode i, Function f, int rnk
Input::Call call, ImplOrTraitItemNode i, Function f, int rnk
) {
ArgIsInstantiationOfToIndex::argIsInstantiationOf(MkCallAndPos(call, pos), i, _) and
toCheckRanked(i, f, _, pos, rnk)
exists(FunctionPositionAdj posAdj |
ArgIsInstantiationOfToIndex::argIsInstantiationOf(MkCallAndPosAdj(call, posAdj), i, _) and
toCheckRanked(i, f, _, posAdj, rnk)
)
}
pragma[nomagic]
private predicate argsAreInstantiationsOfToIndex(
Input::Call call, ImplOrTraitItemNode i, Function f, int rnk
) {
exists(FunctionPosition pos |
argIsInstantiationOf(call, pos, i, f, rnk) and
call.hasTargetCand(i, f)
|
argIsInstantiationOf(call, i, f, rnk) and
call.hasTargetCand(i, f) and
(
rnk = 0
or
argsAreInstantiationsOfToIndex(call, i, f, rnk - 1)
@@ -448,11 +521,11 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
}
private module ArgsAreNotInstantiationOfInput implements
IsInstantiationOfInputSig<CallAndPos, AssocFunctionType>
IsInstantiationOfInputSig<CallAndPosAdj, AssocFunctionType>
{
pragma[nomagic]
predicate potentialInstantiationOf(
CallAndPos cp, TypeAbstraction abs, AssocFunctionType constraint
CallAndPosAdj cp, TypeAbstraction abs, AssocFunctionType constraint
) {
potentialInstantiationOf0(cp, _, _, _, _, abs, constraint)
}
@@ -461,13 +534,13 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
}
private module ArgsAreNotInstantiationOf =
ArgIsInstantiationOf<CallAndPos, ArgsAreNotInstantiationOfInput>;
ArgIsInstantiationOf<CallAndPosAdj, ArgsAreNotInstantiationOfInput>;
pragma[nomagic]
private predicate argsAreNotInstantiationsOf0(
Input::Call call, FunctionPosition pos, ImplOrTraitItemNode i
Input::Call call, FunctionPositionAdj posAdj, ImplOrTraitItemNode i
) {
ArgsAreNotInstantiationOf::argIsNotInstantiationOf(MkCallAndPos(call, pos), i, _, _)
ArgsAreNotInstantiationOf::argIsNotInstantiationOf(MkCallAndPosAdj(call, posAdj), i, _, _)
}
/**
@@ -478,10 +551,10 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
*/
pragma[nomagic]
predicate argsAreNotInstantiationsOf(Input::Call call, ImplOrTraitItemNode i, Function f) {
exists(FunctionPosition pos |
argsAreNotInstantiationsOf0(call, pos, i) and
exists(FunctionPositionAdj posAdj |
argsAreNotInstantiationsOf0(call, posAdj, i) and
call.hasTargetCand(i, f) and
Input::toCheck(i, f, _, pos)
Input::toCheck(i, f, _, posAdj)
)
}
}