mirror of
https://github.com/github/codeql.git
synced 2025-12-17 09:13:20 +01:00
Rework SAM conversion extraction (handle arbitrary expression that's being converted)
This commit is contained in:
@@ -9,6 +9,7 @@ import com.github.codeql.utils.toRawType
|
||||
import com.github.codeql.utils.versions.getIrStubFromDescriptor
|
||||
import com.semmle.extractor.java.OdasaOutput
|
||||
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
|
||||
import org.jetbrains.kotlin.backend.common.ir.simpleFunctions
|
||||
import org.jetbrains.kotlin.backend.common.pop
|
||||
import org.jetbrains.kotlin.builtins.functions.BuiltInFunctionArity
|
||||
import org.jetbrains.kotlin.descriptors.*
|
||||
@@ -22,6 +23,7 @@ import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
|
||||
import org.jetbrains.kotlin.ir.symbols.IrSymbol
|
||||
import org.jetbrains.kotlin.ir.types.*
|
||||
import org.jetbrains.kotlin.ir.util.*
|
||||
import org.jetbrains.kotlin.metadata.ProtoBuf
|
||||
import org.jetbrains.kotlin.name.FqName
|
||||
import org.jetbrains.kotlin.util.OperatorNameConventions
|
||||
import org.jetbrains.kotlin.types.Variance
|
||||
@@ -545,7 +547,7 @@ open class KotlinFileExtractor(
|
||||
}
|
||||
}
|
||||
|
||||
fun extractFunction(f: IrFunction, parentId: Label<out DbReftype>, extractBody: Boolean, typeSubstitution: TypeSubstitution?, classTypeArgsIncludingOuterClasses: List<IrTypeArgument>?, memberName: String? = null): Label<out DbCallable> {
|
||||
fun extractFunction(f: IrFunction, parentId: Label<out DbReftype>, extractBody: Boolean, typeSubstitution: TypeSubstitution?, classTypeArgsIncludingOuterClasses: List<IrTypeArgument>?, memberName: String? = null, idOverride: Label<DbMethod>? = null): Label<out DbCallable> {
|
||||
with("function", f) {
|
||||
DeclarationStackAdjuster(f).use {
|
||||
|
||||
@@ -554,7 +556,9 @@ open class KotlinFileExtractor(
|
||||
val locId = tw.getLocation(f)
|
||||
|
||||
val id =
|
||||
if (f.isLocalFunction())
|
||||
if (idOverride != null)
|
||||
idOverride
|
||||
else if (f.isLocalFunction())
|
||||
getLocallyVisibleFunctionLabels(f).function
|
||||
else
|
||||
useFunction<DbCallable>(f, parentId, classTypeArgsIncludingOuterClasses)
|
||||
@@ -2297,11 +2301,64 @@ open class KotlinFileExtractor(
|
||||
}
|
||||
}
|
||||
|
||||
private inner class FunctionReferenceHelper(private val locId: Label<DbLocation>, private val ids: LocallyVisibleFunctionLabels) {
|
||||
fun writeExpressionMetadataToTrapFile(id: Label<out DbExpr>, callable: Label<out DbCallable>, stmt: Label<out DbStmt>) {
|
||||
tw.writeHasLocation(id, locId)
|
||||
tw.writeCallableEnclosingExpr(id, callable)
|
||||
tw.writeStatementEnclosingExpr(id, stmt)
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a parameter to field assignment, such as `this.field = paramName` below:
|
||||
* ```
|
||||
* constructor(paramName: type) {
|
||||
* this.field = paramName
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
fun extractParameterToFieldAssignmentInConstructor(
|
||||
paramName: String,
|
||||
type: IrType,
|
||||
fieldId: Label<DbField>,
|
||||
paramIdx: Int,
|
||||
stmtIdx: Int
|
||||
) {
|
||||
val paramId = tw.getFreshIdLabel<DbParam>()
|
||||
val paramType = extractValueParameter(paramId, type, paramName, locId, ids.constructor, paramIdx, null, paramId, false)
|
||||
|
||||
val assignmentStmtId = tw.getFreshIdLabel<DbExprstmt>()
|
||||
tw.writeStmts_exprstmt(assignmentStmtId, ids.constructorBlock, stmtIdx, ids.constructor)
|
||||
tw.writeHasLocation(assignmentStmtId, locId)
|
||||
|
||||
val assignmentId = tw.getFreshIdLabel<DbAssignexpr>()
|
||||
tw.writeExprs_assignexpr(assignmentId, paramType.javaResult.id, assignmentStmtId, 0)
|
||||
tw.writeExprsKotlinType(assignmentId, paramType.kotlinResult.id)
|
||||
writeExpressionMetadataToTrapFile(assignmentId, ids.constructor, assignmentStmtId)
|
||||
|
||||
val lhsId = tw.getFreshIdLabel<DbVaraccess>()
|
||||
tw.writeExprs_varaccess(lhsId, paramType.javaResult.id, assignmentId, 0)
|
||||
tw.writeExprsKotlinType(lhsId, paramType.kotlinResult.id)
|
||||
tw.writeVariableBinding(lhsId, fieldId)
|
||||
writeExpressionMetadataToTrapFile(lhsId, ids.constructor, assignmentStmtId)
|
||||
|
||||
val thisId = tw.getFreshIdLabel<DbThisaccess>()
|
||||
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, lhsId, -1)
|
||||
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
|
||||
writeExpressionMetadataToTrapFile(thisId, ids.constructor, assignmentStmtId)
|
||||
|
||||
val rhsId = tw.getFreshIdLabel<DbVaraccess>()
|
||||
tw.writeExprs_varaccess(rhsId, paramType.javaResult.id, assignmentId, 1)
|
||||
tw.writeExprsKotlinType(rhsId, paramType.kotlinResult.id)
|
||||
tw.writeVariableBinding(rhsId, paramId)
|
||||
writeExpressionMetadataToTrapFile(rhsId, ids.constructor, assignmentStmtId)
|
||||
}
|
||||
}
|
||||
|
||||
private fun extractFunctionReference(
|
||||
functionReferenceExpr: IrFunctionReference,
|
||||
parent: StmtExprParent,
|
||||
callable: Label<out DbCallable>
|
||||
) {
|
||||
) : Label<out DbClassinstancexpr> {
|
||||
with("function reference", functionReferenceExpr) {
|
||||
val target = functionReferenceExpr.reflectionTarget ?: run {
|
||||
logger.errorElement("Expected to find reflection target for function reference. Using underlying symbol instead.", functionReferenceExpr)
|
||||
@@ -2373,56 +2430,7 @@ open class KotlinFileExtractor(
|
||||
val currentDeclaration = declarationStack.peek()
|
||||
val id = extractGeneratedClass(ids, listOf(pluginContext.irBuiltIns.anyType, fnInterfaceType), locId, currentDeclaration)
|
||||
|
||||
fun writeExpressionMetadataToTrapFile(id: Label<out DbExpr>, callable: Label<out DbCallable>, stmt: Label<out DbStmt>) {
|
||||
tw.writeHasLocation(id, locId)
|
||||
tw.writeCallableEnclosingExpr(id, callable)
|
||||
tw.writeStatementEnclosingExpr(id, stmt)
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a parameter to field assignment, such as `this.field = paramName` below:
|
||||
* ```
|
||||
* constructor(paramName: type) {
|
||||
* this.field = paramName
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
fun extractParameterToFieldAssignmentInConstructor(
|
||||
paramName: String,
|
||||
type: IrType,
|
||||
fieldId: Label<DbField>,
|
||||
paramIdx: Int,
|
||||
stmtIdx: Int
|
||||
) {
|
||||
val paramId = tw.getFreshIdLabel<DbParam>()
|
||||
val paramType = extractValueParameter(paramId, type, paramName, locId, ids.constructor, paramIdx, null, paramId, false)
|
||||
|
||||
val assignmentStmtId = tw.getFreshIdLabel<DbExprstmt>()
|
||||
tw.writeStmts_exprstmt(assignmentStmtId, ids.constructorBlock, stmtIdx, ids.constructor)
|
||||
tw.writeHasLocation(assignmentStmtId, locId)
|
||||
|
||||
val assignmentId = tw.getFreshIdLabel<DbAssignexpr>()
|
||||
tw.writeExprs_assignexpr(assignmentId, paramType.javaResult.id, assignmentStmtId, 0)
|
||||
tw.writeExprsKotlinType(assignmentId, paramType.kotlinResult.id)
|
||||
writeExpressionMetadataToTrapFile(assignmentId, ids.constructor, assignmentStmtId)
|
||||
|
||||
val lhsId = tw.getFreshIdLabel<DbVaraccess>()
|
||||
tw.writeExprs_varaccess(lhsId, paramType.javaResult.id, assignmentId, 0)
|
||||
tw.writeExprsKotlinType(lhsId, paramType.kotlinResult.id)
|
||||
tw.writeVariableBinding(lhsId, fieldId)
|
||||
writeExpressionMetadataToTrapFile(lhsId, ids.constructor, assignmentStmtId)
|
||||
|
||||
val thisId = tw.getFreshIdLabel<DbThisaccess>()
|
||||
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, lhsId, -1)
|
||||
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
|
||||
writeExpressionMetadataToTrapFile(thisId, ids.constructor, assignmentStmtId)
|
||||
|
||||
val rhsId = tw.getFreshIdLabel<DbVaraccess>()
|
||||
tw.writeExprs_varaccess(rhsId, paramType.javaResult.id, assignmentId, 1)
|
||||
tw.writeExprsKotlinType(rhsId, paramType.kotlinResult.id)
|
||||
tw.writeVariableBinding(rhsId, paramId)
|
||||
writeExpressionMetadataToTrapFile(rhsId, ids.constructor, assignmentStmtId)
|
||||
}
|
||||
val helper = FunctionReferenceHelper(locId, ids)
|
||||
|
||||
val firstAssignmentStmtIdx = 1
|
||||
val extensionParameterIndex: Int
|
||||
@@ -2433,7 +2441,7 @@ open class KotlinFileExtractor(
|
||||
extensionParameterIndex = 1
|
||||
|
||||
extractField(dispatchFieldId, "<dispatchReceiver>", dispatchReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, functionReferenceExpr, false)
|
||||
extractParameterToFieldAssignmentInConstructor("<dispatchReceiver>", dispatchReceiver.type, dispatchFieldId, 0, firstAssignmentStmtIdx)
|
||||
helper.extractParameterToFieldAssignmentInConstructor("<dispatchReceiver>", dispatchReceiver.type, dispatchFieldId, 0, firstAssignmentStmtIdx)
|
||||
} else {
|
||||
dispatchFieldId = null
|
||||
extensionParameterIndex = 0
|
||||
@@ -2445,7 +2453,7 @@ open class KotlinFileExtractor(
|
||||
extensionFieldId = tw.getFreshIdLabel()
|
||||
|
||||
extractField(extensionFieldId, "<extensionReceiver>", extensionReceiver.type, id, locId, DescriptorVisibilities.PRIVATE, functionReferenceExpr, false)
|
||||
extractParameterToFieldAssignmentInConstructor( "<extensionReceiver>", extensionReceiver.type, extensionFieldId, 0 + extensionParameterIndex, firstAssignmentStmtIdx + extensionParameterIndex)
|
||||
helper.extractParameterToFieldAssignmentInConstructor( "<extensionReceiver>", extensionReceiver.type, extensionFieldId, 0 + extensionParameterIndex, firstAssignmentStmtIdx + extensionParameterIndex)
|
||||
} else {
|
||||
extensionFieldId = null
|
||||
}
|
||||
@@ -2483,7 +2491,7 @@ open class KotlinFileExtractor(
|
||||
dispatchReceiverIdx = -1
|
||||
}
|
||||
|
||||
writeExpressionMetadataToTrapFile(callId, funLabels.methodId, retId)
|
||||
helper.writeExpressionMetadataToTrapFile(callId, funLabels.methodId, retId)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
tw.writeCallableBinding(callId as Label<out DbCaller>, targetCallableId)
|
||||
|
||||
@@ -2496,7 +2504,7 @@ open class KotlinFileExtractor(
|
||||
tw.writeExprs_varaccess(pId, pType.javaResult.id, callId, idx)
|
||||
tw.writeExprsKotlinType(pId, pType.kotlinResult.id)
|
||||
tw.writeVariableBinding(pId, variable)
|
||||
writeExpressionMetadataToTrapFile(pId, funLabels.methodId, retId)
|
||||
helper.writeExpressionMetadataToTrapFile(pId, funLabels.methodId, retId)
|
||||
return pId
|
||||
}
|
||||
|
||||
@@ -2505,7 +2513,7 @@ open class KotlinFileExtractor(
|
||||
val thisId = tw.getFreshIdLabel<DbThisaccess>()
|
||||
tw.writeExprs_thisaccess(thisId, ids.type.javaResult.id, accessId, -1)
|
||||
tw.writeExprsKotlinType(thisId, ids.type.kotlinResult.id)
|
||||
writeExpressionMetadataToTrapFile(thisId, funLabels.methodId, retId)
|
||||
helper.writeExpressionMetadataToTrapFile(thisId, funLabels.methodId, retId)
|
||||
}
|
||||
|
||||
val useFirstArgAsDispatch: Boolean
|
||||
@@ -2526,7 +2534,7 @@ open class KotlinFileExtractor(
|
||||
}
|
||||
|
||||
if (functionNTypeArguments.size > BuiltInFunctionArity.BIG_ARITY) {
|
||||
addArgumentsToInvocationInInvokeNBody(parameters, funLabels, retId, callId, locId, { exp -> writeExpressionMetadataToTrapFile(exp, funLabels.methodId, retId) }, extensionIdxOffset, useFirstArgAsDispatch, dispatchReceiverIdx)
|
||||
addArgumentsToInvocationInInvokeNBody(parameters, funLabels, retId, callId, locId, { exp -> helper.writeExpressionMetadataToTrapFile(exp, funLabels.methodId, retId) }, extensionIdxOffset, useFirstArgAsDispatch, dispatchReceiverIdx)
|
||||
} else {
|
||||
val dispatchIdxOffset = if (useFirstArgAsDispatch) 1 else 0
|
||||
for ((pIdx, p) in funLabels.parameters.withIndex()) {
|
||||
@@ -2563,6 +2571,8 @@ open class KotlinFileExtractor(
|
||||
}
|
||||
|
||||
tw.writeIsAnonymClass(id, idMemberRef)
|
||||
|
||||
return idMemberRef
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2892,43 +2902,130 @@ open class KotlinFileExtractor(
|
||||
Boolean accept(Integer i);
|
||||
}
|
||||
class <Anon> extends Object implements IntPredicate {
|
||||
public Boolean accept(Integer i) { return i % 2 == 0; }
|
||||
Function1<Integer, Boolean> <fn>;
|
||||
public <Anon>(Function1<Integer, Boolean> <fn>) { this.<fn> = <fn>; }
|
||||
public Boolean accept(Integer i) { return <fn>.invoke(i); }
|
||||
}
|
||||
|
||||
IntPredicate x = (IntPredicate)new <Anon>();
|
||||
IntPredicate x = (IntPredicate)new <Anon>(...);
|
||||
```
|
||||
*/
|
||||
|
||||
val fnExpr = e.argument
|
||||
if (fnExpr !is IrFunctionExpression) {
|
||||
logger.errorElement("Expected to find function expression in SAM conversion. Found '${e.argument.javaClass}' instead", e)
|
||||
if (!e.argument.type.isFunctionOrKFunction()) {
|
||||
logger.errorElement("Expected to find expression with function type in SAM conversion.", e)
|
||||
return
|
||||
}
|
||||
|
||||
val ids = getLocallyVisibleFunctionLabels(fnExpr.function)
|
||||
val locId = tw.getLocation(e)
|
||||
val functionType = if (e.argument.type.isKFunction()) {
|
||||
// todo: add error handling to the below
|
||||
getFunctionalInterfaceType((e.argument.type as IrSimpleType).arguments.filterIsInstance<IrTypeProjection>().map { it.type })
|
||||
} else {
|
||||
e.argument.type
|
||||
}
|
||||
|
||||
val invokeMethod = functionType.classOrNull?.owner?.declarations?.filterIsInstance<IrFunction>()?.find { it.name.asString() == "invoke"}
|
||||
if (invokeMethod == null) {
|
||||
logger.errorElement("Couldn't find `invoke` method on functional interface.", e)
|
||||
return
|
||||
}
|
||||
|
||||
val typeOwner = e.typeOperandClassifier.owner
|
||||
val samMemberName = if (typeOwner !is IrClass) {
|
||||
logger.errorElement("Expected to find SAM conversion to IrClass. Found '${typeOwner.javaClass}' instead. Can't rename lambda function to match SAM interface member name.", e)
|
||||
null
|
||||
val samMember = if (typeOwner !is IrClass) {
|
||||
logger.errorElement("Expected to find SAM conversion to IrClass. Found '${typeOwner.javaClass}' instead. Can't implement SAM interface.", e)
|
||||
return
|
||||
} else {
|
||||
val samMember = typeOwner.declarations.filterIsInstance<IrFunction>().firstOrNull { it is IrOverridableMember && it.modality == Modality.ABSTRACT }
|
||||
val samMember = typeOwner.declarations.filterIsInstance<IrFunction>().find { it is IrOverridableMember && it.modality == Modality.ABSTRACT }
|
||||
if (samMember == null) {
|
||||
logger.errorElement("Couldn't find SAM member in type '${typeOwner.kotlinFqName.asString()}'. Can't rename lambda function to match SAM interface member name.", e)
|
||||
null
|
||||
logger.errorElement("Couldn't find SAM member in type '${typeOwner.kotlinFqName.asString()}'. Can't implement SAM interface.", e)
|
||||
return
|
||||
} else {
|
||||
samMember.name.asString()
|
||||
samMember
|
||||
}
|
||||
}
|
||||
|
||||
extractGeneratedClass(
|
||||
fnExpr.function,
|
||||
listOf(pluginContext.irBuiltIns.anyType, e.typeOperand),
|
||||
samMemberName)
|
||||
val javaResult = TypeResult(tw.getFreshIdLabel<DbClass>(), "", "")
|
||||
val kotlinResult = TypeResult(tw.getFreshIdLabel<DbKt_notnull_type>(), "", "")
|
||||
tw.writeKt_notnull_types(kotlinResult.id, javaResult.id)
|
||||
val ids = LocallyVisibleFunctionLabels(
|
||||
TypeResults(javaResult, kotlinResult),
|
||||
tw.getFreshIdLabel(),
|
||||
tw.getFreshIdLabel(),
|
||||
tw.getFreshIdLabel()
|
||||
)
|
||||
val locId = tw.getLocation(e)
|
||||
val helper = FunctionReferenceHelper(locId, ids)
|
||||
|
||||
val currentDeclaration = declarationStack.peek()
|
||||
val classId = extractGeneratedClass(ids, listOf(pluginContext.irBuiltIns.anyType, e.typeOperand), locId, currentDeclaration)
|
||||
|
||||
// add field
|
||||
val fieldId = tw.getFreshIdLabel<DbField>()
|
||||
extractField(fieldId, "<fn>", functionType, classId, locId, DescriptorVisibilities.PRIVATE, e, false)
|
||||
|
||||
// adjust constructor
|
||||
helper.extractParameterToFieldAssignmentInConstructor("<fn>", functionType, fieldId, 0, 1)
|
||||
|
||||
// add implementation function
|
||||
val functionId = tw.getFreshIdLabel<DbMethod>()
|
||||
extractFunction(samMember, classId, false, null, null, null, functionId)
|
||||
|
||||
//body
|
||||
val blockId = tw.getFreshIdLabel<DbBlock>()
|
||||
tw.writeStmts_block(blockId, functionId, 0, functionId)
|
||||
tw.writeHasLocation(blockId, locId)
|
||||
|
||||
//return stmt
|
||||
val returnId = tw.getFreshIdLabel<DbReturnstmt>()
|
||||
tw.writeStmts_returnstmt(returnId, blockId, 0, functionId)
|
||||
tw.writeHasLocation(returnId, locId)
|
||||
|
||||
//extractExpressionExpr(b.expression, callable, returnId, 0, returnId)
|
||||
//<fn>.invoke(vp0, cp1, vp2, vp3, ...) or
|
||||
//<fn>.invoke(new Object[x]{vp0, vp1, vp2, ...}) // TODO: handle big arity functions. We'd need an arraycreationexpr with an initializer
|
||||
|
||||
fun extractCommonExpr(id: Label<out DbExpr>) {
|
||||
tw.writeHasLocation(id, locId)
|
||||
tw.writeCallableEnclosingExpr(id, functionId)
|
||||
tw.writeStatementEnclosingExpr(id, returnId)
|
||||
}
|
||||
|
||||
// Call to original `invoke`:
|
||||
val callId = tw.getFreshIdLabel<DbMethodaccess>()
|
||||
val callType = useType(samMember.returnType)
|
||||
tw.writeExprs_methodaccess(callId, callType.javaResult.id, returnId, 0)
|
||||
tw.writeExprsKotlinType(callId, callType.kotlinResult.id)
|
||||
extractCommonExpr(callId)
|
||||
val calledMethodId = useFunction<DbMethod>(invokeMethod, (functionType as IrSimpleType).arguments)
|
||||
tw.writeCallableBinding(callId, calledMethodId)
|
||||
|
||||
// <fn> access
|
||||
val lhsId = tw.getFreshIdLabel<DbVaraccess>()
|
||||
val lhsType = useType(functionType)
|
||||
tw.writeExprs_varaccess(lhsId, lhsType.javaResult.id, callId, -1)
|
||||
tw.writeExprsKotlinType(lhsId, lhsType.kotlinResult.id)
|
||||
extractCommonExpr(lhsId)
|
||||
tw.writeVariableBinding(lhsId, fieldId)
|
||||
|
||||
fun extractArgument(p: IrValueParameter, idx: Int) {
|
||||
val argsAccessId = tw.getFreshIdLabel<DbVaraccess>()
|
||||
val paramType = useType(p.type)
|
||||
tw.writeExprs_varaccess(argsAccessId, paramType.javaResult.id, callId, idx)
|
||||
tw.writeExprsKotlinType(argsAccessId, paramType.kotlinResult.id)
|
||||
extractCommonExpr(argsAccessId)
|
||||
tw.writeVariableBinding(argsAccessId, useValueParameter(p, functionId))
|
||||
}
|
||||
|
||||
var idx = 0
|
||||
val extParam = samMember.extensionReceiverParameter
|
||||
if (extParam != null) {
|
||||
extractArgument(extParam, idx++)
|
||||
}
|
||||
for (vp in samMember.valueParameters) {
|
||||
extractArgument(vp, idx++)
|
||||
}
|
||||
|
||||
val id = tw.getFreshIdLabel<DbCastexpr>()
|
||||
val type = useType(e.type)
|
||||
val type = useType(e.typeOperand)
|
||||
tw.writeExprs_castexpr(id, type.javaResult.id, parent, idx)
|
||||
tw.writeExprsKotlinType(id, type.kotlinResult.id)
|
||||
tw.writeHasLocation(id, locId)
|
||||
@@ -2947,7 +3044,8 @@ open class KotlinFileExtractor(
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
tw.writeIsAnonymClass(ids.type.javaResult.id as Label<DbClass>, idNewexpr)
|
||||
|
||||
extractTypeAccess(pluginContext.irBuiltIns.anyType, callable, idNewexpr, -3, e, enclosingStmt)
|
||||
extractTypeAccess(e.typeOperand, callable, idNewexpr, -3, e, enclosingStmt)
|
||||
extractExpressionExpr(e.argument, callable, idNewexpr, 0, enclosingStmt)
|
||||
}
|
||||
else -> {
|
||||
logger.errorElement("Unrecognised IrTypeOperatorCall for ${e.operator}: " + e.render(), e)
|
||||
|
||||
Reference in New Issue
Block a user