diff --git a/ql/src/semmle/go/dataflow/internal/DataFlowUtil.qll b/ql/src/semmle/go/dataflow/internal/DataFlowUtil.qll index bdb4ab046b5..544abc75479 100644 --- a/ql/src/semmle/go/dataflow/internal/DataFlowUtil.qll +++ b/ql/src/semmle/go/dataflow/internal/DataFlowUtil.qll @@ -1322,46 +1322,69 @@ private predicate onlyPossibleReturnOfNil(FuncDecl fd, FunctionOutput res, DataF } /** - * Gets a predecessor of `succ` without following edges corresponding to - * passing constant case tests in the switch block which is switching on - * `protectedExpr`. + * Holds if data flows from `node` to `switchExprNode`, which is the expression + * of a switch statement. */ -private ControlFlow::Node getANonTestPassingPredecessor(ControlFlow::Node succ, Expr protectedExpr) { +private predicate flowsToSwitchExpression(Node node, Node switchExprNode) { + switchExprNode.asExpr() = any(ExpressionSwitchStmt ess).getExpr() and + localFlow(node, switchExprNode) +} + +/** + * Holds if `inputNode` is the exit node of a parameter to `fd` and data flows + * from `inputNode` to the expression of a switch statement. + */ +private predicate isPossibleInputNode(Node inputNode, FuncDef fd) { + inputNode = any(FunctionInput inp | inp.isParameter(_)).getExitNode(fd) and + flowsToSwitchExpression(inputNode, _) +} + +/** + * Gets a predecessor of `succ` without following edges corresponding to + * passing a constant case test in a switch statement which is switching on + * an expression which data flows to from `inputNode`. + */ +private ControlFlow::Node getANonTestPassingPredecessor(ControlFlow::Node succ, Node inputNode) { + isPossibleInputNode(inputNode, succ.getRoot().(FuncDef)) and result = succ.getAPredecessor() and - not exists(Expr testExpr | - ControlFlow::isSwitchCaseTestPassingEdge(result, succ, protectedExpr, testExpr) and + not exists(Expr testExpr, Node switchExprNode | + flowsToSwitchExpression(inputNode, switchExprNode) and + ControlFlow::isSwitchCaseTestPassingEdge(result, succ, switchExprNode.asExpr(), testExpr) and testExpr.isConst() ) } private ControlFlow::Node getANonTestPassingReachingNodeRecursive( - ControlFlow::Node n, Expr protectedExpr + ControlFlow::Node n, Node inputNode ) { - result = n or - result = - getANonTestPassingReachingNodeRecursive(getANonTestPassingPredecessor(n, protectedExpr), - protectedExpr) + isPossibleInputNode(inputNode, n.getRoot().(FuncDef)) and + ( + result = n or + result = + getANonTestPassingReachingNodeRecursive(getANonTestPassingPredecessor(n, inputNode), inputNode) + ) } /** * Gets a node by following predecessors from `ret` without following edges - * corresponding to passing constant test cases in switch blocks. + * corresponding to passing a constant case test in a switch statement which is + * switching on an expression which data flows to from `inputNode`. */ private ControlFlow::Node getANonTestPassingReachingNodeBase( - IR::ReturnInstruction ret, Expr protectedExpr + IR::ReturnInstruction ret, Node inputNode ) { - protectedExpr.getEnclosingFunction() = ret.getReturnStmt().getEnclosingFunction() and - result = getANonTestPassingReachingNodeRecursive(ret, protectedExpr) + result = getANonTestPassingReachingNodeRecursive(ret, inputNode) } /** * Holds if every way to get from the entry node of the function to `ret` - * involves passing a constant test case in the switch statement switching on - * `protectedExpr`. + * involves passing a constant test case in a switch statement which is + * switching on an expression which data flows to from `inputNode`. */ -private predicate mustPassConstantCaseTestToReach(IR::ReturnInstruction ret, Expr protectedExpr) { +private predicate mustPassConstantCaseTestToReach(IR::ReturnInstruction ret, Node inputNode) { + isPossibleInputNode(inputNode, ret.getRoot().(FuncDef)) and not exists(ControlFlow::Node entry | entry = ret.getRoot().getEntryNode() | - entry = getANonTestPassingReachingNodeBase(ret, protectedExpr) + entry = getANonTestPassingReachingNodeBase(ret, inputNode) ) } @@ -1369,48 +1392,46 @@ private predicate mustPassConstantCaseTestToReach(IR::ReturnInstruction ret, Exp * Holds if whenever `outp` of function `f` satisfies `p`, the input `inp` of * `f` matched a constant in a case clause of a switch statement. * - * We check this by looking for guards on `inp` that collectively dominate a - * `return` statement that is the only `return` in `f` that can return `true`. - * This means that if `f` returns `true`, one of the guards must have been - * satisfied. (Similar reasoning is applied for statements returning `false`, - * `nil` or a non-`nil` value.) + * We check this by looking for guards on `inp` that collectively dominate all + * the `return` statements in `f` that can return `true`. This means that if + * `f` returns `true`, one of the guards must have been satisfied. (Similar + * reasoning is applied for statements returning `false`, `nil` or a non-`nil` + * value.) */ predicate functionEnsuresInputIsConstant( Function f, FunctionInput inp, FunctionOutput outp, DataFlow::Property p ) { - outp.isResult(_) and - exists(FuncDecl fd, ExpressionSwitchStmt ess, Node exprNode | - fd.getFunction() = f and - exprNode.asExpr() = ess.getExpr() and - localFlow(inp.getExitNode(fd), exprNode) - | + exists(FuncDecl fd | fd.getFunction() = f | exists(boolean b | p.isBoolean(b) and forex(DataFlow::Node ret, IR::ReturnInstruction ri | - ret = outp.getEntryNode(f.getFuncDecl()) and + ret = outp.getEntryNode(fd) and ri.getReturnStmt().getAnExpr() = ret.asExpr() and possiblyReturnsBool(fd, outp, ret, b) | - mustPassConstantCaseTestToReach(ri, ess.getExpr()) + mustPassConstantCaseTestToReach(ri, inp.getExitNode(fd)) ) ) or p.isNonNil() and forex(DataFlow::Node ret, IR::ReturnInstruction ri | - ret = outp.getEntryNode(f.getFuncDecl()) and + ret = outp.getEntryNode(fd) and ri.getReturnStmt().getAnExpr() = ret.asExpr() and possiblyReturnsNonNil(fd, outp, ret) | - mustPassConstantCaseTestToReach(ri, ess.getExpr()) + mustPassConstantCaseTestToReach(ri, inp.getExitNode(fd)) ) or p.isNil() and forex(DataFlow::Node ret, IR::ReturnInstruction ri | - ret = outp.getEntryNode(f.getFuncDecl()) and + ret = outp.getEntryNode(fd) and ri.getReturnStmt().getAnExpr() = ret.asExpr() and ret.asExpr() = Builtin::nil().getAReference() | - mustPassConstantCaseTestToReach(ri, ess.getExpr()) + exists(Node exprNode | + localFlow(inp.getExitNode(fd), exprNode) and + mustPassConstantCaseTestToReach(ri, inp.getExitNode(fd)) + ) ) ) }