diff --git a/rust/ql/lib/codeql/rust/dataflow/internal/Node.qll b/rust/ql/lib/codeql/rust/dataflow/internal/Node.qll index cc738d1dc86..f2f6fa1b0d8 100644 --- a/rust/ql/lib/codeql/rust/dataflow/internal/Node.qll +++ b/rust/ql/lib/codeql/rust/dataflow/internal/Node.qll @@ -230,7 +230,7 @@ final class ExprArgumentNode extends ArgumentNode, ExprNode { ExprArgumentNode() { isArgumentForCall(n, call_, pos_) and not TypeInference::implicitDeref(n) and - not TypeInference::implicitBorrow(n) + not TypeInference::implicitBorrow(n, _) } override predicate isArgumentOf(DataFlowCall call, RustDataFlow::ArgumentPosition pos) { @@ -579,7 +579,7 @@ newtype TNode = TypeInference::implicitDeref(n) and borrow = false or - TypeInference::implicitBorrow(n) and + TypeInference::implicitBorrow(n, _) and borrow = true } or TDerefOutNode(DerefExpr de, Boolean isPost) or diff --git a/rust/ql/lib/codeql/rust/internal/Type.qll b/rust/ql/lib/codeql/rust/internal/Type.qll index 83dcfff8c3a..b4907dee172 100644 --- a/rust/ql/lib/codeql/rust/internal/Type.qll +++ b/rust/ql/lib/codeql/rust/internal/Type.qll @@ -226,22 +226,12 @@ TypeParamTypeParameter getArrayTypeParameter() { abstract class RefType extends StructType { } -pragma[nomagic] -TypeParamTypeParameter getRefTypeParameter() { - result = any(RefType t).getPositionalTypeParameter(0) -} - class RefMutType extends RefType { RefMutType() { this.getStruct() instanceof Builtins::RefMutType } override string toString() { result = "&mut" } } -pragma[nomagic] -TypeParamTypeParameter getRefMutTypeParameter() { - result = any(RefMutType t).getPositionalTypeParameter(0) -} - class RefSharedType extends RefType { RefSharedType() { this.getStruct() instanceof Builtins::RefSharedType } @@ -249,8 +239,15 @@ class RefSharedType extends RefType { } pragma[nomagic] -TypeParamTypeParameter getRefSharedTypeParameter() { - result = any(RefSharedType t).getPositionalTypeParameter(0) +RefType getRefType(boolean isMutable) { + isMutable = true and result instanceof RefMutType + or + isMutable = false and result instanceof RefSharedType +} + +pragma[nomagic] +TypeParamTypeParameter getRefTypeParameter(boolean isMutable) { + result = getRefType(isMutable).getPositionalTypeParameter(0) } /** diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index a004b2cbf4f..5b0ed687357 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -501,9 +501,9 @@ module CertainTypeInference { prefix1.isEmpty() and if ip.isRef() then - if ip.isMut() - then prefix2 = TypePath::singleton(getRefMutTypeParameter()) - else prefix2 = TypePath::singleton(getRefSharedTypeParameter()) + exists(boolean isMutable | if ip.isMut() then isMutable = true else isMutable = false | + prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) + ) else prefix2.isEmpty() ) } @@ -726,9 +726,9 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat any(RefPat rp | n1 = rp.getPat() and prefix1.isEmpty() and - if rp.isMut() - then prefix2 = TypePath::singleton(getRefMutTypeParameter()) - else prefix2 = TypePath::singleton(getRefSharedTypeParameter()) + exists(boolean isMutable | if rp.isMut() then isMutable = true else isMutable = false | + prefix2 = TypePath::singleton(getRefTypeParameter(isMutable)) + ) ) or exists(int i, int arity | @@ -1272,29 +1272,31 @@ private predicate isComplexRootStripped(TypePath path, Type type) { private newtype TBorrowKind = TNoBorrowKind() or - TSharedBorrowKind() or - TMutBorrowKind() + TSomeBorrowKind(Boolean isMutable) private class BorrowKind extends TBorrowKind { predicate isNoBorrow() { this = TNoBorrowKind() } + predicate isSharedBorrow() { this = TSomeBorrowKind(false) } + + predicate isMutableBorrow() { this = TSomeBorrowKind(true) } + RefType getRefType() { - this = TSharedBorrowKind() and - result instanceof RefSharedType - or - this = TMutBorrowKind() and - result instanceof RefMutType + exists(boolean isMutable | + this = TSomeBorrowKind(isMutable) and + result = getRefType(isMutable) + ) } string toString() { - this = TNoBorrowKind() and + this.isNoBorrow() and result = "" or - this = TSharedBorrowKind() and - result = "&" - or - this = TMutBorrowKind() and + this.isMutableBorrow() and result = "&mut" + or + this.isSharedBorrow() and + result = "&" } } @@ -1559,7 +1561,7 @@ private module MethodResolution { this.hasNoCompatibleTargetMutBorrow(derefChain0) and t0 = this.getACandidateReceiverTypeAtNoBorrow(derefChain0, path0) | - path0.isCons(getRefTypeParameter(), path) and + path0.isCons(getRefTypeParameter(_), path) and result = t0 and derefChain = derefChain0 + ".ref" or @@ -1733,13 +1735,14 @@ private module MethodResolution { string derefChain, TypePath strippedTypePath, Type strippedType, int n ) { this.hasNoCompatibleTargetNoBorrow(derefChain) and - strippedType = this.getComplexStrippedType(derefChain, TSharedBorrowKind(), strippedTypePath) and + strippedType = + this.getComplexStrippedType(derefChain, TSomeBorrowKind(false), strippedTypePath) and n = -1 or this.hasNoCompatibleTargetSharedBorrowToIndex(derefChain, strippedTypePath, strippedType, n - 1) and exists(Type t | t = getNthLookupType(strippedType, n) | - this.hasNoCompatibleNonBlanketLikeTargetCheck(derefChain, TSharedBorrowKind(), + this.hasNoCompatibleNonBlanketLikeTargetCheck(derefChain, TSomeBorrowKind(false), strippedTypePath, t) ) } @@ -1762,12 +1765,13 @@ private module MethodResolution { string derefChain, TypePath strippedTypePath, Type strippedType, int n ) { this.hasNoCompatibleTargetSharedBorrow(derefChain) and - strippedType = this.getComplexStrippedType(derefChain, TMutBorrowKind(), strippedTypePath) and + strippedType = + this.getComplexStrippedType(derefChain, TSomeBorrowKind(true), strippedTypePath) and n = -1 or this.hasNoCompatibleTargetMutBorrowToIndex(derefChain, strippedTypePath, strippedType, n - 1) and exists(Type t | t = getNthLookupType(strippedType, n) | - this.hasNoCompatibleNonBlanketLikeTargetCheck(derefChain, TMutBorrowKind(), + this.hasNoCompatibleNonBlanketLikeTargetCheck(derefChain, TSomeBorrowKind(true), strippedTypePath, t) ) } @@ -1790,14 +1794,15 @@ private module MethodResolution { string derefChain, TypePath strippedTypePath, Type strippedType, int n ) { this.hasNoCompatibleTargetNoBorrow(derefChain) and - strippedType = this.getComplexStrippedType(derefChain, TSharedBorrowKind(), strippedTypePath) and + strippedType = + this.getComplexStrippedType(derefChain, TSomeBorrowKind(false), strippedTypePath) and n = -1 or this.hasNoCompatibleNonBlanketTargetSharedBorrowToIndex(derefChain, strippedTypePath, strippedType, n - 1) and exists(Type t | t = getNthLookupType(strippedType, n) | - this.hasNoCompatibleNonBlanketTargetCheck(derefChain, TSharedBorrowKind(), strippedTypePath, - t) + this.hasNoCompatibleNonBlanketTargetCheck(derefChain, TSomeBorrowKind(false), + strippedTypePath, t) ) } @@ -1819,13 +1824,15 @@ private module MethodResolution { string derefChain, TypePath strippedTypePath, Type strippedType, int n ) { this.hasNoCompatibleNonBlanketTargetSharedBorrow(derefChain) and - strippedType = this.getComplexStrippedType(derefChain, TMutBorrowKind(), strippedTypePath) and + strippedType = + this.getComplexStrippedType(derefChain, TSomeBorrowKind(true), strippedTypePath) and n = -1 or this.hasNoCompatibleNonBlanketTargetMutBorrowToIndex(derefChain, strippedTypePath, strippedType, n - 1) and exists(Type t | t = getNthLookupType(strippedType, n) | - this.hasNoCompatibleNonBlanketTargetCheck(derefChain, TMutBorrowKind(), strippedTypePath, t) + this.hasNoCompatibleNonBlanketTargetCheck(derefChain, TSomeBorrowKind(true), + strippedTypePath, t) ) } @@ -1856,17 +1863,17 @@ private module MethodResolution { pragma[nomagic] Type getACandidateReceiverTypeAt(string derefChain, BorrowKind borrow, TypePath path) { result = this.getACandidateReceiverTypeAtNoBorrow(derefChain, path) and - borrow = TNoBorrowKind() + borrow.isNoBorrow() or exists(RefType rt | // first try shared borrow this.supportsAutoDerefAndBorrow() and this.hasNoCompatibleTargetNoBorrow(derefChain) and - borrow = TSharedBorrowKind() + borrow.isSharedBorrow() or // then try mutable borrow this.hasNoCompatibleTargetSharedBorrow(derefChain) and - borrow = TMutBorrowKind() + borrow.isMutableBorrow() | rt = borrow.getRefType() and ( @@ -1899,9 +1906,8 @@ private module MethodResolution { receiver = this.getArg(any(ArgumentPosition pos | pos.isSelf())) } - predicate argumentHasImplicitBorrow(AstNode arg, BorrowKind borrow) { - exists(this.resolveCallTarget(_, "", borrow)) and - borrow != TNoBorrowKind() and + predicate argumentHasImplicitBorrow(AstNode arg, boolean isMutable) { + exists(this.resolveCallTarget(_, "", TSomeBorrowKind(isMutable))) and arg = this.getArg(any(ArgumentPosition pos | pos.isSelf())) } } @@ -2010,24 +2016,22 @@ private module MethodResolution { result = super.getOperand(pos.asPosition() + 1) } - private predicate implicitBorrowAt(ArgumentPosition pos, BorrowKind borrow) { + private predicate implicitBorrowAt(ArgumentPosition pos, boolean isMutable) { exists(int borrows | super.isOverloaded(_, _, borrows) | pos.isSelf() and borrows >= 1 and - if this instanceof AssignmentOperation - then borrow = TMutBorrowKind() - else borrow = TSharedBorrowKind() + if this instanceof CompoundAssignmentExpr then isMutable = true else isMutable = false or pos.asPosition() = 0 and borrows = 2 and - borrow = TSharedBorrowKind() + isMutable = false ) } override Type getArgumentTypeAt(ArgumentPosition pos, TypePath path) { - exists(BorrowKind borrow, RefType rt | - this.implicitBorrowAt(pos, borrow) and - rt = borrow.getRefType() + exists(boolean isMutable, RefType rt | + this.implicitBorrowAt(pos, isMutable) and + rt = getRefType(isMutable) | result = rt and path.isEmpty() @@ -2042,9 +2046,9 @@ private module MethodResolution { result = inferType(this.getArg(pos), path) } - override predicate argumentHasImplicitBorrow(AstNode arg, BorrowKind borrow) { + override predicate argumentHasImplicitBorrow(AstNode arg, boolean isMutable) { exists(ArgumentPosition pos | - this.implicitBorrowAt(pos, borrow) and + this.implicitBorrowAt(pos, isMutable) and arg = this.getArg(pos) ) } @@ -2083,13 +2087,13 @@ private module MethodResolution { pragma[nomagic] predicate hasNoCompatibleNonBlanketTarget() { mc_.hasNoCompatibleNonBlanketTargetSharedBorrow(derefChain) and - borrow = TSharedBorrowKind() + borrow.isSharedBorrow() or mc_.hasNoCompatibleNonBlanketTargetMutBorrow(derefChain) and - borrow = TMutBorrowKind() + borrow.isMutableBorrow() or mc_.hasNoCompatibleNonBlanketTargetNoBorrow(derefChain) and - borrow = TNoBorrowKind() + borrow.isNoBorrow() } pragma[nomagic] @@ -2391,7 +2395,7 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi class AccessEnvironment = string; bindingset[derefChain, borrow] - private AccessEnvironment encodeDerefChainBorrow(string derefChain, BorrowKind borrow) { + additional AccessEnvironment encodeDerefChainBorrow(string derefChain, BorrowKind borrow) { result = derefChain + ";" + borrow } @@ -2437,7 +2441,7 @@ private module MethodCallMatchingInput implements MatchingWithEnvironmentInputSi or exists(TypePath suffix | result = inferType(this.getNodeAt(apos), suffix) and - path = TypePath::cons(getRefTypeParameter(), suffix) + path = TypePath::cons(getRefTypeParameter(_), suffix) ) else ( not apos.isSelf() and @@ -2500,7 +2504,7 @@ private Type inferMethodCallType0( // the implicit deref apos.isReturn() and a instanceof IndexExpr - then path0.isCons(getRefTypeParameter(), path) + then path0.isCons(getRefTypeParameter(_), path) else path = path0 ) } @@ -2523,13 +2527,13 @@ private Type inferMethodCallType1(AstNode n, boolean isReturn, TypePath path) { or // adjust for implicit deref apos.isSelf() and - derefChainBorrow = ".ref;" and - path = TypePath::cons(getRefTypeParameter(), path0) + derefChainBorrow = MethodCallMatchingInput::encodeDerefChainBorrow(".ref", TNoBorrowKind()) and + path = TypePath::cons(getRefTypeParameter(_), path0) or // adjust for implicit borrow apos.isSelf() and - derefChainBorrow = [";&", ";&mut"] and - path0.isCons(getRefTypeParameter(), path) + derefChainBorrow = MethodCallMatchingInput::encodeDerefChainBorrow("", TSomeBorrowKind(_)) and + path0.isCons(getRefTypeParameter(_), path) ) } @@ -3121,7 +3125,7 @@ private module OperationMatchingInput implements MatchingInputSig { this.borrowsAt(dpos) or dpos.isReturn() and this.derefsReturn() - then path0.isCons(getRefTypeParameter(), path) + then path0.isCons(getRefTypeParameter(_), path) else path0 = path ) } @@ -3274,9 +3278,9 @@ private module FieldExprMatchingInput implements MatchingInputSig { if apos.isSelf() then // adjust for implicit deref - path0.isCons(getRefTypeParameter(), path) + path0.isCons(getRefTypeParameter(_), path) or - not path0.isCons(getRefTypeParameter(), _) and + not path0.isCons(getRefTypeParameter(_), _) and not (result instanceof RefType and path0.isEmpty()) and path = path0 else path = path0 @@ -3318,9 +3322,9 @@ private Type inferFieldExprType(AstNode n, TypePath path) { if receiverType instanceof RefType then // adjust for implicit deref - not path0.isCons(getRefTypeParameter(), _) and + not path0.isCons(getRefTypeParameter(_), _) and not (path0.isEmpty() and result instanceof RefType) and - path = TypePath::cons(getRefTypeParameter(), path0) + path = TypePath::cons(getRefTypeParameter(_), path0) else path = path0 ) else path = path0 @@ -3353,7 +3357,7 @@ private Type inferRefPatType(AstNode ref) { or ref = any(RefPat rp | if rp.isMut() then isMut = true else isMut = false) | - if isMut = true then result instanceof RefMutType else result instanceof RefSharedType + result = getRefType(isMut) ) } @@ -3410,7 +3414,7 @@ private Type inferLiteralType(LiteralExpr le, TypePath path, boolean certain) { ( path.isEmpty() and result instanceof RefSharedType or - path = TypePath::singleton(getRefSharedTypeParameter()) and + path = TypePath::singleton(getRefTypeParameter(false)) and result = getStrStruct() ) and certain = true @@ -3531,7 +3535,7 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) { exprPath.isCons(getArrayTypeParameter(), path) or exists(TypePath path0 | - exprPath.isCons(getRefTypeParameter(), path0) and + exprPath.isCons(getRefTypeParameter(_), path0) and path0.isCons(getSliceTypeParameter(), path) ) ) @@ -3843,8 +3847,8 @@ private module Cached { /** Holds if `n` is implicitly borrowed. */ cached - predicate implicitBorrow(AstNode n) { - any(MethodResolution::MethodCall mc).argumentHasImplicitBorrow(n, _) + predicate implicitBorrow(AstNode n, boolean isMutable) { + any(MethodResolution::MethodCall mc).argumentHasImplicitBorrow(n, isMutable) } /**