Merge pull request #21430 from aschackmull/csharp/switch-ast-simplify

C#: Disentangle SwitchStmt AST and CFG.
This commit is contained in:
Anders Schack-Mulligen
2026-03-10 10:23:37 +01:00
committed by GitHub
2 changed files with 119 additions and 88 deletions

View File

@@ -183,9 +183,10 @@ class SwitchStmt extends SelectionStmt, Switch, @switch_stmt {
* return 3;
* }
* ```
* Note that this reorders the `default` case to always be at the end.
*/
override CaseStmt getCase(int i) { result = SwithStmtInternal::getCase(this, i) }
override CaseStmt getCase(int i) {
result = rank[i + 1](CaseStmt cs, int idx | cs = this.getChildStmt(idx) | cs order by idx)
}
/** Gets a case of this `switch` statement. */
override CaseStmt getACase() { result = this.getCase(_) }
@@ -208,87 +209,29 @@ class SwitchStmt extends SelectionStmt, Switch, @switch_stmt {
* ```csharp
* switch (x) {
* case "abc": // i = 0
* return 0;
* case int i when i > 0: // i = 1
* return 1;
* case string s: // i = 2
* Console.WriteLine(s);
* return 2; // i = 3
* default: // i = 4
* return 3; // i = 5
* return 0; // i = 1
* case int i when i > 0: // i = 2
* return 1; // i = 3
* case string s: // i = 4
* Console.WriteLine(s); // i = 5
* return 2; // i = 6
* default: // i = 7
* return 3; // i = 8
* }
* ```
*
* Note that each non-`default` case is a labeled statement, so the statement
* that follows is a child of the labeled statement, and not the `switch` block.
*/
Stmt getStmt(int i) { result = SwithStmtInternal::getStmt(this, i) }
Stmt getStmt(int i) { result = this.getChildStmt(i) }
/** Gets a statement in the body of this `switch` statement. */
Stmt getAStmt() { result = this.getStmt(_) }
}
cached
private module SwithStmtInternal {
cached
CaseStmt getCase(SwitchStmt ss, int i) {
exists(int index, int rankIndex |
caseIndex(ss, result, index) and
rankIndex = i + 1 and
index = rank[rankIndex](int j, CaseStmt cs | caseIndex(ss, cs, j) | j)
)
}
/** Implicitly reorder case statements to put the default case last if needed. */
private predicate caseIndex(SwitchStmt ss, CaseStmt case, int index) {
exists(int i | case = ss.getChildStmt(i) |
if case instanceof DefaultCase
then index = max(int j | exists(ss.getChildStmt(j))) + 1
else index = i
)
}
cached
Stmt getStmt(SwitchStmt ss, int i) {
exists(int index, int rankIndex |
result = ss.getChildStmt(index) and
rankIndex = i + 1 and
index =
rank[rankIndex](int j, Stmt s |
// `getChild` includes both labeled statements and the targeted
// statements of labeled statement as separate children, but we
// only want the labeled statement
s = getLabeledStmt(ss, j)
|
j
)
)
}
private Stmt getLabeledStmt(SwitchStmt ss, int i) {
result = ss.getChildStmt(i) and
not result = any(CaseStmt cs).getBody()
}
}
/** A `case` statement. */
class CaseStmt extends Case, @case_stmt {
override Expr getExpr() { result = any(SwitchStmt ss | ss.getACase() = this).getExpr() }
override PatternExpr getPattern() { result = this.getChild(0) }
override Stmt getBody() {
exists(int i, Stmt next |
this = this.getParent().getChild(i) and
next = this.getParent().getChild(i + 1)
|
result = next and
not result instanceof CaseStmt
or
result = next.(CaseStmt).getBody()
)
}
/**
* Gets the condition on this case, if any. For example, the type case on line 3
* has no condition, and the type case on line 4 has condition `s.Length > 0`, in

View File

@@ -308,6 +308,93 @@ private class ConstructorTree extends ControlFlowTree instanceof Constructor {
}
}
cached
private module SwithStmtInternal {
// Reorders default to be last if needed
cached
CaseStmt getCase(SwitchStmt ss, int i) {
exists(int index, int rankIndex |
caseIndex(ss, result, index) and
rankIndex = i + 1 and
index = rank[rankIndex](int j, CaseStmt cs | caseIndex(ss, cs, j) | j)
)
}
/** Implicitly reorder case statements to put the default case last if needed. */
private predicate caseIndex(SwitchStmt ss, CaseStmt case, int index) {
exists(int i | case = ss.getChildStmt(i) |
if case instanceof DefaultCase
then index = max(int j | exists(ss.getChildStmt(j))) + 1
else index = i
)
}
/**
* Gets the `i`th statement in the body of this `switch` statement.
*
* Example:
*
* ```csharp
* switch (x) {
* case "abc": // i = 0
* return 0;
* case int i when i > 0: // i = 1
* return 1;
* case string s: // i = 2
* Console.WriteLine(s);
* return 2; // i = 3
* default: // i = 4
* return 3; // i = 5
* }
* ```
*
* Note that each non-`default` case is a labeled statement, so the statement
* that follows is a child of the labeled statement, and not the `switch` block.
*/
cached
Stmt getStmt(SwitchStmt ss, int i) {
exists(int index, int rankIndex |
result = ss.getChildStmt(index) and
rankIndex = i + 1 and
index =
rank[rankIndex](int j, Stmt s |
// `getChild` includes both labeled statements and the targeted
// statements of labeled statement as separate children, but we
// only want the labeled statement
s = getLabeledStmt(ss, j)
|
j
)
)
}
private Stmt getLabeledStmt(SwitchStmt ss, int i) {
result = ss.getChildStmt(i) and
not result = caseStmtGetBody(_)
}
}
private ControlFlowElement caseGetBody(Case c) {
result = c.getBody() or result = caseStmtGetBody(c)
}
private ControlFlowElement caseStmtGetBody(CaseStmt c) {
exists(int i, Stmt next |
c = c.getParent().getChild(i) and
next = c.getParent().getChild(i + 1)
|
result = next and
not result instanceof CaseStmt
or
result = caseStmtGetBody(next)
)
}
// Reorders default to be last if needed
private Case switchGetCase(Switch s, int i) {
result = s.(SwitchExpr).getCase(i) or result = SwithStmtInternal::getCase(s, i)
}
abstract private class SwitchTree extends ControlFlowTree instanceof Switch {
override predicate propagatesAbnormal(AstNode child) { child = super.getExpr() }
@@ -315,27 +402,27 @@ abstract private class SwitchTree extends ControlFlowTree instanceof Switch {
// Flow from last element of switch expression to first element of first case
last(super.getExpr(), pred, c) and
c instanceof NormalCompletion and
first(super.getCase(0), succ)
first(switchGetCase(this, 0), succ)
or
// Flow from last element of case pattern to next case
exists(Case case, int i | case = super.getCase(i) |
exists(Case case, int i | case = switchGetCase(this, i) |
last(case.getPattern(), pred, c) and
c.(MatchingCompletion).isNonMatch() and
first(super.getCase(i + 1), succ)
first(switchGetCase(this, i + 1), succ)
)
or
// Flow from last element of condition to next case
exists(Case case, int i | case = super.getCase(i) |
exists(Case case, int i | case = switchGetCase(this, i) |
last(case.getCondition(), pred, c) and
c instanceof FalseCompletion and
first(super.getCase(i + 1), succ)
first(switchGetCase(this, i + 1), succ)
)
}
}
abstract private class CaseTree extends ControlFlowTree instanceof Case {
final override predicate propagatesAbnormal(AstNode child) {
child in [super.getPattern().(ControlFlowElement), super.getCondition(), super.getBody()]
child in [super.getPattern().(ControlFlowElement), super.getCondition(), caseGetBody(this)]
}
override predicate succ(AstNode pred, AstNode succ, Completion c) {
@@ -348,13 +435,13 @@ abstract private class CaseTree extends ControlFlowTree instanceof Case {
first(super.getCondition(), succ)
else
// Flow from last element of pattern to first element of body
first(super.getBody(), succ)
first(caseGetBody(this), succ)
)
or
// Flow from last element of condition to first element of body
last(super.getCondition(), pred, c) and
c instanceof TrueCompletion and
first(super.getBody(), succ)
first(caseGetBody(this), succ)
}
}
@@ -1226,10 +1313,11 @@ module Statements {
c instanceof NormalCompletion
or
// A statement exits with a `break` completion
last(super.getStmt(_), last, c.(NestedBreakCompletion).getAnInnerCompatibleCompletion())
last(SwithStmtInternal::getStmt(this, _), last,
c.(NestedBreakCompletion).getAnInnerCompatibleCompletion())
or
// A statement exits abnormally
last(super.getStmt(_), last, c) and
last(SwithStmtInternal::getStmt(this, _), last, c) and
not c instanceof BreakCompletion and
not c instanceof NormalCompletion and
not any(LabeledStmtTree t |
@@ -1238,8 +1326,8 @@ module Statements {
or
// Last case exits with a non-match
exists(CaseStmt cs, int last_ |
last_ = max(int i | exists(super.getCase(i))) and
cs = super.getCase(last_)
last_ = max(int i | exists(SwithStmtInternal::getCase(this, i))) and
cs = SwithStmtInternal::getCase(this, last_)
|
last(cs.getPattern(), last, c) and
not c.(MatchingCompletion).isMatch()
@@ -1258,22 +1346,22 @@ module Statements {
c instanceof SimpleCompletion
or
// Flow from last element of non-`case` statement `i` to first element of statement `i+1`
exists(int i | last(super.getStmt(i), pred, c) |
not super.getStmt(i) instanceof CaseStmt and
exists(int i | last(SwithStmtInternal::getStmt(this, i), pred, c) |
not SwithStmtInternal::getStmt(this, i) instanceof CaseStmt and
c instanceof NormalCompletion and
first(super.getStmt(i + 1), succ)
first(SwithStmtInternal::getStmt(this, i + 1), succ)
)
or
// Flow from last element of `case` statement `i` to first element of statement `i+1`
exists(int i, Stmt body |
body = super.getStmt(i).(CaseStmt).getBody() and
body = caseStmtGetBody(SwithStmtInternal::getStmt(this, i)) and
// in case of fall-through cases, make sure to not jump from their shared body back
// to one of the fall-through cases
not body = super.getStmt(i + 1).(CaseStmt).getBody() and
not body = caseStmtGetBody(SwithStmtInternal::getStmt(this, i + 1)) and
last(body, pred, c)
|
c instanceof NormalCompletion and
first(super.getStmt(i + 1), succ)
first(SwithStmtInternal::getStmt(this, i + 1), succ)
)
}
}
@@ -1289,7 +1377,7 @@ module Statements {
not c.(MatchingCompletion).isMatch()
or
// Case body exits with any completion
last(super.getBody(), last, c)
last(caseStmtGetBody(this), last, c)
}
final override predicate succ(AstNode pred, AstNode succ, Completion c) {