Extract function expressions

This commit is contained in:
Tamas Vajk
2021-12-03 14:39:11 +01:00
committed by Ian Lynagh
parent b32ac935f6
commit f4c87cb79d
17 changed files with 371 additions and 13 deletions

View File

@@ -1,16 +1,21 @@
package com.github.codeql
import com.github.codeql.utils.versions.functionN
import com.semmle.extractor.java.OdasaOutput
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.builtins.functions.BuiltInFunctionArity
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.interpreter.toIrConst
import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.util.OperatorNameConventions
open class KotlinFileExtractor(
override val logger: FileLogger,
@@ -254,13 +259,16 @@ open class KotlinFileExtractor(
return FieldResult(instanceId, instanceName)
}
fun extractValueParameter(vp: IrValueParameter, parent: Label<out DbCallable>, idx: Int): TypeResults {
val id = useValueParameter(vp)
val type = useType(vp.type)
val locId = tw.getLocation(vp)
private fun extractValueParameter(vp: IrValueParameter, parent: Label<out DbCallable>, idx: Int): TypeResults {
return extractValueParameter(useValueParameter(vp), vp.type, vp.name.asString(), vp, parent, idx)
}
private fun extractValueParameter(id: Label<out DbParam>, t: IrType, name: String, loc: IrElement, parent: Label<out DbCallable>, idx: Int): TypeResults {
val type = useType(t)
val locId = tw.getLocation(loc)
tw.writeParams(id, type.javaResult.id, type.kotlinResult.id, idx, parent, id)
tw.writeHasLocation(id, locId)
tw.writeParamName(id, vp.name.asString())
tw.writeParamName(id, name)
return type
}
@@ -371,9 +379,10 @@ open class KotlinFileExtractor(
tw.writeConstrs(id as Label<DbConstructor>, shortName, "$shortName$paramsSignature", returnType.javaResult.id, returnType.kotlinResult.id, parentId, id)
} else {
val returnType = useType(f.returnType, TypeContext.RETURN)
val shortName = f.name.asString()
val shortName = getFunctionShortName(f)
@Suppress("UNCHECKED_CAST")
tw.writeMethods(id as Label<DbMethod>, shortName, "$shortName$paramsSignature", returnType.javaResult.id, returnType.kotlinResult.id, parentId, id)
// TODO: fix `sourceId`. It doesn't always match the method ID.
}
tw.writeHasLocation(id, locId)
@@ -1612,6 +1621,85 @@ open class KotlinFileExtractor(
tw.writeVariableBinding(id, instance.id)
}
}
is IrFunctionExpression -> {
val ids = getLocalFunctionLabels(e.function)
val locId = tw.getLocation(e)
val ext = e.function.extensionReceiverParameter
val parameters = if (ext != null) {
val l = mutableListOf(ext)
l.addAll(e.function.valueParameters)
l
} else {
e.function.valueParameters
}
var types = parameters.map { it.type }
types += e.function.returnType
val fnInterface = if (types.size > BuiltInFunctionArity.BIG_ARITY) {
pluginContext.referenceClass(FqName("kotlin.jvm.functions.FunctionN"))!!.typeWith(e.function.returnType)
} else {
functionN(pluginContext)(parameters.size).typeWith(types)
}
extractGeneratedClass(
e.function, // We're adding this function as a member, and changing its name to `invoke` to implement `kotlin.FunctionX<,,,>.invoke(,,)`
listOf(
pluginContext.referenceClass(FqName("kotlin.jvm.internal.Lambda"))!!.typeWith(),
fnInterface),
listOf(e.function.valueParameters.size.toIrConst(pluginContext.irBuiltIns.intType, e.startOffset, e.endOffset)))
val objectType = useType(pluginContext.irBuiltIns.anyNType).javaResult.id
// Only add bridge method if its signature is different from the lambda function
if (!types.all { useType(it).javaResult.id == objectType } ||
types.size > BuiltInFunctionArity.BIG_ARITY) {
val methodId = tw.getFreshIdLabel<DbMethod>()
val paramTypes =
if (types.size > BuiltInFunctionArity.BIG_ARITY) {
// signature is `Object invoke(Object[] p)`
listOf(extractValueParameter(tw.getFreshIdLabel(), pluginContext.irBuiltIns.arrayClass.typeWith(pluginContext.irBuiltIns.anyNType), "p", e, methodId, 0))
} else {
// signature is `Object invoke(Object p0, Object p1, ..., Object pN)`
parameters.mapIndexed { i, _ ->
extractValueParameter(tw.getFreshIdLabel(), pluginContext.irBuiltIns.anyNType, "p$i", e, methodId, i)
}
}
val paramsSignature = paramTypes.joinToString(separator = ",", prefix = "(", postfix = ")") { it.javaResult.signature!! }
val returnType = useType(pluginContext.irBuiltIns.anyNType, TypeContext.RETURN)
val shortName = OperatorNameConventions.INVOKE.asString()
@Suppress("UNCHECKED_CAST")
tw.writeMethods(methodId, shortName, "$shortName$paramsSignature", returnType.javaResult.id, returnType.kotlinResult.id, ids.type.javaResult.id as Label<out DbReftype>, methodId)
tw.writeHasLocation(methodId, locId)
// TODO:
// - Add body of bridge method, which calls `e.function`:
// ```
// public int invoke(int i, Object j, String k) { return 5; }
// public Object invoke(Object p0, Object p1, Object p2) {
// return invoke((int)p0, (Object)p1, (String)p2);
// or
// invoke((int)p0, (Object)p1, (String)p2);
// return kotlin.Unit.INSTANCE
// }
// ```
}
val exprParent = parent.expr(e, callable)
val idNewexpr = tw.getFreshIdLabel<DbNewexpr>()
tw.writeExprs_newexpr(idNewexpr, ids.type.javaResult.id, ids.type.kotlinResult.id, exprParent.parent, exprParent.idx)
tw.writeHasLocation(idNewexpr, locId)
tw.writeCallableEnclosingExpr(idNewexpr, callable)
tw.writeStatementEnclosingExpr(idNewexpr, exprParent.enclosingStmt)
tw.writeCallableBinding(idNewexpr, ids.constructor)
}
else -> {
logger.warnElement(Severity.ErrorSevere, "Unrecognised IrExpression: " + e.javaClass, e)
}
@@ -1754,7 +1842,7 @@ open class KotlinFileExtractor(
private val IrType.isAnonymous: Boolean
get() = ((this as? IrSimpleType)?.classifier?.owner as? IrClass)?.isAnonymousObject ?: false
fun extractGeneratedClass(localFunction: IrFunction, superTypes: List<IrType>) : Label<out DbClass> {
private fun extractGeneratedClass(localFunction: IrFunction, superTypes: List<IrType>, superConstructorArgs: List<IrExpression> = listOf()) : Label<out DbClass> {
val ids = getLocalFunctionLabels(localFunction)
// Write class
@@ -1780,6 +1868,10 @@ open class KotlinFileExtractor(
// Super call
val superCallId = tw.getFreshIdLabel<DbSuperconstructorinvocationstmt>()
tw.writeStmts_superconstructorinvocationstmt(superCallId, constructorBlockId, 0, ids.function)
for (i in 0 until superConstructorArgs.size) {
val arg = superConstructorArgs[i]
extractExpressionExpr(arg, ids.constructor, superCallId, i, superCallId)
}
val baseConstructor = superTypes.first().classOrNull!!.owner.declarations.find { it is IrFunction && it.symbol is IrConstructorSymbol }
val baseConstructorId = useFunction<DbConstructor>(baseConstructor as IrFunction)

View File

@@ -13,7 +13,9 @@ import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.types.impl.IrSimpleTypeImpl
import org.jetbrains.kotlin.ir.types.impl.makeTypeProjection
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.name.SpecialNames
import org.jetbrains.kotlin.types.Variance
import org.jetbrains.kotlin.util.OperatorNameConventions
open class KotlinUsesExtractor(
open val logger: Logger,
@@ -452,8 +454,17 @@ class X {
}
}
private val IrDeclaration.isAnonymousFunction get() = this is IrSimpleFunction && name == SpecialNames.NO_NAME_PROVIDED
fun getFunctionShortName(f: IrFunction) : String {
if (f.origin == IrDeclarationOrigin.LOCAL_FUNCTION_FOR_LAMBDA || f.isAnonymousFunction)
return OperatorNameConventions.INVOKE.asString()
else
return f.name.asString()
}
fun getFunctionLabel(f: IrFunction) : String {
return getFunctionLabel(f.parent, f.name.asString(), f.valueParameters, f.returnType, f.extensionReceiverParameter)
return getFunctionLabel(f.parent, getFunctionShortName(f), f.valueParameters, f.returnType, f.extensionReceiverParameter)
}
fun getFunctionLabel(
@@ -477,8 +488,7 @@ class X {
}
protected fun IrFunction.isLocalFunction(): Boolean {
return this.visibility == DescriptorVisibilities.LOCAL &&
this.origin != IrDeclarationOrigin.LOCAL_FUNCTION_FOR_LAMBDA
return this.visibility == DescriptorVisibilities.LOCAL
}
private val generatedLocalFunctionTypeMapping: MutableMap<IrFunction, LocalFunctionLabels> = mutableMapOf()

View File

@@ -0,0 +1,8 @@
package com.github.codeql.utils.versions
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.declarations.IrClass
fun functionN(pluginContext: IrPluginContext): (Int) -> IrClass {
return { i -> pluginContext.irBuiltIns.functionFactory.functionN(i) }
}

View File

@@ -0,0 +1,8 @@
package com.github.codeql.utils.versions
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.declarations.IrClass
fun functionN(pluginContext: IrPluginContext): (Int) -> IrClass {
return { i -> pluginContext.irBuiltIns.functionFactory.functionN(i) }
}

View File

@@ -0,0 +1,5 @@
package com.github.codeql.utils.versions
import org.jetbrains.kotlin.ir.IrFileEntry
typealias FileEntry = IrFileEntry

View File

@@ -0,0 +1,5 @@
package com.github.codeql.utils.versions
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
fun functionN(pluginContext: IrPluginContext) = pluginContext.irBuiltIns::functionN

View File

@@ -0,0 +1,18 @@
package com.github.codeql.utils.versions
import com.intellij.psi.PsiElement
import org.jetbrains.kotlin.backend.common.psi.PsiSourceManager
import org.jetbrains.kotlin.backend.jvm.ir.getKtFile
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.psi.KtFile
class Psi2Ir: Psi2IrFacade {
override fun getKtFile(irFile: IrFile): KtFile? {
return irFile.getKtFile()
}
override fun findPsiElement(irElement: IrElement, irFile: IrFile): PsiElement? {
return PsiSourceManager.findPsiElement(irElement, irFile)
}
}