Extract array update operations

These are of the form arrExpr[indexExpr] op= rhs
This commit is contained in:
Chris Smowton
2022-02-18 16:24:21 +00:00
committed by Ian Lynagh
parent d9c72b1c04
commit 7cb6e19e44
5 changed files with 181 additions and 68 deletions

View File

@@ -1206,24 +1206,24 @@ open class KotlinFileExtractor(
isFunction(target, "kotlin", "Double", fName)
}
fun isArrayType(typeName: String) =
when(typeName) {
"Array" -> true
"IntArray" -> true
"ByteArray" -> true
"ShortArray" -> true
"LongArray" -> true
"FloatArray" -> true
"DoubleArray" -> true
"CharArray" -> true
"BooleanArray" -> true
else -> false
}
fun extractCall(c: IrCall, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, enclosingStmt: Label<out DbStmt>) {
with("call", c) {
val target = c.symbol.owner
fun isArrayType(typeName: String) =
when(typeName) {
"Array" -> true
"IntArray" -> true
"ByteArray" -> true
"ShortArray" -> true
"LongArray" -> true
"FloatArray" -> true
"DoubleArray" -> true
"CharArray" -> true
"BooleanArray" -> true
else -> false
}
fun extractMethodAccess(syntacticCallTarget: IrFunction, extractMethodTypeArguments: Boolean = true, extractClassTypeArguments: Boolean = false) {
val typeArgs =
if (extractMethodTypeArguments)
@@ -1782,6 +1782,113 @@ open class KotlinFileExtractor(
}
}
fun getStatementOriginOperator(origin: IrStatementOrigin?) = when (origin) {
IrStatementOrigin.PLUSEQ -> "plus"
IrStatementOrigin.MINUSEQ -> "minus"
IrStatementOrigin.MULTEQ -> "times"
IrStatementOrigin.DIVEQ -> "div"
IrStatementOrigin.PERCEQ -> "rem"
else -> null
}
fun getUpdateInPlaceRHS(origin: IrStatementOrigin?, isExpectedLhs: (IrExpression?) -> Boolean, updateRhs: IrExpression): IrExpression? {
// Check for a desugared in-place update operator, such as "v += e":
return getStatementOriginOperator(origin)?.let {
if (updateRhs is IrCall &&
isNumericFunction(updateRhs.symbol.owner, it)
) {
// Check for an expression like x = get(x).op(e):
val opReceiver = updateRhs.dispatchReceiver
if (isExpectedLhs(opReceiver)) {
updateRhs.getValueArgument(0)
} else null
} else null
} ?: null
}
fun writeUpdateInPlaceExpr(origin: IrStatementOrigin, tw: TrapWriter, id: Label<DbAssignexpr>, type: TypeResults, exprParent: ExprParent): Boolean {
when(origin) {
IrStatementOrigin.PLUSEQ -> tw.writeExprs_assignaddexpr(id as Label<DbAssignaddexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.MINUSEQ -> tw.writeExprs_assignsubexpr(id as Label<DbAssignsubexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.MULTEQ -> tw.writeExprs_assignmulexpr(id as Label<DbAssignmulexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.DIVEQ -> tw.writeExprs_assigndivexpr(id as Label<DbAssigndivexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.PERCEQ -> tw.writeExprs_assignremexpr(id as Label<DbAssignremexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
else -> return false
}
return true
}
fun tryExtractArrayUpdate(e: IrContainerExpression, callable: Label<out DbCallable>, parent: StmtExprParent): Boolean {
/*
* We're expecting the pattern
* {
* val array = e1
* val idx = e2
* array.set(idx, array.get(idx).op(e3))
* }
*
* If we find it, we'll extract e1[e2] op= e3 (op is +, -, ...)
*/
if(e.statements.size != 3)
return false
(e.statements[0] as? IrVariable)?.let { arrayVarDecl ->
arrayVarDecl.initializer?.let { arrayVarInitializer ->
(e.statements[1] as? IrVariable)?.let { indexVarDecl ->
indexVarDecl.initializer?.let { indexVarInitializer ->
(e.statements[2] as? IrCall)?.let { arraySetCall ->
if (isFunction(arraySetCall.symbol.owner, "kotlin", "(some array type)", { isArrayType(it) }, "set")) {
getUpdateInPlaceRHS(
e.origin, // Using e.origin not arraySetCall.origin here distinguishes a compiler-generated block from a user manually code that looks the same.
{ oldValue ->
oldValue is IrCall &&
isFunction(oldValue.symbol.owner, "kotlin", "(some array type)", { typeName -> isArrayType(typeName) }, "get") &&
(oldValue.dispatchReceiver as? IrGetValue)?.let {
receiverVal -> receiverVal.symbol.owner == arrayVarDecl.symbol.owner
} ?: false
},
arraySetCall.getValueArgument(1)!!
)?.let { updateRhs ->
// Create an assignment skeleton _ op= _
val exprParent = parent.expr(e, callable)
val assignId = tw.getFreshIdLabel<DbAssignexpr>()
val type = useType(arrayVarInitializer.type)
val locId = tw.getLocation(e)
tw.writeExprsKotlinType(assignId, type.kotlinResult.id)
tw.writeHasLocation(assignId, locId)
tw.writeCallableEnclosingExpr(assignId, callable)
tw.writeStatementEnclosingExpr(assignId, exprParent.enclosingStmt)
if (!writeUpdateInPlaceExpr(e.origin!!, tw, assignId, type, exprParent)) {
logger.errorElement("Unexpected origin", e)
return false
}
// Extract e1[e2]
val lhsId = tw.getFreshIdLabel<DbArrayaccess>()
val elementType = useType(updateRhs.type)
tw.writeExprs_arrayaccess(lhsId, elementType.javaResult.id, assignId, 0)
tw.writeExprsKotlinType(lhsId, elementType.kotlinResult.id)
tw.writeHasLocation(lhsId, locId)
tw.writeCallableEnclosingExpr(lhsId, callable)
tw.writeStatementEnclosingExpr(lhsId, exprParent.enclosingStmt)
extractExpressionExpr(arrayVarInitializer, callable, lhsId, 0, exprParent.enclosingStmt)
extractExpressionExpr(indexVarInitializer, callable, lhsId, 1, exprParent.enclosingStmt)
// Extract e3
extractExpressionExpr(updateRhs, callable, assignId, 1, exprParent.enclosingStmt)
return true
}
}
}
}
}
}
}
return false
}
fun extractExpressionStmt(e: IrExpression, callable: Label<out DbCallable>, parent: Label<out DbStmtparent>, idx: Int) {
extractExpression(e, callable, StmtParent(parent, idx))
}
@@ -1879,13 +1986,15 @@ open class KotlinFileExtractor(
}
}
is IrContainerExpression -> {
val stmtParent = parent.stmt(e, callable)
val id = tw.getFreshIdLabel<DbBlock>()
val locId = tw.getLocation(e)
tw.writeStmts_block(id, stmtParent.parent, stmtParent.idx, callable)
tw.writeHasLocation(id, locId)
e.statements.forEachIndexed { i, s ->
extractStatement(s, callable, id, i)
if(!tryExtractArrayUpdate(e, callable, parent)) {
val stmtParent = parent.stmt(e, callable)
val id = tw.getFreshIdLabel<DbBlock>()
val locId = tw.getLocation(e)
tw.writeStmts_block(id, stmtParent.parent, stmtParent.idx, callable)
tw.writeHasLocation(id, locId)
e.statements.forEachIndexed { i, s ->
extractStatement(s, callable, id, i)
}
}
}
is IrWhileLoop -> {
@@ -2170,34 +2279,14 @@ open class KotlinFileExtractor(
val rhsValue = e.value
// Check for a desugared in-place update operator, such as "v += e":
val expectedOperator = when (e.origin) {
IrStatementOrigin.PLUSEQ -> "plus"
IrStatementOrigin.MINUSEQ -> "minus"
IrStatementOrigin.MULTEQ -> "times"
IrStatementOrigin.DIVEQ -> "div"
IrStatementOrigin.PERCEQ -> "rem"
else -> null
}
val inPlaceUpdateRhs = expectedOperator?.let {
if (rhsValue is IrCall &&
isNumericFunction(rhsValue.symbol.owner, expectedOperator)
) {
// Check for an expression like x = get(x).op(e):
val opReceiver = rhsValue.dispatchReceiver
if (opReceiver is IrGetValue && opReceiver.symbol.owner == e.symbol.owner) {
rhsValue.getValueArgument(0)
} else null
} else null
}
val extractOrigin = if (inPlaceUpdateRhs == null) null else e.origin
when(extractOrigin) {
IrStatementOrigin.PLUSEQ -> tw.writeExprs_assignaddexpr(id as Label<DbAssignaddexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.MINUSEQ -> tw.writeExprs_assignsubexpr(id as Label<DbAssignsubexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.MULTEQ -> tw.writeExprs_assignmulexpr(id as Label<DbAssignmulexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.DIVEQ -> tw.writeExprs_assigndivexpr(id as Label<DbAssigndivexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.PERCEQ -> tw.writeExprs_assignremexpr(id as Label<DbAssignremexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
else -> tw.writeExprs_assignexpr(id, type.javaResult.id, exprParent.parent, exprParent.idx)
val inPlaceUpdateRhs = getUpdateInPlaceRHS(e.origin, { it is IrGetValue && it.symbol.owner == e.symbol.owner }, rhsValue)
if (inPlaceUpdateRhs != null) {
if (!writeUpdateInPlaceExpr(e.origin!!, tw, id, type, exprParent)) {
logger.errorElement("Unexpected origin", e)
return
}
} else {
tw.writeExprs_assignexpr(id, type.javaResult.id, exprParent.parent, exprParent.idx)
}
val lhsType = useType(e.symbol.owner.type)