From 7fc1d53edec3da5e57b2a33fa3706e6d65a3cbf4 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Tue, 17 Mar 2026 20:02:01 +0100 Subject: [PATCH] Rust: Disambiguate types inferred from trait bounds --- .../typeinference/FunctionOverloading.qll | 149 +++++++----- .../internal/typeinference/TypeInference.qll | 181 ++++++++++----- .../internal/typeinference/TypeMention.qll | 72 +++--- .../type-inference/overloading.rs | 6 +- .../type-inference/type-inference.expected | 9 - .../typeinference/internal/TypeInference.qll | 215 ++++++++++++++++-- shared/util/codeql/util/UnboundList.qll | 8 + 7 files changed, 469 insertions(+), 171 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/FunctionOverloading.qll b/rust/ql/lib/codeql/rust/internal/typeinference/FunctionOverloading.qll index 6e4cc6e2c2e..d217fc3760a 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/FunctionOverloading.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/FunctionOverloading.qll @@ -13,68 +13,105 @@ private import TypeMention private import TypeInference private import FunctionType -pragma[nomagic] -private Type resolveNonTypeParameterTypeAt(TypeMention tm, TypePath path) { - result = tm.getTypeAt(path) and - not result instanceof TypeParameter -} - -bindingset[t1, t2] -private predicate typeMentionEqual(TypeMention t1, TypeMention t2) { - forex(TypePath path, Type type | resolveNonTypeParameterTypeAt(t1, path) = type | - resolveNonTypeParameterTypeAt(t2, path) = type - ) -} - -pragma[nomagic] -private predicate implSiblingCandidate( - Impl impl, TraitItemNode trait, Type rootType, TypeMention selfTy -) { - trait = impl.(ImplItemNode).resolveTraitTy() and - selfTy = impl.getSelfTy() and - rootType = selfTy.getType() -} - -pragma[nomagic] -private predicate blanketImplSiblingCandidate(ImplItemNode impl, Trait trait) { - impl.isBlanketImplementation() and - trait = impl.resolveTraitTy() -} +private signature Type resolveTypeMentionAtSig(AstNode tm, TypePath path); /** - * Holds if `impl1` and `impl2` are a sibling implementations of `trait`. We - * consider implementations to be siblings if they implement the same trait for - * the same type. In that case `Self` is the same type in both implementations, - * and method calls to the implementations cannot be resolved unambiguously - * based only on the receiver type. + * Provides logic for identifying sibling implementations, parameterized over + * how to resolve type mentions (`PreTypeMention` vs. `TypeMention`). */ -pragma[inline] -private predicate implSiblings(TraitItemNode trait, Impl impl1, Impl impl2) { - impl1 != impl2 and - ( - exists(Type rootType, TypeMention selfTy1, TypeMention selfTy2 | - implSiblingCandidate(impl1, trait, rootType, selfTy1) and - implSiblingCandidate(impl2, trait, rootType, selfTy2) and - // In principle the second conjunct below should be superflous, but we still - // have ill-formed type mentions for types that we don't understand. For - // those checking both directions restricts further. Note also that we check - // syntactic equality, whereas equality up to renaming would be more - // correct. - typeMentionEqual(selfTy1, selfTy2) and - typeMentionEqual(selfTy2, selfTy1) +private module MkSiblingImpls { + pragma[nomagic] + private Type resolveNonTypeParameterTypeAt(AstNode tm, TypePath path) { + result = resolveTypeMentionAt(tm, path) and + not result instanceof TypeParameter + } + + bindingset[t1, t2] + private predicate typeMentionEqual(AstNode t1, AstNode t2) { + forex(TypePath path, Type type | resolveNonTypeParameterTypeAt(t1, path) = type | + resolveNonTypeParameterTypeAt(t2, path) = type ) - or - blanketImplSiblingCandidate(impl1, trait) and - blanketImplSiblingCandidate(impl2, trait) - ) + } + + pragma[nomagic] + private predicate implSiblingCandidate( + Impl impl, TraitItemNode trait, Type rootType, AstNode selfTy + ) { + trait = impl.(ImplItemNode).resolveTraitTy() and + selfTy = impl.getSelfTy() and + rootType = resolveTypeMentionAt(selfTy, TypePath::nil()) + } + + pragma[nomagic] + private predicate blanketImplSiblingCandidate(ImplItemNode impl, Trait trait) { + impl.isBlanketImplementation() and + trait = impl.resolveTraitTy() + } + + /** + * Holds if `impl1` and `impl2` are sibling implementations of `trait`. We + * consider implementations to be siblings if they implement the same trait for + * the same type. In that case `Self` is the same type in both implementations, + * and method calls to the implementations cannot be resolved unambiguously + * based only on the receiver type. + */ + pragma[inline] + predicate implSiblings(TraitItemNode trait, Impl impl1, Impl impl2) { + impl1 != impl2 and + ( + exists(Type rootType, AstNode selfTy1, AstNode selfTy2 | + implSiblingCandidate(impl1, trait, rootType, selfTy1) and + implSiblingCandidate(impl2, trait, rootType, selfTy2) and + // In principle the second conjunct below should be superfluous, but we still + // have ill-formed type mentions for types that we don't understand. For + // those checking both directions restricts further. Note also that we check + // syntactic equality, whereas equality up to renaming would be more + // correct. + typeMentionEqual(selfTy1, selfTy2) and + typeMentionEqual(selfTy2, selfTy1) + ) + or + blanketImplSiblingCandidate(impl1, trait) and + blanketImplSiblingCandidate(impl2, trait) + ) + } + + /** + * Holds if `impl` is an implementation of `trait` and if another implementation + * exists for the same type. + */ + pragma[nomagic] + predicate implHasSibling(ImplItemNode impl, Trait trait) { implSiblings(trait, impl, _) } + + pragma[nomagic] + predicate implHasAmbiguousSiblingAt(ImplItemNode impl, Trait trait, TypePath path) { + exists(ImplItemNode impl2, Type t1, Type t2 | + implSiblings(trait, impl, impl2) and + t1 = resolveTypeMentionAt(impl.getTraitPath(), path) and + t2 = resolveTypeMentionAt(impl2.getTraitPath(), path) and + t1 != t2 + | + not t1 instanceof TypeParameter or + not t2 instanceof TypeParameter + ) + } } -/** - * Holds if `impl` is an implementation of `trait` and if another implementation - * exists for the same type. - */ -pragma[nomagic] -private predicate implHasSibling(ImplItemNode impl, Trait trait) { implSiblings(trait, impl, _) } +private Type resolvePreTypeMention(AstNode tm, TypePath path) { + result = tm.(PreTypeMention).getTypeAt(path) +} + +private module PreSiblingImpls = MkSiblingImpls; + +predicate preImplHasAmbiguousSiblingAt = PreSiblingImpls::implHasAmbiguousSiblingAt/3; + +private Type resolveTypeMention(AstNode tm, TypePath path) { + result = tm.(TypeMention).getTypeAt(path) +} + +private module SiblingImpls = MkSiblingImpls; + +import SiblingImpls /** * Holds if `f` is a function declared inside `trait`, and the type of `f` at diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 3229b3ee0bb..6d990040f52 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -30,7 +30,7 @@ private newtype TTypeArgumentPosition = } or TTypeParamTypeArgumentPosition(TypeParam tp) -private module Input implements InputSig1, InputSig2 { +private module Input1 implements InputSig1 { private import Type as T private import codeql.rust.elements.internal.generated.Raw private import codeql.rust.elements.internal.generated.Synth @@ -122,12 +122,28 @@ private module Input implements InputSig1, InputSig2 { tp0 order by kind, id1, id2 ) } +} - int getTypePathLimit() { result = 10 } +private import Input1 - PreTypeMention getABaseTypeMention(Type t) { none() } +private module M1 = Make1; - PreTypeMention getATypeParameterConstraint(TypeParameter tp) { +import M1 + +predicate getTypePathLimit = Input1::getTypePathLimit/0; + +predicate getTypeParameterId = Input1::getTypeParameterId/1; + +class TypePath = M1::TypePath; + +module TypePath = M1::TypePath; + +/** + * Provides shared logic for implementing `InputSig2` and + * `InputSig2`. + */ +private module Input2Common { + AstNode getATypeParameterConstraint(TypeParameter tp) { result = tp.(TypeParamTypeParameter).getTypeParam().getATypeBound().getTypeRepr() or result = tp.(SelfTypeParameter).getTrait() or result = @@ -146,7 +162,7 @@ private module Input implements InputSig1, InputSig2 { * inference module for more information. */ predicate conditionSatisfiesConstraint( - TypeAbstraction abs, PreTypeMention condition, PreTypeMention constraint, boolean transitive + TypeAbstraction abs, AstNode condition, AstNode constraint, boolean transitive ) { // `impl` blocks implementing traits transitive = false and @@ -194,23 +210,64 @@ private module Input implements InputSig1, InputSig2 { ) ) } + + predicate typeParameterIsFunctionallyDetermined(TypeParameter tp) { + tp instanceof AssociatedTypeTypeParameter + } } -private import Input +private module PreInput2 implements InputSig2 { + PreTypeMention getABaseTypeMention(Type t) { none() } -private module M1 = Make1; + PreTypeMention getATypeParameterConstraint(TypeParameter tp) { + result = Input2Common::getATypeParameterConstraint(tp) + } -import M1 + predicate conditionSatisfiesConstraint( + TypeAbstraction abs, PreTypeMention condition, PreTypeMention constraint, boolean transitive + ) { + Input2Common::conditionSatisfiesConstraint(abs, condition, constraint, transitive) + } -predicate getTypePathLimit = Input::getTypePathLimit/0; + predicate typeAbstractionHasAmbiguousConstraintAt( + TypeAbstraction abs, Type constraint, TypePath path + ) { + FunctionOverloading::preImplHasAmbiguousSiblingAt(abs, constraint.(TraitType).getTrait(), path) + } -predicate getTypeParameterId = Input::getTypeParameterId/1; + predicate typeParameterIsFunctionallyDetermined = + Input2Common::typeParameterIsFunctionallyDetermined/1; +} -class TypePath = M1::TypePath; +/** Provides an instantiation of the shared type inference library for `PreTypeMention`s. */ +module PreM2 = Make2; -module TypePath = M1::TypePath; +private module Input2 implements InputSig2 { + TypeMention getABaseTypeMention(Type t) { none() } -private module M2 = Make2; + TypeMention getATypeParameterConstraint(TypeParameter tp) { + result = Input2Common::getATypeParameterConstraint(tp) + } + + predicate conditionSatisfiesConstraint( + TypeAbstraction abs, TypeMention condition, TypeMention constraint, boolean transitive + ) { + Input2Common::conditionSatisfiesConstraint(abs, condition, constraint, transitive) + } + + predicate typeAbstractionHasAmbiguousConstraintAt( + TypeAbstraction abs, Type constraint, TypePath path + ) { + FunctionOverloading::implHasAmbiguousSiblingAt(abs, constraint.(TraitType).getTrait(), path) + } + + predicate typeParameterIsFunctionallyDetermined = + Input2Common::typeParameterIsFunctionallyDetermined/1; +} + +private import Input2 + +private module M2 = Make2; import M2 @@ -596,17 +653,18 @@ module CertainTypeInference { } /** - * Holds if `n` has complete and certain type information at _some_ type path. + * Holds if `n` has complete and certain type information at `path`. */ pragma[nomagic] - predicate hasInferredCertainType(AstNode n) { exists(inferCertainType(n, _)) } + predicate hasInferredCertainType(AstNode n, TypePath path) { exists(inferCertainType(n, path)) } /** - * Holds if `n` having type `t` at `path` conflicts with certain type information. + * Holds if `n` having type `t` at `path` conflicts with certain type information + * at `prefix`. */ - bindingset[n, path, t] + bindingset[n, prefix, path, t] pragma[inline_late] - predicate certainTypeConflict(AstNode n, TypePath path, Type t) { + predicate certainTypeConflict(AstNode n, TypePath prefix, TypePath path, Type t) { inferCertainType(n, path) != t or // If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also @@ -615,7 +673,7 @@ module CertainTypeInference { // otherwise there is a conflict. // // Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`. - exists(TypePath prefix, TypePath suffix, TypeParameter tp, Type certainType | + exists(TypePath suffix, TypeParameter tp, Type certainType | path = prefix.appendInverse(suffix) and tp = suffix.getHead() and inferCertainType(n, prefix) = certainType and @@ -1054,9 +1112,12 @@ private module ContextTyping { pragma[nomagic] private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) } - signature Type inferCallTypeSig( - AstNode n, FunctionPosition pos, boolean hasReceiver, TypePath path - ); + newtype FunctionPositionKind = + SelfKind() or + ReturnKind() or + PositionalKind() + + signature Type inferCallTypeSig(AstNode n, FunctionPositionKind kind, TypePath path); /** * Given a predicate `inferCallType` for inferring the type of a call at a given @@ -1064,35 +1125,28 @@ private module ContextTyping { * predicate and checks that types are only propagated into arguments when they * are context-typed. */ - module CheckContextTyping { + module CheckContextTyping { pragma[nomagic] private Type inferCallNonReturnType( - AstNode n, FunctionPosition pos, boolean hasReceiver, TypePath path + AstNode n, FunctionPositionKind kind, TypePath prefix, TypePath path ) { - result = inferCallType(n, pos, hasReceiver, path) and - not pos.isReturn() - } - - pragma[nomagic] - private Type inferCallNonReturnType( - AstNode n, FunctionPosition pos, boolean hasReceiver, TypePath prefix, TypePath path - ) { - result = inferCallNonReturnType(n, pos, hasReceiver, path) and + result = inferCallType(n, kind, path) and hasUnknownType(n) and + kind != ReturnKind() and prefix = path.getAPrefix() } pragma[nomagic] Type check(AstNode n, TypePath path) { - result = inferCallType(n, any(FunctionPosition pos | pos.isReturn()), _, path) + result = inferCallType(n, ReturnKind(), path) or - exists(FunctionPosition pos, boolean hasReceiver, TypePath prefix | - result = inferCallNonReturnType(n, pos, hasReceiver, prefix, path) and + exists(FunctionPositionKind kind, TypePath prefix | + result = inferCallNonReturnType(n, kind, prefix, path) and hasUnknownTypeAt(n, prefix) | // Never propagate type information directly into the receiver, since its type // must already have been known in order to resolve the call - if pos.asPosition() = 0 and hasReceiver = true then not prefix.isEmpty() else any() + if kind = SelfKind() then not prefix.isEmpty() else any() ) } } @@ -2877,17 +2931,20 @@ private Type inferFunctionCallTypeSelf( } private Type inferFunctionCallTypePreCheck( - AstNode n, FunctionPosition pos, boolean hasReceiver, TypePath path + AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path ) { - result = inferFunctionCallTypeNonSelf(n, pos, path) and - hasReceiver = false + exists(FunctionPosition pos | + result = inferFunctionCallTypeNonSelf(n, pos, path) and + if pos.isPosition() + then kind = ContextTyping::PositionalKind() + else kind = ContextTyping::ReturnKind() + ) or exists(FunctionCallMatchingInput::Access a | result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and - pos.asPosition() = 0 and if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver() - then hasReceiver = true - else hasReceiver = false + then kind = ContextTyping::SelfKind() + else kind = ContextTyping::PositionalKind() ) } @@ -2896,7 +2953,7 @@ private Type inferFunctionCallTypePreCheck( * argument/receiver of a function call. */ private predicate inferFunctionCallType = - ContextTyping::CheckContextTyping::check/2; + ContextTyping::CheckContextTyping::check/2; abstract private class Constructor extends Addressable { final TypeParameter getTypeParameter(TypeParameterPosition ppos) { @@ -3055,10 +3112,14 @@ private module ConstructionMatching = Matching; pragma[nomagic] private Type inferConstructionTypePreCheck( - AstNode n, FunctionPosition pos, boolean hasReceiver, TypePath path + AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path ) { - hasReceiver = false and - exists(ConstructionMatchingInput::Access a | n = a.getNodeAt(pos) | + exists(ConstructionMatchingInput::Access a, FunctionPosition pos | + n = a.getNodeAt(pos) and + if pos.isPosition() + then kind = ContextTyping::PositionalKind() + else kind = ContextTyping::ReturnKind() + | result = ConstructionMatching::inferAccessType(a, pos, path) or a.hasUnknownTypeAt(pos, path) and @@ -3067,7 +3128,7 @@ private Type inferConstructionTypePreCheck( } private predicate inferConstructionType = - ContextTyping::CheckContextTyping::check/2; + ContextTyping::CheckContextTyping::check/2; /** * A matching configuration for resolving types of operations like `a + b`. @@ -3133,17 +3194,22 @@ private module OperationMatching = Matching; pragma[nomagic] private Type inferOperationTypePreCheck( - AstNode n, FunctionPosition pos, boolean hasReceiver, TypePath path + AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path ) { - exists(OperationMatchingInput::Access a | + exists(OperationMatchingInput::Access a, FunctionPosition pos | n = a.getNodeAt(pos) and result = OperationMatching::inferAccessType(a, pos, path) and - hasReceiver = true + if pos.asPosition() = 0 + then kind = ContextTyping::SelfKind() + else + if pos.isPosition() + then kind = ContextTyping::PositionalKind() + else kind = ContextTyping::ReturnKind() ) } private predicate inferOperationType = - ContextTyping::CheckContextTyping::check/2; + ContextTyping::CheckContextTyping::check/2; pragma[nomagic] private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefChain) { @@ -3900,10 +3966,11 @@ private module Cached { or // Don't propagate type information into a node which conflicts with certain // type information. - ( - if CertainTypeInference::hasInferredCertainType(n) - then not CertainTypeInference::certainTypeConflict(n, path, result) - else any() + forall(TypePath prefix | + CertainTypeInference::hasInferredCertainType(n, prefix) and + prefix.isPrefixOf(path) + | + not CertainTypeInference::certainTypeConflict(n, prefix, path, result) ) and ( result = inferAssignmentOperationType(n, path) @@ -3970,7 +4037,7 @@ private module Debug { TypeAbstraction abs, TypeMention condition, TypeMention constraint, boolean transitive ) { abs = getRelevantLocatable() and - Input::conditionSatisfiesConstraint(abs, condition, constraint, transitive) + Input2::conditionSatisfiesConstraint(abs, condition, constraint, transitive) } predicate debugInferShorthandSelfType(ShorthandSelfParameterMention self, TypePath path, Type t) { diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll index d9a00f33940..70dfbeda848 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeMention.qll @@ -205,6 +205,13 @@ private module MkTypeMention::...` + exists(PathTypeRepr typeRepr, PathTypeRepr traitRepr | + pathTypeAsTraitAssoc(_, typeRepr, traitRepr, _, _) and + this = traitRepr.getPath() and + result = typeRepr.getPath() + ) } pragma[nomagic] @@ -696,16 +703,26 @@ private module PreTypeMention = MkTypeMention; class PreTypeMention = PreTypeMention::TypeMention; +private class TraitOrTmTrait extends AstNode { + Type getTypeAt(TypePath path) { + pathTypeAsTraitAssoc(_, _, this, _, _) and + result = this.(PreTypeMention).getTypeAt(path) + or + result = TTrait(this) and + path.isEmpty() + } +} + /** * Holds if `path` accesses an associated type `alias` from `trait` on a * concrete type given by `tm`. * - * `implOrTmTrait` is either the mention that resolves to `trait` when `path` - * is of the form `::AssocType`, or the enclosing `impl` block - * when `path` is of the form `Self::AssocType`. + * `traitOrTmTrait` is either the mention that resolves to `trait` when `path` + * is of the form `::AssocType`, or the trait being implemented + * when `path` is of the form `Self::AssocType` within an `impl` block. */ private predicate pathConcreteTypeAssocType( - Path path, PreTypeMention tm, TraitItemNode trait, AstNode implOrTmTrait, TypeAlias alias + Path path, PreTypeMention tm, TraitItemNode trait, TraitOrTmTrait traitOrTmTrait, TypeAlias alias ) { exists(Path qualifier | qualifier = path.getQualifier() and @@ -713,31 +730,34 @@ private predicate pathConcreteTypeAssocType( | // path of the form `::AssocType` // ^^^ tm ^^^^^^^^^ name + // ^^^^^ traitOrTmTrait exists(string name | - pathTypeAsTraitAssoc(path, tm, implOrTmTrait, trait, name) and + pathTypeAsTraitAssoc(path, tm, traitOrTmTrait, trait, name) and getTraitAssocType(trait, name) = alias ) or // path of the form `Self::AssocType` within an `impl` block // tm ^^^^ ^^^^^^^^^ name - implOrTmTrait = - any(ImplItemNode impl | - alias = resolvePath(path) and - qualifier = impl.getASelfPath() and - tm = impl.(Impl).getSelfTy() and - trait.getAnAssocItem() = alias - ) + exists(ImplItemNode impl | + alias = resolvePath(path) and + qualifier = impl.getASelfPath() and + tm = impl.(Impl).getSelfTy() and + trait.getAnAssocItem() = alias and + traitOrTmTrait = trait + ) ) } -private module PathSatisfiesConstraintInput implements SatisfiesTypeInputSig { - predicate relevantConstraint(PreTypeMention tm, Type constraint) { - pathConcreteTypeAssocType(_, tm, constraint.(TraitType).getTrait(), _, _) +private module PathSatisfiesConstraintInput implements + PreM2::SatisfiesConstraintInputSig +{ + predicate relevantConstraint(PreTypeMention tm, TraitOrTmTrait constraint) { + pathConcreteTypeAssocType(_, tm, _, constraint, _) } } private module PathSatisfiesConstraint = - SatisfiesType; + PreM2::SatisfiesConstraint; /** * Gets the type of `path` at `typePath` when `path` accesses an associated type @@ -745,26 +765,12 @@ private module PathSatisfiesConstraint = */ private Type getPathConcreteAssocTypeAt(Path path, TypePath typePath) { exists( - PreTypeMention tm, ImplItemNode impl, TraitItemNode trait, TraitType t, AstNode implOrTmTrait, + PreTypeMention tm, ImplItemNode impl, TraitItemNode trait, TraitOrTmTrait traitOrTmTrait, TypeAlias alias, TypePath path0 | - pathConcreteTypeAssocType(path, tm, trait, implOrTmTrait, alias) and - t = TTrait(trait) and - PathSatisfiesConstraint::satisfiesConstraintTypeThrough(tm, impl, t, path0, result) and + pathConcreteTypeAssocType(path, tm, trait, traitOrTmTrait, alias) and + PathSatisfiesConstraint::satisfiesConstraintTypeThrough(tm, impl, traitOrTmTrait, path0, result) and path0.isCons(TAssociatedTypeTypeParameter(trait, alias), typePath) - | - implOrTmTrait instanceof Impl - or - // When `path` is of the form `::AssocType` we need to check - // that `impl` is not more specific than the mentioned trait - implOrTmTrait = - any(PreTypeMention tmTrait | - not exists(TypePath path1, Type t1 | - t1 = impl.getTraitPath().(PreTypeMention).getTypeAt(path1) and - not t1 instanceof TypeParameter and - t1 != tmTrait.getTypeAt(path1) - ) - ) ) } diff --git a/rust/ql/test/library-tests/type-inference/overloading.rs b/rust/ql/test/library-tests/type-inference/overloading.rs index e0f3dbf6954..06353a12c8f 100644 --- a/rust/ql/test/library-tests/type-inference/overloading.rs +++ b/rust/ql/test/library-tests/type-inference/overloading.rs @@ -509,15 +509,15 @@ mod trait_bound_impl_overlap { fn test() { let x = S(0); - let y = call_f(x); // $ target=call_f type=y:i32 $ SPURIOUS: type=y:i64 + let y = call_f(x); // $ target=call_f type=y:i32 let z: i32 = y; let x = S(0); let y = call_f::(x); // $ target=call_f type=y:i32 let x = S(0); - let y = call_f2(S(0i32), x); // $ target=call_f2 type=y:i32 $ SPURIOUS: type=y:i64 + let y = call_f2(S(0i32), x); // $ target=call_f2 type=y:i32 let x = S(0); - let y = call_f2(S(0i64), x); // $ target=call_f2 type=y:i64 $ SPURIOUS: type=y:i32 + let y = call_f2(S(0i64), x); // $ target=call_f2 type=y:i64 } } diff --git a/rust/ql/test/library-tests/type-inference/type-inference.expected b/rust/ql/test/library-tests/type-inference/type-inference.expected index 54b338bd5c4..a25a9daf003 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.expected +++ b/rust/ql/test/library-tests/type-inference/type-inference.expected @@ -10775,7 +10775,6 @@ inferType | main.rs:2032:56:2034:9 | { ... } | | {EXTERNAL LOCATION} | & | | main.rs:2032:56:2034:9 | { ... } | TRef | main.rs:2028:10:2028:10 | T | | main.rs:2033:13:2033:29 | &... | | {EXTERNAL LOCATION} | & | -| main.rs:2033:13:2033:29 | &... | TRef | {EXTERNAL LOCATION} | u8 | | main.rs:2033:13:2033:29 | &... | TRef | main.rs:2028:10:2028:10 | T | | main.rs:2033:14:2033:17 | self | | {EXTERNAL LOCATION} | & | | main.rs:2033:14:2033:17 | self | TRef | main.rs:2013:5:2016:5 | MyVec | @@ -10783,7 +10782,6 @@ inferType | main.rs:2033:14:2033:22 | self.data | | {EXTERNAL LOCATION} | Vec | | main.rs:2033:14:2033:22 | self.data | A | {EXTERNAL LOCATION} | Global | | main.rs:2033:14:2033:22 | self.data | T | main.rs:2028:10:2028:10 | T | -| main.rs:2033:14:2033:29 | ...[index] | | {EXTERNAL LOCATION} | u8 | | main.rs:2033:14:2033:29 | ...[index] | | main.rs:2028:10:2028:10 | T | | main.rs:2033:24:2033:28 | index | | {EXTERNAL LOCATION} | usize | | main.rs:2037:22:2037:26 | slice | | {EXTERNAL LOCATION} | & | @@ -12931,14 +12929,11 @@ inferType | overloading.rs:511:17:511:20 | S(...) | T | {EXTERNAL LOCATION} | i32 | | overloading.rs:511:19:511:19 | 0 | | {EXTERNAL LOCATION} | i32 | | overloading.rs:512:13:512:13 | y | | {EXTERNAL LOCATION} | i32 | -| overloading.rs:512:13:512:13 | y | | {EXTERNAL LOCATION} | i64 | | overloading.rs:512:17:512:25 | call_f(...) | | {EXTERNAL LOCATION} | i32 | -| overloading.rs:512:17:512:25 | call_f(...) | | {EXTERNAL LOCATION} | i64 | | overloading.rs:512:24:512:24 | x | | overloading.rs:464:5:464:19 | S | | overloading.rs:512:24:512:24 | x | T | {EXTERNAL LOCATION} | i32 | | overloading.rs:513:13:513:13 | z | | {EXTERNAL LOCATION} | i32 | | overloading.rs:513:22:513:22 | y | | {EXTERNAL LOCATION} | i32 | -| overloading.rs:513:22:513:22 | y | | {EXTERNAL LOCATION} | i64 | | overloading.rs:515:13:515:13 | x | | overloading.rs:464:5:464:19 | S | | overloading.rs:515:13:515:13 | x | T | {EXTERNAL LOCATION} | i32 | | overloading.rs:515:17:515:20 | S(...) | | overloading.rs:464:5:464:19 | S | @@ -12954,9 +12949,7 @@ inferType | overloading.rs:518:17:518:20 | S(...) | T | {EXTERNAL LOCATION} | i32 | | overloading.rs:518:19:518:19 | 0 | | {EXTERNAL LOCATION} | i32 | | overloading.rs:519:13:519:13 | y | | {EXTERNAL LOCATION} | i32 | -| overloading.rs:519:13:519:13 | y | | {EXTERNAL LOCATION} | i64 | | overloading.rs:519:17:519:35 | call_f2(...) | | {EXTERNAL LOCATION} | i32 | -| overloading.rs:519:17:519:35 | call_f2(...) | | {EXTERNAL LOCATION} | i64 | | overloading.rs:519:25:519:31 | S(...) | | overloading.rs:464:5:464:19 | S | | overloading.rs:519:25:519:31 | S(...) | T | {EXTERNAL LOCATION} | i32 | | overloading.rs:519:27:519:30 | 0i32 | | {EXTERNAL LOCATION} | i32 | @@ -12967,9 +12960,7 @@ inferType | overloading.rs:520:17:520:20 | S(...) | | overloading.rs:464:5:464:19 | S | | overloading.rs:520:17:520:20 | S(...) | T | {EXTERNAL LOCATION} | i32 | | overloading.rs:520:19:520:19 | 0 | | {EXTERNAL LOCATION} | i32 | -| overloading.rs:521:13:521:13 | y | | {EXTERNAL LOCATION} | i32 | | overloading.rs:521:13:521:13 | y | | {EXTERNAL LOCATION} | i64 | -| overloading.rs:521:17:521:35 | call_f2(...) | | {EXTERNAL LOCATION} | i32 | | overloading.rs:521:17:521:35 | call_f2(...) | | {EXTERNAL LOCATION} | i64 | | overloading.rs:521:25:521:31 | S(...) | | overloading.rs:464:5:464:19 | S | | overloading.rs:521:25:521:31 | S(...) | T | {EXTERNAL LOCATION} | i64 | diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index cbc1f608813..7cd4dab479d 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -385,6 +385,45 @@ module Make1 Input1> { predicate conditionSatisfiesConstraint( TypeAbstraction abs, TypeMention condition, TypeMention constraint, boolean transitive ); + + /** + * Holds if the constraint belonging to `abs` with root type `constraint` is + * ambiguous at `path`, meaning that there is _some_ other abstraction `abs2` + * with a structurally identical condition and same root constraint type + * `constraint`, and where the constraints differ at `path`. + * + * Example: + * + * ```rust + * trait Trait { } + * + * impl Trait for Foo { ... } + * // ^^^ `abs` + * // ^^^^^ `constraint` + * // ^^^^^^ `condition` + * + * impl Trait for Foo { } + * // ^^^ `abs2` + * // ^^^^^ `constraint` + * // ^^^^^^ `condition2` + * ``` + * + * In the above, `abs` and `abs2` have structurally identical conditions, + * `condition` and `condition2`, and they differ at the path `"T1"`, but + * not at the path `"T2"`. + */ + predicate typeAbstractionHasAmbiguousConstraintAt( + TypeAbstraction abs, Type constraint, TypePath path + ); + + /** + * Holds if all instantiations of `tp` are functionally determined by the + * instantiations of the other type parameters in the same abstraction. + * + * For example, in Rust all associated types act as functionally determined + * type parameters. + */ + predicate typeParameterIsFunctionallyDetermined(TypeParameter tp); } module Make2 Input2> { @@ -661,6 +700,7 @@ module Make1 Input1> { * Holds if the type mention `condition` satisfies `constraint` with the * type `t` at the path `path`. */ + pragma[nomagic] predicate conditionSatisfiesConstraintTypeAt( TypeAbstraction abs, TypeMention condition, TypeMention constraint, TypePath path, Type t ) { @@ -820,15 +860,30 @@ module Make1 Input1> { private import BaseTypes - /** Provides the input to `SatisfiesConstraint`. */ - signature module SatisfiesConstraintInputSig { + /** Provides the input to `SatisfiesConstraintWithTypeMatching`. */ + signature module SatisfiesConstraintWithTypeMatchingInputSig< + HasTypeTreeSig Term, HasTypeTreeSig Constraint> + { /** Holds if it is relevant to know if `term` satisfies `constraint`. */ predicate relevantConstraint(Term term, Constraint constraint); + + /** A context in which a type parameter can be matched with an instantiation. */ + class TypeMatchingContext; + + /** Gets the type matching context for `t`. */ + TypeMatchingContext getTypeMatchingContext(Term t); + + /** + * Holds if `tp` can be matched with the type `t` at `path` in the context `ctx`. + * + * This may be used to disambiguate between multiple constraints that a term may satisfy. + */ + predicate typeMatch(TypeMatchingContext ctx, TypeParameter tp, TypePath path, Type t); } - module SatisfiesConstraint< + module SatisfiesConstraintWithTypeMatching< HasTypeTreeSig Term, HasTypeTreeSig Constraint, - SatisfiesConstraintInputSig Input> + SatisfiesConstraintWithTypeMatchingInputSig Input> { private import Input @@ -944,12 +999,103 @@ module Make1 Input1> { pragma[nomagic] private predicate satisfiesConstraintTypeMention0( + Term term, Constraint constraint, TypeMention constraintMention, TypeAbstraction abs, + TypeMention sub, TypePath path, Type t, boolean ambiguous + ) { + exists(Type constraintRoot | + hasConstraintMention(term, abs, sub, constraint, constraintRoot, constraintMention) and + conditionSatisfiesConstraintTypeAt(abs, sub, constraintMention, path, t) and + if + exists(TypePath prefix | + typeAbstractionHasAmbiguousConstraintAt(abs, constraintRoot, prefix) and + prefix.isPrefixOf(path) + ) + then ambiguous = true + else ambiguous = false + ) + } + + pragma[nomagic] + private predicate conditionSatisfiesConstraintTypeAtForDisambiguation( + TypeAbstraction abs, TypeMention condition, TypeMention constraint, TypePath path, Type t + ) { + conditionSatisfiesConstraintTypeAt(abs, condition, constraint, path, t) and + not t instanceof TypeParameter and + not typeParameterIsFunctionallyDetermined(path.getHead()) + } + + pragma[nomagic] + private predicate constraintTypeMatchForDisambiguation0( + Term term, Constraint constraint, TypePath path, TypePath suffix, TypeParameter tp + ) { + exists( + TypeMention constraintMention, TypeAbstraction abs, TypeMention sub, TypePath prefix + | + satisfiesConstraintTypeMention0(term, constraint, constraintMention, abs, sub, _, _, true) and + conditionSatisfiesConstraintTypeAtForDisambiguation(abs, sub, constraintMention, path, _) and + tp = constraint.getTypeAt(prefix) and + path = prefix.appendInverse(suffix) + ) + } + + pragma[nomagic] + private predicate constraintTypeMatchForDisambiguation1( + Term term, Constraint constraint, TypePath path, TypeMatchingContext ctx, TypePath suffix, + TypeParameter tp + ) { + constraintTypeMatchForDisambiguation0(term, constraint, path, suffix, tp) and + ctx = getTypeMatchingContext(term) + } + + /** + * Holds if the type of `constraint` at `path` is `t` because it is possible + * to match some type parameter that occurs in `constraint` at a prefix of + * `path` in the context of `term`. + * + * For example, if we have + * + * ```rust + * fn f>(x: T1, y: T2) -> T2::Output { ... } + * ``` + * + * then at a call like `f(true, ...)` the constraint `SomeTrait` has the + * type `bool` substituted for `T1`. + */ + pragma[nomagic] + private predicate constraintTypeMatchForDisambiguation( + Term term, Constraint constraint, TypePath path, Type t + ) { + exists(TypeMatchingContext ctx, TypeParameter tp, TypePath suffix | + constraintTypeMatchForDisambiguation1(term, constraint, path, ctx, suffix, tp) and + typeMatch(ctx, tp, suffix, t) + ) + } + + pragma[nomagic] + private predicate satisfiesConstraintTypeMention1( Term term, Constraint constraint, TypeAbstraction abs, TypeMention sub, TypePath path, Type t ) { - exists(Type constraintRoot, TypeMention constraintMention | - hasConstraintMention(term, abs, sub, constraint, constraintRoot, constraintMention) and - conditionSatisfiesConstraintTypeAt(abs, sub, constraintMention, path, t) + exists(TypeMention constraintMention, boolean ambiguous | + satisfiesConstraintTypeMention0(term, constraint, constraintMention, abs, sub, path, t, + ambiguous) + | + if ambiguous = true + then + // When the constraint is not uniquely satisfied, we check that the satisfying + // abstraction is not more specific than the constraint to be satisfied. For example, + // if the constraint is `MyTrait` and there is both `impl MyTrait for ...` and + // `impl MyTrait for ...`, then the latter will be filtered away + forall(TypePath path1, Type t1 | + conditionSatisfiesConstraintTypeAtForDisambiguation(abs, sub, constraintMention, + path1, t1) + | + t1 = constraint.getTypeAt(path1) + or + // The constraint may contain a type parameter, which we can match to the right type + constraintTypeMatchForDisambiguation(term, constraint, path1, t1) + ) + else any() ) } @@ -959,7 +1105,7 @@ module Make1 Input1> { TypePath pathToTypeParamInSub ) { exists(TypeMention sub, TypeParameter tp | - satisfiesConstraintTypeMention0(term, constraint, abs, sub, path, tp) and + satisfiesConstraintTypeMention1(term, constraint, abs, sub, path, tp) and tp = abs.getATypeParameter() and sub.getTypeAt(pathToTypeParamInSub) = tp ) @@ -984,7 +1130,7 @@ module Make1 Input1> { private predicate satisfiesConstraintTypeNonTypeParamInline( Term term, TypeAbstraction abs, Constraint constraint, TypePath path, Type t ) { - satisfiesConstraintTypeMention0(term, constraint, abs, _, path, t) and + satisfiesConstraintTypeMention1(term, constraint, abs, _, path, t) and not t = abs.getATypeParameter() } @@ -1048,12 +1194,41 @@ module Make1 Input1> { predicate dissatisfiesConstraint(Term term, Constraint constraint) { hasNotConstraintMention(term, constraint, _) and exists(Type t, Type constraintRoot | - hasTypeConstraint(term, t, constraint, constraintRoot) and // todo + hasTypeConstraint(term, t, constraint, constraintRoot) and t != constraintRoot ) } } + /** Provides the input to `SatisfiesConstraint`. */ + signature module SatisfiesConstraintInputSig { + /** Holds if it is relevant to know if `term` satisfies `constraint`. */ + predicate relevantConstraint(Term term, Constraint constraint); + } + + module SatisfiesConstraint< + HasTypeTreeSig Term, HasTypeTreeSig Constraint, + SatisfiesConstraintInputSig Input> + { + private module Inp implements SatisfiesConstraintWithTypeMatchingInputSig { + private import codeql.util.Void + + predicate relevantConstraint(Term term, Constraint constraint) { + Input::relevantConstraint(term, constraint) + } + + class TypeMatchingContext = Void; + + TypeMatchingContext getTypeMatchingContext(Term t) { none() } + + predicate typeMatch(TypeMatchingContext ctx, TypeParameter tp, TypePath path, Type t) { + none() + } + } + + import SatisfiesConstraintWithTypeMatching + } + /** Provides the input to `SatisfiesType`. */ signature module SatisfiesTypeInputSig { /** Holds if it is relevant to know if `term` satisfies `type`. */ @@ -1306,7 +1481,7 @@ module Make1 Input1> { } private module AccessConstraint { - predicate relevantAccessConstraint( + private predicate relevantAccessConstraint( Access a, AccessEnvironment e, Declaration target, AccessPosition apos, TypePath path, TypeMention constraint ) { @@ -1331,6 +1506,7 @@ module Make1 Input1> { RelevantAccess() { this = MkRelevantAccess(a, apos, e, path) } + pragma[nomagic] Type getTypeAt(TypePath suffix) { result = a.getInferredType(e, apos, path.appendInverse(suffix)) } @@ -1348,16 +1524,29 @@ module Make1 Input1> { } private module SatisfiesTypeParameterConstraintInput implements - SatisfiesConstraintInputSig + SatisfiesConstraintWithTypeMatchingInputSig { predicate relevantConstraint(RelevantAccess at, TypeMention constraint) { constraint = at.getConstraint(_) } + + class TypeMatchingContext = Access; + + TypeMatchingContext getTypeMatchingContext(RelevantAccess at) { + at = MkRelevantAccess(result, _, _, _) + } + + pragma[nomagic] + predicate typeMatch(TypeMatchingContext ctx, TypeParameter tp, TypePath path, Type t) { + typeMatch(ctx, _, _, path, t, tp) + } } private module SatisfiesTypeParameterConstraint = - SatisfiesConstraint; + SatisfiesConstraintWithTypeMatching; + pragma[nomagic] predicate satisfiesConstraintType( Access a, AccessEnvironment e, Declaration target, AccessPosition apos, TypePath prefix, TypeMention constraint, TypePath path, Type t diff --git a/shared/util/codeql/util/UnboundList.qll b/shared/util/codeql/util/UnboundList.qll index 79fac6506d6..6f05d6cddfc 100644 --- a/shared/util/codeql/util/UnboundList.qll +++ b/shared/util/codeql/util/UnboundList.qll @@ -167,6 +167,14 @@ module Make Input> { */ bindingset[this] UnboundList getAPrefix() { result = [this, this.getAProperPrefix()] } + + /** + * Holds if this list is a prefix of `other`. + * + * This is equivalent to `this = other.getAPrefix()`, but more performant. + */ + bindingset[this, other] + predicate isPrefixOf(UnboundList other) { this = other.prefix(this.stringLength()) } } /** Provides predicates for constructing `UnboundList`s. */