Rework SAM conversion extraction (handle arbitrary expression that's being converted)

This commit is contained in:
Tamas Vajk
2022-02-15 16:43:24 +01:00
committed by Ian Lynagh
parent 34ae00fa62
commit a598c7fc0c
6 changed files with 543 additions and 159 deletions

View File

@@ -9,6 +9,7 @@ import com.github.codeql.utils.toRawType
import com.github.codeql.utils.versions.getIrStubFromDescriptor
import com.semmle.extractor.java.OdasaOutput
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.ir.simpleFunctions
import org.jetbrains.kotlin.backend.common.pop
import org.jetbrains.kotlin.builtins.functions.BuiltInFunctionArity
import org.jetbrains.kotlin.descriptors.*
@@ -22,6 +23,7 @@ import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
import org.jetbrains.kotlin.ir.symbols.IrSymbol
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.metadata.ProtoBuf
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.util.OperatorNameConventions
import org.jetbrains.kotlin.types.Variance
@@ -545,7 +547,7 @@ open class KotlinFileExtractor(
}
}
fun extractFunction(f: IrFunction, parentId: Label<out DbReftype>, extractBody: Boolean, typeSubstitution: TypeSubstitution?, classTypeArgsIncludingOuterClasses: List<IrTypeArgument>?, memberName: String? = null): Label<out DbCallable> {
fun extractFunction(f: IrFunction, parentId: Label<out DbReftype>, extractBody: Boolean, typeSubstitution: TypeSubstitution?, classTypeArgsIncludingOuterClasses: List<IrTypeArgument>?, memberName: String? = null, idOverride: Label<DbMethod>? = null): Label<out DbCallable> {
with("function", f) {
DeclarationStackAdjuster(f).use {
@@ -554,7 +556,9 @@ open class KotlinFileExtractor(
val locId = tw.getLocation(f)
val id =
if (f.isLocalFunction())
if (idOverride != null)
idOverride
else if (f.isLocalFunction())
getLocallyVisibleFunctionLabels(f).function
else
useFunction<DbCallable>(f, parentId, classTypeArgsIncludingOuterClasses)
@@ -2297,11 +2301,64 @@ open class KotlinFileExtractor(
}
}
private inner class FunctionReferenceHelper(private val locId: Label<DbLocation>, private val ids: LocallyVisibleFunctionLabels) {
fun writeExpressionMetadataToTrapFile(id: Label<out DbExpr>, callable: Label<out DbCallable>, stmt: Label<out DbStmt>) {
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, callable)
tw.writeStatementEnclosingExpr(id, stmt)
}
/**
* Extract a parameter to field assignment, such as `this.field = paramName` below:
* ```
* constructor(paramName: type) {
* this.field = paramName
* }
* ```
*/
fun extractParameterToFieldAssignmentInConstructor(
paramName: String,
type: IrType,
fieldId: Label<DbField>,
paramIdx: Int,
stmtIdx: Int
) {
val paramId = tw.getFreshIdLabel<DbParam>()
val paramType = extractValueParameter(paramId, type, paramName, locId, ids.constructor, paramIdx, null, paramId, false)
val assignmentStmtId = tw.getFreshIdLabel<DbExprstmt>()
tw.writeStmts_exprstmt(assignmentStmtId, ids.constructorBlock, stmtIdx, ids.constructor)
tw.writeHasLocation(assignmentStmtId, locId)
val assignmentId = tw.getFreshIdLabel<DbAssignexpr>()
tw.writeExprs_assignexpr(assignmentId, paramType.javaResult.id, assignmentStmtId, 0)
tw.writeExprsKotlinType(assignmentId, paramType.kotlinResult.id)
writeExpressionMetadataToTrapFile(assignmentId, ids.constructor, assignmentStmtId)
val lhsId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(lhsId, paramType.javaResult.id, assignmentId, 0)
tw.writeExprsKotlinType(lhsId, paramType.kotlinResult.id)
tw.writeVariableBinding(lhsId, fieldId)
writeExpressionMetadataToTrapFile(lhsId, ids.constructor, assignmentStmtId)
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, lhsId, -1)
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
writeExpressionMetadataToTrapFile(thisId, ids.constructor, assignmentStmtId)
val rhsId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(rhsId, paramType.javaResult.id, assignmentId, 1)
tw.writeExprsKotlinType(rhsId, paramType.kotlinResult.id)
tw.writeVariableBinding(rhsId, paramId)
writeExpressionMetadataToTrapFile(rhsId, ids.constructor, assignmentStmtId)
}
}
private fun extractFunctionReference(
functionReferenceExpr: IrFunctionReference,
parent: StmtExprParent,
callable: Label<out DbCallable>
) {
) : Label<out DbClassinstancexpr> {
with("function reference", functionReferenceExpr) {
val target = functionReferenceExpr.reflectionTarget ?: run {
logger.errorElement("Expected to find reflection target for function reference. Using underlying symbol instead.", functionReferenceExpr)
@@ -2373,56 +2430,7 @@ open class KotlinFileExtractor(
val currentDeclaration = declarationStack.peek()
val id = extractGeneratedClass(ids, listOf(pluginContext.irBuiltIns.anyType, fnInterfaceType), locId, currentDeclaration)
fun writeExpressionMetadataToTrapFile(id: Label<out DbExpr>, callable: Label<out DbCallable>, stmt: Label<out DbStmt>) {
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, callable)
tw.writeStatementEnclosingExpr(id, stmt)
}
/**
* Extract a parameter to field assignment, such as `this.field = paramName` below:
* ```
* constructor(paramName: type) {
* this.field = paramName
* }
* ```
*/
fun extractParameterToFieldAssignmentInConstructor(
paramName: String,
type: IrType,
fieldId: Label<DbField>,
paramIdx: Int,
stmtIdx: Int
) {
val paramId = tw.getFreshIdLabel<DbParam>()
val paramType = extractValueParameter(paramId, type, paramName, locId, ids.constructor, paramIdx, null, paramId, false)
val assignmentStmtId = tw.getFreshIdLabel<DbExprstmt>()
tw.writeStmts_exprstmt(assignmentStmtId, ids.constructorBlock, stmtIdx, ids.constructor)
tw.writeHasLocation(assignmentStmtId, locId)
val assignmentId = tw.getFreshIdLabel<DbAssignexpr>()
tw.writeExprs_assignexpr(assignmentId, paramType.javaResult.id, assignmentStmtId, 0)
tw.writeExprsKotlinType(assignmentId, paramType.kotlinResult.id)
writeExpressionMetadataToTrapFile(assignmentId, ids.constructor, assignmentStmtId)
val lhsId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(lhsId, paramType.javaResult.id, assignmentId, 0)
tw.writeExprsKotlinType(lhsId, paramType.kotlinResult.id)
tw.writeVariableBinding(lhsId, fieldId)
writeExpressionMetadataToTrapFile(lhsId, ids.constructor, assignmentStmtId)
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, lhsId, -1)
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
writeExpressionMetadataToTrapFile(thisId, ids.constructor, assignmentStmtId)
val rhsId = tw.getFreshIdLabel<DbVaraccess>()
tw.writeExprs_varaccess(rhsId, paramType.javaResult.id, assignmentId, 1)
tw.writeExprsKotlinType(rhsId, paramType.kotlinResult.id)
tw.writeVariableBinding(rhsId, paramId)
writeExpressionMetadataToTrapFile(rhsId, ids.constructor, assignmentStmtId)
}
val helper = FunctionReferenceHelper(locId, ids)
val firstAssignmentStmtIdx = 1
val extensionParameterIndex: Int
@@ -2433,7 +2441,7 @@ open class KotlinFileExtractor(
extensionParameterIndex = 1
extractField(dispatchFieldId, "<dispatchReceiver>", dispatchReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, functionReferenceExpr, false)
extractParameterToFieldAssignmentInConstructor("<dispatchReceiver>", dispatchReceiver.type, dispatchFieldId, 0, firstAssignmentStmtIdx)
helper.extractParameterToFieldAssignmentInConstructor("<dispatchReceiver>", dispatchReceiver.type, dispatchFieldId, 0, firstAssignmentStmtIdx)
} else {
dispatchFieldId = null
extensionParameterIndex = 0
@@ -2445,7 +2453,7 @@ open class KotlinFileExtractor(
extensionFieldId = tw.getFreshIdLabel()
extractField(extensionFieldId, "<extensionReceiver>", extensionReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, functionReferenceExpr, false)
extractParameterToFieldAssignmentInConstructor( "<extensionReceiver>", extensionReceiver.type, extensionFieldId, 0 + extensionParameterIndex, firstAssignmentStmtIdx + extensionParameterIndex)
helper.extractParameterToFieldAssignmentInConstructor( "<extensionReceiver>", extensionReceiver.type, extensionFieldId, 0 + extensionParameterIndex, firstAssignmentStmtIdx + extensionParameterIndex)
} else {
extensionFieldId = null
}
@@ -2483,7 +2491,7 @@ open class KotlinFileExtractor(
dispatchReceiverIdx = -1
}
writeExpressionMetadataToTrapFile(callId, funLabels.methodId, retId)
helper.writeExpressionMetadataToTrapFile(callId, funLabels.methodId, retId)
@Suppress("UNCHECKED_CAST")
tw.writeCallableBinding(callId as Label<out DbCaller>, targetCallableId)
@@ -2496,7 +2504,7 @@ open class KotlinFileExtractor(
tw.writeExprs_varaccess(pId, pType.javaResult.id, callId, idx)
tw.writeExprsKotlinType(pId, pType.kotlinResult.id)
tw.writeVariableBinding(pId, variable)
writeExpressionMetadataToTrapFile(pId, funLabels.methodId, retId)
helper.writeExpressionMetadataToTrapFile(pId, funLabels.methodId, retId)
return pId
}
@@ -2505,7 +2513,7 @@ open class KotlinFileExtractor(
val thisId = tw.getFreshIdLabel<DbThisaccess>()
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, accessId, -1)
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
writeExpressionMetadataToTrapFile(thisId, funLabels.methodId, retId)
helper.writeExpressionMetadataToTrapFile(thisId, funLabels.methodId, retId)
}
val useFirstArgAsDispatch: Boolean
@@ -2526,7 +2534,7 @@ open class KotlinFileExtractor(
}
if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY) {
addArgumentsToInvocationInInvokeNBody(parameters, funLabels, retId, callId, locId, { exp -> writeExpressionMetadataToTrapFile(exp, funLabels.methodId, retId) }, extensionIdxOffset, useFirstArgAsDispatch, dispatchReceiverIdx)
addArgumentsToInvocationInInvokeNBody(parameters, funLabels, retId, callId, locId, { exp -> helper.writeExpressionMetadataToTrapFile(exp, funLabels.methodId, retId) }, extensionIdxOffset, useFirstArgAsDispatch, dispatchReceiverIdx)
} else {
val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0
for ((pIdx, p) in funLabels.parameters.withIndex()) {
@@ -2563,6 +2571,8 @@ open class KotlinFileExtractor(
}
tw.writeIsAnonymClass(id, idMemberRef)
return idMemberRef
}
}
@@ -2892,43 +2902,130 @@ open class KotlinFileExtractor(
Boolean accept(Integer i);
}
class <Anon> extends Object implements IntPredicate {
public Boolean accept(Integer i) { return i % 2 == 0; }
Function1<Integer, Boolean> <fn>;
public <Anon>(Function1<Integer, Boolean> <fn>) { this.<fn> = <fn>; }
public Boolean accept(Integer i) { return <fn>.invoke(i); }
}
IntPredicate x = (IntPredicate)new <Anon>();
IntPredicate x = (IntPredicate)new <Anon>(...);
```
*/
val fnExpr = e.argument
if (fnExpr !is IrFunctionExpression) {
logger.errorElement("Expected to find function expression in SAM conversion. Found '${e.argument.javaClass}' instead", e)
if (!e.argument.type.isFunctionOrKFunction()) {
logger.errorElement("Expected to find expression with function type in SAM conversion.", e)
return
}
val ids = getLocallyVisibleFunctionLabels(fnExpr.function)
val locId = tw.getLocation(e)
val functionType = if (e.argument.type.isKFunction()) {
// todo: add error handling to the below
getFunctionalInterfaceType((e.argument.type as IrSimpleType).arguments.filterIsInstance<IrTypeProjection>().map { it.type })
} else {
e.argument.type
}
val invokeMethod = functionType.classOrNull?.owner?.declarations?.filterIsInstance<IrFunction>()?.find { it.name.asString() == "invoke"}
if (invokeMethod == null) {
logger.errorElement("Couldn't find `invoke` method on functional interface.", e)
return
}
val typeOwner = e.typeOperandClassifier.owner
val samMemberName = if (typeOwner !is IrClass) {
logger.errorElement("Expected to find SAM conversion to IrClass. Found '${typeOwner.javaClass}' instead. Can't rename lambda function to match SAM interface member name.", e)
null
val samMember = if (typeOwner !is IrClass) {
logger.errorElement("Expected to find SAM conversion to IrClass. Found '${typeOwner.javaClass}' instead. Can't implement SAM interface.", e)
return
} else {
val samMember = typeOwner.declarations.filterIsInstance<IrFunction>().firstOrNull { it is IrOverridableMember && it.modality == Modality.ABSTRACT }
val samMember = typeOwner.declarations.filterIsInstance<IrFunction>().find { it is IrOverridableMember && it.modality == Modality.ABSTRACT }
if (samMember == null) {
logger.errorElement("Couldn't find SAM member in type '${typeOwner.kotlinFqName.asString()}'. Can't rename lambda function to match SAM interface member name.", e)
null
logger.errorElement("Couldn't find SAM member in type '${typeOwner.kotlinFqName.asString()}'. Can't implement SAM interface.", e)
return
} else {
samMember.name.asString()
samMember
}
}
extractGeneratedClass(
fnExpr.function,
listOf(pluginContext.irBuiltIns.anyType, e.typeOperand),
samMemberName)
val javaResult = TypeResult(tw.getFreshIdLabel<DbClass>(), "", "")
val kotlinResult = TypeResult(tw.getFreshIdLabel<DbKt_notnull_type>(), "", "")
tw.writeKt_notnull_types(kotlinResult.id, javaResult.id)
val ids = LocallyVisibleFunctionLabels(
TypeResults(javaResult, kotlinResult),
tw.getFreshIdLabel(),
tw.getFreshIdLabel(),
tw.getFreshIdLabel()
)
val locId = tw.getLocation(e)
val helper = FunctionReferenceHelper(locId, ids)
val currentDeclaration = declarationStack.peek()
val classId = extractGeneratedClass(ids, listOf(pluginContext.irBuiltIns.anyType, e.typeOperand), locId, currentDeclaration)
// add field
val fieldId = tw.getFreshIdLabel<DbField>()
extractField(fieldId, "<fn>", functionType, classId, locId, DescriptorVisibilities.PRIVATE, e, false)
// adjust constructor
helper.extractParameterToFieldAssignmentInConstructor("<fn>", functionType, fieldId, 0, 1)
// add implementation function
val functionId = tw.getFreshIdLabel<DbMethod>()
extractFunction(samMember, classId, false, null, null, null, functionId)
//body
val blockId = tw.getFreshIdLabel<DbBlock>()
tw.writeStmts_block(blockId, functionId, 0, functionId)
tw.writeHasLocation(blockId, locId)
//return stmt
val returnId = tw.getFreshIdLabel<DbReturnstmt>()
tw.writeStmts_returnstmt(returnId, blockId, 0, functionId)
tw.writeHasLocation(returnId, locId)
//extractExpressionExpr(b.expression, callable, returnId, 0, returnId)
//<fn>.invoke(vp0, cp1, vp2, vp3, ...) or
//<fn>.invoke(new Object[x]{vp0, vp1, vp2, ...}) // TODO: handle big arity functions. We'd need an arraycreationexpr with an initializer
fun extractCommonExpr(id: Label<out DbExpr>) {
tw.writeHasLocation(id, locId)
tw.writeCallableEnclosingExpr(id, functionId)
tw.writeStatementEnclosingExpr(id, returnId)
}
// Call to original `invoke`:
val callId = tw.getFreshIdLabel<DbMethodaccess>()
val callType = useType(samMember.returnType)
tw.writeExprs_methodaccess(callId, callType.javaResult.id, returnId, 0)
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
extractCommonExpr(callId)
val calledMethodId = useFunction<DbMethod>(invokeMethod, (functionType as IrSimpleType).arguments)
tw.writeCallableBinding(callId, calledMethodId)
// <fn> access
val lhsId = tw.getFreshIdLabel<DbVaraccess>()
val lhsType = useType(functionType)
tw.writeExprs_varaccess(lhsId, lhsType.javaResult.id, callId, -1)
tw.writeExprsKotlinType(lhsId, lhsType.kotlinResult.id)
extractCommonExpr(lhsId)
tw.writeVariableBinding(lhsId, fieldId)
fun extractArgument(p: IrValueParameter, idx: Int) {
val argsAccessId = tw.getFreshIdLabel<DbVaraccess>()
val paramType = useType(p.type)
tw.writeExprs_varaccess(argsAccessId, paramType.javaResult.id, callId, idx)
tw.writeExprsKotlinType(argsAccessId, paramType.kotlinResult.id)
extractCommonExpr(argsAccessId)
tw.writeVariableBinding(argsAccessId, useValueParameter(p, functionId))
}
var idx = 0
val extParam = samMember.extensionReceiverParameter
if (extParam != null) {
extractArgument(extParam, idx++)
}
for (vp in samMember.valueParameters) {
extractArgument(vp, idx++)
}
val id = tw.getFreshIdLabel<DbCastexpr>()
val type = useType(e.type)
val type = useType(e.typeOperand)
tw.writeExprs_castexpr(id, type.javaResult.id, parent, idx)
tw.writeExprsKotlinType(id, type.kotlinResult.id)
tw.writeHasLocation(id, locId)
@@ -2947,7 +3044,8 @@ open class KotlinFileExtractor(
@Suppress("UNCHECKED_CAST")
tw.writeIsAnonymClass(ids.type.javaResult.id as Label<DbClass>, idNewexpr)
extractTypeAccess(pluginContext.irBuiltIns.anyType, callable, idNewexpr, -3, e, enclosingStmt)
extractTypeAccess(e.typeOperand, callable, idNewexpr, -3, e, enclosingStmt)
extractExpressionExpr(e.argument, callable, idNewexpr, 0, enclosingStmt)
}
else -> {
logger.errorElement("Unrecognised IrTypeOperatorCall for ${e.operator}: " + e.render(), e)