Handle large arity lambdas, and add missing type access for some constructor calls (needed for anonymous classes)

This commit is contained in:
Tamas Vajk
2021-12-06 14:09:29 +01:00
committed by Ian Lynagh
parent f4c87cb79d
commit 3cd2583ec8
4 changed files with 410 additions and 122 deletions

View File

@@ -178,7 +178,7 @@ open class KotlinFileExtractor(
val parentId =
if (parent.isAnonymousObject) {
@Suppress("UNCHECKED_CAST")
useAnonymousClass(c).javaResult.id as Label<out DbClass>
useAnonymousClass(parent).javaResult.id as Label<out DbClass>
} else {
useClassInstance(parent, listOf()).typeResult.id
}
@@ -260,12 +260,11 @@ open class KotlinFileExtractor(
}
private fun extractValueParameter(vp: IrValueParameter, parent: Label<out DbCallable>, idx: Int): TypeResults {
return extractValueParameter(useValueParameter(vp), vp.type, vp.name.asString(), vp, parent, idx)
return extractValueParameter(useValueParameter(vp), vp.type, vp.name.asString(), tw.getLocation(vp), parent, idx)
}
private fun extractValueParameter(id: Label<out DbParam>, t: IrType, name: String, loc: IrElement, parent: Label<out DbCallable>, idx: Int): TypeResults {
private fun extractValueParameter(id: Label<out DbParam>, t: IrType, name: String, locId: Label<DbLocation>, parent: Label<out DbCallable>, idx: Int): TypeResults {
val type = useType(t)
val locId = tw.getLocation(loc)
tw.writeParams(id, type.javaResult.id, type.kotlinResult.id, idx, parent, id)
tw.writeHasLocation(id, locId)
tw.writeParamName(id, name)
@@ -721,6 +720,15 @@ open class KotlinFileExtractor(
tw.writeStatementEnclosingExpr(idNewexpr, enclosingStmt)
tw.writeCallableBinding(idNewexpr, ids.constructor)
@Suppress("UNCHECKED_CAST")
tw.writeIsAnonymClass(ids.type.javaResult.id as Label<DbClass>, idNewexpr)
val typeAccessId = tw.getFreshIdLabel<DbUnannotatedtypeaccess>()
val anyType = useType(pluginContext.irBuiltIns.anyType)
tw.writeExprs_unannotatedtypeaccess(typeAccessId, anyType.javaResult.id, anyType.kotlinResult.id, idNewexpr, -3)
tw.writeCallableEnclosingExpr(typeAccessId, callable)
tw.writeStatementEnclosingExpr(typeAccessId, enclosingStmt)
} else {
val methodId = useFunction<DbMethod>(callTarget)
tw.writeCallableBinding(id, methodId)
@@ -1644,61 +1652,57 @@ open class KotlinFileExtractor(
functionN(pluginContext)(parameters.size).typeWith(types)
}
val lambdaType = pluginContext.referenceClass(FqName("kotlin.jvm.internal.Lambda"))!!.typeWith()
/*
* Extract generated class:
* ```
* class C : kotlin.jvm.internal.Lambda, kotlin.FunctionI<T0,T1, ... TI, R> {
* constructor() { super(I); }
* fun invoke(a0:T0, a1:T1, ... aI: TI): R { ... }
* }
* ```
* or in case of big arity lambdas
* ```
* class C : kotlin.jvm.internal.Lambda, kotlin.FunctionN<R> {
* constructor() { super(I); }
* fun invoke(a0:T0, a1:T1, ... aI: TI): R { ... }
* fun invoke(vararg args: Any?): R {
* return invoke(args[0] as T0, args[1] as T1, ..., args[I] as TI)
* }
* }
* ```
* */
extractGeneratedClass(
e.function, // We're adding this function as a member, and changing its name to `invoke` to implement `kotlin.FunctionX<,,,>.invoke(,,)`
listOf(
pluginContext.referenceClass(FqName("kotlin.jvm.internal.Lambda"))!!.typeWith(),
fnInterface),
listOf(lambdaType, fnInterface),
listOf(e.function.valueParameters.size.toIrConst(pluginContext.irBuiltIns.intType, e.startOffset, e.endOffset)))
val objectType = useType(pluginContext.irBuiltIns.anyNType).javaResult.id
if (types.size > BuiltInFunctionArity.BIG_ARITY) {
implementFunctionNInvoke(e.function, ids, locId, parameters)
// Only add bridge method if its signature is different from the lambda function
if (!types.all { useType(it).javaResult.id == objectType } ||
types.size > BuiltInFunctionArity.BIG_ARITY) {
val methodId = tw.getFreshIdLabel<DbMethod>()
val paramTypes =
if (types.size > BuiltInFunctionArity.BIG_ARITY) {
// signature is `Object invoke(Object[] p)`
listOf(extractValueParameter(tw.getFreshIdLabel(), pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType), "p", e, methodId, 0))
} else {
// signature is `Object invoke(Object p0, Object p1, ..., Object pN)`
parameters.mapIndexed { i, _ ->
extractValueParameter(tw.getFreshIdLabel(), pluginContext.irBuiltIns.anyNType, "p$i", e, methodId, i)
}
}
val paramsSignature = paramTypes.joinToString(separator = ",", prefix = "(", postfix = ")") { it.javaResult.signature!! }
val returnType = useType(pluginContext.irBuiltIns.anyNType, TypeContext.RETURN)
val shortName = OperatorNameConventions.INVOKE.asString()
@Suppress("UNCHECKED_CAST")
tw.writeMethods(methodId, shortName, "$shortName$paramsSignature", returnType.javaResult.id, returnType.kotlinResult.id, ids.type.javaResult.id as Label<out DbReftype>, methodId)
tw.writeHasLocation(methodId, locId)
// TODO:
// - Add body of bridge method, which calls `e.function`:
// ```
// public int invoke(int i, Object j, String k) { return 5; }
// public Object invoke(Object p0, Object p1, Object p2) {
// return invoke((int)p0, (Object)p1, (String)p2);
// or
// invoke((int)p0, (Object)p1, (String)p2);
// return kotlin.Unit.INSTANCE
// }
// ```
// todo: which method should be returned in `LambdaExpr.asMethod()`?
}
val exprParent = parent.expr(e, callable)
val idLambdaExpr = tw.getFreshIdLabel<DbLambdaexpr>()
tw.writeExprs_lambdaexpr(idLambdaExpr, ids.type.javaResult.id, ids.type.kotlinResult.id, exprParent.parent, exprParent.idx)
tw.writeHasLocation(idLambdaExpr, locId)
tw.writeCallableEnclosingExpr(idLambdaExpr, callable)
tw.writeStatementEnclosingExpr(idLambdaExpr, exprParent.enclosingStmt)
tw.writeCallableBinding(idLambdaExpr, ids.constructor)
val idNewexpr = tw.getFreshIdLabel<DbNewexpr>()
tw.writeExprs_newexpr(idNewexpr, ids.type.javaResult.id, ids.type.kotlinResult.id, exprParent.parent, exprParent.idx)
tw.writeHasLocation(idNewexpr, locId)
tw.writeCallableEnclosingExpr(idNewexpr, callable)
tw.writeStatementEnclosingExpr(idNewexpr, exprParent.enclosingStmt)
tw.writeCallableBinding(idNewexpr, ids.constructor)
val typeAccessId = tw.getFreshIdLabel<DbUnannotatedtypeaccess>()
// todo: in Java, we're accessing the base functional interface type.
val typeAccessType = useType(lambdaType)
tw.writeExprs_unannotatedtypeaccess(typeAccessId, typeAccessType.javaResult.id, typeAccessType.kotlinResult.id, idLambdaExpr, -3)
tw.writeCallableEnclosingExpr(typeAccessId, callable)
tw.writeStatementEnclosingExpr(typeAccessId, exprParent.enclosingStmt)
// todo: fix hard coded block body of lambda
tw.writeLambdaKind(idLambdaExpr, 1)
@Suppress("UNCHECKED_CAST")
tw.writeIsAnonymClass(ids.type.javaResult.id as Label<DbClass>, idLambdaExpr)
}
else -> {
logger.warnElement(Severity.ErrorSevere, "Unrecognised IrExpression: " + e.javaClass, e)
@@ -1706,6 +1710,108 @@ open class KotlinFileExtractor(
}
}
/*
* This function generates an implementation for `fun kotlin.FunctionN<R>.invoke(vararg args: Any?): R`
*
* The following body is added:
* ```
* fun invoke(vararg args: Any?): R {
* return invoke(args[0] as T0, args[1] as T1, ..., args[I] as TI)
* }
* ```
* */
private fun implementFunctionNInvoke(
lambda: IrFunction,
ids: LocalFunctionLabels,
locId: Label<DbLocation>,
parameters: List<IrValueParameter>
) {
val methodId = tw.getFreshIdLabel<DbMethod>()
val argsParamId = tw.getFreshIdLabel<DbParam>()
val argsParamType = pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)
val paramType = extractValueParameter(argsParamId, argsParamType, "args", locId, methodId, 0)
val paramsSignature = "(${paramType.javaResult.signature!!})"
val returnType = useType(lambda.returnType, TypeContext.RETURN)
val shortName = OperatorNameConventions.INVOKE.asString()
@Suppress("UNCHECKED_CAST")
tw.writeMethods(methodId, shortName, "$shortName$paramsSignature", returnType.javaResult.id, returnType.kotlinResult.id, ids.type.javaResult.id as Label<out DbReftype>, methodId)
tw.writeHasLocation(methodId, locId)
// Block
val blockId = tw.getFreshIdLabel<DbBlock>()
tw.writeStmts_block(blockId, methodId, 0, methodId)
tw.writeHasLocation(blockId, locId)
// Return
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, blockId, 0, methodId)
tw.writeHasLocation(retId, locId)
fun extractCommonExpr(id: Label<out DbExpr>) {
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, methodId)
tw.writeStatementEnclosingExpr(id, retId)
}
// Call to original `invoke`:
val callId = tw.getFreshIdLabel<DbMethodaccess>()
val callType = useType(lambda.returnType)
tw.writeExprs_methodaccess(callId, callType.javaResult.id, callType.kotlinResult.id, retId, 0)
extractCommonExpr(callId)
val calledMethodId = useFunction<DbMethod>(lambda)
tw.writeCallableBinding(callId, calledMethodId)
// this access
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, ids.type.kotlinResult.id, callId, -1)
extractCommonExpr(thisId)
// parameters
val intType = useType(pluginContext.irBuiltIns.intType)
val argsType = useType(argsParamType)
val anyNType = useType(pluginContext.irBuiltIns.anyNType)
val func =
pluginContext.irBuiltIns.arrayClass.owner.declarations.find { it is IrFunction && it.name.asString() == "get" }
@Suppress("UNCHECKED_CAST")
val arrayGetMethodId = useFunction<DbMethod>(func as IrFunction)
for ((pIdx, p) in parameters.withIndex()) {
// `args[i] as Ti` is generated below for each parameter
// cast
val castId = tw.getFreshIdLabel<DbCastexpr>()
val type = useType(p.type)
tw.writeExprs_castexpr(castId, type.javaResult.id, type.kotlinResult.id, callId, pIdx)
extractCommonExpr(castId)
// type access
extractTypeAccess(p.type, locId, methodId, castId, 0, retId)
// element access: `args.get(i)`
val getCallId = tw.getFreshIdLabel<DbMethodaccess>()
tw.writeExprs_methodaccess(getCallId, anyNType.javaResult.id, anyNType.kotlinResult.id, castId, 1)
extractCommonExpr(getCallId)
tw.writeCallableBinding(getCallId, arrayGetMethodId)
// parameter access:
val argsAccessId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(argsAccessId, argsType.javaResult.id, argsType.kotlinResult.id, getCallId, -1)
extractCommonExpr(argsAccessId)
tw.writeVariableBinding(argsAccessId, argsParamId)
// index access:
val indexId = tw.getFreshIdLabel<DbIntegerliteral>()
tw.writeExprs_integerliteral(indexId, intType.javaResult.id, intType.kotlinResult.id, getCallId, pIdx)
extractCommonExpr(indexId)
tw.writeNamestrings(pIdx.toString(), pIdx.toString(), indexId)
}
}
fun extractVarargElement(e: IrVarargElement, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, enclosingStmt: Label<out DbStmt>) {
when(e) {
is IrExpression -> {
@@ -1717,19 +1823,22 @@ open class KotlinFileExtractor(
}
}
fun extractTypeAccess(t: IrType, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, elementForLocation: IrElement, enclosingStmt: Label<out DbStmt>) {
private fun extractTypeAccess(t: IrType, location: Label<DbLocation>, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, enclosingStmt: Label<out DbStmt>) {
// TODO: elementForLocation allows us to give some sort of
// location, but a proper location for the type access will
// require upstream changes
val type = useType(t)
val id = tw.getFreshIdLabel<DbUnannotatedtypeaccess>()
tw.writeExprs_unannotatedtypeaccess(id, type.javaResult.id, type.kotlinResult.id, parent, idx)
val locId = tw.getLocation(elementForLocation)
tw.writeHasLocation(id, locId)
tw.writeHasLocation(id, location)
tw.writeCallableEnclosingExpr(id, callable)
tw.writeStatementEnclosingExpr(id, enclosingStmt)
}
private fun extractTypeAccess(t: IrType, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, elementForLocation: IrElement, enclosingStmt: Label<out DbStmt>) {
extractTypeAccess(t, tw.getLocation(elementForLocation), callable, parent, idx, enclosingStmt)
}
fun extractTypeOperatorCall(e: IrTypeOperatorCall, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, enclosingStmt: Label<out DbStmt>) {
when(e.operator) {
IrTypeOperator.CAST -> {
@@ -1867,7 +1976,7 @@ open class KotlinFileExtractor(
// Super call
val superCallId = tw.getFreshIdLabel<DbSuperconstructorinvocationstmt>()
tw.writeStmts_superconstructorinvocationstmt(superCallId, constructorBlockId, 0, ids.function)
tw.writeStmts_superconstructorinvocationstmt(superCallId, constructorBlockId, 0, ids.constructor)
for (i in 0 until superConstructorArgs.size) {
val arg = superConstructorArgs[i]
extractExpressionExpr(arg, ids.constructor, superCallId, i, superCallId)
@@ -1885,6 +1994,34 @@ open class KotlinFileExtractor(
addModifiers(id, "public", "static", "final")
extractClassSupertypes(superTypes, listOf(), id)
var parent: IrDeclarationParent? = localFunction.parent
while (parent != null) {
// todo: merge this with the implementation in `extractClassSource`
if (parent is IrClass) {
val parentId =
if (parent.isAnonymousObject) {
@Suppress("UNCHECKED_CAST")
useAnonymousClass(parent).javaResult.id as Label<out DbClass>
} else {
useClassInstance(parent, listOf()).typeResult.id
}
tw.writeEnclInReftype(id, parentId)
break
}
if (parent is IrFile) {
if (this is KotlinSourceFileExtractor && this.file == localFunction.fileOrNull) {
tw.writeEnclInReftype(id, this.fileClass)
} else {
logger.warn(Severity.ErrorSevere, "Unexpected file parent found")
}
break
}
parent = (parent as? IrDeclaration)?.parent
}
return id
}
}