Merge function and property reference extraction logic in helper class

This commit is contained in:
Tamas Vajk
2022-02-23 16:20:47 +01:00
committed by Ian Lynagh
parent b4b1976bc4
commit 5fea49a3c9
4 changed files with 299 additions and 390 deletions

View File

@@ -9,7 +9,6 @@ import com.github.codeql.utils.toRawType
import com.github.codeql.utils.versions.getIrStubFromDescriptor
import com.semmle.extractor.java.OdasaOutput
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.ir.simpleFunctions
import org.jetbrains.kotlin.backend.common.pop
import org.jetbrains.kotlin.builtins.functions.BuiltInFunctionArity
import org.jetbrains.kotlin.descriptors.*
@@ -21,6 +20,7 @@ import org.jetbrains.kotlin.ir.backend.js.utils.realOverrideTarget
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.symbols.IrSymbol
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
@@ -2538,30 +2538,31 @@ open class KotlinFileExtractor(
}
}
private inner class FunctionReferenceHelper(private val locId: Label<DbLocation>, private val ids: LocallyVisibleFunctionLabels) {
fun writeExpressionMetadataToTrapFile(id: Label<out DbExpr>, callable: Label<out DbCallable>, stmt: Label<out DbStmt>) {
private open inner class GeneratedClassHelper(protected val locId: Label<DbLocation>, protected val ids: GeneratedClassLabels) {
protected fun writeExpressionMetadataToTrapFile(id: Label<out DbExpr>, callable: Label<out DbCallable>, stmt: Label<out DbStmt>) {
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, callable)
tw.writeStatementEnclosingExpr(id, stmt)
}
/**
* Extract a parameter to field assignment, such as `this.field = paramName` below:
* ```
* constructor(paramName: type) {
* this.field = paramName
* }
* ```
*/
* Extract a parameter to field assignment, such as `this.field = paramName` below:
* ```
* constructor(paramName: type) {
* this.field = paramName
* }
* ```
*/
fun extractParameterToFieldAssignmentInConstructor(
paramName: String,
type: IrType,
paramType: IrType,
fieldId: Label<DbField>,
paramIdx: Int,
stmtIdx: Int
) {
val paramId = tw.getFreshIdLabel<DbParam>()
val paramType = extractValueParameter(paramId, type, paramName, locId, ids.constructor, paramIdx, null, paramId, false)
val paramType = extractValueParameter(paramId, paramType, paramName, locId, ids.constructor, paramIdx, null, paramId, false)
val assignmentStmtId = tw.getFreshIdLabel<DbExprstmt>()
tw.writeStmts_exprstmt(assignmentStmtId, ids.constructorBlock, stmtIdx, ids.constructor)
@@ -2591,6 +2592,156 @@ open class KotlinFileExtractor(
}
}
private inner class CallableReferenceHelper(private val callableReferenceExpr: IrCallableReference<out IrSymbol>, locId: Label<DbLocation>, ids: GeneratedClassLabels)
: GeneratedClassHelper(locId, ids) {
private val dispatchReceiver = callableReferenceExpr.dispatchReceiver
private val extensionReceiver = callableReferenceExpr.extensionReceiver
private val dispatchFieldId: Label<DbField>? = if (dispatchReceiver != null) tw.getFreshIdLabel() else null
private val extensionFieldId: Label<DbField>? = if (extensionReceiver != null) tw.getFreshIdLabel() else null
private val extensionParameterIndex: Int = if (dispatchReceiver != null) 1 else 0
fun extractReceiverFields(classId: Label<out DbClass>) {
val firstAssignmentStmtIdx = 1
if (dispatchReceiver != null) {
extractField(dispatchFieldId!!, "<dispatchReceiver>", dispatchReceiver.type, classId, locId, DescriptorVisibilities.PRIVATE, callableReferenceExpr, false)
extractParameterToFieldAssignmentInConstructor("<dispatchReceiver>", dispatchReceiver.type, dispatchFieldId!!, 0, firstAssignmentStmtIdx)
}
if (extensionReceiver != null) {
extractField(extensionFieldId!!, "<extensionReceiver>", extensionReceiver.type, classId, locId, DescriptorVisibilities.PRIVATE, callableReferenceExpr, false)
extractParameterToFieldAssignmentInConstructor( "<extensionReceiver>", extensionReceiver.type, extensionFieldId!!, 0 + extensionParameterIndex, firstAssignmentStmtIdx + extensionParameterIndex)
}
}
/**
* Extracts a call to `target` inside the function identified by `labels`. Special parameters (`dispatch` and `extension`) are also handled.
*
* Examples are:
* ```
* this.<dispatchReceiver>.fn(this.<extensionReceiver>, param1, param2, param3, ...)
* param1.fn(this.<extensionReceiver>, param2, ...)
* param1.fn(param2, param3, ...)
* fn(this.<extensionReceiver>, param1, param2, ...)
* fn(param1, param2, ...)
* ```
*
* The parameters with default argument values cover special cases:
* - dispatchReceiverIdx is usually -1, except if a constructor is referenced
* - big arity function references need to call `invoke` with arguments received in an object array: `fn(param1[0] as T0, param1[1] as T1, ...)`
*/
fun extractCallToReflectionTarget(
labels: FunctionLabels,
target: IrFunctionSymbol,
extractAccessToTarget: (Label<DbReturnstmt>, TypeResults) -> Label<out DbExpr>,
dispatchReceiverIdx: Int = -1,
isBigArity: Boolean = false,
bigArityParameters: LinkedList<IrValueParameter>? = null
) {
// Return statement of generated function:
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, labels.blockId, 0, labels.methodId)
tw.writeHasLocation(retId, locId)
// Call to target function:
val callType = useType(target.owner.returnType)
val callId = extractAccessToTarget(retId, callType)
writeExpressionMetadataToTrapFile(callId, labels.methodId, retId)
// todo: type arguments
val callableId = useFunction<DbCallable>(target.owner, null)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(callId as Label<out DbCaller>, callableId)
fun writeVariableAccessInFunctionBody(
pType: TypeResults,
idx: Int,
variable: Label<out DbVariable>
): Label<DbVaraccess> {
val pId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(pId, pType.javaResult.id, callId, idx)
tw.writeExprsKotlinType(pId, pType.kotlinResult.id)
tw.writeVariableBinding(pId, variable)
writeExpressionMetadataToTrapFile(pId, labels.methodId, retId)
return pId
}
fun writeFieldAccessInFunctionBody(pType: IrType, idx: Int, variable: Label<out DbField>) {
val accessId = writeVariableAccessInFunctionBody(useType(pType), idx, variable)
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, accessId, -1)
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
writeExpressionMetadataToTrapFile(thisId, labels.methodId, retId)
}
val useFirstArgAsDispatch: Boolean
if (dispatchReceiver != null) {
writeFieldAccessInFunctionBody(dispatchReceiver.type, dispatchReceiverIdx, dispatchFieldId!!)
useFirstArgAsDispatch = false
} else {
useFirstArgAsDispatch = target.owner.dispatchReceiverParameter != null
}
val extensionIdxOffset: Int
if (extensionReceiver != null) {
writeFieldAccessInFunctionBody(extensionReceiver.type, 0, extensionFieldId!!)
extensionIdxOffset = 1
} else {
extensionIdxOffset = 0
}
if (isBigArity) {
// In case we're extracting a big arity function reference:
addArgumentsToInvocationInInvokeNBody(
bigArityParameters!!, labels, retId, callId, locId,
{ exp -> writeExpressionMetadataToTrapFile(exp, labels.methodId, retId) },
extensionIdxOffset, useFirstArgAsDispatch, dispatchReceiverIdx)
} else {
val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0
for ((pIdx, p) in labels.parameters.withIndex()) {
val childIdx = if (pIdx == 0 && useFirstArgAsDispatch) {
dispatchReceiverIdx
} else {
pIdx + extensionIdxOffset - dispatchIdxOffset
}
writeVariableAccessInFunctionBody(p.second, childIdx, p.first)
}
}
}
fun extractConstructorArguments(
callable: Label<out DbCallable>,
idCtorRef: Label<out DbClassinstancexpr>,
enclosingStmt: Label<out DbStmt>
) {
if (dispatchReceiver != null) {
extractExpressionExpr(dispatchReceiver, callable, idCtorRef, 0, enclosingStmt)
}
if (extensionReceiver != null) {
extractExpressionExpr(extensionReceiver, callable, idCtorRef, 0 + extensionParameterIndex, enclosingStmt)
}
}
fun getExtraParameters(target: IrFunctionSymbol): LinkedList<IrValueParameter> {
val extensionParameter = target.owner.extensionReceiverParameter
val dispatchParameter = target.owner.dispatchReceiverParameter
var parameters = LinkedList<IrValueParameter>()
if (extensionParameter != null && callableReferenceExpr.extensionReceiver == null) {
parameters.addFirst(extensionParameter)
}
if (dispatchParameter != null && callableReferenceExpr.dispatchReceiver == null) {
parameters.addFirst(dispatchParameter)
}
return parameters
}
}
private fun extractPropertyReference(
propertyReferenceExpr: IrPropertyReference,
parent: StmtExprParent,
@@ -2633,11 +2784,10 @@ open class KotlinFileExtractor(
val javaResult = TypeResult(tw.getFreshIdLabel<DbClass>(), "", "")
val kotlinResult = TypeResult(tw.getFreshIdLabel<DbKt_notnull_type>(), "", "")
tw.writeKt_notnull_types(kotlinResult.id, javaResult.id)
val ids = LocallyVisibleFunctionLabels(
val ids = GeneratedClassLabels(
TypeResults(javaResult, kotlinResult),
tw.getFreshIdLabel(),
tw.getFreshIdLabel(), // not used
tw.getFreshIdLabel()
constructor = tw.getFreshIdLabel(),
constructorBlock = tw.getFreshIdLabel()
)
val currentDeclaration = declarationStack.peek()
@@ -2645,193 +2795,42 @@ open class KotlinFileExtractor(
val baseClass = pluginContext.referenceClass(FqName("kotlin.jvm.internal.PropertyReference"))?.owner?.typeWith()
?: pluginContext.irBuiltIns.anyType
val id = extractGeneratedClass(ids, listOf(baseClass, kPropertyType), locId, currentDeclaration)
val classId = extractGeneratedClass(ids, listOf(baseClass, kPropertyType), locId, currentDeclaration)
val helper = FunctionReferenceHelper(locId, ids)
val firstAssignmentStmtIdx = 1
val extensionParameterIndex: Int
val dispatchReceiver = propertyReferenceExpr.dispatchReceiver
val dispatchFieldId: Label<DbField>?
if (dispatchReceiver != null) {
dispatchFieldId = tw.getFreshIdLabel()
extensionParameterIndex = 1
val helper = CallableReferenceHelper(propertyReferenceExpr, locId, ids)
helper.extractReceiverFields(classId)
var parameters = helper.getExtraParameters((getter ?: setter)!!)
extractField(dispatchFieldId, "<dispatchReceiver>", dispatchReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, propertyReferenceExpr, false)
helper.extractParameterToFieldAssignmentInConstructor("<dispatchReceiver>", dispatchReceiver.type, dispatchFieldId, 0, firstAssignmentStmtIdx)
} else {
dispatchFieldId = null
extensionParameterIndex = 0
}
fun extractAccessToTarget(targetId: Label<DbMethod>, retId: Label<DbReturnstmt>, callType: TypeResults) : Label<out DbExpr> {
val callId = tw.getFreshIdLabel<DbMethodaccess>()
tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
val extensionReceiver = propertyReferenceExpr.extensionReceiver
val extensionFieldId: Label<out DbField>?
if (extensionReceiver != null) {
extensionFieldId = tw.getFreshIdLabel()
extractTypeArguments(propertyReferenceExpr, callId, targetId, retId, -2, true)
extractField(extensionFieldId, "<extensionReceiver>", extensionReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, propertyReferenceExpr, false)
helper.extractParameterToFieldAssignmentInConstructor( "<extensionReceiver>", extensionReceiver.type, extensionFieldId, 0 + extensionParameterIndex, firstAssignmentStmtIdx + extensionParameterIndex)
} else {
extensionFieldId = null
}
val accessor = (getter ?: setter)!!
val extensionParameter = accessor.owner.extensionReceiverParameter
val dispatchParameter = accessor.owner.dispatchReceiverParameter
var parameters = LinkedList<IrValueParameter>()
if (extensionParameter != null && propertyReferenceExpr.extensionReceiver == null) {
parameters.addFirst(extensionParameter)
}
if (dispatchParameter != null && propertyReferenceExpr.dispatchReceiver == null) {
parameters.addFirst(dispatchParameter)
return callId
}
if (getter != null) {
val getterParameterTypes = parameters.map { it.type }
val getLabels = addFunctionManual("get", getterParameterTypes, getter.owner.returnType, id, locId)
val getLabels = addFunctionManual(tw.getFreshIdLabel(), "get", getterParameterTypes, getter.owner.returnType, classId, locId)
// Return statement of generated function:
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, getLabels.blockId, 0, getLabels.methodId)
tw.writeHasLocation(retId, locId)
// Call to target function:
val callId: Label<out DbExpr>
val callType = useType(getter.owner.returnType)
callId = tw.getFreshIdLabel<DbMethodaccess>()
tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
extractTypeArguments(propertyReferenceExpr, callId, getLabels.methodId, retId, -2, true)
helper.writeExpressionMetadataToTrapFile(callId, getLabels.methodId, retId)
// todo: type arguments
val getterCallableId = useFunction<DbCallable>(getter.owner, null)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(callId as Label<out DbCaller>, getterCallableId)
fun writeVariableAccessInFunctionBody(
pType: TypeResults,
idx: Int,
variable: Label<out DbVariable>
): Label<DbVaraccess> {
val pId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(pId, pType.javaResult.id, callId, idx)
tw.writeExprsKotlinType(pId, pType.kotlinResult.id)
tw.writeVariableBinding(pId, variable)
helper.writeExpressionMetadataToTrapFile(pId, getLabels.methodId, retId)
return pId
}
fun writeFieldAccessInFunctionBody(pType: IrType, idx: Int, variable: Label<out DbField>) {
val accessId = writeVariableAccessInFunctionBody(useType(pType), idx, variable)
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, accessId, -1)
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
helper.writeExpressionMetadataToTrapFile(thisId, getLabels.methodId, retId)
}
val useFirstArgAsDispatch: Boolean
if (dispatchReceiver != null) {
writeFieldAccessInFunctionBody(dispatchReceiver.type, -1, dispatchFieldId!!)
useFirstArgAsDispatch = false
} else {
useFirstArgAsDispatch = dispatchParameter != null
}
val extensionIdxOffset: Int
if (extensionReceiver != null) {
writeFieldAccessInFunctionBody(extensionReceiver.type, 0, extensionFieldId!!)
extensionIdxOffset = 1
} else {
extensionIdxOffset = 0
}
val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0
for ((pIdx, p) in getLabels.parameters.withIndex()) {
val childIdx = if (pIdx == 0 && useFirstArgAsDispatch) {
-1
} else {
pIdx + extensionIdxOffset - dispatchIdxOffset
}
writeVariableAccessInFunctionBody(p.second, childIdx, p.first)
}
helper.extractCallToReflectionTarget(
getLabels,
getter,
{ r, c -> extractAccessToTarget(getLabels.methodId, r, c) }
)
}
if (setter != null) {
val setterParameterTypes = parameters.map { it.type } + setter.owner.valueParameters.map { it.type }
val setLabels = addFunctionManual("set", setterParameterTypes, setter.owner.returnType, id, locId)
val setterParameterTypes = (parameters + setter.owner.valueParameters).map { it.type }
val setLabels = addFunctionManual(tw.getFreshIdLabel(), "set", setterParameterTypes, setter.owner.returnType, classId, locId)
// Return statement of generated function:
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, setLabels.blockId, 0, setLabels.methodId)
tw.writeHasLocation(retId, locId)
// Call to target function:
val callId: Label<out DbExpr>
val callType = useType(setter.owner.returnType)
callId = tw.getFreshIdLabel<DbMethodaccess>()
tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
extractTypeArguments(propertyReferenceExpr, callId, setLabels.methodId, retId, -2, true)
helper.writeExpressionMetadataToTrapFile(callId, setLabels.methodId, retId)
// todo: type arguments
val setterCallableId = useFunction<DbCallable>(setter.owner, null)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(callId as Label<out DbCaller>, setterCallableId)
fun writeVariableAccessInFunctionBody(
pType: TypeResults,
idx: Int,
variable: Label<out DbVariable>
): Label<DbVaraccess> {
val pId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(pId, pType.javaResult.id, callId, idx)
tw.writeExprsKotlinType(pId, pType.kotlinResult.id)
tw.writeVariableBinding(pId, variable)
helper.writeExpressionMetadataToTrapFile(pId, setLabels.methodId, retId)
return pId
}
fun writeFieldAccessInFunctionBody(pType: IrType, idx: Int, variable: Label<out DbField>) {
val accessId = writeVariableAccessInFunctionBody(useType(pType), idx, variable)
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, accessId, -1)
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
helper.writeExpressionMetadataToTrapFile(thisId, setLabels.methodId, retId)
}
val useFirstArgAsDispatch: Boolean
if (dispatchReceiver != null) {
writeFieldAccessInFunctionBody(dispatchReceiver.type, -1, dispatchFieldId!!)
useFirstArgAsDispatch = false
} else {
useFirstArgAsDispatch = dispatchParameter != null
}
val extensionIdxOffset: Int
if (extensionReceiver != null) {
writeFieldAccessInFunctionBody(extensionReceiver.type, 0, extensionFieldId!!)
extensionIdxOffset = 1
} else {
extensionIdxOffset = 0
}
val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0
for ((pIdx, p) in setLabels.parameters.withIndex()) {
val childIdx = if (pIdx == 0 && useFirstArgAsDispatch) {
-1
} else {
pIdx + extensionIdxOffset - dispatchIdxOffset
}
writeVariableAccessInFunctionBody(p.second, childIdx, p.first)
}
helper.extractCallToReflectionTarget(
setLabels,
setter,
{ r, c -> extractAccessToTarget(setLabels.methodId, r, c) }
)
}
// todo: property ref
@@ -2850,16 +2849,9 @@ open class KotlinFileExtractor(
// todo: property ref:
//tw.writeMemberRefBinding(idMemberRef, targetCallableId)
// constructor arguments:
if (dispatchReceiver != null) {
extractExpressionExpr(dispatchReceiver, callable, idCtorRef, 0, exprParent.enclosingStmt)
}
helper.extractConstructorArguments(callable, idCtorRef, exprParent.enclosingStmt)
if (extensionReceiver != null) {
extractExpressionExpr(extensionReceiver, callable, idCtorRef, 0 + extensionParameterIndex, exprParent.enclosingStmt)
}
tw.writeIsAnonymClass(id, idCtorRef)
tw.writeIsAnonymClass(classId, idCtorRef)
}
}
@@ -2912,154 +2904,67 @@ open class KotlinFileExtractor(
val targetCallableId = useFunction<DbCallable>(target.owner, typeArguments)
val locId = tw.getLocation(functionReferenceExpr)
val extensionParameter = target.owner.extensionReceiverParameter
val dispatchParameter = target.owner.dispatchReceiverParameter
var parameters = LinkedList(target.owner.valueParameters)
if (extensionParameter != null && functionReferenceExpr.extensionReceiver == null) {
parameters.addFirst(extensionParameter)
}
if (dispatchParameter != null && functionReferenceExpr.dispatchReceiver == null) {
parameters.addFirst(dispatchParameter)
}
val parameterTypes = parameters.map { it.type }
val functionNTypeArguments = parameterTypes + target.owner.returnType
val fnInterfaceType = getFunctionalInterfaceType(functionNTypeArguments)
val javaResult = TypeResult(tw.getFreshIdLabel<DbClass>(), "", "")
val kotlinResult = TypeResult(tw.getFreshIdLabel<DbKt_notnull_type>(), "", "")
tw.writeKt_notnull_types(kotlinResult.id, javaResult.id)
val ids = LocallyVisibleFunctionLabels(
TypeResults(javaResult, kotlinResult),
tw.getFreshIdLabel(),
tw.getFreshIdLabel(), // not used
tw.getFreshIdLabel()
constructor = tw.getFreshIdLabel(),
function = tw.getFreshIdLabel(),
constructorBlock = tw.getFreshIdLabel()
)
val helper = CallableReferenceHelper(functionReferenceExpr, locId, ids)
var parameters: LinkedList<IrValueParameter> = helper.getExtraParameters(target)
parameters += target.owner.valueParameters
val parameterTypes = parameters.map { it.type }
val functionNTypeArguments = parameterTypes + target.owner.returnType
val fnInterfaceType = getFunctionalInterfaceType(functionNTypeArguments)
val currentDeclaration = declarationStack.peek()
// `FunctionReference` base class is required, because that's implementing `KFunction`.
val baseClass = pluginContext.referenceClass(FqName("kotlin.jvm.internal.FunctionReference"))?.owner?.typeWith()
?: pluginContext.irBuiltIns.anyType
val id = extractGeneratedClass(ids, listOf(baseClass, fnInterfaceType), locId, currentDeclaration)
val classId = extractGeneratedClass(ids, listOf(baseClass, fnInterfaceType), locId, currentDeclaration)
val helper = FunctionReferenceHelper(locId, ids)
helper.extractReceiverFields(classId)
val firstAssignmentStmtIdx = 1
val extensionParameterIndex: Int
val dispatchReceiver = functionReferenceExpr.dispatchReceiver
val dispatchFieldId: Label<DbField>?
if (dispatchReceiver != null) {
dispatchFieldId = tw.getFreshIdLabel()
extensionParameterIndex = 1
extractField(dispatchFieldId, "<dispatchReceiver>", dispatchReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, functionReferenceExpr, false)
helper.extractParameterToFieldAssignmentInConstructor("<dispatchReceiver>", dispatchReceiver.type, dispatchFieldId, 0, firstAssignmentStmtIdx)
val isBigArity = functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY
val funLabels = if (isBigArity) {
addFunctionNInvoke(ids.function, target.owner.returnType, classId, locId)
} else {
dispatchFieldId = null
extensionParameterIndex = 0
addFunctionInvoke(ids.function, parameterTypes, target.owner.returnType, classId, locId)
}
val extensionReceiver = functionReferenceExpr.extensionReceiver
val extensionFieldId: Label<out DbField>?
if (extensionReceiver != null) {
extensionFieldId = tw.getFreshIdLabel()
val dispatchReceiverIdx: Int = if (target is IrConstructorSymbol) -2 else -1
fun extractAccessToTarget(retId: Label<DbReturnstmt>, callType: TypeResults) : Label<out DbExpr> {
if (target is IrConstructorSymbol) {
val callId = tw.getFreshIdLabel<DbNewexpr>()
tw.writeExprs_newexpr(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
extractField(extensionFieldId, "<extensionReceiver>", extensionReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, functionReferenceExpr, false)
helper.extractParameterToFieldAssignmentInConstructor( "<extensionReceiver>", extensionReceiver.type, extensionFieldId, 0 + extensionParameterIndex, firstAssignmentStmtIdx + extensionParameterIndex)
} else {
extensionFieldId = null
}
val typeAccessId = extractTypeAccess(callType, locId, funLabels.methodId, callId, -3, retId)
val funLabels = if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY) {
addFunctionNInvoke(target.owner.returnType, id, locId)
} else {
addFunctionInvoke(parameterTypes, target.owner.returnType, id, locId)
}
extractTypeArguments(functionReferenceExpr, typeAccessId, funLabels.methodId, retId)
return callId
} else {
var callId = tw.getFreshIdLabel<DbMethodaccess>()
tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
// Return statement of generated function:
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, funLabels.blockId, 0, funLabels.methodId)
tw.writeHasLocation(retId, locId)
// Call to target function:
val dispatchReceiverIdx: Int
val callId: Label<out DbExpr>
val callType = useType(target.owner.returnType)
if (target is IrConstructorSymbol) {
callId = tw.getFreshIdLabel<DbNewexpr>()
tw.writeExprs_newexpr(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
val typeAccessId = extractTypeAccess(callType, locId, funLabels.methodId, callId, -3, retId)
extractTypeArguments(functionReferenceExpr, typeAccessId, funLabels.methodId, retId)
dispatchReceiverIdx = -2
} else {
callId = tw.getFreshIdLabel<DbMethodaccess>()
tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
extractTypeArguments(functionReferenceExpr, callId, funLabels.methodId, retId, -2, true)
dispatchReceiverIdx = -1
}
helper.writeExpressionMetadataToTrapFile(callId, funLabels.methodId, retId)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(callId as Label<out DbCaller>, targetCallableId)
fun writeVariableAccessInInvokeBody(
pType: TypeResults,
idx: Int,
variable: Label<out DbVariable>
): Label<DbVaraccess> {
val pId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(pId, pType.javaResult.id, callId, idx)
tw.writeExprsKotlinType(pId, pType.kotlinResult.id)
tw.writeVariableBinding(pId, variable)
helper.writeExpressionMetadataToTrapFile(pId, funLabels.methodId, retId)
return pId
}
fun writeFieldAccessInInvokeBody(pType: IrType, idx: Int, variable: Label<out DbField>) {
val accessId = writeVariableAccessInInvokeBody(useType(pType), idx, variable)
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, accessId, -1)
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
helper.writeExpressionMetadataToTrapFile(thisId, funLabels.methodId, retId)
}
val useFirstArgAsDispatch: Boolean
if (dispatchReceiver != null) {
writeFieldAccessInInvokeBody(dispatchReceiver.type, dispatchReceiverIdx, dispatchFieldId!!)
useFirstArgAsDispatch = false
} else {
useFirstArgAsDispatch = dispatchParameter != null
}
val extensionIdxOffset: Int
if (extensionReceiver != null) {
writeFieldAccessInInvokeBody(extensionReceiver.type, 0, extensionFieldId!!)
extensionIdxOffset = 1
} else {
extensionIdxOffset = 0
}
if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY) {
addArgumentsToInvocationInInvokeNBody(parameters, funLabels, retId, callId, locId, { exp -> helper.writeExpressionMetadataToTrapFile(exp, funLabels.methodId, retId) }, extensionIdxOffset, useFirstArgAsDispatch, dispatchReceiverIdx)
} else {
val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0
for ((pIdx, p) in funLabels.parameters.withIndex()) {
val childIdx = if (pIdx == 0 && useFirstArgAsDispatch) {
dispatchReceiverIdx
} else {
pIdx + extensionIdxOffset - dispatchIdxOffset
}
writeVariableAccessInInvokeBody(p.second, childIdx, p.first)
extractTypeArguments(functionReferenceExpr, callId, funLabels.methodId, retId, -2, true)
return callId
}
}
helper.extractCallToReflectionTarget(
funLabels,
target,
::extractAccessToTarget,
dispatchReceiverIdx,
isBigArity,
parameters)
// Add constructor (member ref) call:
val exprParent = parent.expr(functionReferenceExpr, callable)
@@ -3075,16 +2980,9 @@ open class KotlinFileExtractor(
tw.writeMemberRefBinding(idMemberRef, targetCallableId)
// constructor arguments:
if (dispatchReceiver != null) {
extractExpressionExpr(dispatchReceiver, callable, idMemberRef, 0, exprParent.enclosingStmt)
}
helper.extractConstructorArguments(callable, idMemberRef, exprParent.enclosingStmt)
if (extensionReceiver != null) {
extractExpressionExpr(extensionReceiver, callable, idMemberRef, 0 + extensionParameterIndex, exprParent.enclosingStmt)
}
tw.writeIsAnonymClass(id, idMemberRef)
tw.writeIsAnonymClass(classId, idMemberRef)
}
}
@@ -3110,24 +3008,30 @@ open class KotlinFileExtractor(
val parameters: List<Pair<Label<DbParam>, TypeResults>>)
/**
* Adds a function `invoke(a: Any[])` with the specified return type to the class identified by parentId.
* Adds a function `invoke(a: Any[])` with the specified return type to the class identified by `parentId`.
*/
private fun addFunctionNInvoke(returnType: IrType, parentId: Label<out DbReftype>, locId: Label<DbLocation>): FunctionLabels {
return addFunctionInvoke(listOf(pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)), returnType, parentId, locId)
private fun addFunctionNInvoke(methodId: Label<DbMethod>, returnType: IrType, parentId: Label<out DbReftype>, locId: Label<DbLocation>): FunctionLabels {
return addFunctionInvoke(methodId, listOf(pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)), returnType, parentId, locId)
}
/**
* Adds a function named "invoke" with the specified parameter types and return type to the class identified by `parentId`.
* Adds a function named `invoke` with the specified parameter types and return type to the class identified by `parentId`.
*/
private fun addFunctionInvoke(parameterTypes: List<IrType>, returnType: IrType, parentId: Label<out DbReftype>, locId: Label<DbLocation>): FunctionLabels {
return addFunctionManual(OperatorNameConventions.INVOKE.asString(), parameterTypes, returnType, parentId, locId)
private fun addFunctionInvoke(methodId: Label<DbMethod>, parameterTypes: List<IrType>, returnType: IrType, parentId: Label<out DbReftype>, locId: Label<DbLocation>): FunctionLabels {
return addFunctionManual(methodId, OperatorNameConventions.INVOKE.asString(), parameterTypes, returnType, parentId, locId)
}
/**
* Extracts a function with the given name, parameter types, return type, and containing type.
* Extracts a function with the given name, parameter types, return type, containing type, and location.
*/
private fun addFunctionManual(name: String, parameterTypes: List<IrType>, returnType: IrType, parentId: Label<out DbReftype>, locId: Label<DbLocation>): FunctionLabels {
val methodId = tw.getFreshIdLabel<DbMethod>()
private fun addFunctionManual(
methodId: Label<DbMethod>,
name: String,
parameterTypes: List<IrType>,
returnType: IrType,
parentId: Label<out DbReftype>,
locId: Label<DbLocation>)
: FunctionLabels {
val parameters = parameterTypes.mapIndexed { idx, p ->
val paramId = tw.getFreshIdLabel<DbParam>()
@@ -3156,8 +3060,8 @@ open class KotlinFileExtractor(
*
* The following body is added:
* ```
* fun invoke(vararg args: Any?): R {
* return invoke(args[0] as T0, args[1] as T1, ..., args[I] as TI)
* fun invoke(vararg a0: Any?): R {
* return invoke(a0[0] as T0, a0[1] as T1, ..., a0[I] as TI)
* }
* ```
* */
@@ -3168,7 +3072,7 @@ open class KotlinFileExtractor(
parameters: List<IrValueParameter>
) {
@Suppress("UNCHECKED_CAST")
val funLabels = addFunctionNInvoke(lambda.returnType, ids.type.javaResult.id as Label<DbReftype>, locId)
val funLabels = addFunctionNInvoke(tw.getFreshIdLabel(), lambda.returnType, ids.type.javaResult.id as Label<DbReftype>, locId)
// Return
val retId = tw.getFreshIdLabel<DbReturnstmt>()
@@ -3203,21 +3107,21 @@ open class KotlinFileExtractor(
* Adds the arguments to the method call inside `invoke(a0: Any[])`. Each argument is an array access with a cast:
*
* ```
* fun invoke(args: Any[]) : T {
* return fn(args[0] as T0, args[1] as T1, ...)
* fun invoke(a0: Any[]) : T {
* return fn(a0[0] as T0, a0[1] as T1, ...)
* }
* ```
*/
private fun addArgumentsToInvocationInInvokeNBody(
parameters: List<IrValueParameter>,
funLabels: FunctionLabels,
retId: Label<DbReturnstmt>,
callId: Label<out DbExprparent>,
locId: Label<DbLocation>,
extractCommonExpr: (Label<out DbExpr>) -> Unit,
firstArgumentOffset: Int = 0,
useFirstArgAsDispatch: Boolean = false,
dispatchReceiverIdx: Int = -1
parameters: List<IrValueParameter>, // list of parameters
funLabels: FunctionLabels, // already generated labels for the function definition
enclosingStmtId: Label<out DbStmt>, // label for the enclosing statement (return)
exprParentId: Label<out DbExprparent>, // label for the expression parent (call)
locId: Label<DbLocation>, // label for the location of all generated items
extractCommonExpr: (Label<out DbExpr>) -> Unit, // lambda used for extracting location, enclosing stmt and expr for all new expressions
firstArgumentOffset: Int = 0, // 0 or 1, the index used for the first argument. 1 in case an extension parameter is already accessed at index 0
useFirstArgAsDispatch: Boolean = false, // true if the first argument should be used as the dispatch receiver
dispatchReceiverIdx: Int = -1 // index of the dispatch receiver. -1 in case of functions, -2 in case of constructors
) {
val intType = useType(pluginContext.irBuiltIns.intType)
val argsParamType = pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)
@@ -3244,12 +3148,12 @@ open class KotlinFileExtractor(
// cast
val castId = tw.getFreshIdLabel<DbCastexpr>()
val type = useType(p.type)
tw.writeExprs_castexpr(castId, type.javaResult.id, callId, childIdx)
tw.writeExprs_castexpr(castId, type.javaResult.id, exprParentId, childIdx)
tw.writeExprsKotlinType(castId, type.kotlinResult.id)
extractCommonExpr(castId)
// type access
extractTypeAccess(p.type, locId, funLabels.methodId, castId, 0, retId)
extractTypeAccess(p.type, locId, funLabels.methodId, castId, 0, enclosingStmtId)
// element access: `args.get(i)`
val getCallId = tw.getFreshIdLabel<DbMethodaccess>()
@@ -3475,12 +3379,12 @@ open class KotlinFileExtractor(
tw.writeKt_notnull_types(kotlinResult.id, javaResult.id)
val ids = LocallyVisibleFunctionLabels(
TypeResults(javaResult, kotlinResult),
tw.getFreshIdLabel(),
tw.getFreshIdLabel(),
tw.getFreshIdLabel()
)
constructor = tw.getFreshIdLabel(),
constructorBlock = tw.getFreshIdLabel(),
function = tw.getFreshIdLabel())
val locId = tw.getLocation(e)
val helper = FunctionReferenceHelper(locId, ids)
val helper = GeneratedClassHelper(locId, ids)
val currentDeclaration = declarationStack.peek()
val classId = extractGeneratedClass(ids, listOf(pluginContext.irBuiltIns.anyType, e.typeOperand), locId, currentDeclaration)
@@ -3493,17 +3397,16 @@ open class KotlinFileExtractor(
helper.extractParameterToFieldAssignmentInConstructor("<fn>", functionType, fieldId, 0, 1)
// add implementation function
val functionId = tw.getFreshIdLabel<DbMethod>()
extractFunction(samMember, classId, false, null, null, functionId)
extractFunction(samMember, classId, false, null, null, ids.function)
//body
val blockId = tw.getFreshIdLabel<DbBlock>()
tw.writeStmts_block(blockId, functionId, 0, functionId)
tw.writeStmts_block(blockId, ids.function, 0, ids.function)
tw.writeHasLocation(blockId, locId)
//return stmt
val returnId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(returnId, blockId, 0, functionId)
tw.writeStmts_returnstmt(returnId, blockId, 0, ids.function)
tw.writeHasLocation(returnId, locId)
//<fn>.invoke(vp0, cp1, vp2, vp3, ...) or
@@ -3511,7 +3414,7 @@ open class KotlinFileExtractor(
fun extractCommonExpr(id: Label<out DbExpr>) {
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, functionId)
tw.writeCallableEnclosingExpr(id, ids.function)
tw.writeStatementEnclosingExpr(id, returnId)
}
@@ -3545,7 +3448,7 @@ open class KotlinFileExtractor(
tw.writeExprs_varaccess(argsAccessId, paramType.javaResult.id, parent, idx)
tw.writeExprsKotlinType(argsAccessId, paramType.kotlinResult.id)
extractCommonExpr(argsAccessId)
tw.writeVariableBinding(argsAccessId, useValueParameter(p, functionId))
tw.writeVariableBinding(argsAccessId, useValueParameter(p, ids.function))
}
if (st.arguments.size > BuiltInFunctionArity.BIG_ARITY) {
@@ -3557,7 +3460,7 @@ open class KotlinFileExtractor(
tw.writeExprsKotlinType(arrayCreationId, at.kotlinResult.id)
extractCommonExpr(arrayCreationId)
extractTypeAccess(pluginContext.irBuiltIns.anyNType, functionId, arrayCreationId, -1, e, returnId)
extractTypeAccess(pluginContext.irBuiltIns.anyNType, ids.function, arrayCreationId, -1, e, returnId)
val initId = tw.getFreshIdLabel<DbArrayinit>()
tw.writeExprs_arrayinit(initId, at.javaResult.id, arrayCreationId, -2)
@@ -3640,7 +3543,7 @@ open class KotlinFileExtractor(
* Extracts the class around a local function, a lambda, or a function reference.
*/
private fun extractGeneratedClass(
ids: LocallyVisibleFunctionLabels,
ids: GeneratedClassLabels,
superTypes: List<IrType>,
locId: Label<DbLocation>,
currentDeclaration: IrDeclaration