Rust: Refactor and generalize Call

This commit is contained in:
Simon Friis Vindum
2025-06-16 12:51:36 +02:00
parent 87b52cc347
commit ebdffcc4ef
5 changed files with 102 additions and 72 deletions

View File

@@ -182,8 +182,8 @@ final class CallCfgNode extends ExprCfgNode {
}
/** Gets the `i`th argument of this call, if any. */
ExprCfgNode getArgument(int i) {
any(ChildMapping mapping).hasCfgChild(node, node.getArgument(i), this, result)
ExprCfgNode getPositionalArgument(int i) {
any(ChildMapping mapping).hasCfgChild(node, node.getPositionalArgument(i), this, result)
}
}

View File

@@ -133,7 +133,7 @@ final class ParameterPosition extends TParameterPosition {
final class ArgumentPosition extends ParameterPosition {
/** Gets the argument of `call` at this position, if any. */
Expr getArgument(Call call) {
result = call.getArgument(this.getPosition())
result = call.getPositionalArgument(this.getPosition())
or
result = call.getReceiver() and this.isSelf()
}
@@ -146,7 +146,7 @@ final class ArgumentPosition extends ParameterPosition {
* as the synthetic `ReceiverNode` is the argument for the `self` parameter.
*/
predicate isArgumentForCall(ExprCfgNode arg, CallCfgNode call, ParameterPosition pos) {
call.getArgument(pos.getPosition()) = arg
call.getPositionalArgument(pos.getPosition()) = arg
or
call.getReceiver() = arg and pos.isSelf() and not call.getCall().receiverImplicitlyBorrowed()
}

View File

@@ -4,4 +4,6 @@
private import internal.CallImpl
final class ArgumentPosition = Impl::ArgumentPosition;
final class Call = Impl::Call;

View File

@@ -5,20 +5,37 @@ private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
private import codeql.rust.elements.Operation
module Impl {
newtype TArgumentPosition =
TPositionalArgumentPosition(int i) {
i in [0 .. max([any(ParamList l).getNumberOfParams(), any(ArgList l).getNumberOfArgs()]) - 1]
} or
TSelfArgumentPosition()
/** An argument position in a call. */
class ArgumentPosition extends TArgumentPosition {
/** Gets the index of the argument in the call, if this is a positional argument. */
int asPosition() { this = TPositionalArgumentPosition(result) }
/** Holds if this call position is a self argument. */
predicate isSelf() { this instanceof TSelfArgumentPosition }
/** Gets a string representation of this argument position. */
string toString() {
result = this.asPosition().toString()
or
this.isSelf() and result = "self"
}
}
/**
* An expression that calls a function.
*
* This class abstracts over the different ways in which a function can be called in Rust.
* This class abstracts over the different ways in which a function can be
* called in Rust.
*/
abstract class Call extends ExprImpl::Expr {
/** Gets the number of arguments _excluding_ any `self` argument. */
abstract int getNumberOfArguments();
/** Gets the receiver of this call if it is a method call. */
abstract Expr getReceiver();
/** Holds if the call has a receiver that might be implicitly borrowed. */
abstract predicate receiverImplicitlyBorrowed();
/** Holds if the receiver of this call is implicitly borrowed. */
predicate receiverImplicitlyBorrowed() { this.implicitBorrowAt(TSelfArgumentPosition()) }
/** Gets the trait targeted by this call, if any. */
abstract Trait getTrait();
@@ -26,8 +43,20 @@ module Impl {
/** Gets the name of the method called if this call is a method call. */
abstract string getMethodName();
/** Gets the argument at the given position, if any. */
abstract Expr getArgument(ArgumentPosition pos);
/** Holds if the argument at `pos` might be implicitly borrowed. */
abstract predicate implicitBorrowAt(ArgumentPosition pos);
/** Gets the number of arguments _excluding_ any `self` argument. */
int getNumberOfArguments() { result = count(this.getArgument(TPositionalArgumentPosition(_))) }
/** Gets the `i`th argument of this call, if any. */
abstract Expr getArgument(int i);
Expr getPositionalArgument(int i) { result = this.getArgument(TPositionalArgumentPosition(i)) }
/** Gets the receiver of this call if it is a method call. */
Expr getReceiver() { result = this.getArgument(TSelfArgumentPosition()) }
/** Gets the static target of this call, if any. */
Function getStaticTarget() {
@@ -54,15 +83,13 @@ module Impl {
override string getMethodName() { none() }
override Expr getReceiver() { none() }
override Trait getTrait() { none() }
override predicate receiverImplicitlyBorrowed() { none() }
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() }
override Expr getArgument(int i) { result = super.getArgList().getArg(i) }
override Expr getArgument(ArgumentPosition pos) {
result = super.getArgList().getArg(pos.asPosition())
}
}
private class CallExprMethodCall extends Call instanceof CallExpr {
@@ -73,8 +100,6 @@ module Impl {
override string getMethodName() { result = methodName }
override Expr getReceiver() { result = super.getArgList().getArg(0) }
override Trait getTrait() {
result = resolvePath(qualifier) and
// When the qualifier is `Self` and resolves to a trait, it's inside a
@@ -84,25 +109,27 @@ module Impl {
qualifier.toString() != "Self"
}
override predicate receiverImplicitlyBorrowed() { none() }
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() - 1 }
override Expr getArgument(int i) { result = super.getArgList().getArg(i + 1) }
override Expr getArgument(ArgumentPosition pos) {
pos.isSelf() and result = super.getArgList().getArg(0)
or
result = super.getArgList().getArg(pos.asPosition() + 1)
}
}
private class MethodCallExprCall extends Call instanceof MethodCallExpr {
override string getMethodName() { result = super.getIdentifier().getText() }
override Expr getReceiver() { result = this.(MethodCallExpr).getReceiver() }
override Trait getTrait() { none() }
override predicate receiverImplicitlyBorrowed() { any() }
override predicate implicitBorrowAt(ArgumentPosition pos) { pos.isSelf() }
override int getNumberOfArguments() { result = super.getArgList().getNumberOfArgs() }
override Expr getArgument(int i) { result = super.getArgList().getArg(i) }
override Expr getArgument(ArgumentPosition pos) {
pos.isSelf() and result = this.(MethodCallExpr).getReceiver()
or
result = super.getArgList().getArg(pos.asPosition())
}
}
private class OperatorCall extends Call instanceof Operation {
@@ -113,14 +140,14 @@ module Impl {
override string getMethodName() { result = methodName }
override Expr getReceiver() { result = super.getOperand(0) }
override Trait getTrait() { result = trait }
override predicate receiverImplicitlyBorrowed() { none() }
override predicate implicitBorrowAt(ArgumentPosition pos) { none() }
override int getNumberOfArguments() { result = super.getNumberOfOperands() - 1 }
override Expr getArgument(int i) { result = super.getOperand(1) and i = 0 }
override Expr getArgument(ArgumentPosition pos) {
pos.isSelf() and result = super.getOperand(0)
or
pos.asPosition() = 0 and result = super.getOperand(1)
}
}
}

View File

@@ -503,22 +503,20 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
private predicate paramPos(ParamList pl, Param p, int pos) { p = pl.getParam(pos) }
private newtype TDeclarationPosition =
TSelfDeclarationPosition() or
TPositionalDeclarationPosition(int pos) { paramPos(_, _, pos) } or
TArgumentDeclarationPosition(ArgumentPosition pos) or
TReturnDeclarationPosition()
class DeclarationPosition extends TDeclarationPosition {
predicate isSelf() { this = TSelfDeclarationPosition() }
predicate isSelf() { this.asArgumentPosition().isSelf() }
int asPosition() { this = TPositionalDeclarationPosition(result) }
int asPosition() { result = this.asArgumentPosition().asPosition() }
ArgumentPosition asArgumentPosition() { this = TArgumentDeclarationPosition(result) }
predicate isReturn() { this = TReturnDeclarationPosition() }
string toString() {
this.isSelf() and
result = "self"
or
result = this.asPosition().toString()
result = this.asArgumentPosition().toString()
or
this.isReturn() and
result = "(return)"
@@ -551,7 +549,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(int pos |
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
dpos = TPositionalDeclarationPosition(pos)
pos = dpos.asPosition()
)
}
@@ -572,9 +570,9 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(int p |
result = this.getTupleField(p).getTypeRepr().(TypeMention).resolveTypeAt(path) and
dpos = TPositionalDeclarationPosition(p)
exists(int pos |
result = this.getTupleField(pos).getTypeRepr().(TypeMention).resolveTypeAt(path) and
pos = dpos.asPosition()
)
}
@@ -609,7 +607,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
override Type getParameterType(DeclarationPosition dpos, TypePath path) {
exists(Param p, int i |
paramPos(this.getParamList(), p, i) and
dpos = TPositionalDeclarationPosition(i) and
i = dpos.asPosition() and
result = inferAnnotatedType(p.getPat(), path)
)
or
@@ -642,22 +640,21 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
private newtype TAccessPosition =
TSelfAccessPosition(Boolean implicitlyBorrowed) or
TPositionalAccessPosition(int pos) { exists(TPositionalDeclarationPosition(pos)) } or
TArgumentAccessPosition(ArgumentPosition pos, Boolean borrowed) or
TReturnAccessPosition()
class AccessPosition extends TAccessPosition {
predicate isSelf(boolean implicitlyBorrowed) { this = TSelfAccessPosition(implicitlyBorrowed) }
ArgumentPosition getArgumentPosition() { this = TArgumentAccessPosition(result, _) }
int asPosition() { this = TPositionalAccessPosition(result) }
predicate isBorrowed() { this = TArgumentAccessPosition(_, true) }
predicate isReturn() { this = TReturnAccessPosition() }
string toString() {
this.isSelf(_) and
result = "self"
or
result = this.asPosition().toString()
exists(ArgumentPosition pos, boolean borrowed |
this = TArgumentAccessPosition(pos, borrowed) and
result = pos + ":" + borrowed
)
or
this.isReturn() and
result = "(return)"
@@ -677,10 +674,11 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
AstNode getNodeAt(AccessPosition apos) {
result = this.getArgument(apos.asPosition())
or
result = this.getReceiver() and
if this.receiverImplicitlyBorrowed() then apos.isSelf(true) else apos.isSelf(false)
exists(ArgumentPosition pos, boolean borrowed |
apos = TArgumentAccessPosition(pos, borrowed) and
result = this.getArgument(pos) and
if this.implicitBorrowAt(pos) then borrowed = true else borrowed = false
)
or
result = this and apos.isReturn()
}
@@ -697,9 +695,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
apos.isSelf(_) and dpos.isSelf()
or
apos.asPosition() = dpos.asPosition()
apos.getArgumentPosition() = dpos.asArgumentPosition()
or
apos.isReturn() and dpos.isReturn()
}
@@ -709,10 +705,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
predicate adjustAccessType(
AccessPosition apos, Declaration target, TypePath path, Type t, TypePath pathAdj, Type tAdj
) {
if apos.isSelf(true)
if apos.getArgumentPosition().isSelf() and apos.isBorrowed()
then
exists(Type selfParamType |
selfParamType = target.getParameterType(TSelfDeclarationPosition(), TypePath::nil())
selfParamType =
target
.getParameterType(TArgumentDeclarationPosition(apos.getArgumentPosition()),
TypePath::nil())
|
if selfParamType = TRefType()
then
@@ -771,7 +770,7 @@ private Type inferCallExprBaseType(AstNode n, TypePath path) {
// temporary workaround until implicit borrows are handled correctly
if a instanceof Operation then apos.isReturn() else any()
|
if apos.isSelf(_)
if apos.getArgumentPosition().isSelf()
then
exists(Type receiverType | receiverType = inferType(n) |
if receiverType = TRefType()
@@ -1356,7 +1355,7 @@ private Function getMethodFromImpl(MethodCall mc) {
or
exists(int pos, TypePath path, Type type |
methodResolutionDependsOnArgument(impl, mc.getMethodName(), result, pos, path, type) and
inferType(mc.getArgument(pos), path) = type
inferType(mc.getPositionalArgument(pos), path) = type
)
)
}
@@ -1391,7 +1390,8 @@ private module Cached {
cached
predicate receiverHasImplicitDeref(AstNode receiver) {
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
apos.isSelf(true) and
apos.getArgumentPosition().isSelf() and
apos.isBorrowed() and
receiver = a.getNodeAt(apos) and
inferType(receiver) = TRefType() and
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) != TRefType()
@@ -1402,7 +1402,8 @@ private module Cached {
cached
predicate receiverHasImplicitBorrow(AstNode receiver) {
exists(CallExprBaseMatchingInput::Access a, CallExprBaseMatchingInput::AccessPosition apos |
apos.isSelf(true) and
apos.getArgumentPosition().isSelf() and
apos.isBorrowed() and
receiver = a.getNodeAt(apos) and
CallExprBaseMatching::inferAccessType(a, apos, TypePath::nil()) = TRefType() and
inferType(receiver) != TRefType()