Extract function references

This commit is contained in:
Tamas Vajk
2021-12-14 15:02:46 +01:00
committed by Ian Lynagh
parent 6950f868fb
commit 10ae157682
5 changed files with 1283 additions and 208 deletions

View File

@@ -16,11 +16,13 @@ import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
import org.jetbrains.kotlin.ir.symbols.IrSymbol
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.util.OperatorNameConventions
import org.jetbrains.kotlin.types.Variance
import java.util.*
open class KotlinFileExtractor(
override val logger: FileLogger,
@@ -421,7 +423,7 @@ open class KotlinFileExtractor(
}
fun extractFunction(f: IrFunction, parentId: Label<out DbReftype>, extractBody: Boolean = true, typeSubstitution: TypeSubstitution? = null): Label<out DbCallable> {
currentFunction = f
declarationStack.push(f)
f.typeParameters.map { extractTypeParameter(it) }
@@ -429,7 +431,7 @@ open class KotlinFileExtractor(
val id =
if (f.isLocalFunction())
getLocalFunctionLabels(f).function
getLocallyVisibleFunctionLabels(f).function
else
// TODO: figure out whether to standardise on naming top-level functions for the file-class
// or (as temporarily done here) for their containing package.
@@ -490,26 +492,28 @@ open class KotlinFileExtractor(
extractVisibility(f, id, f.visibility)
currentFunction = null
declarationStack.pop()
return id
}
fun extractField(f: IrField, parentId: Label<out DbReftype>): Label<out DbField> {
val id = useField(f)
val locId = tw.getLocation(f)
val type = useType(f.type)
tw.writeFields(id, f.name.asString(), type.javaResult.id, type.kotlinResult.id, parentId, id)
return extractField(useField(f), f.name.asString(), f.type, parentId, tw.getLocation(f), f.visibility, f, isExternalDeclaration(f))
}
private fun extractField(id: Label<out DbField>, name: String, type: IrType, parentId: Label<out DbReftype>, locId: Label<DbLocation>, visibility: DescriptorVisibility, errorElement: IrElement, isExternalDeclaration: Boolean): Label<out DbField> {
val t = useType(type)
tw.writeFields(id, name, t.javaResult.id, t.kotlinResult.id, parentId, id)
tw.writeHasLocation(id, locId)
extractVisibility(f, id, f.visibility)
extractVisibility(errorElement, id, visibility)
if (!isExternalDeclaration(f)) {
if (!isExternalDeclaration) {
val fieldDeclarationId = tw.getFreshIdLabel<DbFielddecl>()
tw.writeFielddecls(fieldDeclarationId, parentId)
tw.writeFieldDeclaredIn(id, fieldDeclarationId, 0)
tw.writeHasLocation(fieldDeclarationId, locId)
extractTypeAccess(type, locId, fieldDeclarationId, 0)
extractTypeAccess(t, locId, fieldDeclarationId, 0)
}
return id
@@ -660,7 +664,7 @@ open class KotlinFileExtractor(
}
is IrFunction -> {
if (s.isLocalFunction()) {
val classId = extractGeneratedClass(s, listOf(pluginContext.irBuiltIns.anyType))
val classId = extractGeneratedClass(s, listOf(pluginContext.irBuiltIns.anyType))
extracLocalTypeDeclStmt(classId, s, callable, parent, idx)
} else {
logger.warnElement(Severity.ErrorSevere, "Expected to find local function", s)
@@ -829,7 +833,7 @@ open class KotlinFileExtractor(
}
if (callTarget.isLocalFunction()) {
val ids = getLocalFunctionLabels(callTarget)
val ids = getLocallyVisibleFunctionLabels(callTarget)
val methodId = ids.function
tw.writeCallableBinding(id, methodId)
@@ -1192,8 +1196,8 @@ open class KotlinFileExtractor(
}
}
private fun extractTypeArguments(
c: IrFunctionAccessExpression,
private fun <T : IrSymbol> extractTypeArguments(
c: IrMemberAccessExpression<T>,
id: Label<out DbExprparent>,
callable: Label<out DbCallable>,
enclosingStmt: Label<out DbStmt>,
@@ -1269,7 +1273,9 @@ open class KotlinFileExtractor(
private val loopIdMap: MutableMap<IrLoop, Label<out DbKtloopstmt>> = mutableMapOf()
private var currentFunction: IrFunction? = null
// todo: add all declaration types, not only IrFunctions.
// todo: calculating the enclosing ref type could be done through this, instead of walking up the declaration parent chain
private val declarationStack: Stack<IrDeclaration> = Stack()
abstract inner class StmtExprParent {
abstract fun stmt(e: IrExpression, callable: Label<out DbCallable>): StmtParent
@@ -1317,7 +1323,7 @@ open class KotlinFileExtractor(
is IrDelegatingConstructorCall -> {
val stmtParent = parent.stmt(e, callable)
val irCallable = currentFunction
val irCallable = declarationStack.peek()
if (irCallable == null) {
logger.warnElement(Severity.ErrorSevere, "Current function is not set", e)
return
@@ -1443,7 +1449,7 @@ open class KotlinFileExtractor(
}
is IrInstanceInitializerCall -> {
val exprParent = parent.expr(e, callable)
val irCallable = currentFunction
val irCallable = declarationStack.peek()
if (irCallable == null) {
logger.warnElement(Severity.ErrorSevere, "Current function is not set", e)
return
@@ -1749,9 +1755,31 @@ open class KotlinFileExtractor(
tw.writeVariableBinding(id, instance.id)
}
}
is IrFunctionReference -> {
extractFunctionReference(e, parent, callable)
}
is IrFunctionExpression -> {
/*
* Extract generated class:
* ```
* class C : Any, kotlin.FunctionI<T0,T1, ... TI, R> {
* constructor() { super(); }
* fun invoke(a0:T0, a1:T1, ... aI: TI): R { ... }
* }
* ```
* or in case of big arity lambdas
* ```
* class C : Any, kotlin.FunctionN<R> {
* constructor() { super(); }
* fun invoke(a0:T0, a1:T1, ... aI: TI): R { ... }
* fun invoke(vararg args: Any?): R {
* return invoke(args[0] as T0, args[1] as T1, ..., args[I] as TI)
* }
* }
* ```
**/
val ids = getLocalFunctionLabels(e.function)
val ids = getLocallyVisibleFunctionLabels(e.function)
val locId = tw.getLocation(e)
val ext = e.function.extensionReceiverParameter
@@ -1766,34 +1794,10 @@ open class KotlinFileExtractor(
var types = parameters.map { it.type }
types += e.function.returnType
val fnInterface = if (types.size > BuiltInFunctionArity.BIG_ARITY) {
pluginContext.referenceClass(FqName("kotlin.jvm.functions.FunctionN"))!!.typeWith(e.function.returnType)
} else {
functionN(pluginContext)(parameters.size).typeWith(types)
}
/*
* Extract generated class:
* ```
* class C : Any, kotlin.FunctionI<T0,T1, ... TI, R> {
* constructor() { super(); }
* fun invoke(a0:T0, a1:T1, ... aI: TI): R { ... }
* }
* ```
* or in case of big arity lambdas
* ```
* class C : Any, kotlin.FunctionN<R> {
* constructor() { super(); }
* fun invoke(a0:T0, a1:T1, ... aI: TI): R { ... }
* fun invoke(vararg args: Any?): R {
* return invoke(args[0] as T0, args[1] as T1, ..., args[I] as TI)
* }
* }
* ```
* */
extractGeneratedClass(
val fnInterfaceType = getFunctionalInterfaceType(types)
val id = extractGeneratedClass(
e.function, // We're adding this function as a member, and changing its name to `invoke` to implement `kotlin.FunctionX<,,,>.invoke(,,)`
listOf(pluginContext.irBuiltIns.anyType, fnInterface))
listOf(pluginContext.irBuiltIns.anyType, fnInterfaceType))
if (types.size > BuiltInFunctionArity.BIG_ARITY) {
implementFunctionNInvoke(e.function, ids, locId, parameters)
@@ -1807,13 +1811,12 @@ open class KotlinFileExtractor(
tw.writeStatementEnclosingExpr(idLambdaExpr, exprParent.enclosingStmt)
tw.writeCallableBinding(idLambdaExpr, ids.constructor)
extractTypeAccess(fnInterface, callable, idLambdaExpr, -3, e, exprParent.enclosingStmt)
extractTypeAccess(fnInterfaceType, callable, idLambdaExpr, -3, e, exprParent.enclosingStmt)
// todo: fix hard coded block body of lambda
tw.writeLambdaKind(idLambdaExpr, 1)
@Suppress("UNCHECKED_CAST")
tw.writeIsAnonymClass(ids.type.javaResult.id as Label<DbClass>, idLambdaExpr)
tw.writeIsAnonymClass(id, idLambdaExpr)
}
is IrClassReference -> {
val exprParent = parent.expr(e, callable)
@@ -1833,6 +1836,308 @@ open class KotlinFileExtractor(
}
}
private fun extractFunctionReference(
functionReferenceExpr: IrFunctionReference,
parent: StmtExprParent,
callable: Label<out DbCallable>
) {
val target = functionReferenceExpr.reflectionTarget
if (target == null) {
logger.warnElement(Severity.ErrorSevere, "Expected to find reflection target for function reference", functionReferenceExpr)
return
}
/*
* Extract generated class:
* ```
* class C : Any, kotlin.FunctionI<T0,T1, ... TI, R> {
* private dispatchReceiver: TD
* private extensionReceiver: TE
* constructor(dispatchReceiver: TD, extensionReceiver: TE) {
* super()
* this.dispatchReceiver = dispatchReceiver
* this.extensionReceiver = extensionReceiver
* }
* fun invoke(a0:T0, a1:T1, ... aI: TI): R { return this.dispatchReceiver.FN(a0,a1,...,aI) } OR
* fun invoke( a1:T1, ... aI: TI): R { return this.dispatchReceiver.FN(this.dispatchReceiver,a1,...,aI) } OR
* fun invoke(a0:T0, a1:T1, ... aI: TI): R { return Ctor(a0,a1,...,aI) }
* }
* ```
* or in case of big arity lambdas ????
* ```
* class C : Any, kotlin.FunctionN<R> {
* private receiver: TD
* constructor(receiver: TD) { super(); this.receiver = receiver; }
* fun invoke(vararg args: Any?): R {
* return this.receiver.FN(args[0] as T0, args[1] as T1, ..., args[I] as TI)
* }
* }
* ```
**/
val typeArguments = if (target is IrConstructorSymbol) {
(target.owner.returnType as? IrSimpleType)?.arguments
} else {
(functionReferenceExpr.dispatchReceiver?.type as? IrSimpleType)?.arguments
}
val targetCallableId = useFunction<DbCallable>(target.owner, typeArguments)
val locId = tw.getLocation(functionReferenceExpr)
val extensionParameter = target.owner.extensionReceiverParameter
val parameters =
if (extensionParameter != null &&
functionReferenceExpr.extensionReceiver == null) {
// No extension receiver argument is set, so we're creating a parameter for it in `invoke`
val l = mutableListOf(extensionParameter)
l.addAll(target.owner.valueParameters)
l
} else {
// Either not an extension method or one with extension receiver specified. In the latter case a constructor
// argument is created for the extension receiver expression and then passed as the 0th argument of the call
// to `target`.
target.owner.valueParameters
}
val parameterTypes = parameters.map { it.type }
val functionNTypeArguments = parameterTypes + target.owner.returnType
val fnInterfaceType = getFunctionalInterfaceType(functionNTypeArguments)
val javaResult = TypeResult(tw.getFreshIdLabel<DbClass>(), "", "")
val kotlinResult = TypeResult(tw.getFreshIdLabel<DbKt_notnull_type>(), "", "")
tw.writeKt_notnull_types(kotlinResult.id, javaResult.id)
val ids = LocallyVisibleFunctionLabels(
TypeResults(javaResult, kotlinResult),
tw.getFreshIdLabel(),
tw.getFreshIdLabel(),
tw.getFreshIdLabel()
)
val currentDeclaration = declarationStack.peek()
val id = extractGeneratedClass(ids, listOf(pluginContext.irBuiltIns.anyType, fnInterfaceType), locId, currentDeclaration)
fun writeExpressionMetadataToTrapFile(id: Label<out DbExpr>, callable: Label<out DbCallable>, stmt: Label<out DbStmt>) {
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, callable)
tw.writeStatementEnclosingExpr(id, stmt)
}
/**
* Extract a parameter to field assignment, such as `this.field = paramName` below:
* ```
* constructor(paramName: type) {
* this.field = paramName
* }
* ```
*/
fun extractParameterToFieldAssignmentInConstructor(
paramName: String,
type: IrType,
fieldId: Label<DbField>,
paramIdx: Int,
stmtIdx: Int
) {
val paramId = tw.getFreshIdLabel<DbParam>()
val paramType = extractValueParameter(paramId, type, paramName, locId, ids.constructor, paramIdx, null)
val assignmentStmtId = tw.getFreshIdLabel<DbExprstmt>()
tw.writeStmts_exprstmt(assignmentStmtId, ids.constructorBlock, stmtIdx, ids.constructor)
tw.writeHasLocation(assignmentStmtId, locId)
val assignmentId = tw.getFreshIdLabel<DbAssignexpr>()
tw.writeExprs_assignexpr(assignmentId, paramType.javaResult.id, paramType.kotlinResult.id, assignmentStmtId, 0)
writeExpressionMetadataToTrapFile(assignmentId, ids.constructor, assignmentStmtId)
val lhsId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(lhsId, paramType.javaResult.id, paramType.kotlinResult.id, assignmentId, 0)
tw.writeVariableBinding(lhsId, fieldId)
writeExpressionMetadataToTrapFile(lhsId, ids.constructor, assignmentStmtId)
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, ids.type.kotlinResult.id, lhsId, -1)
writeExpressionMetadataToTrapFile(thisId, ids.constructor, assignmentStmtId)
val rhsId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(rhsId, paramType.javaResult.id, paramType.kotlinResult.id, assignmentId, 1)
tw.writeVariableBinding(rhsId, paramId)
writeExpressionMetadataToTrapFile(rhsId, ids.constructor, assignmentStmtId)
}
val firstAssignmentStmtIdx = 1
val extensionParameterIndex: Int
val dispatchReceiver = functionReferenceExpr.dispatchReceiver
val dispatchFieldId: Label<DbField>?
if (dispatchReceiver != null) {
dispatchFieldId = tw.getFreshIdLabel()
extensionParameterIndex = 1
extractField(dispatchFieldId, "<dispatchReceiver>", dispatchReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, functionReferenceExpr, false)
extractParameterToFieldAssignmentInConstructor("<dispatchReceiver>", dispatchReceiver.type, dispatchFieldId, 0, firstAssignmentStmtIdx)
} else {
dispatchFieldId = null
extensionParameterIndex = 0
}
val extensionReceiver = functionReferenceExpr.extensionReceiver
val extensionFieldId: Label<out DbField>?
if (extensionReceiver != null) {
extensionFieldId = tw.getFreshIdLabel()
extractField(extensionFieldId, "<extensionReceiver>", extensionReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, functionReferenceExpr, false)
extractParameterToFieldAssignmentInConstructor( "<extensionReceiver>", extensionReceiver.type, extensionFieldId, 0 + extensionParameterIndex, firstAssignmentStmtIdx + extensionParameterIndex)
} else {
extensionFieldId = null
}
val funLabels = if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY) {
addFunctionNInvoke(target.owner.returnType, id, locId)
} else {
addFunctionInvoke(parameterTypes, target.owner.returnType, id, locId)
}
// Return statement of generated function:
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, funLabels.blockId, 0, funLabels.methodId)
tw.writeHasLocation(retId, locId)
// Call to target function:
val dispatchReceiverId: Int
val callId: Label<out DbExpr>
val callType = useType(target.owner.returnType)
if (target is IrConstructorSymbol) {
callId = tw.getFreshIdLabel<DbNewexpr>()
tw.writeExprs_newexpr(callId, callType.javaResult.id, callType.kotlinResult.id, retId, 0)
val typeAccessId = extractTypeAccess(callType, locId, funLabels.methodId, callId, -3, retId)
extractTypeArguments(functionReferenceExpr, typeAccessId, funLabels.methodId, retId)
dispatchReceiverId = -2
} else {
callId = tw.getFreshIdLabel<DbMethodaccess>()
tw.writeExprs_methodaccess(callId, callType.javaResult.id, callType.kotlinResult.id, retId, 0)
extractTypeArguments(functionReferenceExpr, callId, funLabels.methodId, retId, -2, true)
dispatchReceiverId = -1
}
writeExpressionMetadataToTrapFile(callId, funLabels.methodId, retId)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(callId as Label<out DbCaller>, targetCallableId)
fun writeVariableAccessInInvokeBody(
pType: TypeResults,
idx: Int,
variable: Label<out DbVariable>
): Label<DbVaraccess> {
val pId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(pId, pType.javaResult.id, pType.kotlinResult.id, callId, idx)
tw.writeVariableBinding(pId, variable)
writeExpressionMetadataToTrapFile(pId, funLabels.methodId, retId)
return pId
}
fun writeFieldAccessInInvokeBody(pType: IrType, idx: Int, variable: Label<out DbField>) {
val accessId = writeVariableAccessInInvokeBody(useType(pType), idx, variable)
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, ids.type.kotlinResult.id, accessId, -1)
writeExpressionMetadataToTrapFile(thisId, funLabels.methodId, retId)
}
if (dispatchReceiver != null) {
writeFieldAccessInInvokeBody(dispatchReceiver.type, dispatchReceiverId, dispatchFieldId!!)
}
val extensionIdxOffset: Int
if (extensionReceiver != null) {
writeFieldAccessInInvokeBody(extensionReceiver.type, 0, extensionFieldId!!)
extensionIdxOffset = 1
} else {
extensionIdxOffset = 0
}
if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY) {
addArgumentsToInvocationInInvokeNBody(parameters, funLabels, retId, callId, locId, { exp -> writeExpressionMetadataToTrapFile(exp, funLabels.methodId, retId) }, extensionIdxOffset)
} else {
for ((pIdx, p) in funLabels.parameters.withIndex()) {
writeVariableAccessInInvokeBody(p.second, pIdx + extensionIdxOffset, p.first)
}
}
// Add constructor (member ref) call:
val exprParent = parent.expr(functionReferenceExpr, callable)
val idMemberRef = tw.getFreshIdLabel<DbMemberref>()
tw.writeExprs_memberref(idMemberRef, ids.type.javaResult.id, ids.type.kotlinResult.id, exprParent.parent, exprParent.idx)
tw.writeHasLocation(idMemberRef, locId)
tw.writeCallableEnclosingExpr(idMemberRef, callable)
tw.writeStatementEnclosingExpr(idMemberRef, exprParent.enclosingStmt)
tw.writeCallableBinding(idMemberRef, ids.constructor)
extractTypeAccess(fnInterfaceType, locId, callable, idMemberRef, -3, exprParent.enclosingStmt)
tw.writeMemberRefBinding(idMemberRef, targetCallableId)
// constructor arguments:
if (dispatchReceiver != null) {
extractExpressionExpr(dispatchReceiver, callable, idMemberRef, 0, exprParent.enclosingStmt)
}
if (extensionReceiver != null) {
extractExpressionExpr(extensionReceiver, callable, idMemberRef, 0 + extensionParameterIndex, exprParent.enclosingStmt)
}
tw.writeIsAnonymClass(id, idMemberRef)
}
private fun getFunctionalInterfaceType(functionNTypeArguments: List<IrType>) =
if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY) {
pluginContext.referenceClass(FqName("kotlin.jvm.functions.FunctionN"))!!
.typeWith(functionNTypeArguments.last())
} else {
functionN(pluginContext)(functionNTypeArguments.size - 1).typeWith(functionNTypeArguments)
}
private data class FunctionLabels(
val methodId: Label<DbMethod>,
val blockId: Label<DbBlock>,
val parameters: List<Pair<Label<DbParam>, TypeResults>>)
/**
* Adds a function `invoke(a: Any[])` with the specified return type to the class identified by parentId.
*/
private fun addFunctionNInvoke(returnType: IrType, parentId: Label<out DbReftype>, locId: Label<DbLocation>): FunctionLabels {
return addFunctionInvoke(listOf(pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)), returnType, parentId, locId)
}
/**
* Adds a function named "invoke" with the specified parameter types and return type to the class identified by parentId.
*/
private fun addFunctionInvoke(parameterTypes: List<IrType>, returnType: IrType, parentId: Label<out DbReftype>, locId: Label<DbLocation>): FunctionLabels {
val methodId = tw.getFreshIdLabel<DbMethod>()
val parameters = parameterTypes.mapIndexed { idx, p ->
val paramId = tw.getFreshIdLabel<DbParam>()
val paramType = extractValueParameter(paramId, p, "a$idx", locId, methodId, idx, null)
Pair(paramId, paramType)
}
val paramsSignature = parameters.joinToString(separator = ",", prefix = "(", postfix = ")") { it.second.javaResult.signature!! }
val rt = useType(returnType, TypeContext.RETURN)
val shortName = OperatorNameConventions.INVOKE.asString()
tw.writeMethods(methodId, shortName, "$shortName$paramsSignature", rt.javaResult.id, rt.kotlinResult.id, parentId, methodId)
tw.writeHasLocation(methodId, locId)
// Block
val blockId = tw.getFreshIdLabel<DbBlock>()
tw.writeStmts_block(blockId, methodId, 0, methodId)
tw.writeHasLocation(blockId, locId)
return FunctionLabels(methodId, blockId, parameters)
}
/*
* This function generates an implementation for `fun kotlin.FunctionN<R>.invoke(vararg args: Any?): R`
*
@@ -1845,37 +2150,21 @@ open class KotlinFileExtractor(
* */
private fun implementFunctionNInvoke(
lambda: IrFunction,
ids: LocalFunctionLabels,
ids: LocallyVisibleFunctionLabels,
locId: Label<DbLocation>,
parameters: List<IrValueParameter>
) {
val methodId = tw.getFreshIdLabel<DbMethod>()
val argsParamId = tw.getFreshIdLabel<DbParam>()
val argsParamType = pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)
val paramType = extractValueParameter(argsParamId, argsParamType, "args", locId, methodId, 0, null)
val paramsSignature = "(${paramType.javaResult.signature!!})"
val returnType = useType(lambda.returnType, TypeContext.RETURN)
val shortName = OperatorNameConventions.INVOKE.asString()
@Suppress("UNCHECKED_CAST")
tw.writeMethods(methodId, shortName, "$shortName$paramsSignature", returnType.javaResult.id, returnType.kotlinResult.id, ids.type.javaResult.id as Label<out DbReftype>, methodId)
tw.writeHasLocation(methodId, locId)
// Block
val blockId = tw.getFreshIdLabel<DbBlock>()
tw.writeStmts_block(blockId, methodId, 0, methodId)
tw.writeHasLocation(blockId, locId)
val funLabels = addFunctionNInvoke(lambda.returnType, ids.type.javaResult.id as Label<DbReftype>, locId)
// Return
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, blockId, 0, methodId)
tw.writeStmts_returnstmt(retId, funLabels.blockId, 0, funLabels.methodId)
tw.writeHasLocation(retId, locId)
fun extractCommonExpr(id: Label<out DbExpr>) {
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, methodId)
tw.writeCallableEnclosingExpr(id, funLabels.methodId)
tw.writeStatementEnclosingExpr(id, retId)
}
@@ -1892,16 +2181,36 @@ open class KotlinFileExtractor(
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, ids.type.kotlinResult.id, callId, -1)
extractCommonExpr(thisId)
// parameters
addArgumentsToInvocationInInvokeNBody(parameters, funLabels, retId, callId, locId, ::extractCommonExpr)
}
/**
* Adds the arguments to the method call inside `invoke(a0: Any[])`. Each argument is an array access with a cast:
*
* ```
* fun invoke(args: Any[]) : T {
* return fn(args[0] as T0, args[1] as T1, ...)
* }
* ```
*/
private fun addArgumentsToInvocationInInvokeNBody(
parameters: List<IrValueParameter>,
funLabels: FunctionLabels,
retId: Label<DbReturnstmt>,
callId: Label<out DbExprparent>,
locId: Label<DbLocation>,
extractCommonExpr: (Label<out DbExpr>) -> Unit,
firstArgumentOffset: Int = 0
) {
val intType = useType(pluginContext.irBuiltIns.intType)
val argsParamType = pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)
val argsType = useType(argsParamType)
val anyNType = useType(pluginContext.irBuiltIns.anyNType)
val func =
pluginContext.irBuiltIns.arrayClass.owner.declarations.find { it is IrFunction && it.name.asString() == "get" }
val arrayIndexerFunction = pluginContext.irBuiltIns.arrayClass.owner.declarations.find { it is IrFunction && it.name.asString() == "get" }
@Suppress("UNCHECKED_CAST")
val arrayGetMethodId = useFunction<DbMethod>(func as IrFunction)
val arrayIndexerFunctionId = useFunction<DbMethod>(arrayIndexerFunction as IrFunction)
for ((pIdx, p) in parameters.withIndex()) {
// `args[i] as Ti` is generated below for each parameter
@@ -1909,23 +2218,23 @@ open class KotlinFileExtractor(
// cast
val castId = tw.getFreshIdLabel<DbCastexpr>()
val type = useType(p.type)
tw.writeExprs_castexpr(castId, type.javaResult.id, type.kotlinResult.id, callId, pIdx)
tw.writeExprs_castexpr(castId, type.javaResult.id, type.kotlinResult.id, callId, pIdx + firstArgumentOffset)
extractCommonExpr(castId)
// type access
extractTypeAccess(p.type, locId, methodId, castId, 0, retId)
extractTypeAccess(p.type, locId, funLabels.methodId, castId, 0, retId)
// element access: `args.get(i)`
val getCallId = tw.getFreshIdLabel<DbMethodaccess>()
tw.writeExprs_methodaccess(getCallId, anyNType.javaResult.id, anyNType.kotlinResult.id, castId, 1)
extractCommonExpr(getCallId)
tw.writeCallableBinding(getCallId, arrayGetMethodId)
tw.writeCallableBinding(getCallId, arrayIndexerFunctionId)
// parameter access:
val argsAccessId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(argsAccessId, argsType.javaResult.id, argsType.kotlinResult.id, getCallId, -1)
extractCommonExpr(argsAccessId)
tw.writeVariableBinding(argsAccessId, argsParamId)
tw.writeVariableBinding(argsAccessId, funLabels.parameters.first().first)
// index access:
val indexId = tw.getFreshIdLabel<DbIntegerliteral>()
@@ -2087,27 +2396,29 @@ open class KotlinFileExtractor(
private val IrType.isAnonymous: Boolean
get() = ((this as? IrSimpleType)?.classifier?.owner as? IrClass)?.isAnonymousObject ?: false
private fun extractGeneratedClass(localFunction: IrFunction, superTypes: List<IrType>) : Label<out DbClass> {
val ids = getLocalFunctionLabels(localFunction)
/**
* Extracts the class around a local function, a lambda, or a function reference.
*/
private fun extractGeneratedClass(
ids: LocallyVisibleFunctionLabels,
superTypes: List<IrType>,
locId: Label<DbLocation>,
currentDeclaration: IrDeclaration
): Label<out DbClass> {
// Write class
@Suppress("UNCHECKED_CAST")
val id = ids.type.javaResult.id as Label<out DbClass>
val pkgId = extractPackage("")
tw.writeClasses(id, "", pkgId, id)
val locId = tw.getLocation(localFunction)
tw.writeHasLocation(id, locId)
// Extract local function as a member
extractFunctionIfReal(localFunction, id)
// Extract constructor
val unitType = useType(pluginContext.irBuiltIns.unitType)
tw.writeConstrs(ids.constructor, "", "", unitType.javaResult.id, unitType.kotlinResult.id, id, ids.constructor)
tw.writeHasLocation(ids.constructor, locId)
// Constructor body
val constructorBlockId = tw.getFreshIdLabel<DbBlock>()
val constructorBlockId = ids.constructorBlock
tw.writeStmts_block(constructorBlockId, ids.constructor, 0, ids.constructor)
tw.writeHasLocation(constructorBlockId, locId)
@@ -2127,7 +2438,7 @@ open class KotlinFileExtractor(
addModifiers(id, "public", "static", "final")
extractClassSupertypes(superTypes, listOf(), id)
var parent: IrDeclarationParent? = localFunction.parent
var parent: IrDeclarationParent? = currentDeclaration.parent
while (parent != null) {
// todo: merge this with the implementation in `extractClassSource`
if (parent is IrClass) {
@@ -2144,7 +2455,7 @@ open class KotlinFileExtractor(
}
if (parent is IrFile) {
if (this is KotlinSourceFileExtractor && this.file == localFunction.fileOrNull) {
if (this is KotlinSourceFileExtractor && this.file == parent) {
tw.writeEnclInReftype(id, this.fileClass)
} else {
logger.warn(Severity.ErrorSevere, "Unexpected file parent found")
@@ -2157,4 +2468,18 @@ open class KotlinFileExtractor(
return id
}
/**
* Extracts the class around a local function or a lambda.
*/
private fun extractGeneratedClass(localFunction: IrFunction, superTypes: List<IrType>) : Label<out DbClass> {
val ids = getLocallyVisibleFunctionLabels(localFunction)
val id = extractGeneratedClass(ids, superTypes, tw.getLocation(localFunction), localFunction)
// Extract local function as a member
extractFunctionIfReal(localFunction, id)
return id
}
}

View File

@@ -542,25 +542,38 @@ class X {
return this.visibility == DescriptorVisibilities.LOCAL
}
private val generatedLocalFunctionTypeMapping: MutableMap<IrFunction, LocalFunctionLabels> = mutableMapOf()
private val locallyVisibleFunctionLabelMapping: MutableMap<IrFunction, LocallyVisibleFunctionLabels> = mutableMapOf()
data class LocalFunctionLabels(val type: TypeResults, val constructor: Label<DbConstructor>, val function: Label<DbMethod>)
/**
* Data class to hold labels generated for locally visible functions, such as
* - local functions,
* - lambdas, and
* - wrappers around function references.
*/
data class LocallyVisibleFunctionLabels(val type: TypeResults, val constructor: Label<DbConstructor>, val function: Label<DbMethod>, val constructorBlock: Label<DbBlock>)
fun getLocalFunctionLabels(f: IrFunction): LocalFunctionLabels {
/**
* Gets the labels for functions belonging to
* - local functions, and
* - lambdas.
*/
fun getLocallyVisibleFunctionLabels(f: IrFunction): LocallyVisibleFunctionLabels {
if (!f.isLocalFunction()){
logger.warn(Severity.ErrorSevere, "Extracting a non-local function as a local one")
}
var res = generatedLocalFunctionTypeMapping[f]
var res = locallyVisibleFunctionLabelMapping[f]
if (res == null) {
val javaResult = TypeResult(tw.getFreshIdLabel<DbClass>(), "", "")
val kotlinResult = TypeResult(tw.getFreshIdLabel<DbKt_notnull_type>(), "", "")
tw.writeKt_notnull_types(kotlinResult.id, javaResult.id)
res = LocalFunctionLabels(
res = LocallyVisibleFunctionLabels(
TypeResults(javaResult, kotlinResult),
tw.getFreshIdLabel(),
tw.getFreshIdLabel())
generatedLocalFunctionTypeMapping[f] = res
tw.getFreshIdLabel(),
tw.getFreshIdLabel()
)
locallyVisibleFunctionLabelMapping[f] = res
}
return res
@@ -576,7 +589,7 @@ class X {
fun <T: DbCallable> useFunction(f: IrFunction, classTypeArguments: List<IrTypeArgument>? = null): Label<out T> {
if (f.isLocalFunction()) {
val ids = getLocalFunctionLabels(f)
val ids = getLocallyVisibleFunctionLabels(f)
@Suppress("UNCHECKED_CAST")
return ids.function as Label<out T>
} else {