Rust: Use 'infer' instead of 'resolve' in type inference library

This commit is contained in:
Tom Hvitved
2025-03-13 13:34:43 +01:00
parent 2394f2fab8
commit 78280af570
5 changed files with 72 additions and 72 deletions

View File

@@ -119,7 +119,7 @@ private TypeMention getTypeAnnotation(AstNode n) {
/** Gets the type of `n`, which has an explicit type annotation. */
pragma[nomagic]
private Type resolveAnnotatedType(AstNode n, TypePath path) {
private Type inferAnnotatedType(AstNode n, TypePath path) {
result = getTypeAnnotation(n).resolveTypeAt(path)
}
@@ -159,8 +159,8 @@ private predicate typeSymmetry(AstNode n1, TypePath path1, AstNode n2, TypePath
}
pragma[nomagic]
private Type resolveTypeSymmetry(AstNode n, TypePath path) {
exists(AstNode n2, TypePath path2 | result = resolveType(n2, path2) |
private Type inferTypeSymmetry(AstNode n, TypePath path) {
exists(AstNode n2, TypePath path2 | result = inferType(n2, path2) |
typeSymmetry(n, path, n2, path2)
or
typeSymmetry(n2, path2, n, path)
@@ -193,12 +193,12 @@ private Type getRefAdjustImplicitSelfType(SelfParam self, TypePath suffix, Type
}
pragma[nomagic]
private Type resolveImplSelfType(Impl i, TypePath path) {
private Type inferImplSelfType(Impl i, TypePath path) {
result = i.getSelfTy().(TypeReprMention).resolveTypeAt(path)
}
pragma[nomagic]
private Type resolveTraitSelfType(Trait t, TypePath path) {
private Type inferTraitSelfType(Trait t, TypePath path) {
result = TTrait(t) and
path.isEmpty()
or
@@ -208,15 +208,15 @@ private Type resolveTraitSelfType(Trait t, TypePath path) {
/** Gets the type at `path` of the implicitly typed `self` parameter. */
pragma[nomagic]
private Type resolveImplicitSelfType(SelfParam self, TypePath path) {
private Type inferImplicitSelfType(SelfParam self, TypePath path) {
exists(ImplOrTraitItemNode i, Function f, TypePath suffix, Type t |
f = i.getAnAssocItem() and
self = f.getParamList().getSelfParam() and
result = getRefAdjustImplicitSelfType(self, suffix, t, path)
|
t = resolveImplSelfType(i, suffix)
t = inferImplSelfType(i, suffix)
or
t = resolveTraitSelfType(i, suffix)
t = inferTraitSelfType(i, suffix)
)
}
@@ -327,8 +327,8 @@ private module RecordExprMatchingInput implements MatchingInputSig {
apos.isRecordPos()
}
Type getResolvedType(AccessPosition apos, TypePath path) {
result = resolveType(this.getNodeAt(apos), path)
Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
}
Declaration getTarget() { result = resolvePath(this.getPath()) }
@@ -346,15 +346,15 @@ private module RecordExprMatching = Matching<RecordExprMatchingInput>;
* a field expression of a record expression.
*/
pragma[nomagic]
private Type resolveRecordExprType(AstNode n, TypePath path) {
private Type inferRecordExprType(AstNode n, TypePath path) {
exists(RecordExprMatchingInput::Access a, RecordExprMatchingInput::AccessPosition apos |
n = a.getNodeAt(apos) and
result = RecordExprMatching::resolveAccessType(a, apos, path)
result = RecordExprMatching::inferAccessType(a, apos, path)
)
}
pragma[nomagic]
private Type resolvePathExprType(PathExpr pe, TypePath path) {
private Type inferPathExprType(PathExpr pe, TypePath path) {
// nullary struct/variant constructors
not exists(CallExpr ce | pe = ce.getFunction()) and
path.isEmpty() and
@@ -466,7 +466,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
pragma[nomagic]
private Type resolveAnnotatedTypeInclSelf(AstNode n, TypePath path) {
private Type inferAnnotatedTypeInclSelf(AstNode n, TypePath path) {
result = getTypeAnnotation(n).resolveTypeAtInclSelf(path)
}
@@ -477,7 +477,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
exists(Param p, int i, boolean inMethod |
paramPos(this.getParamList(), p, i, inMethod) and
dpos = TPositionalDeclarationPosition(i, inMethod) and
result = resolveAnnotatedTypeInclSelf(p.getPat(), path)
result = inferAnnotatedTypeInclSelf(p.getPat(), path)
)
or
exists(SelfParam self |
@@ -485,10 +485,10 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
dpos.isSelf()
|
// `self` parameter with type annotation
result = resolveAnnotatedTypeInclSelf(self, path)
result = inferAnnotatedTypeInclSelf(self, path)
or
// `self` parameter without type annotation
result = resolveImplicitSelfType(self, path)
result = inferImplicitSelfType(self, path)
or
// `self` parameter without type annotation should also have the special `Self` type
result = getRefAdjustImplicitSelfType(self, TypePath::nil(), TSelfTypeParameter(), path)
@@ -559,8 +559,8 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
apos = TReturnAccessPosition()
}
Type getResolvedType(AccessPosition apos, TypePath path) {
result = resolveType(this.getNodeAt(apos), path)
Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
}
Declaration getTarget() {
@@ -644,9 +644,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
pragma[nomagic]
additional Type resolveReceiverType(AstNode n) {
additional Type inferReceiverType(AstNode n) {
exists(Access a, AccessPosition apos |
result = resolveType(n) and
result = inferType(n) and
n = a.getNodeAt(apos) and
apos.isSelf()
)
@@ -660,17 +660,17 @@ private module CallExprBaseMatching = Matching<CallExprBaseMatchingInput>;
* argument/receiver of a call.
*/
pragma[nomagic]
private Type resolveCallExprBaseType(AstNode n, TypePath path) {
private Type inferCallExprBaseType(AstNode n, TypePath path) {
exists(
CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos,
TypePath path0
|
n = a.getNodeAt(apos) and
result = CallExprBaseMatching::resolveAccessType(a, apos, path0)
result = CallExprBaseMatching::inferAccessType(a, apos, path0)
|
if apos.isSelf()
then
exists(Type receiverType | receiverType = CallExprBaseMatchingInput::resolveReceiverType(n) |
exists(Type receiverType | receiverType = CallExprBaseMatchingInput::inferReceiverType(n) |
if receiverType = TRefType()
then
path = path0 and
@@ -758,8 +758,8 @@ private module FieldExprMatchingInput implements MatchingInputSig {
apos.isField()
}
Type getResolvedType(AccessPosition apos, TypePath path) {
result = resolveType(this.getNodeAt(apos), path)
Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
}
Declaration getTarget() {
@@ -795,9 +795,9 @@ private module FieldExprMatchingInput implements MatchingInputSig {
}
pragma[nomagic]
additional Type resolveReceiverType(AstNode n) {
additional Type inferReceiverType(AstNode n) {
exists(Access a, AccessPosition apos |
result = resolveType(n) and
result = inferType(n) and
n = a.getNodeAt(apos) and
apos.isSelf()
)
@@ -811,16 +811,16 @@ private module FieldExprMatching = Matching<FieldExprMatchingInput>;
* the receiver of field expression call.
*/
pragma[nomagic]
private Type resolveFieldExprType(AstNode n, TypePath path) {
private Type inferFieldExprType(AstNode n, TypePath path) {
exists(
FieldExprMatchingInput::Access a, FieldExprMatchingInput::AccessPosition apos, TypePath path0
|
n = a.getNodeAt(apos) and
result = FieldExprMatching::resolveAccessType(a, apos, path0)
result = FieldExprMatching::inferAccessType(a, apos, path0)
|
if apos.isSelf()
then
exists(Type receiverType | receiverType = FieldExprMatchingInput::resolveReceiverType(n) |
exists(Type receiverType | receiverType = FieldExprMatchingInput::inferReceiverType(n) |
if receiverType = TRefType()
then
// adjust for implicit deref
@@ -838,14 +838,14 @@ private Type resolveFieldExprType(AstNode n, TypePath path) {
* `& x` or an expression `x` inside a reference expression `& x`.
*/
pragma[nomagic]
private Type resolveRefExprType(Expr e, TypePath path) {
private Type inferRefExprType(Expr e, TypePath path) {
exists(RefExpr re |
e = re and
path.isEmpty() and
result = TRefType()
or
e = re and
exists(TypePath exprPath | result = resolveType(re.getExpr(), exprPath) |
exists(TypePath exprPath | result = inferType(re.getExpr(), exprPath) |
if exprPath.startsWith(TRefTypeParameter(), _)
then
// `&x` simply means `x` when `x` already has reference type
@@ -858,9 +858,9 @@ private Type resolveRefExprType(Expr e, TypePath path) {
or
e = re.getExpr() and
exists(TypePath exprPath, TypePath refPath, Type exprType |
result = resolveType(re, exprPath) and
result = inferType(re, exprPath) and
exprPath = TypePath::cons(TRefTypeParameter(), refPath) and
exprType = resolveType(e)
exprType = inferType(e)
|
if exprType = TRefType()
then
@@ -878,11 +878,11 @@ private module Cached {
pragma[inline]
private Type getLookupType(AstNode n) {
exists(Type t |
t = resolveType(n) and
t = inferType(n) and
if t = TRefType()
then
// for reference types, lookup members in the type being referenced
result = resolveType(n, TypePath::singleton(TRefTypeParameter()))
result = inferType(n, TypePath::singleton(TRefTypeParameter()))
else result = t
)
}
@@ -894,7 +894,7 @@ private module Cached {
}
/**
* Gets a method that the method call `mce` resolves to, if any.
* Gets a method that the method call `mce` infers to, if any.
*/
cached
Function resolveMethodCallExpr(MethodCallExpr mce) {
@@ -908,7 +908,7 @@ private module Cached {
}
/**
* Gets the record field that the field expression `fe` resolves to, if any.
* Gets the record field that the field expression `fe` infers to, if any.
*/
cached
RecordField resolveRecordFieldExpr(FieldExpr fe) {
@@ -924,7 +924,7 @@ private module Cached {
}
/**
* Gets the tuple field that the field expression `fe` resolves to, if any.
* Gets the tuple field that the field expression `fe` infers to, if any.
*/
cached
TupleField resolveTupleFieldExpr(FieldExpr fe) {
@@ -932,7 +932,7 @@ private module Cached {
}
/**
* Gets a type at `path` that `n` resolves to, if any.
* Gets a type at `path` that `n` infers to, if any.
*
* The type inference implementation works by computing all possible types, so
* the result is not necessarily unique. For example, in
@@ -971,29 +971,29 @@ private module Cached {
* 5. `x.bar()` has type `&MyTrait` (via 2 and 4).
*/
cached
Type resolveType(AstNode n, TypePath path) {
Type inferType(AstNode n, TypePath path) {
Stages::TypeInference::backref() and
result = resolveAnnotatedType(n, path)
result = inferAnnotatedType(n, path)
or
result = resolveTypeSymmetry(n, path)
result = inferTypeSymmetry(n, path)
or
result = resolveImplicitSelfType(n, path)
result = inferImplicitSelfType(n, path)
or
result = resolveRecordExprType(n, path)
result = inferRecordExprType(n, path)
or
result = resolvePathExprType(n, path)
result = inferPathExprType(n, path)
or
result = resolveCallExprBaseType(n, path)
result = inferCallExprBaseType(n, path)
or
result = resolveFieldExprType(n, path)
result = inferFieldExprType(n, path)
or
result = resolveRefExprType(n, path)
result = inferRefExprType(n, path)
}
}
import Cached
/**
* Gets a type that `n` resolves to, if any.
* Gets a type that `n` infers to, if any.
*/
Type resolveType(AstNode n) { result = resolveType(n, TypePath::nil()) }
Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) }

View File

@@ -121,7 +121,7 @@ module Stages {
or
exists(Type t)
or
exists(resolveType(_))
exists(inferType(_))
}
}

View File

@@ -1,4 +1,4 @@
resolveType
inferType
| main.rs:5:19:5:22 | SelfParam | | main.rs:2:5:2:21 | struct Foo |
| main.rs:5:33:7:9 | { ... } | | main.rs:2:5:2:21 | struct Foo |
| main.rs:6:13:6:16 | self | | main.rs:2:5:2:21 | struct Foo |

View File

@@ -3,8 +3,8 @@ import codeql.rust.elements.internal.TypeInference as TypeInference
import TypeInference
import utils.test.InlineExpectationsTest
query predicate resolveType(AstNode n, TypePath path, Type t) {
t = TypeInference::resolveType(n, path)
query predicate inferType(AstNode n, TypePath path, Type t) {
t = TypeInference::inferType(n, path)
}
query predicate resolveMethodCallExpr(MethodCallExpr mce, Function f) {

View File

@@ -465,12 +465,12 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
Type getExplicitTypeArgument(TypeArgumentPosition tapos, TypePath path);
/**
* Gets the resolved type at `path` for the position `apos` of this access.
* Gets the inferred type at `path` for the position `apos` of this access.
*
* For example, if this access is the method call `M(42)`, then the resolved
* For example, if this access is the method call `M(42)`, then the inferred
* type at argument position `0` is `int`.
*/
Type getResolvedType(AccessPosition apos, TypePath path);
Type getInferredType(AccessPosition apos, TypePath path);
/** Gets the declaration that this access targets. */
Declaration getTarget();
@@ -482,7 +482,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos);
/**
* Holds if matching a resolved type `t` at `path` inside an access at `apos`
* Holds if matching an inferred type `t` at `path` inside an access at `apos`
* against the declaration `target` means that the type should be adjusted to
* `tAdj` at `pathAdj`.
*
@@ -493,7 +493,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
* M(42);
* ```
*
* the resolved type of `42` is `int`, but it should be adjusted to `int?`
* the inferred type of `42` is `int`, but it should be adjusted to `int?`
* when matching against `M`.
*/
bindingset[apos, target, path, t]
@@ -520,7 +520,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
) {
target = a.getTarget() and
exists(TypePath path0, Type t0 |
t0 = a.getResolvedType(apos, path0) and
t0 = a.getInferredType(apos, path0) and
adjustAccessType(apos, target, path0, t0, path, t)
)
}
@@ -562,16 +562,16 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
}
pragma[nomagic]
private Type resolveRootType(Access a, AccessPosition apos) {
private Type inferRootType(Access a, AccessPosition apos) {
relevantAccess(a, apos) and
result = a.getResolvedType(apos, TypePath::nil())
result = a.getInferredType(apos, TypePath::nil())
}
pragma[nomagic]
private Type resolveTypeAt(Access a, AccessPosition apos, TypeParameter tp, TypePath suffix) {
private Type inferTypeAt(Access a, AccessPosition apos, TypeParameter tp, TypePath suffix) {
relevantAccess(a, apos) and
exists(TypePath path0 |
result = a.getResolvedType(apos, path0) and
result = a.getInferredType(apos, path0) and
path0.startsWith(tp, suffix)
)
}
@@ -608,12 +608,12 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
predicate hasBaseTypeMention(
Access a, AccessPosition apos, TypeMention baseMention, TypePath path, Type t
) {
exists(Type sub | sub = resolveRootType(a, apos) |
exists(Type sub | sub = inferRootType(a, apos) |
baseTypeMentionHasNonTypeParameterAt(sub, baseMention, path, t)
or
exists(TypePath prefix, TypePath suffix, TypeParameter i |
baseTypeMentionHasTypeParameterAt(sub, baseMention, prefix, i) and
t = resolveTypeAt(a, apos, i, suffix) and
t = inferTypeAt(a, apos, i, suffix) and
path = prefix.append(suffix)
)
)
@@ -709,7 +709,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
}
/**
* Gets the resolved type of `a` at `path` for position `apos`.
* Gets the inferred type of `a` at `path` for position `apos`.
*
* For example, in
*
@@ -728,7 +728,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
* // ^^^^^^^^^^^^^^^^^^^^^^^ `a`
* ```
*
* we resolve the following types for the return position:
* we infer the following types for the return position:
*
* `path` | `t`
* ----------- | -------
@@ -737,7 +737,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
* `"0.0.0"` | ``C`1``
* `"0.0.0.1"` | `int`
*
* We also resolve the following types for the receiver position:
* We also infer the following types for the receiver position:
*
* `path` | `t`
* ----------- | -------
@@ -747,7 +747,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
* `"0.0.0.1"` | `int`
*/
pragma[nomagic]
Type resolveAccessType(Access a, AccessPosition apos, TypePath path) {
Type inferAccessType(Access a, AccessPosition apos, TypePath path) {
exists(DeclarationPosition dpos | accessDeclarationPositionMatch(apos, dpos) |
exists(Declaration target, TypePath prefix, TypeParameter tp, TypePath suffix |
tp = target.getDeclaredType(pragma[only_bind_into](dpos), prefix) and