diff --git a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll index 504d8979c52..3229b3ee0bb 100644 --- a/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll @@ -127,17 +127,15 @@ private module Input implements InputSig1, InputSig2 { PreTypeMention getABaseTypeMention(Type t) { none() } - Type getATypeParameterConstraint(TypeParameter tp, TypePath path) { - exists(TypeMention tm | result = tm.getTypeAt(path) | - tm = tp.(TypeParamTypeParameter).getTypeParam().getATypeBound().getTypeRepr() or - tm = tp.(SelfTypeParameter).getTrait() or - tm = - tp.(ImplTraitTypeTypeParameter) - .getImplTraitTypeRepr() - .getTypeBoundList() - .getABound() - .getTypeRepr() - ) + PreTypeMention getATypeParameterConstraint(TypeParameter tp) { + result = tp.(TypeParamTypeParameter).getTypeParam().getATypeBound().getTypeRepr() or + result = tp.(SelfTypeParameter).getTrait() or + result = + tp.(ImplTraitTypeTypeParameter) + .getImplTraitTypeRepr() + .getTypeBoundList() + .getABound() + .getTypeRepr() } /** @@ -988,7 +986,7 @@ private module ContextTyping { or exists(TypeParameter mid | assocFunctionMentionsTypeParameterAtNonRetPos(i, f, mid) and - tp = getATypeParameterConstraint(mid, _) + tp = getATypeParameterConstraint(mid).getTypeAt(_) ) } diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index d42aef05a40..cbc1f608813 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -336,7 +336,7 @@ module Make1 Input1> { * ``` * the type parameter `T` has the constraint `IComparable`. */ - Type getATypeParameterConstraint(TypeParameter tp, TypePath path); + TypeMention getATypeParameterConstraint(TypeParameter tp); /** * Holds if @@ -1308,7 +1308,7 @@ module Make1 Input1> { private module AccessConstraint { predicate relevantAccessConstraint( Access a, AccessEnvironment e, Declaration target, AccessPosition apos, TypePath path, - Type constraint + TypeMention constraint ) { target = a.getTarget(e) and typeParameterConstraintHasTypeParameter(target, apos, path, constraint, _, _) @@ -1336,7 +1336,7 @@ module Make1 Input1> { } /** Gets the constraint that this relevant access should satisfy. */ - Type getConstraint(Declaration target) { + TypeMention getConstraint(Declaration target) { relevantAccessConstraint(a, e, target, apos, path, result) } @@ -1347,20 +1347,24 @@ module Make1 Input1> { Location getLocation() { result = a.getLocation() } } - private module SatisfiesConstraintInput implements SatisfiesTypeInputSig { - predicate relevantConstraint(RelevantAccess at, Type constraint) { + private module SatisfiesTypeParameterConstraintInput implements + SatisfiesConstraintInputSig + { + predicate relevantConstraint(RelevantAccess at, TypeMention constraint) { constraint = at.getConstraint(_) } } + private module SatisfiesTypeParameterConstraint = + SatisfiesConstraint; + predicate satisfiesConstraintType( Access a, AccessEnvironment e, Declaration target, AccessPosition apos, TypePath prefix, - Type constraint, TypePath path, Type t + TypeMention constraint, TypePath path, Type t ) { exists(RelevantAccess ra | ra = MkRelevantAccess(a, apos, e, prefix) and - SatisfiesType::satisfiesConstraintType(ra, - constraint, path, t) and + SatisfiesTypeParameterConstraint::satisfiesConstraintType(ra, constraint, path, t) and constraint = ra.getConstraint(target) ) } @@ -1469,17 +1473,17 @@ module Make1 Input1> { */ pragma[nomagic] private predicate typeParameterConstraintHasTypeParameter( - Declaration target, AccessPosition apos, TypePath pathToConstrained, Type constraint, + Declaration target, AccessPosition apos, TypePath pathToConstrained, TypeMention constraint, TypePath pathToTp, TypeParameter tp ) { exists(DeclarationPosition dpos, TypeParameter constrainedTp | accessDeclarationPositionMatch(apos, dpos) and constrainedTp = target.getTypeParameter(_) and + constraint = getATypeParameterConstraint(constrainedTp) and tp = target.getTypeParameter(_) and - tp = getATypeParameterConstraint(constrainedTp, pathToTp) and + tp = constraint.getTypeAt(pathToTp) and constrainedTp != tp and - constrainedTp = target.getDeclaredType(dpos, pathToConstrained) and - constraint = getATypeParameterConstraint(constrainedTp, TypePath::nil()) + constrainedTp = target.getDeclaredType(dpos, pathToConstrained) ) } @@ -1488,7 +1492,7 @@ module Make1 Input1> { Access a, AccessEnvironment e, Declaration target, TypePath path, Type t, TypeParameter tp ) { not exists(getTypeArgument(a, target, tp, _)) and - exists(Type constraint, AccessPosition apos, TypePath pathToTp, TypePath pathToTp2 | + exists(TypeMention constraint, AccessPosition apos, TypePath pathToTp, TypePath pathToTp2 | typeParameterConstraintHasTypeParameter(target, apos, pathToTp2, constraint, pathToTp, tp) and AccessConstraint::satisfiesConstraintType(a, e, target, apos, pathToTp2, constraint, pathToTp.appendInverse(path), t)