Rust: Use Call in type inference

This commit is contained in:
Simon Friis Vindum
2025-06-06 13:44:14 +02:00
parent 47864781c1
commit 7684e01c3a

View File

@@ -8,6 +8,7 @@ private import TypeMention
private import codeql.typeinference.internal.TypeInference
private import codeql.rust.frameworks.stdlib.Stdlib
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
private import codeql.rust.elements.Call
class Type = T::Type;
@@ -496,20 +497,17 @@ private Type inferPathExprType(PathExpr pe, TypePath path) {
* like `foo::bar(baz)` and `foo.bar(baz)`.
*/
private module CallExprBaseMatchingInput implements MatchingInputSig {
private predicate paramPos(ParamList pl, Param p, int pos, boolean inMethod) {
p = pl.getParam(pos) and
if pl.hasSelfParam() then inMethod = true else inMethod = false
}
private predicate paramPos(ParamList pl, Param p, int pos) { p = pl.getParam(pos) }
private newtype TDeclarationPosition =
TSelfDeclarationPosition() or
TPositionalDeclarationPosition(int pos, boolean inMethod) { paramPos(_, _, pos, inMethod) } or
TPositionalDeclarationPosition(int pos) { paramPos(_, _, pos) } or
TReturnDeclarationPosition()
class DeclarationPosition extends TDeclarationPosition {
predicate isSelf() { this = TSelfDeclarationPosition() }
int asPosition(boolean inMethod) { this = TPositionalDeclarationPosition(result, inMethod) }
int asPosition() { this = TPositionalDeclarationPosition(result) }
predicate isReturn() { this = TReturnDeclarationPosition() }
@@ -517,7 +515,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
this.isSelf() and
result = "self"
or
result = this.asPosition(_).toString()
result = this.asPosition().toString()
or
this.isReturn() and
result = "(return)"
@@ -550,7 +548,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(int pos |
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
dpos = TPositionalDeclarationPosition(pos, false)
dpos = TPositionalDeclarationPosition(pos)
)
}
@@ -573,7 +571,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(int p |
result = this.getTupleField(p).getTypeRepr().(TypeMention).resolveTypeAt(path) and
dpos = TPositionalDeclarationPosition(p, false)
dpos = TPositionalDeclarationPosition(p)
)
}
@@ -606,9 +604,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(Param p, int i, boolean inMethod |
paramPos(this.getParamList(), p, i, inMethod) and
dpos = TPositionalDeclarationPosition(i, inMethod) and
exists(Param p, int i |
paramPos(this.getParamList(), p, i) and
dpos = TPositionalDeclarationPosition(i) and
result = inferAnnotatedType(p.getPat(), path)
)
or
@@ -640,108 +638,36 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
}
private predicate argPos(CallExprBase call, Expr e, int pos, boolean isMethodCall) {
exists(ArgList al |
e = al.getArg(pos) and
call.getArgList() = al and
if call instanceof MethodCallExpr then isMethodCall = true else isMethodCall = false
)
}
private newtype TAccessPosition =
TSelfAccessPosition() or
TPositionalAccessPosition(int pos, boolean isMethodCall) { argPos(_, _, pos, isMethodCall) } or
TReturnAccessPosition()
class AccessPosition extends TAccessPosition {
predicate isSelf() { this = TSelfAccessPosition() }
int asPosition(boolean isMethodCall) { this = TPositionalAccessPosition(result, isMethodCall) }
predicate isReturn() { this = TReturnAccessPosition() }
string toString() {
this.isSelf() and
result = "self"
or
result = this.asPosition(_).toString()
or
this.isReturn() and
result = "(return)"
}
}
class AccessPosition = DeclarationPosition;
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
abstract class Access extends Expr {
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
abstract AstNode getNodeAt(AccessPosition apos);
abstract Type getInferredType(AccessPosition apos, TypePath path);
abstract Declaration getTarget();
}
private class CallExprBaseAccess extends Access instanceof CallExprBase {
private TypeMention getMethodTypeArg(int i) {
result = this.(MethodCallExpr).getGenericArgList().getTypeArg(i)
}
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
final class Access extends Call {
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
or
arg = this.getMethodTypeArg(apos.asMethodTypeArgumentPosition())
arg =
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
)
}
override AstNode getNodeAt(AccessPosition apos) {
exists(int p, boolean isMethodCall |
argPos(this, result, p, isMethodCall) and
apos = TPositionalAccessPosition(p, isMethodCall)
)
AstNode getNodeAt(AccessPosition apos) {
result = this.getArgument(apos.asPosition())
or
result = this.(MethodCallExpr).getReceiver() and
apos = TSelfAccessPosition()
result = this.getReceiver() and apos.isSelf()
or
result = this and
apos = TReturnAccessPosition()
result = this and apos.isReturn()
}
override Type getInferredType(AccessPosition apos, TypePath path) {
Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
}
override Declaration getTarget() {
Declaration getTarget() {
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
or
result = CallExprImpl::getResolvedFunction(this)
or
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
}
}
private class OperationAccess extends Access instanceof Operation {
OperationAccess() { super.isOverloaded(_, _) }
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
// The syntax for operators does not allow type arguments.
none()
}
override AstNode getNodeAt(AccessPosition apos) {
result = super.getOperand(0) and apos = TSelfAccessPosition()
or
result = super.getOperand(1) and apos = TPositionalAccessPosition(0, true)
or
result = this and apos = TReturnAccessPosition()
}
override Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
}
override Declaration getTarget() {
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
}
}
@@ -749,16 +675,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
apos.isSelf() and
dpos.isSelf()
or
exists(int pos, boolean isMethodCall | pos = apos.asPosition(isMethodCall) |
pos = 0 and
isMethodCall = false and
dpos.isSelf()
or
isMethodCall = false and
pos = dpos.asPosition(true) + 1
or
pos = dpos.asPosition(isMethodCall)
)
apos.asPosition() = dpos.asPosition()
or
apos.isReturn() and
dpos.isReturn()
@@ -1180,31 +1097,18 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
)
}
private module MethodCall {
/** An expression that calls a method. */
abstract private class MethodCallImpl extends Expr {
/** Gets the name of the method targeted. */
abstract string getMethodName();
/** Gets the number of arguments _excluding_ the `self` argument. */
abstract int getArity();
/** Gets the trait targeted by this method call, if any. */
Trait getTrait() { none() }
/** Gets the type of the receiver of the method call at `path`. */
abstract Type getTypeAt(TypePath path);
final class MethodCall extends Call {
MethodCall() {
exists(this.getReceiver()) and
// We want the method calls that don't have a path to a concrete method in
// an impl block. We need to exclude calls like `MyType::my_method(..)`.
(this instanceof CallExpr implies exists(this.getTrait()))
}
final class MethodCall = MethodCallImpl;
private class MethodCallExprMethodCall extends MethodCallImpl instanceof MethodCallExpr {
override string getMethodName() { result = super.getIdentifier().getText() }
override int getArity() { result = super.getArgList().getNumberOfArgs() }
pragma[nomagic]
override Type getTypeAt(TypePath path) {
/** Gets the type of the receiver of the method call at `path`. */
Type getTypeAt(TypePath path) {
if this.receiverImplicitlyBorrowed()
then
exists(TypePath path0 | result = inferType(super.getReceiver(), path0) |
path0.isCons(TRefTypeParameter(), path)
or
@@ -1212,59 +1116,10 @@ private module MethodCall {
not (path0.isEmpty() and result = TRefType()) and
path = path0
)
}
}
private class CallExprMethodCall extends MethodCallImpl instanceof CallExpr {
TraitItemNode trait;
string methodName;
Expr receiver;
CallExprMethodCall() {
receiver = this.getArg(0) and
exists(Path path, Function f |
path = this.getFunction().(PathExpr).getPath() and
f = resolvePath(path) and
f.getParamList().hasSelfParam() and
trait = resolvePath(path.getQualifier()) and
trait.getAnAssocItem() = f and
path.getSegment().getIdentifier().getText() = methodName
)
}
override string getMethodName() { result = methodName }
override int getArity() { result = super.getArgList().getNumberOfArgs() - 1 }
override Trait getTrait() { result = trait }
pragma[nomagic]
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
}
private class OperationMethodCall extends MethodCallImpl instanceof Operation {
TraitItemNode trait;
string methodName;
OperationMethodCall() { super.isOverloaded(trait, methodName) }
override string getMethodName() { result = methodName }
override int getArity() { result = this.(Operation).getNumberOfOperands() - 1 }
override Trait getTrait() { result = trait }
pragma[nomagic]
override Type getTypeAt(TypePath path) {
result = inferType(this.(BinaryExpr).getLhs(), path)
or
result = inferType(this.(PrefixExpr).getExpr(), path)
}
else result = inferType(super.getReceiver(), path)
}
}
import MethodCall
/**
* Holds if a method for `type` with the name `name` and the arity `arity`
* exists in `impl`.
@@ -1293,7 +1148,7 @@ private module IsInstantiationOfInput implements IsInstantiationOfInputSig<Metho
private predicate isMethodCall(MethodCall mc, Type rootType, string name, int arity) {
rootType = mc.getTypeAt(TypePath::nil()) and
name = mc.getMethodName() and
arity = mc.getArity()
arity = mc.getNumberOfArguments()
}
pragma[nomagic]