From 44e318546fc4e14218d6cb40869c2ea721cd3a02 Mon Sep 17 00:00:00 2001 From: Tamas Vajk Date: Tue, 26 Nov 2024 13:36:35 +0100 Subject: [PATCH] KE2: Extract more constructs for lambda expressions --- .../src/main/kotlin/KotlinFileExtractor.kt | 245 --------------- .../src/main/kotlin/entities/Expression.kt | 4 +- .../src/main/kotlin/entities/Function.kt | 33 ++ .../kotlin/entities/FunctionalInterface.kt | 294 ++++++++++++++++-- .../src/main/kotlin/utils/Helpers.kt | 8 + 5 files changed, 317 insertions(+), 267 deletions(-) diff --git a/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt b/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt index ab6c73f4f5d..cede8d6bec5 100644 --- a/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt +++ b/java/kotlin-extractor2/src/main/kotlin/KotlinFileExtractor.kt @@ -4986,251 +4986,6 @@ OLD: KE1 functionNTypeArguments.map { makeTypeProjection(it, Variance.INVARIANT) } ) - private fun getFunctionalInterfaceTypeWithTypeArgs( - functionNTypeArguments: List - ) = - if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY) - referenceExternalClass("kotlin.jvm.functions.FunctionN") - ?.symbol - ?.typeWithArguments(listOf(functionNTypeArguments.last())) - else - functionN(pluginContext)(functionNTypeArguments.size - 1) - .symbol - .typeWithArguments(functionNTypeArguments) - - private data class FunctionLabels( - val methodId: Label, - val blockId: Label, - val parameters: List, TypeResults>> - ) - - /** - * Adds a function `invoke(a: Any[])` with the specified return type to the class identified by - * `parentId`. - */ - private fun addFunctionNInvoke( - methodId: Label, - returnType: IrType, - parentId: Label, - locId: Label - ): FunctionLabels { - return addFunctionInvoke( - methodId, - listOf(pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType)), - returnType, - parentId, - locId - ) - } - - /** - * Adds a function named `invoke` with the specified parameter types and return type to the - * class identified by `parentId`. - */ - private fun addFunctionInvoke( - methodId: Label, - parameterTypes: List, - returnType: IrType, - parentId: Label, - locId: Label - ): FunctionLabels { - return addFunctionManual( - methodId, - OperatorNameConventions.INVOKE.asString(), - parameterTypes, - returnType, - parentId, - locId - ) - } - - /** - * Extracts a function with the given name, parameter types, return type, containing type, and - * location. - */ - private fun addFunctionManual( - methodId: Label, - name: String, - parameterTypes: List, - returnType: IrType, - parentId: Label, - locId: Label - ): FunctionLabels { - - val parameters = - parameterTypes.mapIndexed { idx, p -> - val paramId = tw.getFreshIdLabel() - val paramType = - extractValueParameter( - paramId, - p, - "a$idx", - locId, - methodId, - idx, - paramId, - syntheticParameterNames = false, - isVararg = false, - isNoinline = false, - isCrossinline = false - ) - - Pair(paramId, paramType) - } - - val paramsSignature = - parameters.joinToString(separator = ",", prefix = "(", postfix = ")") { - signatureOrWarn(it.second.javaResult, declarationStack.tryPeek()?.first) - } - - val rt = useType(returnType, TypeContext.RETURN) - tw.writeMethods( - methodId, - name, - "$name$paramsSignature", - rt.javaResult.id, - parentId, - methodId - ) - tw.writeMethodsKotlinType(methodId, rt.kotlinResult.id) - tw.writeHasLocation(methodId, locId) - - addModifiers(methodId, "public") - addModifiers(methodId, "override") - - return FunctionLabels(methodId, extractBlockBody(methodId, locId), parameters) - } - - /* - * This function generates an implementation for `fun kotlin.FunctionN.invoke(vararg args: Any?): R` - * - * The following body is added: - * ``` - * fun invoke(vararg a0: Any?): R { - * return invoke(a0[0] as T0, a0[1] as T1, ..., a0[I] as TI) - * } - * ``` - * */ - private fun implementFunctionNInvoke( - lambda: IrFunction, - ids: LocallyVisibleFunctionLabels, - locId: Label, - parameters: List - ) { - val funLabels = - addFunctionNInvoke( - tw.getFreshIdLabel(), - lambda.returnType, - ids.type.javaResult.id.cast(), - locId - ) - - // Return - val retId = tw.getFreshIdLabel() - tw.writeStmts_returnstmt(retId, funLabels.blockId, 0, funLabels.methodId) - tw.writeHasLocation(retId, locId) - - // Call to original `invoke`: - val callId = tw.getFreshIdLabel() - val callType = useType(lambda.returnType) - tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0) - tw.writeExprsKotlinType(callId, callType.kotlinResult.id) - extractExprContext(callId, locId, funLabels.methodId, retId) - val calledMethodId = useFunction(lambda) - if (calledMethodId == null) { - logger.errorElement("Cannot get ID for called lambda", lambda) - } else { - tw.writeCallableBinding(callId, calledMethodId) - } - - // this access - extractThisAccess(ids.type, funLabels.methodId, callId, -1, retId, locId) - - addArgumentsToInvocationInInvokeNBody( - parameters.map { it.type }, - funLabels, - retId, - callId, - locId - ) - } - - /** - * Adds the arguments to the method call inside `invoke(a0: Any[])`. Each argument is an array - * access with a cast: - * ``` - * fun invoke(a0: Any[]) : T { - * return fn(a0[0] as T0, a0[1] as T1, ...) - * } - * ``` - */ - private fun addArgumentsToInvocationInInvokeNBody( - parameterTypes: List, // list of parameter types - funLabels: FunctionLabels, // already generated labels for the function definition - enclosingStmtId: Label, // label for the enclosing statement (return) - exprParentId: Label, // label for the expression parent (call) - locId: Label, // label for the location of all generated items - firstArgumentOffset: Int = - 0, // 0 or 1, the index used for the first argument. 1 in case an extension parameter is - // already accessed at index 0 - useFirstArgAsDispatch: Boolean = - false, // true if the first argument should be used as the dispatch receiver - dispatchReceiverIdx: Int = - -1 // index of the dispatch receiver. -1 in case of functions, -2 in case of - // constructors - ) { - val argsParamType = - pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType) - val argsType = useType(argsParamType) - val anyNType = useType(pluginContext.irBuiltIns.anyNType) - - val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0 - - for ((pIdx, pType) in parameterTypes.withIndex()) { - // `a0[i] as Ti` is generated below for each parameter - - val childIdx = - if (pIdx == 0 && useFirstArgAsDispatch) { - dispatchReceiverIdx - } else { - pIdx + firstArgumentOffset - dispatchIdxOffset - } - - // cast: `(Ti)a0[i]` - val castId = tw.getFreshIdLabel() - val type = useType(pType) - tw.writeExprs_castexpr(castId, type.javaResult.id, exprParentId, childIdx) - tw.writeExprsKotlinType(castId, type.kotlinResult.id) - extractExprContext(castId, locId, funLabels.methodId, enclosingStmtId) - - // type access `Ti` - extractTypeAccessRecursive(pType, locId, castId, 0, funLabels.methodId, enclosingStmtId) - - // element access: `a0[i]` - val arrayAccessId = tw.getFreshIdLabel() - tw.writeExprs_arrayaccess(arrayAccessId, anyNType.javaResult.id, castId, 1) - tw.writeExprsKotlinType(arrayAccessId, anyNType.kotlinResult.id) - extractExprContext(arrayAccessId, locId, funLabels.methodId, enclosingStmtId) - - // parameter access: `a0` - val argsAccessId = tw.getFreshIdLabel() - tw.writeExprs_varaccess(argsAccessId, argsType.javaResult.id, arrayAccessId, 0) - tw.writeExprsKotlinType(argsAccessId, argsType.kotlinResult.id) - extractExprContext(argsAccessId, locId, funLabels.methodId, enclosingStmtId) - tw.writeVariableBinding(argsAccessId, funLabels.parameters.first().first) - - // index access: `i` - extractConstantInteger( - pIdx, - locId, - arrayAccessId, - 1, - funLabels.methodId, - enclosingStmtId - ) - } - } - private fun extractVarargElement( e: IrVarargElement, callable: Label, diff --git a/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt b/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt index 786296b7ec9..7f085db2b6f 100644 --- a/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt +++ b/java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt @@ -57,7 +57,7 @@ fun KotlinFileExtractor.extractBody(b: KtExpression, callable: Label, locId: Label) = +fun KotlinFileExtractor.extractBlockBody(callable: Label, locId: Label) = tw.getFreshIdLabel().also { tw.writeStmts_block(it, callable, 0, callable) tw.writeHasLocation(it, locId) @@ -166,7 +166,7 @@ OLD: KE1 } */ -private fun KotlinFileExtractor.extractConstantInteger( +fun KotlinFileExtractor.extractConstantInteger( text: String, t: KaType, v: Number, diff --git a/java/kotlin-extractor2/src/main/kotlin/entities/Function.kt b/java/kotlin-extractor2/src/main/kotlin/entities/Function.kt index bd2a11e1de4..823fc7dc31e 100644 --- a/java/kotlin-extractor2/src/main/kotlin/entities/Function.kt +++ b/java/kotlin-extractor2/src/main/kotlin/entities/Function.kt @@ -1,5 +1,6 @@ package com.github.codeql +import com.github.codeql.useType import com.github.codeql.utils.getJvmName import org.jetbrains.kotlin.analysis.api.KaSession import org.jetbrains.kotlin.analysis.api.symbols.* @@ -639,6 +640,38 @@ OLD: KE1 } } +fun KotlinFileExtractor.extractValueParameter( + id: Label, + t: KaType, + name: String, + locId: Label, + parent: Label, + idx: Int, + paramSourceDeclaration: Label, + syntheticParameterNames: Boolean, + isVararg: Boolean, + isNoinline: Boolean, + isCrossinline: Boolean +): TypeResults { + val type = useType(t) + tw.writeParams(id, type.javaResult.id, idx, parent, paramSourceDeclaration) + tw.writeParamsKotlinType(id, type.kotlinResult.id) + tw.writeHasLocation(id, locId) + if (!syntheticParameterNames) { + tw.writeParamName(id, name) + } + if (isVararg) { + tw.writeIsVarargsParam(id) + } + if (isNoinline) { + addModifiers(id, "noinline") + } + if (isCrossinline) { + addModifiers(id, "crossinline") + } + return type +} + // TODO: Can this be inlined? private fun KotlinFileExtractor.extractMethod( id: Label, diff --git a/java/kotlin-extractor2/src/main/kotlin/entities/FunctionalInterface.kt b/java/kotlin-extractor2/src/main/kotlin/entities/FunctionalInterface.kt index 5cc73d1c17a..401d41306db 100644 --- a/java/kotlin-extractor2/src/main/kotlin/entities/FunctionalInterface.kt +++ b/java/kotlin-extractor2/src/main/kotlin/entities/FunctionalInterface.kt @@ -1,11 +1,14 @@ import com.github.codeql.* import com.github.codeql.KotlinFileExtractor.StmtExprParent +import com.github.codeql.utils.type import com.intellij.psi.PsiElement import org.jetbrains.kotlin.analysis.api.KaExperimentalApi import org.jetbrains.kotlin.analysis.api.KaSession import org.jetbrains.kotlin.analysis.api.symbols.* import org.jetbrains.kotlin.analysis.api.types.KaClassType +import org.jetbrains.kotlin.analysis.api.types.KaFunctionType import org.jetbrains.kotlin.analysis.api.types.KaType +import org.jetbrains.kotlin.builtins.StandardNames import org.jetbrains.kotlin.builtins.functions.BuiltInFunctionArity import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.FqName @@ -45,19 +48,15 @@ fun KotlinFileExtractor.extractFunctionLiteral( val functionSymbol = e.symbol val ids = getLocallyVisibleFunctionLabels(functionSymbol) - // todo: is it possible that the receiver parameter is a dispatch receiver? - val ext = if (functionSymbol.isExtension) functionSymbol.receiverParameter else null - val parameterTypes = functionSymbol.valueParameters.map { (it as KaVariableSymbol).returnType }.toMutableList() - if (ext != null) { - parameterTypes.add(0, ext.type) + val parameters = if (functionSymbol.isExtension) { + listOf(functionSymbol.receiverParameter!!) + functionSymbol.valueParameters + } else { + functionSymbol.valueParameters } - parameterTypes += functionSymbol.returnType - - val isBigArity = parameterTypes.size > BuiltInFunctionArity.BIG_ARITY + val isBigArity = parameters.size >= BuiltInFunctionArity.BIG_ARITY if (isBigArity) { - // OLD: KE1 - // implementFunctionNInvoke(e.function, ids, locId, parameters) + implementFunctionNInvoke(functionSymbol, ids, locId, parameters) } else { addModifiers(ids.function, "override") } @@ -77,8 +76,7 @@ fun KotlinFileExtractor.extractFunctionLiteral( // todo: fix hard coded block body of lambda tw.writeLambdaKind(idLambdaExpr, 1) - val functionType = - e.functionType // TODO: change this type for BIG_ARITY lambdas, this should be kotlin.FunctionN and not kotlin.Function33<....,R>. The latter doesn't exist. + val functionType = getRealFunctionalInterfaceType(e.functionType as KaFunctionType) if (!functionType.isFunctionType) { logger.warnElement( "Cannot find functional interface type for function expression", @@ -113,6 +111,264 @@ fun KotlinFileExtractor.extractFunctionLiteral( return idLambdaExpr } +context(KaSession) +private fun KotlinFileExtractor.getRealFunctionalInterfaceType(typeFromApi: KaFunctionType): KaType { + if (typeFromApi.arity < BuiltInFunctionArity.BIG_ARITY) { + return typeFromApi + } + + // TODO: the below doesn't work, see https://youtrack.jetbrains.com/issue/KT-73421/ + return buildClassType( + ClassId( + FqName("kotlin.jvm.functions"), + Name.identifier("FunctionN") + ) + ) { + argument(typeFromApi.typeArguments.last().type!!) + } +} + +/** + * This function generates an implementation for `fun kotlin.FunctionN.invoke(vararg args: Any?): R` + * + * The following body is added: + * ``` + * fun invoke(vararg a0: Any?): R { + * return invoke(a0[0] as T0, a0[1] as T1, ..., a0[I] as TI) + * } + * ``` + * */ +context(KaSession) +private fun KotlinFileExtractor.implementFunctionNInvoke( + lambda: KaFunctionSymbol, + ids: LocallyVisibleFunctionLabels, + locId: Label, + parameters: List +) { + val funLabels = + addFunctionNInvoke( + tw.getFreshIdLabel(), + lambda.returnType, + ids.type.javaResult.id.cast(), + locId + ) + + // Return + val retId = tw.getFreshIdLabel() + tw.writeStmts_returnstmt(retId, funLabels.blockId, 0, funLabels.methodId) + tw.writeHasLocation(retId, locId) + + // Call to original `invoke`: + val callId = tw.getFreshIdLabel() + val callType = useType(lambda.returnType) + tw.writeExprs_methodaccess(callId, callType.javaResult.id, retId, 0) + tw.writeExprsKotlinType(callId, callType.kotlinResult.id) + extractExprContext(callId, locId, funLabels.methodId, retId) + tw.writeCallableBinding(callId, ids.function) + + // this access + // OLD: KE1 + // extractThisAccess(ids.type, funLabels.methodId, callId, -1, retId, locId) + + addArgumentsToInvocationInInvokeNBody( + parameters.map { it.type }, + funLabels, + retId, + callId, + locId + ) +} + +private data class FunctionLabels( + val methodId: Label, + val blockId: Label, + val parameters: List, TypeResults>> +) + +/** + * Adds a function `invoke(a: Any[])` with the specified return type to the class identified by + * `parentId`. + */ +context(KaSession) +private fun KotlinFileExtractor.addFunctionNInvoke( + methodId: Label, + returnType: KaType, + parentId: Label, + locId: Label +): FunctionLabels { + return addFunctionInvoke( + methodId, + listOf(nullableAnyArrayType), + returnType, + parentId, + locId + ) +} + +context(KaSession) +private val nullableAnyArrayType: KaType + get() = buildClassType(ClassId.topLevel(StandardNames.FqNames.array.toSafe())) { + argument(builtinTypes.nullableAny) + } + +/** + * Adds a function named `invoke` with the specified parameter types and return type to the + * class identified by `parentId`. + */ +private fun KotlinFileExtractor.addFunctionInvoke( + methodId: Label, + parameterTypes: List, + returnType: KaType, + parentId: Label, + locId: Label +): FunctionLabels { + return addFunctionManual( + methodId, + "invoke", + parameterTypes, + returnType, + parentId, + locId + ) +} + +/** + * Extracts a function with the given name, parameter types, return type, containing type, and + * location. + */ +private fun KotlinFileExtractor.addFunctionManual( + methodId: Label, + name: String, + parameterTypes: List, + returnType: KaType, + parentId: Label, + locId: Label +): FunctionLabels { + + val parameters = + parameterTypes.mapIndexed { idx, p -> + val paramId = tw.getFreshIdLabel() + val paramType = + extractValueParameter( + paramId, + p, + "a$idx", + locId, + methodId, + idx, + paramId, + syntheticParameterNames = false, + isVararg = false, + isNoinline = false, + isCrossinline = false + ) + + Pair(paramId, paramType) + } + + /* OLD: KE1 + val paramsSignature = + parameters.joinToString(separator = ",", prefix = "(", postfix = ")") { + signatureOrWarn(it.second.javaResult, declarationStack.tryPeek()?.first) + } + */ + val paramsSignature = "()" // TODO + + val rt = useType(returnType, TypeContext.RETURN) + tw.writeMethods( + methodId, + name, + "$name$paramsSignature", + rt.javaResult.id, + parentId, + methodId + ) + tw.writeMethodsKotlinType(methodId, rt.kotlinResult.id) + tw.writeHasLocation(methodId, locId) + + addModifiers(methodId, "public") + addModifiers(methodId, "override") + + return FunctionLabels(methodId, extractBlockBody(methodId, locId), parameters) +} + +/** + * Adds the arguments to the method call inside `invoke(a0: Any[])`. Each argument is an array + * access with a cast: + * ``` + * fun invoke(a0: Any[]) : T { + * return fn(a0[0] as T0, a0[1] as T1, ...) + * } + * ``` + */ +context(KaSession) +private fun KotlinFileExtractor.addArgumentsToInvocationInInvokeNBody( + parameterTypes: List, // list of parameter types + funLabels: FunctionLabels, // already generated labels for the function definition + enclosingStmtId: Label, // label for the enclosing statement (return) + exprParentId: Label, // label for the expression parent (call) + locId: Label, // label for the location of all generated items + firstArgumentOffset: Int = + 0, // 0 or 1, the index used for the first argument. 1 in case an extension parameter is already accessed at index 0 + useFirstArgAsDispatch: Boolean = + false, // true if the first argument should be used as the dispatch receiver + dispatchReceiverIdx: Int = + -1 // index of the dispatch receiver. -1 in case of functions, -2 in case of constructors +) { + val argsParamType = nullableAnyArrayType + val argsType = useType(argsParamType) + val anyNType = useType(builtinTypes.nullableAny) + + val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0 + + for ((pIdx, pType) in parameterTypes.withIndex()) { + // `a0[i] as Ti` is generated below for each parameter + + val childIdx = + if (pIdx == 0 && useFirstArgAsDispatch) { + dispatchReceiverIdx + } else { + pIdx + firstArgumentOffset - dispatchIdxOffset + } + + // cast: `(Ti)a0[i]` + val castId = tw.getFreshIdLabel() + val type = useType(pType) + tw.writeExprs_castexpr(castId, type.javaResult.id, exprParentId, childIdx) + tw.writeExprsKotlinType(castId, type.kotlinResult.id) + extractExprContext(castId, locId, funLabels.methodId, enclosingStmtId) + + // type access `Ti` + // TODO: extractTypeAccessRecursive(pType, locId, castId, 0, funLabels.methodId, enclosingStmtId) + + // element access: `a0[i]` + val arrayAccessId = tw.getFreshIdLabel() + tw.writeExprs_arrayaccess(arrayAccessId, anyNType.javaResult.id, castId, 1) + tw.writeExprsKotlinType(arrayAccessId, anyNType.kotlinResult.id) + extractExprContext(arrayAccessId, locId, funLabels.methodId, enclosingStmtId) + + // parameter access: `a0` + val argsAccessId = tw.getFreshIdLabel() + tw.writeExprs_varaccess(argsAccessId, argsType.javaResult.id, arrayAccessId, 0) + tw.writeExprsKotlinType(argsAccessId, argsType.kotlinResult.id) + extractExprContext(argsAccessId, locId, funLabels.methodId, enclosingStmtId) + tw.writeVariableBinding(argsAccessId, funLabels.parameters.first().first) + + // index access: `i` + extractConstantInteger( + pIdx.toString(), + builtinTypes.int, + pIdx, + locId, + arrayAccessId, + 1, + funLabels.methodId, + enclosingStmtId + ) + } +} + + /** * Gets the labels for functions belonging to * - local functions, and @@ -193,11 +449,11 @@ private fun KotlinFileExtractor.extractGeneratedClass( locId: Label, elementToReportOn: PsiElement, compilerGeneratedKindOverride: CompilerGeneratedKinds? = null, + superConstructorSelector: (KaFunctionSymbol) -> Boolean = { it.valueParameters.isEmpty() }, + extractSuperConstructorArgs: (Label) -> Unit = {}, /* OLD: KE1 declarationParent: IrDeclarationParent, - superConstructorSelector: (IrFunction) -> Boolean = { it.valueParameters.isEmpty() }, - extractSuperconstructorArgs: (Label) -> Unit = {}, */ ): Label { // Write class @@ -223,15 +479,13 @@ private fun KotlinFileExtractor.extractGeneratedClass( tw.writeHasLocation(constructorBlockId, locId) // Super call - // TODO: we should check if this is class or not - val baseClass = superTypes.first() as? KaClassType // superTypes.first().classOrNull - if (baseClass == null) { + val baseClass = superTypes.first() as? KaClassType + if ((baseClass?.symbol as? KaClassSymbol)?.classKind != KaClassKind.CLASS) { logger.warnElement("Cannot find base class", elementToReportOn) } else { val baseConstructor = baseClass.scope?.declarationScope?.constructors?.find { - // TODO: OLD KE1 superConstructorSelector(it) - true + superConstructorSelector(it) } if (baseConstructor == null) { logger.warnElement("Cannot find base constructor", elementToReportOn) @@ -251,7 +505,7 @@ private fun KotlinFileExtractor.extractGeneratedClass( tw.writeHasLocation(superCallId, locId) tw.writeCallableBinding(superCallId.cast(), baseConstructorId) - // TODO: OLD KE1 extractSuperconstructorArgs(superCallId) + extractSuperConstructorArgs(superCallId) } } } diff --git a/java/kotlin-extractor2/src/main/kotlin/utils/Helpers.kt b/java/kotlin-extractor2/src/main/kotlin/utils/Helpers.kt index c60badc360d..2b0c7e13e69 100644 --- a/java/kotlin-extractor2/src/main/kotlin/utils/Helpers.kt +++ b/java/kotlin-extractor2/src/main/kotlin/utils/Helpers.kt @@ -1,7 +1,15 @@ package com.github.codeql.utils import org.jetbrains.kotlin.analysis.api.symbols.* +import org.jetbrains.kotlin.analysis.api.types.KaType val KaClassSymbol.isInterfaceLike get() = classKind == KaClassKind.INTERFACE || classKind == KaClassKind.ANNOTATION_CLASS +val KaParameterSymbol.type: KaType + get() { + return when (this) { + is KaValueParameterSymbol -> this.returnType + is KaReceiverParameterSymbol -> this.type + } + } \ No newline at end of file