Avoid quadratic switch case intermediate

This commit is contained in:
Chris Smowton
2023-11-24 17:26:31 +00:00
parent d1e16ada4c
commit 9b5b496462
3 changed files with 64 additions and 37 deletions

View File

@@ -82,6 +82,7 @@
import java
private import Completion
private import controlflow.internal.Preconditions
private import controlflow.internal.SwitchCases
/** A node in the expression-level control-flow graph. */
class ControlFlowNode extends Top, @exprparent {
@@ -436,34 +437,6 @@ private module ControlFlowGraphImpl {
)
}
/**
* Gets the `i`th `SwitchCase` defined on `switch`, if one exists.
*/
private SwitchCase getCase(StmtParent switch, int i) {
result = switch.(SwitchExpr).getCase(i) or result = switch.(SwitchStmt).getCase(i)
}
/**
* Gets the `i`th `PatternCase` defined on `switch`, if one exists.
*/
private PatternCase getPatternCase(StmtParent switch, int i) {
result =
rank[i + 1](PatternCase pc, int caseIdx | pc = getCase(switch, caseIdx) | pc order by caseIdx)
}
/**
* Gets the PatternCase after pc, if one exists.
*/
private PatternCase getNextPatternCase(PatternCase pc) {
exists(int idx, StmtParent switch |
pc = getPatternCase(switch, idx) and result = getPatternCase(switch, idx + 1)
)
}
private int lastCaseIndex(StmtParent switch) {
result = max(int i | any(SwitchCase c).isNthCaseOf(switch, i))
}
// Join order engineering -- first determine the switch block and the case indices required, then retrieve them.
bindingset[switch, i]
pragma[inline_late]

View File

@@ -7,6 +7,7 @@ import java
private import semmle.code.java.controlflow.Dominance
private import semmle.code.java.controlflow.internal.GuardsLogic
private import semmle.code.java.controlflow.internal.Preconditions
private import semmle.code.java.controlflow.internal.SwitchCases
/**
* A basic block that terminates in a condition, splitting the subsequent control flow.
@@ -72,6 +73,35 @@ class ConditionBlock extends BasicBlock {
}
}
// Join order engineering -- first determine the switch block and the case indices required, then retrieve them.
bindingset[switch, i]
pragma[inline_late]
private predicate isNthCaseOf(StmtParent switch, SwitchCase c, int i) { c.isNthCaseOf(switch, i) }
/**
* Gets a switch case >= pred, up to but not including `pred`'s successor pattern case,
* where `pred` is declared on `switch`.
*/
private SwitchCase getACaseUpToNextPattern(PatternCase pred, StmtParent switch) {
// Note we do include `case null, default` (as well as plain old `default`) here.
not result.(ConstCase).getValue(_) instanceof NullLiteral and
exists(int maxCaseIndex |
switch = pred.getParent() and
if exists(getNextPatternCase(pred))
then maxCaseIndex = getNextPatternCase(pred).getCaseIndex() - 1
else maxCaseIndex = lastCaseIndex(switch)
|
isNthCaseOf(switch, result, [pred.getCaseIndex() .. maxCaseIndex])
)
}
/**
* Gets the closest pattern case preceding `case`, including `case` itself, if any.
*/
private PatternCase getClosestPrecedingPatternCase(SwitchCase case) {
case = getACaseUpToNextPattern(result, _)
}
/**
* A condition that can be evaluated to either true or false. This can either
* be an `Expr` of boolean type that isn't a boolean literal, or a case of a
@@ -113,17 +143,10 @@ class Guard extends ExprParent {
result = this.(Expr).getBasicBlock()
or
// Return the closest pattern case statement before this one, including this one.
result =
max(int i, PatternCase c |
c = this.(SwitchCase).getSiblingCase(i) and i <= this.(SwitchCase).getCaseIndex()
|
c order by i
).getBasicBlock()
result = getClosestPrecedingPatternCase(this).getBasicBlock()
or
// Not a pattern case and no preceding pattern case -- return the top of the switch block.
not exists(PatternCase c, int i |
c = this.(SwitchCase).getSiblingCase(i) and i <= this.(SwitchCase).getCaseIndex()
) and
not exists(getClosestPrecedingPatternCase(this)) and
result = this.(SwitchCase).getSelectorExpr().getBasicBlock()
}

View File

@@ -0,0 +1,31 @@
/** Provides utility predicates relating to switch cases. */
import java
/**
* Gets the `i`th `SwitchCase` defined on `switch`, if one exists.
*/
SwitchCase getCase(StmtParent switch, int i) {
result = switch.(SwitchExpr).getCase(i) or result = switch.(SwitchStmt).getCase(i)
}
/**
* Gets the `i`th `PatternCase` defined on `switch`, if one exists.
*/
PatternCase getPatternCase(StmtParent switch, int i) {
result =
rank[i + 1](PatternCase pc, int caseIdx | pc = getCase(switch, caseIdx) | pc order by caseIdx)
}
/**
* Gets the PatternCase after pc, if one exists.
*/
PatternCase getNextPatternCase(PatternCase pc) {
exists(int idx, StmtParent switch |
pc = getPatternCase(switch, idx) and result = getPatternCase(switch, idx + 1)
)
}
int lastCaseIndex(StmtParent switch) {
result = max(int i | any(SwitchCase c).isNthCaseOf(switch, i))
}