diff --git a/java/kotlin-extractor/src/main/kotlin/KotlinFileExtractor.kt b/java/kotlin-extractor/src/main/kotlin/KotlinFileExtractor.kt index 6c3eb2d3c9c..a32e7938d51 100644 --- a/java/kotlin-extractor/src/main/kotlin/KotlinFileExtractor.kt +++ b/java/kotlin-extractor/src/main/kotlin/KotlinFileExtractor.kt @@ -497,7 +497,9 @@ open class KotlinFileExtractor( else null } ?: vp.type - val substitutedType = typeSubstitution?.let { it(maybeErasedType, TypeContext.OTHER, pluginContext) } ?: maybeErasedType + val typeWithWildcards = addJavaLoweringWildcards(maybeErasedType, true) + val substitutedType = typeSubstitution?.let { it(typeWithWildcards, TypeContext.OTHER, pluginContext) } ?: typeWithWildcards + val id = useValueParameter(vp, parent) if (extractTypeAccess) { extractTypeAccessRecursive(substitutedType, location, id, -1) @@ -704,7 +706,7 @@ open class KotlinFileExtractor( val paramsSignature = allParamTypes.joinToString(separator = ",", prefix = "(", postfix = ")") { it.javaResult.signature!! } - val adjustedReturnType = getAdjustedReturnType(f) + val adjustedReturnType = addJavaLoweringWildcards(getAdjustedReturnType(f), false) val substReturnType = typeSubstitution?.let { it(adjustedReturnType, TypeContext.RETURN, pluginContext) } ?: adjustedReturnType val locId = locOverride ?: getLocation(f, classTypeArgsIncludingOuterClasses) diff --git a/java/kotlin-extractor/src/main/kotlin/KotlinUsesExtractor.kt b/java/kotlin-extractor/src/main/kotlin/KotlinUsesExtractor.kt index fef02884df1..951ac4aecf0 100644 --- a/java/kotlin-extractor/src/main/kotlin/KotlinUsesExtractor.kt +++ b/java/kotlin-extractor/src/main/kotlin/KotlinUsesExtractor.kt @@ -6,6 +6,7 @@ import com.github.codeql.utils.versions.isRawType import com.semmle.extractor.java.OdasaOutput import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext import org.jetbrains.kotlin.backend.common.ir.allOverridden +import org.jetbrains.kotlin.backend.common.ir.isFinalClass import org.jetbrains.kotlin.backend.common.lower.parentsWithSelf import org.jetbrains.kotlin.backend.jvm.ir.getJvmNameFromAnnotation import org.jetbrains.kotlin.backend.jvm.ir.propertyIfAccessor @@ -850,6 +851,49 @@ open class KotlinUsesExtractor( (f.name.asString() == "addAll" && overridesFunctionDefinedOn(f, "kotlin.collections", "MutableCollection")) || (f.name.asString() == "addAll" && overridesFunctionDefinedOn(f, "kotlin.collections", "MutableList")) + + private fun wildcardAdditionAllowed(v: Variance, t: IrType, inParameterContext: Boolean) = + when { + t.hasAnnotation(FqName("kotlin.jvm.JvmWildcard")) -> true + !inParameterContext -> false // By default, wildcards are only automatically added for method parameters. + t.hasAnnotation(FqName("kotlin.jvm.JvmSuppressWildcards")) -> false + v == Variance.IN_VARIANCE -> !(t.isNullableAny() || t.isAny()) + v == Variance.OUT_VARIANCE -> ((t as? IrSimpleType)?.classOrNull?.owner?.isFinalClass) != true + else -> false + } + + private fun addJavaLoweringArgumentWildcards(p: IrTypeParameter, t: IrTypeArgument, inParameterContext: Boolean): IrTypeArgument = + (t as? IrTypeProjection)?.let { + val newBase = addJavaLoweringWildcards(it.type, inParameterContext) + val newVariance = + if (it.variance == Variance.INVARIANT && + p.variance != Variance.INVARIANT && + wildcardAdditionAllowed(p.variance, it.type, inParameterContext)) + p.variance + else + it.variance + if (newBase !== it.type || newVariance != it.variance) + makeTypeProjection(newBase, newVariance) + else + null + } ?: t + + fun addJavaLoweringWildcards(t: IrType, inParameterContext: Boolean): IrType = + (t as? IrSimpleType)?.let { + val typeParams = it.classOrNull?.owner?.typeParameters ?: return t + val newArgs = typeParams.zip(it.arguments).map { pair -> + addJavaLoweringArgumentWildcards( + pair.first, + pair.second, + inParameterContext + ) + } + return if (newArgs.zip(it.arguments).all { pair -> pair.first === pair.second }) + t + else + it.toBuilder().also { builder -> builder.arguments = newArgs }.buildSimpleType() + } ?: t + /* * This is the normal getFunctionLabel function to use. If you want * to refer to the function in its source class then @@ -956,8 +1000,10 @@ open class KotlinUsesExtractor( // Collection.remove(Object) because Collection.remove(Collection::E) in the Kotlin universe. // If this has happened, erase the type again to get the correct Java signature. val maybeAmendedForCollections = if (overridesCollectionsMethod) eraseCollectionsMethodParameterType(it.value.type, name, it.index) else it.value.type + // Add any wildcard types that the Kotlin compiler would add in the Java lowering of this function: + val withAddedWildcards = addJavaLoweringWildcards(maybeAmendedForCollections, true) // Now substitute any class type parameters in: - val maybeSubbed = maybeAmendedForCollections.substituteTypeAndArguments(substitutionMap, TypeContext.OTHER, pluginContext) + val maybeSubbed = withAddedWildcards.substituteTypeAndArguments(substitutionMap, TypeContext.OTHER, pluginContext) // Finally, mimic the Java extractor's behaviour by naming functions with type parameters for their erased types; // those without type parameters are named for the generic type. val maybeErased = if (functionTypeParameters.isEmpty()) maybeSubbed else erase(maybeSubbed) @@ -969,6 +1015,8 @@ open class KotlinUsesExtractor( pluginContext.irBuiltIns.unitType else erase(returnType.substituteTypeAndArguments(substitutionMap, TypeContext.RETURN, pluginContext)) + // Note that `addJavaLoweringWildcards` is not required here because the return type used to form the function + // label is always erased. val returnTypeId = useType(labelReturnType, TypeContext.RETURN).javaResult.id // This suffix is added to generic methods (and constructors) to match the Java extractor's behaviour. // Comments in that extractor indicates it didn't want the label of the callable to clash with the raw