mirror of
https://github.com/github/codeql.git
synced 2025-12-20 18:56:32 +01:00
Extract array update operations
These are of the form arrExpr[indexExpr] op= rhs
This commit is contained in:
committed by
Ian Lynagh
parent
d9c72b1c04
commit
7cb6e19e44
@@ -1206,24 +1206,24 @@ open class KotlinFileExtractor(
|
||||
isFunction(target, "kotlin", "Double", fName)
|
||||
}
|
||||
|
||||
fun isArrayType(typeName: String) =
|
||||
when(typeName) {
|
||||
"Array" -> true
|
||||
"IntArray" -> true
|
||||
"ByteArray" -> true
|
||||
"ShortArray" -> true
|
||||
"LongArray" -> true
|
||||
"FloatArray" -> true
|
||||
"DoubleArray" -> true
|
||||
"CharArray" -> true
|
||||
"BooleanArray" -> true
|
||||
else -> false
|
||||
}
|
||||
|
||||
fun extractCall(c: IrCall, callable: Label<out DbCallable>, parent: Label<out DbExprparent>, idx: Int, enclosingStmt: Label<out DbStmt>) {
|
||||
with("call", c) {
|
||||
val target = c.symbol.owner
|
||||
|
||||
fun isArrayType(typeName: String) =
|
||||
when(typeName) {
|
||||
"Array" -> true
|
||||
"IntArray" -> true
|
||||
"ByteArray" -> true
|
||||
"ShortArray" -> true
|
||||
"LongArray" -> true
|
||||
"FloatArray" -> true
|
||||
"DoubleArray" -> true
|
||||
"CharArray" -> true
|
||||
"BooleanArray" -> true
|
||||
else -> false
|
||||
}
|
||||
|
||||
fun extractMethodAccess(syntacticCallTarget: IrFunction, extractMethodTypeArguments: Boolean = true, extractClassTypeArguments: Boolean = false) {
|
||||
val typeArgs =
|
||||
if (extractMethodTypeArguments)
|
||||
@@ -1782,6 +1782,113 @@ open class KotlinFileExtractor(
|
||||
}
|
||||
}
|
||||
|
||||
fun getStatementOriginOperator(origin: IrStatementOrigin?) = when (origin) {
|
||||
IrStatementOrigin.PLUSEQ -> "plus"
|
||||
IrStatementOrigin.MINUSEQ -> "minus"
|
||||
IrStatementOrigin.MULTEQ -> "times"
|
||||
IrStatementOrigin.DIVEQ -> "div"
|
||||
IrStatementOrigin.PERCEQ -> "rem"
|
||||
else -> null
|
||||
}
|
||||
|
||||
fun getUpdateInPlaceRHS(origin: IrStatementOrigin?, isExpectedLhs: (IrExpression?) -> Boolean, updateRhs: IrExpression): IrExpression? {
|
||||
// Check for a desugared in-place update operator, such as "v += e":
|
||||
return getStatementOriginOperator(origin)?.let {
|
||||
if (updateRhs is IrCall &&
|
||||
isNumericFunction(updateRhs.symbol.owner, it)
|
||||
) {
|
||||
// Check for an expression like x = get(x).op(e):
|
||||
val opReceiver = updateRhs.dispatchReceiver
|
||||
if (isExpectedLhs(opReceiver)) {
|
||||
updateRhs.getValueArgument(0)
|
||||
} else null
|
||||
} else null
|
||||
} ?: null
|
||||
}
|
||||
|
||||
fun writeUpdateInPlaceExpr(origin: IrStatementOrigin, tw: TrapWriter, id: Label<DbAssignexpr>, type: TypeResults, exprParent: ExprParent): Boolean {
|
||||
when(origin) {
|
||||
IrStatementOrigin.PLUSEQ -> tw.writeExprs_assignaddexpr(id as Label<DbAssignaddexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
IrStatementOrigin.MINUSEQ -> tw.writeExprs_assignsubexpr(id as Label<DbAssignsubexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
IrStatementOrigin.MULTEQ -> tw.writeExprs_assignmulexpr(id as Label<DbAssignmulexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
IrStatementOrigin.DIVEQ -> tw.writeExprs_assigndivexpr(id as Label<DbAssigndivexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
IrStatementOrigin.PERCEQ -> tw.writeExprs_assignremexpr(id as Label<DbAssignremexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
else -> return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
fun tryExtractArrayUpdate(e: IrContainerExpression, callable: Label<out DbCallable>, parent: StmtExprParent): Boolean {
|
||||
/*
|
||||
* We're expecting the pattern
|
||||
* {
|
||||
* val array = e1
|
||||
* val idx = e2
|
||||
* array.set(idx, array.get(idx).op(e3))
|
||||
* }
|
||||
*
|
||||
* If we find it, we'll extract e1[e2] op= e3 (op is +, -, ...)
|
||||
*/
|
||||
if(e.statements.size != 3)
|
||||
return false
|
||||
(e.statements[0] as? IrVariable)?.let { arrayVarDecl ->
|
||||
arrayVarDecl.initializer?.let { arrayVarInitializer ->
|
||||
(e.statements[1] as? IrVariable)?.let { indexVarDecl ->
|
||||
indexVarDecl.initializer?.let { indexVarInitializer ->
|
||||
(e.statements[2] as? IrCall)?.let { arraySetCall ->
|
||||
if (isFunction(arraySetCall.symbol.owner, "kotlin", "(some array type)", { isArrayType(it) }, "set")) {
|
||||
getUpdateInPlaceRHS(
|
||||
e.origin, // Using e.origin not arraySetCall.origin here distinguishes a compiler-generated block from a user manually code that looks the same.
|
||||
{ oldValue ->
|
||||
oldValue is IrCall &&
|
||||
isFunction(oldValue.symbol.owner, "kotlin", "(some array type)", { typeName -> isArrayType(typeName) }, "get") &&
|
||||
(oldValue.dispatchReceiver as? IrGetValue)?.let {
|
||||
receiverVal -> receiverVal.symbol.owner == arrayVarDecl.symbol.owner
|
||||
} ?: false
|
||||
},
|
||||
arraySetCall.getValueArgument(1)!!
|
||||
)?.let { updateRhs ->
|
||||
// Create an assignment skeleton _ op= _
|
||||
val exprParent = parent.expr(e, callable)
|
||||
val assignId = tw.getFreshIdLabel<DbAssignexpr>()
|
||||
val type = useType(arrayVarInitializer.type)
|
||||
val locId = tw.getLocation(e)
|
||||
tw.writeExprsKotlinType(assignId, type.kotlinResult.id)
|
||||
tw.writeHasLocation(assignId, locId)
|
||||
tw.writeCallableEnclosingExpr(assignId, callable)
|
||||
tw.writeStatementEnclosingExpr(assignId, exprParent.enclosingStmt)
|
||||
|
||||
if (!writeUpdateInPlaceExpr(e.origin!!, tw, assignId, type, exprParent)) {
|
||||
logger.errorElement("Unexpected origin", e)
|
||||
return false
|
||||
}
|
||||
|
||||
// Extract e1[e2]
|
||||
val lhsId = tw.getFreshIdLabel<DbArrayaccess>()
|
||||
val elementType = useType(updateRhs.type)
|
||||
tw.writeExprs_arrayaccess(lhsId, elementType.javaResult.id, assignId, 0)
|
||||
tw.writeExprsKotlinType(lhsId, elementType.kotlinResult.id)
|
||||
tw.writeHasLocation(lhsId, locId)
|
||||
tw.writeCallableEnclosingExpr(lhsId, callable)
|
||||
tw.writeStatementEnclosingExpr(lhsId, exprParent.enclosingStmt)
|
||||
extractExpressionExpr(arrayVarInitializer, callable, lhsId, 0, exprParent.enclosingStmt)
|
||||
extractExpressionExpr(indexVarInitializer, callable, lhsId, 1, exprParent.enclosingStmt)
|
||||
|
||||
// Extract e3
|
||||
extractExpressionExpr(updateRhs, callable, assignId, 1, exprParent.enclosingStmt)
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
fun extractExpressionStmt(e: IrExpression, callable: Label<out DbCallable>, parent: Label<out DbStmtparent>, idx: Int) {
|
||||
extractExpression(e, callable, StmtParent(parent, idx))
|
||||
}
|
||||
@@ -1879,13 +1986,15 @@ open class KotlinFileExtractor(
|
||||
}
|
||||
}
|
||||
is IrContainerExpression -> {
|
||||
val stmtParent = parent.stmt(e, callable)
|
||||
val id = tw.getFreshIdLabel<DbBlock>()
|
||||
val locId = tw.getLocation(e)
|
||||
tw.writeStmts_block(id, stmtParent.parent, stmtParent.idx, callable)
|
||||
tw.writeHasLocation(id, locId)
|
||||
e.statements.forEachIndexed { i, s ->
|
||||
extractStatement(s, callable, id, i)
|
||||
if(!tryExtractArrayUpdate(e, callable, parent)) {
|
||||
val stmtParent = parent.stmt(e, callable)
|
||||
val id = tw.getFreshIdLabel<DbBlock>()
|
||||
val locId = tw.getLocation(e)
|
||||
tw.writeStmts_block(id, stmtParent.parent, stmtParent.idx, callable)
|
||||
tw.writeHasLocation(id, locId)
|
||||
e.statements.forEachIndexed { i, s ->
|
||||
extractStatement(s, callable, id, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
is IrWhileLoop -> {
|
||||
@@ -2170,34 +2279,14 @@ open class KotlinFileExtractor(
|
||||
val rhsValue = e.value
|
||||
|
||||
// Check for a desugared in-place update operator, such as "v += e":
|
||||
val expectedOperator = when (e.origin) {
|
||||
IrStatementOrigin.PLUSEQ -> "plus"
|
||||
IrStatementOrigin.MINUSEQ -> "minus"
|
||||
IrStatementOrigin.MULTEQ -> "times"
|
||||
IrStatementOrigin.DIVEQ -> "div"
|
||||
IrStatementOrigin.PERCEQ -> "rem"
|
||||
else -> null
|
||||
}
|
||||
val inPlaceUpdateRhs = expectedOperator?.let {
|
||||
if (rhsValue is IrCall &&
|
||||
isNumericFunction(rhsValue.symbol.owner, expectedOperator)
|
||||
) {
|
||||
// Check for an expression like x = get(x).op(e):
|
||||
val opReceiver = rhsValue.dispatchReceiver
|
||||
if (opReceiver is IrGetValue && opReceiver.symbol.owner == e.symbol.owner) {
|
||||
rhsValue.getValueArgument(0)
|
||||
} else null
|
||||
} else null
|
||||
}
|
||||
|
||||
val extractOrigin = if (inPlaceUpdateRhs == null) null else e.origin
|
||||
when(extractOrigin) {
|
||||
IrStatementOrigin.PLUSEQ -> tw.writeExprs_assignaddexpr(id as Label<DbAssignaddexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
IrStatementOrigin.MINUSEQ -> tw.writeExprs_assignsubexpr(id as Label<DbAssignsubexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
IrStatementOrigin.MULTEQ -> tw.writeExprs_assignmulexpr(id as Label<DbAssignmulexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
IrStatementOrigin.DIVEQ -> tw.writeExprs_assigndivexpr(id as Label<DbAssigndivexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
IrStatementOrigin.PERCEQ -> tw.writeExprs_assignremexpr(id as Label<DbAssignremexpr>, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
else -> tw.writeExprs_assignexpr(id, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
val inPlaceUpdateRhs = getUpdateInPlaceRHS(e.origin, { it is IrGetValue && it.symbol.owner == e.symbol.owner }, rhsValue)
|
||||
if (inPlaceUpdateRhs != null) {
|
||||
if (!writeUpdateInPlaceExpr(e.origin!!, tw, id, type, exprParent)) {
|
||||
logger.errorElement("Unexpected origin", e)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
tw.writeExprs_assignexpr(id, type.javaResult.id, exprParent.parent, exprParent.idx)
|
||||
}
|
||||
|
||||
val lhsType = useType(e.symbol.owner.type)
|
||||
|
||||
Reference in New Issue
Block a user