diff --git a/go/ql/lib/semmle/go/controlflow/ControlFlowGraphShared.qll b/go/ql/lib/semmle/go/controlflow/ControlFlowGraphShared.qll index 7ef3f5034f7..9dd4326dcd7 100644 --- a/go/ql/lib/semmle/go/controlflow/ControlFlowGraphShared.qll +++ b/go/ql/lib/semmle/go/controlflow/ControlFlowGraphShared.qll @@ -1280,19 +1280,67 @@ module GoCfg { ) } + /** + * Holds if `sw` is a tagless expression switch, in which case the case + * expressions are themselves the boolean conditions being tested and are + * therefore in a boolean conditional context. + */ + private predicate isBooleanSwitch(Go::SwitchStmt sw) { + sw instanceof Go::ExpressionSwitchStmt and + not exists(sw.(Go::ExpressionSwitchStmt).getExpr()) + } + + /** + * Holds if `n` is the control-flow node immediately after evaluating case + * expression `caseExpr` of switch `sw` on the branch where `caseExpr` + * matches. + */ + private predicate afterCaseExprMatch(Go::SwitchStmt sw, Go::Expr caseExpr, PreControlFlowNode n) { + caseExpr = sw.getACase().(Go::CaseClause).getAnExpr() and + ( + isBooleanSwitch(sw) and n.isAfterTrue(caseExpr) + or + not isBooleanSwitch(sw) and n.isAfter(caseExpr) + ) + } + + /** + * Holds if `n` is the control-flow node immediately after evaluating case + * expression `caseExpr` of switch `sw` on the branch where `caseExpr` + * does not match. + */ + private predicate afterCaseExprNoMatch( + Go::SwitchStmt sw, Go::Expr caseExpr, PreControlFlowNode n + ) { + caseExpr = sw.getACase().(Go::CaseClause).getAnExpr() and + ( + isBooleanSwitch(sw) and n.isAfterFalse(caseExpr) + or + not isBooleanSwitch(sw) and n.isAfter(caseExpr) + ) + } + private predicate caseClause(PreControlFlowNode n1, PreControlFlowNode n2) { exists(Go::SwitchStmt sw, Go::CaseClause cc, int i | cc = sw.getNonDefaultCase(i) | n1.isBefore(cc) and n2.isBefore(cc.getExpr(0)) or - exists(int j | n1.isAfter(cc.getExpr(j)) and n2.isBefore(cc.getExpr(j + 1))) + // For a tagless expression switch the case expressions are themselves + // booleans in a conditional context, so we only fall through to the + // next case expression on the false branch. + exists(int j | + afterCaseExprNoMatch(sw, cc.getExpr(j), n1) and n2.isBefore(cc.getExpr(j + 1)) + ) or exists(int last | last = max(int j | exists(cc.getExpr(j))) | - n1.isAfter(cc.getExpr(last)) and + afterCaseExprMatch(sw, cc.getExpr(last), n1) and ( n2.isBefore(cc.getStmt(0)) or not exists(cc.getStmt(0)) and n2.isAfter(sw) - or + ) + or + afterCaseExprNoMatch(sw, cc.getExpr(last), n1) and + ( n2.isBefore(sw.getNonDefaultCase(i + 1)) or not exists(sw.getNonDefaultCase(i + 1)) and n2.isBefore(sw.getDefault())