C++: Use a rank aggregate for a much better implementation.

This commit is contained in:
Mathias Vorreiter Pedersen
2021-05-31 11:17:09 +02:00
parent b2bdf95a9d
commit b4e4c12d0f

View File

@@ -1136,20 +1136,9 @@ private predicate inForUpdate(Expr forUpdate, Expr child) {
exists(Expr mid | inForUpdate(forUpdate, mid) and child.getParent() = mid)
}
/**
* Holds if `next` is the first `SwitchCase` after `case` that is enclosed by the `BlockStmt` `b`.
*/
pragma[nomagic]
private predicate nextSwitchCaseInSameBlock(BlockStmt b, SwitchCase case, SwitchCase next) {
// If the next switch case is in the block `b` we're done.
case.getNextSwitchCase() = next and
next.getEnclosingBlock() = b
or
// Otherwise, skip past the next switch block when it's not enclosed in the block `b`.
exists(SwitchCase mid | mid = case.getNextSwitchCase() |
not mid.getEnclosingBlock() = b and
nextSwitchCaseInSameBlock(b, mid, next)
)
/** Gets the `rnk`'th `case` statement in `b`. */
private int indexOfSwitchCaseRank(BlockStmt b, int rnk) {
result = rank[rnk](int i | b.getStmt(i) instanceof SwitchCase)
}
/**
@@ -1347,19 +1336,14 @@ class SwitchCase extends Stmt, @stmt_switch_case {
* `default:` has results `{ x = 3; }, `x = 4;` and `break;`.
*/
Stmt getAStmt() {
exists(BlockStmt b, int i | this = b.getStmt(i) |
// This is the most usual case:
// We locate the next `SwitchCase` and pick a statement between this `SwitchCase` and the `SwitchCase`
// in the `j`'th position in the block `b`.
exists(int j |
nextSwitchCaseInSameBlock(b, this,
pragma[only_bind_into](b).getStmt(pragma[only_bind_into](j))) and
result = b.getStmt(any(int k | i < k and k < j))
)
exists(BlockStmt b, int rnk, int i |
b.getStmt(i) = this and
i = indexOfSwitchCaseRank(b, rnk)
|
pragma[only_bind_into](b).getStmt([i + 1 .. indexOfSwitchCaseRank(b, rnk + 1) - 1]) = result
or
// If there is no next switch case we pick any subsequent statement in the block `b`.
not nextSwitchCaseInSameBlock(b, this, _) and
result = b.getStmt(any(int k | i < k))
not exists(indexOfSwitchCaseRank(b, rnk + 1)) and
b.getStmt([i + 1 .. b.getNumStmt() + 1]) = result
)
}