Extract array update operations

These are of the form arrExpr[indexExpr] op= rhs
This commit is contained in:
Chris Smowton
2022-02-18 16:24:21 +00:00
committed by Ian Lynagh
parent d9c72b1c04
commit 7cb6e19e44
5 changed files with 181 additions and 68 deletions

View File

@@ -1206,24 +1206,24 @@ open class KotlinFileExtractor(
isFunction(target, "kotlin", "Double", fName)
}
fun isArrayType(typeName: String) =
when(typeName) {
"Array" -> true
"IntArray" -> true
"ByteArray" -> true
"ShortArray" -> true
"LongArray" -> true
"FloatArray" -> true
"DoubleArray" -> true
"CharArray" -> true
"BooleanArray" -> true
else -> false
}
fun extractCall(c: IrCall, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, enclosingStmt: Label<out DbStmt>) {
with("call", c) {
val target = c.symbol.owner
fun isArrayType(typeName: String) =
when(typeName) {
"Array" -> true
"IntArray" -> true
"ByteArray" -> true
"ShortArray" -> true
"LongArray" -> true
"FloatArray" -> true
"DoubleArray" -> true
"CharArray" -> true
"BooleanArray" -> true
else -> false
}
fun extractMethodAccess(syntacticCallTarget: IrFunction, extractMethodTypeArguments: Boolean = true, extractClassTypeArguments: Boolean = false) {
val typeArgs =
if (extractMethodTypeArguments)
@@ -1782,6 +1782,113 @@ open class KotlinFileExtractor(
}
}
fun getStatementOriginOperator(origin: IrStatementOrigin?) = when (origin) {
IrStatementOrigin.PLUSEQ -> "plus"
IrStatementOrigin.MINUSEQ -> "minus"
IrStatementOrigin.MULTEQ -> "times"
IrStatementOrigin.DIVEQ -> "div"
IrStatementOrigin.PERCEQ -> "rem"
else -> null
}
fun getUpdateInPlaceRHS(origin: IrStatementOrigin?, isExpectedLhs: (IrExpression?) -> Boolean, updateRhs: IrExpression): IrExpression? {
// Check for a desugared in-place update operator, such as "v += e":
return getStatementOriginOperator(origin)?.let {
if (updateRhs is IrCall &&
isNumericFunction(updateRhs.symbol.owner, it)
) {
// Check for an expression like x = get(x).op(e):
val opReceiver = updateRhs.dispatchReceiver
if (isExpectedLhs(opReceiver)) {
updateRhs.getValueArgument(0)
} else null
} else null
} ?: null
}
fun writeUpdateInPlaceExpr(origin: IrStatementOrigin, tw: TrapWriter, id: Label<DbAssignexpr>, type: TypeResults, exprParent: ExprParent): Boolean {
when(origin) {
IrStatementOrigin.PLUSEQ -> tw.writeExprs_assignaddexpr(id as Label<DbAssignaddexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.MINUSEQ -> tw.writeExprs_assignsubexpr(id as Label<DbAssignsubexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.MULTEQ -> tw.writeExprs_assignmulexpr(id as Label<DbAssignmulexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.DIVEQ -> tw.writeExprs_assigndivexpr(id as Label<DbAssigndivexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.PERCEQ -> tw.writeExprs_assignremexpr(id as Label<DbAssignremexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
else -> return false
}
return true
}
fun tryExtractArrayUpdate(e: IrContainerExpression, callable: Label<out DbCallable>, parent: StmtExprParent): Boolean {
/*
* We're expecting the pattern
* {
* val array = e1
* val idx = e2
* array.set(idx, array.get(idx).op(e3))
* }
*
* If we find it, we'll extract e1[e2] op= e3 (op is +, -, ...)
*/
if(e.statements.size != 3)
return false
(e.statements[0] as? IrVariable)?.let { arrayVarDecl ->
arrayVarDecl.initializer?.let { arrayVarInitializer ->
(e.statements[1] as? IrVariable)?.let { indexVarDecl ->
indexVarDecl.initializer?.let { indexVarInitializer ->
(e.statements[2] as? IrCall)?.let { arraySetCall ->
if (isFunction(arraySetCall.symbol.owner, "kotlin", "(some array type)", { isArrayType(it) }, "set")) {
getUpdateInPlaceRHS(
e.origin, // Using e.origin not arraySetCall.origin here distinguishes a compiler-generated block from a user manually code that looks the same.
{ oldValue ->
oldValue is IrCall &&
isFunction(oldValue.symbol.owner, "kotlin", "(some array type)", { typeName -> isArrayType(typeName) }, "get") &&
(oldValue.dispatchReceiver as? IrGetValue)?.let {
receiverVal -> receiverVal.symbol.owner == arrayVarDecl.symbol.owner
} ?: false
},
arraySetCall.getValueArgument(1)!!
)?.let { updateRhs ->
// Create an assignment skeleton _ op= _
val exprParent = parent.expr(e, callable)
val assignId = tw.getFreshIdLabel<DbAssignexpr>()
val type = useType(arrayVarInitializer.type)
val locId = tw.getLocation(e)
tw.writeExprsKotlinType(assignId, type.kotlinResult.id)
tw.writeHasLocation(assignId, locId)
tw.writeCallableEnclosingExpr(assignId, callable)
tw.writeStatementEnclosingExpr(assignId, exprParent.enclosingStmt)
if (!writeUpdateInPlaceExpr(e.origin!!, tw, assignId, type, exprParent)) {
logger.errorElement("Unexpected origin", e)
return false
}
// Extract e1[e2]
val lhsId = tw.getFreshIdLabel<DbArrayaccess>()
val elementType = useType(updateRhs.type)
tw.writeExprs_arrayaccess(lhsId, elementType.javaResult.id, assignId, 0)
tw.writeExprsKotlinType(lhsId, elementType.kotlinResult.id)
tw.writeHasLocation(lhsId, locId)
tw.writeCallableEnclosingExpr(lhsId, callable)
tw.writeStatementEnclosingExpr(lhsId, exprParent.enclosingStmt)
extractExpressionExpr(arrayVarInitializer, callable, lhsId, 0, exprParent.enclosingStmt)
extractExpressionExpr(indexVarInitializer, callable, lhsId, 1, exprParent.enclosingStmt)
// Extract e3
extractExpressionExpr(updateRhs, callable, assignId, 1, exprParent.enclosingStmt)
return true
}
}
}
}
}
}
}
return false
}
fun extractExpressionStmt(e: IrExpression, callable: Label<out DbCallable>, parent: Label<out DbStmtparent>, idx: Int) {
extractExpression(e, callable, StmtParent(parent, idx))
}
@@ -1879,13 +1986,15 @@ open class KotlinFileExtractor(
}
}
is IrContainerExpression -> {
val stmtParent = parent.stmt(e, callable)
val id = tw.getFreshIdLabel<DbBlock>()
val locId = tw.getLocation(e)
tw.writeStmts_block(id, stmtParent.parent, stmtParent.idx, callable)
tw.writeHasLocation(id, locId)
e.statements.forEachIndexed { i, s ->
extractStatement(s, callable, id, i)
if(!tryExtractArrayUpdate(e, callable, parent)) {
val stmtParent = parent.stmt(e, callable)
val id = tw.getFreshIdLabel<DbBlock>()
val locId = tw.getLocation(e)
tw.writeStmts_block(id, stmtParent.parent, stmtParent.idx, callable)
tw.writeHasLocation(id, locId)
e.statements.forEachIndexed { i, s ->
extractStatement(s, callable, id, i)
}
}
}
is IrWhileLoop -> {
@@ -2170,34 +2279,14 @@ open class KotlinFileExtractor(
val rhsValue = e.value
// Check for a desugared in-place update operator, such as "v += e":
val expectedOperator = when (e.origin) {
IrStatementOrigin.PLUSEQ -> "plus"
IrStatementOrigin.MINUSEQ -> "minus"
IrStatementOrigin.MULTEQ -> "times"
IrStatementOrigin.DIVEQ -> "div"
IrStatementOrigin.PERCEQ -> "rem"
else -> null
}
val inPlaceUpdateRhs = expectedOperator?.let {
if (rhsValue is IrCall &&
isNumericFunction(rhsValue.symbol.owner, expectedOperator)
) {
// Check for an expression like x = get(x).op(e):
val opReceiver = rhsValue.dispatchReceiver
if (opReceiver is IrGetValue && opReceiver.symbol.owner == e.symbol.owner) {
rhsValue.getValueArgument(0)
} else null
} else null
}
val extractOrigin = if (inPlaceUpdateRhs == null) null else e.origin
when(extractOrigin) {
IrStatementOrigin.PLUSEQ -> tw.writeExprs_assignaddexpr(id as Label<DbAssignaddexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.MINUSEQ -> tw.writeExprs_assignsubexpr(id as Label<DbAssignsubexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.MULTEQ -> tw.writeExprs_assignmulexpr(id as Label<DbAssignmulexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.DIVEQ -> tw.writeExprs_assigndivexpr(id as Label<DbAssigndivexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
IrStatementOrigin.PERCEQ -> tw.writeExprs_assignremexpr(id as Label<DbAssignremexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
else -> tw.writeExprs_assignexpr(id, type.javaResult.id, exprParent.parent, exprParent.idx)
val inPlaceUpdateRhs = getUpdateInPlaceRHS(e.origin, { it is IrGetValue && it.symbol.owner == e.symbol.owner }, rhsValue)
if (inPlaceUpdateRhs != null) {
if (!writeUpdateInPlaceExpr(e.origin!!, tw, id, type, exprParent)) {
logger.errorElement("Unexpected origin", e)
return
}
} else {
tw.writeExprs_assignexpr(id, type.javaResult.id, exprParent.parent, exprParent.idx)
}
val lhsType = useType(e.symbol.owner.type)

View File

@@ -1,18 +1,22 @@
| arrayGetsSets.kt:12:3:12:7 | ...[...] | arrayGetsSets.kt:12:3:12:7 | ...=... |
| arrayGetsSets.kt:12:11:12:15 | ...[...] | arrayGetsSets.kt:12:3:12:7 | ...=... |
| arrayGetsSets.kt:13:3:13:7 | ...[...] | arrayGetsSets.kt:13:3:13:7 | ...=... |
| arrayGetsSets.kt:13:11:13:15 | ...[...] | arrayGetsSets.kt:13:3:13:7 | ...=... |
| arrayGetsSets.kt:14:3:14:7 | ...[...] | arrayGetsSets.kt:14:3:14:7 | ...=... |
| arrayGetsSets.kt:14:11:14:15 | ...[...] | arrayGetsSets.kt:14:3:14:7 | ...=... |
| arrayGetsSets.kt:15:3:15:7 | ...[...] | arrayGetsSets.kt:15:3:15:7 | ...=... |
| arrayGetsSets.kt:15:11:15:15 | ...[...] | arrayGetsSets.kt:15:3:15:7 | ...=... |
| arrayGetsSets.kt:16:3:16:7 | ...[...] | arrayGetsSets.kt:16:3:16:7 | ...=... |
| arrayGetsSets.kt:16:11:16:15 | ...[...] | arrayGetsSets.kt:16:3:16:7 | ...=... |
| arrayGetsSets.kt:17:3:17:7 | ...[...] | arrayGetsSets.kt:17:3:17:7 | ...=... |
| arrayGetsSets.kt:17:11:17:15 | ...[...] | arrayGetsSets.kt:17:3:17:7 | ...=... |
| arrayGetsSets.kt:18:3:18:7 | ...[...] | arrayGetsSets.kt:18:3:18:7 | ...=... |
| arrayGetsSets.kt:18:11:18:15 | ...[...] | arrayGetsSets.kt:18:3:18:7 | ...=... |
| arrayGetsSets.kt:19:3:19:7 | ...[...] | arrayGetsSets.kt:19:3:19:7 | ...=... |
| arrayGetsSets.kt:19:11:19:15 | ...[...] | arrayGetsSets.kt:19:3:19:7 | ...=... |
| arrayGetsSets.kt:20:3:20:7 | ...[...] | arrayGetsSets.kt:20:3:20:7 | ...=... |
| arrayGetsSets.kt:20:11:20:15 | ...[...] | arrayGetsSets.kt:20:3:20:7 | ...=... |
| arrayGetsSets.kt:12:3:12:7 | ...[...] | arrayGetsSets.kt:12:3:12:7 | ...=... | int[] | arrayGetsSets.kt:12:3:12:4 | a1 | arrayGetsSets.kt:12:6:12:6 | 0 |
| arrayGetsSets.kt:12:11:12:15 | ...[...] | arrayGetsSets.kt:12:3:12:7 | ...=... | int | arrayGetsSets.kt:12:11:12:12 | a1 | arrayGetsSets.kt:12:14:12:14 | 0 |
| arrayGetsSets.kt:13:3:13:7 | ...[...] | arrayGetsSets.kt:13:3:13:7 | ...=... | short[] | arrayGetsSets.kt:13:3:13:4 | a2 | arrayGetsSets.kt:13:6:13:6 | 0 |
| arrayGetsSets.kt:13:11:13:15 | ...[...] | arrayGetsSets.kt:13:3:13:7 | ...=... | short | arrayGetsSets.kt:13:11:13:12 | a2 | arrayGetsSets.kt:13:14:13:14 | 0 |
| arrayGetsSets.kt:14:3:14:7 | ...[...] | arrayGetsSets.kt:14:3:14:7 | ...=... | byte[] | arrayGetsSets.kt:14:3:14:4 | a3 | arrayGetsSets.kt:14:6:14:6 | 0 |
| arrayGetsSets.kt:14:11:14:15 | ...[...] | arrayGetsSets.kt:14:3:14:7 | ...=... | byte | arrayGetsSets.kt:14:11:14:12 | a3 | arrayGetsSets.kt:14:14:14:14 | 0 |
| arrayGetsSets.kt:15:3:15:7 | ...[...] | arrayGetsSets.kt:15:3:15:7 | ...=... | long[] | arrayGetsSets.kt:15:3:15:4 | a4 | arrayGetsSets.kt:15:6:15:6 | 0 |
| arrayGetsSets.kt:15:11:15:15 | ...[...] | arrayGetsSets.kt:15:3:15:7 | ...=... | long | arrayGetsSets.kt:15:11:15:12 | a4 | arrayGetsSets.kt:15:14:15:14 | 0 |
| arrayGetsSets.kt:16:3:16:7 | ...[...] | arrayGetsSets.kt:16:3:16:7 | ...=... | float[] | arrayGetsSets.kt:16:3:16:4 | a5 | arrayGetsSets.kt:16:6:16:6 | 0 |
| arrayGetsSets.kt:16:11:16:15 | ...[...] | arrayGetsSets.kt:16:3:16:7 | ...=... | float | arrayGetsSets.kt:16:11:16:12 | a5 | arrayGetsSets.kt:16:14:16:14 | 0 |
| arrayGetsSets.kt:17:3:17:7 | ...[...] | arrayGetsSets.kt:17:3:17:7 | ...=... | double[] | arrayGetsSets.kt:17:3:17:4 | a6 | arrayGetsSets.kt:17:6:17:6 | 0 |
| arrayGetsSets.kt:17:11:17:15 | ...[...] | arrayGetsSets.kt:17:3:17:7 | ...=... | double | arrayGetsSets.kt:17:11:17:12 | a6 | arrayGetsSets.kt:17:14:17:14 | 0 |
| arrayGetsSets.kt:18:3:18:7 | ...[...] | arrayGetsSets.kt:18:3:18:7 | ...=... | boolean[] | arrayGetsSets.kt:18:3:18:4 | a7 | arrayGetsSets.kt:18:6:18:6 | 0 |
| arrayGetsSets.kt:18:11:18:15 | ...[...] | arrayGetsSets.kt:18:3:18:7 | ...=... | boolean | arrayGetsSets.kt:18:11:18:12 | a7 | arrayGetsSets.kt:18:14:18:14 | 0 |
| arrayGetsSets.kt:19:3:19:7 | ...[...] | arrayGetsSets.kt:19:3:19:7 | ...=... | char[] | arrayGetsSets.kt:19:3:19:4 | a8 | arrayGetsSets.kt:19:6:19:6 | 0 |
| arrayGetsSets.kt:19:11:19:15 | ...[...] | arrayGetsSets.kt:19:3:19:7 | ...=... | char | arrayGetsSets.kt:19:11:19:12 | a8 | arrayGetsSets.kt:19:14:19:14 | 0 |
| arrayGetsSets.kt:20:3:20:7 | ...[...] | arrayGetsSets.kt:20:3:20:7 | ...=... | Object[] | arrayGetsSets.kt:20:3:20:4 | a9 | arrayGetsSets.kt:20:6:20:6 | 0 |
| arrayGetsSets.kt:20:11:20:15 | ...[...] | arrayGetsSets.kt:20:3:20:7 | ...=... | Object | arrayGetsSets.kt:20:11:20:12 | a9 | arrayGetsSets.kt:20:14:20:14 | 0 |
| arrayGetsSets.kt:32:3:32:7 | ...[...] | arrayGetsSets.kt:32:3:32:7 | ...+=... | int | arrayGetsSets.kt:32:3:32:4 | a1 | arrayGetsSets.kt:32:6:32:6 | 0 |
| arrayGetsSets.kt:38:3:38:7 | ...[...] | arrayGetsSets.kt:38:3:38:7 | .../=... | long | arrayGetsSets.kt:38:3:38:4 | a4 | arrayGetsSets.kt:38:6:38:6 | 0 |
| arrayGetsSets.kt:39:3:39:7 | ...[...] | arrayGetsSets.kt:39:3:39:7 | ...-=... | float | arrayGetsSets.kt:39:3:39:4 | a5 | arrayGetsSets.kt:39:6:39:6 | 0 |
| arrayGetsSets.kt:40:3:40:7 | ...[...] | arrayGetsSets.kt:40:3:40:7 | ...*=... | double | arrayGetsSets.kt:40:3:40:4 | a6 | arrayGetsSets.kt:40:6:40:6 | 0 |

View File

@@ -1,4 +1,4 @@
import java
from ArrayAccess aa
select aa, aa.getParent()
select aa, aa.getParent(), aa.getType().toString(), aa.getArray(), aa.getIndexExpr()

View File

@@ -20,3 +20,22 @@ fun arrayGetSet(
a9[0] = a9[0]
}
fun arrayGetSetInPlace(
a1: IntArray,
//a2: ShortArray,
//a3: ByteArray,
a4: LongArray,
a5: FloatArray,
a6: DoubleArray) {
a1[0] += 1
// Short and Byte's arithmetic operators yield an Int,
// so we don't have syntax to convert the result of the arithmetic op
// back to the right type.
//a2[0] %= 1.toShort()
//a3[0] *= 1.toByte()
a4[0] /= 1L
a5[0] -= 1f
a6[0] *= 1.0
}

View File

@@ -18,6 +18,7 @@ sourceSignatures
| arrayCreations.kt:27:24:27:38 | | |
| arrayCreations.kt:27:24:27:38 | invoke | invoke(int) |
| arrayGetsSets.kt:1:1:22:1 | arrayGetSet | arrayGetSet(int[],short[],byte[],long[],float[],double[],boolean[],char[],java.lang.Object[]) |
| arrayGetsSets.kt:24:1:41:1 | arrayGetSetInPlace | arrayGetSetInPlace(int[],long[],float[],double[]) |
| primitiveArrays.kt:3:1:7:1 | <obinit> | <obinit>() |
| primitiveArrays.kt:3:1:7:1 | Test | Test() |
| primitiveArrays.kt:5:3:5:123 | test | test(java.lang.Integer[],java.lang.Integer[],int[],java.lang.Integer[][],java.lang.Integer[][],int[][]) |