Extract String?.plus as either an AddExpr or a call to an intrinsic

If it is used by the compiler to implement the infix plus operator, resugar it and extract a `+` as Java would. If it is literally called by the user (e.g. `(if (x) then "not null" else null).plus(something)`), then extract a call to the real method Intrinsics.stringPlus (a two-arg static method).
This commit is contained in:
Chris Smowton
2022-02-03 15:21:08 +00:00
committed by Ian Lynagh
parent 93e8d5a2d6
commit 377bd8f2e9
4 changed files with 152 additions and 105 deletions

View File

@@ -999,9 +999,112 @@ open class KotlinFileExtractor(
}
}
fun extractRawMethodAccess(
syntacticCallTarget: IrFunction,
callsite: IrCall,
enclosingCallable: Label<out DbCallable>,
callsiteParent: Label<out DbExprparent>,
childIdx: Int,
enclosingStmt: Label<out DbStmt>,
valueArguments: List<IrExpression?>,
dispatchReceiver: IrExpression?,
extensionReceiver: IrExpression?,
typeArguments: List<IrType> = listOf(),
extractClassTypeArguments: Boolean = false) {
val callTarget = syntacticCallTarget.target
val id = tw.getFreshIdLabel<DbMethodaccess>()
val type = useType(callsite.type)
val locId = tw.getLocation(callsite)
tw.writeExprs_methodaccess(id, type.javaResult.id, callsiteParent, childIdx)
tw.writeExprsKotlinType(id, type.kotlinResult.id)
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, enclosingCallable)
tw.writeStatementEnclosingExpr(id, enclosingStmt)
// type arguments at index -2, -3, ...
extractTypeArguments(typeArguments, callsite, id, enclosingCallable, enclosingStmt, -2, true)
if (callTarget.isLocalFunction()) {
val ids = getLocallyVisibleFunctionLabels(callTarget)
val methodId = ids.function
tw.writeCallableBinding(id, methodId)
val idNewexpr = tw.getFreshIdLabel<DbNewexpr>()
tw.writeExprs_newexpr(idNewexpr, ids.type.javaResult.id, id, -1)
tw.writeExprsKotlinType(idNewexpr, ids.type.kotlinResult.id)
tw.writeHasLocation(idNewexpr, locId)
tw.writeCallableEnclosingExpr(idNewexpr, enclosingCallable)
tw.writeStatementEnclosingExpr(idNewexpr, enclosingStmt)
tw.writeCallableBinding(idNewexpr, ids.constructor)
@Suppress("UNCHECKED_CAST")
tw.writeIsAnonymClass(ids.type.javaResult.id as Label<DbClass>, idNewexpr)
extractTypeAccess(pluginContext.irBuiltIns.anyType, enclosingCallable, idNewexpr, -3, callsite, enclosingStmt)
} else {
// Returns true if type is C<T1, T2, ...> where C is declared `class C<T1, T2, ...> { ... }`
fun isUnspecialised(type: IrSimpleType) =
type.classifier.owner is IrClass &&
(type.classifier.owner as IrClass).typeParameters.zip(type.arguments).all { paramAndArg ->
(paramAndArg.second as? IrTypeProjection)?.let {
// Type arg refers to the class' own type parameter?
it.variance == Variance.INVARIANT &&
it.type.classifierOrNull?.owner === paramAndArg.first
} ?: false
}
val drType = dispatchReceiver?.type
val methodId =
if (drType != null && extractClassTypeArguments && drType is IrSimpleType && !isUnspecialised(drType))
useFunction<DbCallable>(callTarget, getDeclaringTypeArguments(callTarget, drType))
else
useFunction<DbCallable>(callTarget)
tw.writeCallableBinding(id, methodId)
if (dispatchReceiver != null) {
extractExpressionExpr(dispatchReceiver, enclosingCallable, id, -1, enclosingStmt)
} else if(callTarget.isStaticMethodOfClass) {
extractTypeAccess(callTarget.parentAsClass.toRawType(), enclosingCallable, id, -1, callsite, enclosingStmt)
}
}
val idxOffset: Int
if (extensionReceiver != null) {
extractExpressionExpr(extensionReceiver, enclosingCallable, id, 0, enclosingStmt)
idxOffset = 1
} else {
idxOffset = 0
}
valueArguments.forEachIndexed { i, arg ->
if(arg != null) {
extractExpressionExpr(arg, enclosingCallable, id, i + idxOffset, enclosingStmt)
}
}
}
fun findFunction(cls: IrClass, name: String): IrFunction? = cls.declarations.find { it is IrFunction && it.name.asString() == name } as IrFunction?
val jvmIntrinsicsClass by lazy {
val result = pluginContext.referenceClass(FqName("kotlin.jvm.internal.Intrinsics"))?.owner
result?.let { extractExternalClassLater(it) }
result
}
fun findJdkIntrinsicOrWarn(name: String, warnAgainstElement: IrElement): IrFunction? {
val result = jvmIntrinsicsClass?.let { findFunction(it, name) }
if(result == null) {
logger.warnElement(Severity.ErrorSevere, "Couldn't find JVM intrinsic function $name", warnAgainstElement)
}
return result
}
fun extractCall(c: IrCall, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, enclosingStmt: Label<out DbStmt>) {
with("call", c) {
fun isFunction(pkgName: String, className: String, fName: String, hasQuestionMark: Boolean = false): Boolean {
fun isFunction(pkgName: String, className: String, fName: String, hasQuestionMark: Boolean? = false): Boolean {
val verbose = false
fun verboseln(s: String) { if(verbose) println(s) }
verboseln("Attempting match for $pkgName $className $fName")
@@ -1012,10 +1115,14 @@ open class KotlinFileExtractor(
}
val extensionReceiverParameter = target.extensionReceiverParameter
val targetClass = if (extensionReceiverParameter == null) {
if (hasQuestionMark == true) {
verboseln("Nullablility of type didn't match (target is not an extension method)")
return false
}
target.parent
} else {
val st = extensionReceiverParameter.type as? IrSimpleType
if (st?.hasQuestionMark != hasQuestionMark) {
if (hasQuestionMark != null && st?.hasQuestionMark != hasQuestionMark) {
verboseln("Nullablility of type didn't match")
return false
}
@@ -1050,86 +1157,15 @@ open class KotlinFileExtractor(
isFunction("kotlin", "Float", fName) ||
isFunction("kotlin", "Double", fName)
}
fun extractMethodAccess(syntacticCallTarget: IrFunction, extractMethodTypeArguments: Boolean = true, extractClassTypeArguments: Boolean = false) {
val callTarget = syntacticCallTarget.target
val id = tw.getFreshIdLabel<DbMethodaccess>()
val type = useType(c.type)
val locId = tw.getLocation(c)
tw.writeExprs_methodaccess(id, type.javaResult.id, parent, idx)
tw.writeExprsKotlinType(id, type.kotlinResult.id)
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, callable)
tw.writeStatementEnclosingExpr(id, enclosingStmt)
if (extractMethodTypeArguments) {
// type arguments at index -2, -3, ...
extractTypeArguments(c, id, callable, enclosingStmt, -2, true)
}
if (callTarget.isLocalFunction()) {
val ids = getLocallyVisibleFunctionLabels(callTarget)
val methodId = ids.function
tw.writeCallableBinding(id, methodId)
val idNewexpr = tw.getFreshIdLabel<DbNewexpr>()
tw.writeExprs_newexpr(idNewexpr, ids.type.javaResult.id, id, -1)
tw.writeExprsKotlinType(idNewexpr, ids.type.kotlinResult.id)
tw.writeHasLocation(idNewexpr, locId)
tw.writeCallableEnclosingExpr(idNewexpr, callable)
tw.writeStatementEnclosingExpr(idNewexpr, enclosingStmt)
tw.writeCallableBinding(idNewexpr, ids.constructor)
@Suppress("UNCHECKED_CAST")
tw.writeIsAnonymClass(ids.type.javaResult.id as Label<DbClass>, idNewexpr)
extractTypeAccess(pluginContext.irBuiltIns.anyType, callable, idNewexpr, -3, c, enclosingStmt)
} else {
val dr = c.dispatchReceiver
// Returns true if type is C<T1, T2, ...> where C is declared `class C<T1, T2, ...> { ... }`
fun isUnspecialised(type: IrSimpleType) =
type.classifier.owner is IrClass &&
(type.classifier.owner as IrClass).typeParameters.zip(type.arguments).all { paramAndArg ->
(paramAndArg.second as? IrTypeProjection)?.let {
// Type arg refers to the class' own type parameter?
it.variance == Variance.INVARIANT &&
it.type.classifierOrNull?.owner === paramAndArg.first
} ?: false
}
val drType = dr?.type
val methodId =
if (drType != null && extractClassTypeArguments && drType is IrSimpleType && !isUnspecialised(drType))
useFunction<DbCallable>(callTarget, getDeclaringTypeArguments(callTarget, drType))
else
useFunction<DbCallable>(callTarget)
tw.writeCallableBinding(id, methodId)
if (dr != null) {
extractExpressionExpr(dr, callable, id, -1, enclosingStmt)
} else if(callTarget.isStaticMethodOfClass) {
extractTypeAccess(callTarget.parentAsClass.toRawType(), callable, id, -1, c, enclosingStmt)
}
}
val er = c.extensionReceiver
val idxOffset: Int
if (er != null) {
extractExpressionExpr(er, callable, id, 0, enclosingStmt)
idxOffset = 1
} else {
idxOffset = 0
}
for(i in 0 until c.valueArgumentsCount) {
val arg = c.getValueArgument(i)
if(arg != null) {
extractExpressionExpr(arg, callable, id, i + idxOffset, enclosingStmt)
}
}
val typeArgs =
if (extractMethodTypeArguments)
(0 until c.typeArgumentsCount).map { c.getTypeArgument(it)!! }
else
listOf()
extractRawMethodAccess(syntacticCallTarget, c, callable, parent, idx, enclosingStmt, (0 until c.valueArgumentsCount).map { c.getValueArgument(it) }, c.dispatchReceiver, c.extensionReceiver, typeArgs, extractClassTypeArguments)
}
fun extractSpecialEnumFunction(fnName: String){
@@ -1153,11 +1189,11 @@ open class KotlinFileExtractor(
tw.writeCallableEnclosingExpr(id, callable)
tw.writeStatementEnclosingExpr(id, enclosingStmt)
val dr = c.dispatchReceiver
if(dr == null) {
logger.warnElement(Severity.ErrorSevere, "Dispatch receiver not found", c)
val receiver = c.dispatchReceiver ?: c.extensionReceiver
if(receiver == null) {
logger.warnElement(Severity.ErrorSevere, "Receiver not found", c)
} else {
extractExpressionExpr(dr, callable, id, 0, enclosingStmt)
extractExpressionExpr(receiver, callable, id, 0, enclosingStmt)
}
if(c.valueArgumentsCount < 1) {
logger.warnElement(Severity.ErrorSevere, "No RHS found", c)
@@ -1178,7 +1214,7 @@ open class KotlinFileExtractor(
when {
c.origin == IrStatementOrigin.PLUS &&
(isNumericFunction("plus")
|| isFunction("kotlin", "String", "plus")) -> {
|| isFunction("kotlin", "String", "plus", null)) -> {
val id = tw.getFreshIdLabel<DbAddexpr>()
val type = useType(c.type)
tw.writeExprs_addexpr(id, type.javaResult.id, parent, idx)
@@ -1186,13 +1222,9 @@ open class KotlinFileExtractor(
binopDisp(id)
}
isFunction("kotlin", "String", "plus", true) -> {
// TODO: this is not correct. `a + b` becomes `(a?:"\"null\"") + (b?:"\"null\"")`.
val func = pluginContext.irBuiltIns.stringType.classOrNull?.owner?.declarations?.find { it is IrFunction && it.name.asString() == "plus" }
if (func == null) {
logger.warnElement(Severity.ErrorSevere, "Couldn't find plus function on string type", c)
return
findJdkIntrinsicOrWarn("stringPlus", c)?.let { stringPlusFn ->
extractRawMethodAccess(stringPlusFn, c, callable, parent, idx, enclosingStmt, listOf(c.extensionReceiver, c.getValueArgument(0)), null, null)
}
extractMethodAccess(func as IrFunction)
}
c.origin == IrStatementOrigin.MINUS && isNumericFunction("minus") -> {
val id = tw.getFreshIdLabel<DbSubexpr>()
@@ -1475,21 +1507,32 @@ open class KotlinFileExtractor(
}
}
private fun <T : IrSymbol> extractTypeArguments(
c: IrMemberAccessExpression<T>,
id: Label<out DbExprparent>,
callable: Label<out DbCallable>,
private fun extractTypeArguments(
typeArgs: List<IrType>,
elementForLocation: IrElement,
parentExpr: Label<out DbExprparent>,
enclosingCallable: Label<out DbCallable>,
enclosingStmt: Label<out DbStmt>,
startIndex: Int = 0,
reverse: Boolean = false
) {
for (argIdx in 0 until c.typeArgumentsCount) {
val arg = c.getTypeArgument(argIdx)!!
typeArgs.forEachIndexed { argIdx, arg ->
val mul = if (reverse) -1 else 1
extractTypeAccess(arg, callable, id, argIdx * mul + startIndex, c, enclosingStmt, TypeContext.GENERIC_ARGUMENT)
extractTypeAccess(arg, enclosingCallable, parentExpr, argIdx * mul + startIndex, elementForLocation, enclosingStmt, TypeContext.GENERIC_ARGUMENT)
}
}
private fun <T : IrSymbol> extractTypeArguments(
c: IrMemberAccessExpression<T>,
parentExpr: Label<out DbExprparent>,
enclosingCallable: Label<out DbCallable>,
enclosingStmt: Label<out DbStmt>,
startIndex: Int = 0,
reverse: Boolean = false
) {
extractTypeArguments((0 until c.typeArgumentsCount).map { c.getTypeArgument(it)!! }, c, parentExpr, enclosingCallable, enclosingStmt, startIndex, reverse)
}
private fun extractConstructorCall(
e: IrFunctionAccessExpression,
parent: Label<out DbExprparent>,