KE2: Extract more constructs for lambda expressions

This commit is contained in:
Tamas Vajk
2024-11-26 13:36:35 +01:00
parent b42fbde130
commit 44e318546f
5 changed files with 317 additions and 267 deletions

View File

@@ -4986,251 +4986,6 @@ OLD: KE1
functionNTypeArguments.map { makeTypeProjection(it, Variance.INVARIANT) }
)
private fun getFunctionalInterfaceTypeWithTypeArgs(
functionNTypeArguments: List<IrTypeArgument>
) =
if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY)
referenceExternalClass("kotlin.jvm.functions.FunctionN")
?.symbol
?.typeWithArguments(listOf(functionNTypeArguments.last()))
else
functionN(pluginContext)(functionNTypeArguments.size - 1)
.symbol
.typeWithArguments(functionNTypeArguments)
private data class FunctionLabels(
val methodId: Label<DbMethod>,
val blockId: Label<DbBlock>,
val parameters: List<Pair<Label<DbParam>, TypeResults>>
)
/**
* Adds a function `invoke(a: Any[])` with the specified return type to the class identified by
* `parentId`.
*/
private fun addFunctionNInvoke(
methodId: Label<DbMethod>,
returnType: IrType,
parentId: Label<out DbReftype>,
locId: Label<DbLocation>
): FunctionLabels {
return addFunctionInvoke(
methodId,
listOf(pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)),
returnType,
parentId,
locId
)
}
/**
* Adds a function named `invoke` with the specified parameter types and return type to the
* class identified by `parentId`.
*/
private fun addFunctionInvoke(
methodId: Label<DbMethod>,
parameterTypes: List<IrType>,
returnType: IrType,
parentId: Label<out DbReftype>,
locId: Label<DbLocation>
): FunctionLabels {
return addFunctionManual(
methodId,
OperatorNameConventions.INVOKE.asString(),
parameterTypes,
returnType,
parentId,
locId
)
}
/**
* Extracts a function with the given name, parameter types, return type, containing type, and
* location.
*/
private fun addFunctionManual(
methodId: Label<DbMethod>,
name: String,
parameterTypes: List<IrType>,
returnType: IrType,
parentId: Label<out DbReftype>,
locId: Label<DbLocation>
): FunctionLabels {
val parameters =
parameterTypes.mapIndexed { idx, p ->
val paramId = tw.getFreshIdLabel<DbParam>()
val paramType =
extractValueParameter(
paramId,
p,
"a$idx",
locId,
methodId,
idx,
paramId,
syntheticParameterNames = false,
isVararg = false,
isNoinline = false,
isCrossinline = false
)
Pair(paramId, paramType)
}
val paramsSignature =
parameters.joinToString(separator = ",", prefix = "(", postfix = ")") {
signatureOrWarn(it.second.javaResult, declarationStack.tryPeek()?.first)
}
val rt = useType(returnType, TypeContext.RETURN)
tw.writeMethods(
methodId,
name,
"$name$paramsSignature",
rt.javaResult.id,
parentId,
methodId
)
tw.writeMethodsKotlinType(methodId, rt.kotlinResult.id)
tw.writeHasLocation(methodId, locId)
addModifiers(methodId, "public")
addModifiers(methodId, "override")
return FunctionLabels(methodId, extractBlockBody(methodId, locId), parameters)
}
/*
* This function generates an implementation for `fun kotlin.FunctionN<R>.invoke(vararg args: Any?): R`
*
* The following body is added:
* ```
* fun invoke(vararg a0: Any?): R {
* return invoke(a0[0] as T0, a0[1] as T1, ..., a0[I] as TI)
* }
* ```
* */
private fun implementFunctionNInvoke(
lambda: IrFunction,
ids: LocallyVisibleFunctionLabels,
locId: Label<DbLocation>,
parameters: List<IrValueParameter>
) {
val funLabels =
addFunctionNInvoke(
tw.getFreshIdLabel(),
lambda.returnType,
ids.type.javaResult.id.cast<DbReftype>(),
locId
)
// Return
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, funLabels.blockId, 0, funLabels.methodId)
tw.writeHasLocation(retId, locId)
// Call to original `invoke`:
val callId = tw.getFreshIdLabel<DbMethodaccess>()
val callType = useType(lambda.returnType)
tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
extractExprContext(callId, locId, funLabels.methodId, retId)
val calledMethodId = useFunction<DbMethod>(lambda)
if (calledMethodId == null) {
logger.errorElement("Cannot get ID for called lambda", lambda)
} else {
tw.writeCallableBinding(callId, calledMethodId)
}
// this access
extractThisAccess(ids.type, funLabels.methodId, callId, -1, retId, locId)
addArgumentsToInvocationInInvokeNBody(
parameters.map { it.type },
funLabels,
retId,
callId,
locId
)
}
/**
* Adds the arguments to the method call inside `invoke(a0: Any[])`. Each argument is an array
* access with a cast:
* ```
* fun invoke(a0: Any[]) : T {
* return fn(a0[0] as T0, a0[1] as T1, ...)
* }
* ```
*/
private fun addArgumentsToInvocationInInvokeNBody(
parameterTypes: List<IrType>, // list of parameter types
funLabels: FunctionLabels, // already generated labels for the function definition
enclosingStmtId: Label<out DbStmt>, // label for the enclosing statement (return)
exprParentId: Label<out DbExprparent>, // label for the expression parent (call)
locId: Label<DbLocation>, // label for the location of all generated items
firstArgumentOffset: Int =
0, // 0 or 1, the index used for the first argument. 1 in case an extension parameter is
// already accessed at index 0
useFirstArgAsDispatch: Boolean =
false, // true if the first argument should be used as the dispatch receiver
dispatchReceiverIdx: Int =
-1 // index of the dispatch receiver. -1 in case of functions, -2 in case of
// constructors
) {
val argsParamType =
pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)
val argsType = useType(argsParamType)
val anyNType = useType(pluginContext.irBuiltIns.anyNType)
val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0
for ((pIdx, pType) in parameterTypes.withIndex()) {
// `a0[i] as Ti` is generated below for each parameter
val childIdx =
if (pIdx == 0 && useFirstArgAsDispatch) {
dispatchReceiverIdx
} else {
pIdx + firstArgumentOffset - dispatchIdxOffset
}
// cast: `(Ti)a0[i]`
val castId = tw.getFreshIdLabel<DbCastexpr>()
val type = useType(pType)
tw.writeExprs_castexpr(castId, type.javaResult.id, exprParentId, childIdx)
tw.writeExprsKotlinType(castId, type.kotlinResult.id)
extractExprContext(castId, locId, funLabels.methodId, enclosingStmtId)
// type access `Ti`
extractTypeAccessRecursive(pType, locId, castId, 0, funLabels.methodId, enclosingStmtId)
// element access: `a0[i]`
val arrayAccessId = tw.getFreshIdLabel<DbArrayaccess>()
tw.writeExprs_arrayaccess(arrayAccessId, anyNType.javaResult.id, castId, 1)
tw.writeExprsKotlinType(arrayAccessId, anyNType.kotlinResult.id)
extractExprContext(arrayAccessId, locId, funLabels.methodId, enclosingStmtId)
// parameter access: `a0`
val argsAccessId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(argsAccessId, argsType.javaResult.id, arrayAccessId, 0)
tw.writeExprsKotlinType(argsAccessId, argsType.kotlinResult.id)
extractExprContext(argsAccessId, locId, funLabels.methodId, enclosingStmtId)
tw.writeVariableBinding(argsAccessId, funLabels.parameters.first().first)
// index access: `i`
extractConstantInteger(
pIdx,
locId,
arrayAccessId,
1,
funLabels.methodId,
enclosingStmtId
)
}
}
private fun extractVarargElement(
e: IrVarargElement,
callable: Label<out DbCallable>,

View File

@@ -57,7 +57,7 @@ fun KotlinFileExtractor.extractBody(b: KtExpression, callable: Label<out DbCalla
}
// TODO: Can this be inlined?
private fun KotlinFileExtractor.extractBlockBody(callable: Label<out DbCallable>, locId: Label<DbLocation>) =
fun KotlinFileExtractor.extractBlockBody(callable: Label<out DbCallable>, locId: Label<DbLocation>) =
tw.getFreshIdLabel<DbBlock>().also {
tw.writeStmts_block(it, callable, 0, callable)
tw.writeHasLocation(it, locId)
@@ -166,7 +166,7 @@ OLD: KE1
}
*/
private fun KotlinFileExtractor.extractConstantInteger(
fun KotlinFileExtractor.extractConstantInteger(
text: String,
t: KaType,
v: Number,

View File

@@ -1,5 +1,6 @@
package com.github.codeql
import com.github.codeql.useType
import com.github.codeql.utils.getJvmName
import org.jetbrains.kotlin.analysis.api.KaSession
import org.jetbrains.kotlin.analysis.api.symbols.*
@@ -639,6 +640,38 @@ OLD: KE1
}
}
fun KotlinFileExtractor.extractValueParameter(
id: Label<out DbParam>,
t: KaType,
name: String,
locId: Label<DbLocation>,
parent: Label<out DbCallable>,
idx: Int,
paramSourceDeclaration: Label<out DbParam>,
syntheticParameterNames: Boolean,
isVararg: Boolean,
isNoinline: Boolean,
isCrossinline: Boolean
): TypeResults {
val type = useType(t)
tw.writeParams(id, type.javaResult.id, idx, parent, paramSourceDeclaration)
tw.writeParamsKotlinType(id, type.kotlinResult.id)
tw.writeHasLocation(id, locId)
if (!syntheticParameterNames) {
tw.writeParamName(id, name)
}
if (isVararg) {
tw.writeIsVarargsParam(id)
}
if (isNoinline) {
addModifiers(id, "noinline")
}
if (isCrossinline) {
addModifiers(id, "crossinline")
}
return type
}
// TODO: Can this be inlined?
private fun KotlinFileExtractor.extractMethod(
id: Label<out DbMethod>,

View File

@@ -1,11 +1,14 @@
import com.github.codeql.*
import com.github.codeql.KotlinFileExtractor.StmtExprParent
import com.github.codeql.utils.type
import com.intellij.psi.PsiElement
import org.jetbrains.kotlin.analysis.api.KaExperimentalApi
import org.jetbrains.kotlin.analysis.api.KaSession
import org.jetbrains.kotlin.analysis.api.symbols.*
import org.jetbrains.kotlin.analysis.api.types.KaClassType
import org.jetbrains.kotlin.analysis.api.types.KaFunctionType
import org.jetbrains.kotlin.analysis.api.types.KaType
import org.jetbrains.kotlin.builtins.StandardNames
import org.jetbrains.kotlin.builtins.functions.BuiltInFunctionArity
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
@@ -45,19 +48,15 @@ fun KotlinFileExtractor.extractFunctionLiteral(
val functionSymbol = e.symbol
val ids = getLocallyVisibleFunctionLabels(functionSymbol)
// todo: is it possible that the receiver parameter is a dispatch receiver?
val ext = if (functionSymbol.isExtension) functionSymbol.receiverParameter else null
val parameterTypes = functionSymbol.valueParameters.map { (it as KaVariableSymbol).returnType }.toMutableList()
if (ext != null) {
parameterTypes.add(0, ext.type)
val parameters = if (functionSymbol.isExtension) {
listOf(functionSymbol.receiverParameter!!) + functionSymbol.valueParameters
} else {
functionSymbol.valueParameters
}
parameterTypes += functionSymbol.returnType
val isBigArity = parameterTypes.size > BuiltInFunctionArity.BIG_ARITY
val isBigArity = parameters.size >= BuiltInFunctionArity.BIG_ARITY
if (isBigArity) {
// OLD: KE1
// implementFunctionNInvoke(e.function, ids, locId, parameters)
implementFunctionNInvoke(functionSymbol, ids, locId, parameters)
} else {
addModifiers(ids.function, "override")
}
@@ -77,8 +76,7 @@ fun KotlinFileExtractor.extractFunctionLiteral(
// todo: fix hard coded block body of lambda
tw.writeLambdaKind(idLambdaExpr, 1)
val functionType =
e.functionType // TODO: change this type for BIG_ARITY lambdas, this should be kotlin.FunctionN<R> and not kotlin.Function33<....,R>. The latter doesn't exist.
val functionType = getRealFunctionalInterfaceType(e.functionType as KaFunctionType)
if (!functionType.isFunctionType) {
logger.warnElement(
"Cannot find functional interface type for function expression",
@@ -113,6 +111,264 @@ fun KotlinFileExtractor.extractFunctionLiteral(
return idLambdaExpr
}
context(KaSession)
private fun KotlinFileExtractor.getRealFunctionalInterfaceType(typeFromApi: KaFunctionType): KaType {
if (typeFromApi.arity < BuiltInFunctionArity.BIG_ARITY) {
return typeFromApi
}
// TODO: the below doesn't work, see https://youtrack.jetbrains.com/issue/KT-73421/
return buildClassType(
ClassId(
FqName("kotlin.jvm.functions"),
Name.identifier("FunctionN")
)
) {
argument(typeFromApi.typeArguments.last().type!!)
}
}
/**
* This function generates an implementation for `fun kotlin.FunctionN<R>.invoke(vararg args: Any?): R`
*
* The following body is added:
* ```
* fun invoke(vararg a0: Any?): R {
* return invoke(a0[0] as T0, a0[1] as T1, ..., a0[I] as TI)
* }
* ```
* */
context(KaSession)
private fun KotlinFileExtractor.implementFunctionNInvoke(
lambda: KaFunctionSymbol,
ids: LocallyVisibleFunctionLabels,
locId: Label<DbLocation>,
parameters: List<KaParameterSymbol>
) {
val funLabels =
addFunctionNInvoke(
tw.getFreshIdLabel(),
lambda.returnType,
ids.type.javaResult.id.cast<DbReftype>(),
locId
)
// Return
val retId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(retId, funLabels.blockId, 0, funLabels.methodId)
tw.writeHasLocation(retId, locId)
// Call to original `invoke`:
val callId = tw.getFreshIdLabel<DbMethodaccess>()
val callType = useType(lambda.returnType)
tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
extractExprContext(callId, locId, funLabels.methodId, retId)
tw.writeCallableBinding(callId, ids.function)
// this access
// OLD: KE1
// extractThisAccess(ids.type, funLabels.methodId, callId, -1, retId, locId)
addArgumentsToInvocationInInvokeNBody(
parameters.map { it.type },
funLabels,
retId,
callId,
locId
)
}
private data class FunctionLabels(
val methodId: Label<DbMethod>,
val blockId: Label<DbBlock>,
val parameters: List<Pair<Label<DbParam>, TypeResults>>
)
/**
* Adds a function `invoke(a: Any[])` with the specified return type to the class identified by
* `parentId`.
*/
context(KaSession)
private fun KotlinFileExtractor.addFunctionNInvoke(
methodId: Label<DbMethod>,
returnType: KaType,
parentId: Label<out DbReftype>,
locId: Label<DbLocation>
): FunctionLabels {
return addFunctionInvoke(
methodId,
listOf(nullableAnyArrayType),
returnType,
parentId,
locId
)
}
context(KaSession)
private val nullableAnyArrayType: KaType
get() = buildClassType(ClassId.topLevel(StandardNames.FqNames.array.toSafe())) {
argument(builtinTypes.nullableAny)
}
/**
* Adds a function named `invoke` with the specified parameter types and return type to the
* class identified by `parentId`.
*/
private fun KotlinFileExtractor.addFunctionInvoke(
methodId: Label<DbMethod>,
parameterTypes: List<KaType>,
returnType: KaType,
parentId: Label<out DbReftype>,
locId: Label<DbLocation>
): FunctionLabels {
return addFunctionManual(
methodId,
"invoke",
parameterTypes,
returnType,
parentId,
locId
)
}
/**
* Extracts a function with the given name, parameter types, return type, containing type, and
* location.
*/
private fun KotlinFileExtractor.addFunctionManual(
methodId: Label<DbMethod>,
name: String,
parameterTypes: List<KaType>,
returnType: KaType,
parentId: Label<out DbReftype>,
locId: Label<DbLocation>
): FunctionLabels {
val parameters =
parameterTypes.mapIndexed { idx, p ->
val paramId = tw.getFreshIdLabel<DbParam>()
val paramType =
extractValueParameter(
paramId,
p,
"a$idx",
locId,
methodId,
idx,
paramId,
syntheticParameterNames = false,
isVararg = false,
isNoinline = false,
isCrossinline = false
)
Pair(paramId, paramType)
}
/* OLD: KE1
val paramsSignature =
parameters.joinToString(separator = ",", prefix = "(", postfix = ")") {
signatureOrWarn(it.second.javaResult, declarationStack.tryPeek()?.first)
}
*/
val paramsSignature = "()" // TODO
val rt = useType(returnType, TypeContext.RETURN)
tw.writeMethods(
methodId,
name,
"$name$paramsSignature",
rt.javaResult.id,
parentId,
methodId
)
tw.writeMethodsKotlinType(methodId, rt.kotlinResult.id)
tw.writeHasLocation(methodId, locId)
addModifiers(methodId, "public")
addModifiers(methodId, "override")
return FunctionLabels(methodId, extractBlockBody(methodId, locId), parameters)
}
/**
* Adds the arguments to the method call inside `invoke(a0: Any[])`. Each argument is an array
* access with a cast:
* ```
* fun invoke(a0: Any[]) : T {
* return fn(a0[0] as T0, a0[1] as T1, ...)
* }
* ```
*/
context(KaSession)
private fun KotlinFileExtractor.addArgumentsToInvocationInInvokeNBody(
parameterTypes: List<KaType>, // list of parameter types
funLabels: FunctionLabels, // already generated labels for the function definition
enclosingStmtId: Label<out DbStmt>, // label for the enclosing statement (return)
exprParentId: Label<out DbExprparent>, // label for the expression parent (call)
locId: Label<DbLocation>, // label for the location of all generated items
firstArgumentOffset: Int =
0, // 0 or 1, the index used for the first argument. 1 in case an extension parameter is already accessed at index 0
useFirstArgAsDispatch: Boolean =
false, // true if the first argument should be used as the dispatch receiver
dispatchReceiverIdx: Int =
-1 // index of the dispatch receiver. -1 in case of functions, -2 in case of constructors
) {
val argsParamType = nullableAnyArrayType
val argsType = useType(argsParamType)
val anyNType = useType(builtinTypes.nullableAny)
val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0
for ((pIdx, pType) in parameterTypes.withIndex()) {
// `a0[i] as Ti` is generated below for each parameter
val childIdx =
if (pIdx == 0 && useFirstArgAsDispatch) {
dispatchReceiverIdx
} else {
pIdx + firstArgumentOffset - dispatchIdxOffset
}
// cast: `(Ti)a0[i]`
val castId = tw.getFreshIdLabel<DbCastexpr>()
val type = useType(pType)
tw.writeExprs_castexpr(castId, type.javaResult.id, exprParentId, childIdx)
tw.writeExprsKotlinType(castId, type.kotlinResult.id)
extractExprContext(castId, locId, funLabels.methodId, enclosingStmtId)
// type access `Ti`
// TODO: extractTypeAccessRecursive(pType, locId, castId, 0, funLabels.methodId, enclosingStmtId)
// element access: `a0[i]`
val arrayAccessId = tw.getFreshIdLabel<DbArrayaccess>()
tw.writeExprs_arrayaccess(arrayAccessId, anyNType.javaResult.id, castId, 1)
tw.writeExprsKotlinType(arrayAccessId, anyNType.kotlinResult.id)
extractExprContext(arrayAccessId, locId, funLabels.methodId, enclosingStmtId)
// parameter access: `a0`
val argsAccessId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(argsAccessId, argsType.javaResult.id, arrayAccessId, 0)
tw.writeExprsKotlinType(argsAccessId, argsType.kotlinResult.id)
extractExprContext(argsAccessId, locId, funLabels.methodId, enclosingStmtId)
tw.writeVariableBinding(argsAccessId, funLabels.parameters.first().first)
// index access: `i`
extractConstantInteger(
pIdx.toString(),
builtinTypes.int,
pIdx,
locId,
arrayAccessId,
1,
funLabels.methodId,
enclosingStmtId
)
}
}
/**
* Gets the labels for functions belonging to
* - local functions, and
@@ -193,11 +449,11 @@ private fun KotlinFileExtractor.extractGeneratedClass(
locId: Label<DbLocation>,
elementToReportOn: PsiElement,
compilerGeneratedKindOverride: CompilerGeneratedKinds? = null,
superConstructorSelector: (KaFunctionSymbol) -> Boolean = { it.valueParameters.isEmpty() },
extractSuperConstructorArgs: (Label<DbSuperconstructorinvocationstmt>) -> Unit = {},
/*
OLD: KE1
declarationParent: IrDeclarationParent,
superConstructorSelector: (IrFunction) -> Boolean = { it.valueParameters.isEmpty() },
extractSuperconstructorArgs: (Label<DbSuperconstructorinvocationstmt>) -> Unit = {},
*/
): Label<out DbClassorinterface> {
// Write class
@@ -223,15 +479,13 @@ private fun KotlinFileExtractor.extractGeneratedClass(
tw.writeHasLocation(constructorBlockId, locId)
// Super call
// TODO: we should check if this is class or not
val baseClass = superTypes.first() as? KaClassType // superTypes.first().classOrNull
if (baseClass == null) {
val baseClass = superTypes.first() as? KaClassType
if ((baseClass?.symbol as? KaClassSymbol)?.classKind != KaClassKind.CLASS) {
logger.warnElement("Cannot find base class", elementToReportOn)
} else {
val baseConstructor =
baseClass.scope?.declarationScope?.constructors?.find {
// TODO: OLD KE1 superConstructorSelector(it)
true
superConstructorSelector(it)
}
if (baseConstructor == null) {
logger.warnElement("Cannot find base constructor", elementToReportOn)
@@ -251,7 +505,7 @@ private fun KotlinFileExtractor.extractGeneratedClass(
tw.writeHasLocation(superCallId, locId)
tw.writeCallableBinding(superCallId.cast<DbCaller>(), baseConstructorId)
// TODO: OLD KE1 extractSuperconstructorArgs(superCallId)
extractSuperConstructorArgs(superCallId)
}
}
}

View File

@@ -1,7 +1,15 @@
package com.github.codeql.utils
import org.jetbrains.kotlin.analysis.api.symbols.*
import org.jetbrains.kotlin.analysis.api.types.KaType
val KaClassSymbol.isInterfaceLike
get() = classKind == KaClassKind.INTERFACE || classKind == KaClassKind.ANNOTATION_CLASS
val KaParameterSymbol.type: KaType
get() {
return when (this) {
is KaValueParameterSymbol -> this.returnType
is KaReceiverParameterSymbol -> this.type
}
}