diff --git a/change-notes/2020-07-07-missing-error-check.md b/change-notes/2020-07-07-missing-error-check.md new file mode 100644 index 00000000000..9202e3bc9ec --- /dev/null +++ b/change-notes/2020-07-07-missing-error-check.md @@ -0,0 +1,2 @@ +lgtm,codescanning +* New query "Missing error check" (`go/missing-error-check`) added. This checks for dangerous pointer dereferences when an accompanying error value returned from a call has not been checked. diff --git a/ql/src/InconsistentCode/MissingErrorCheck.go b/ql/src/InconsistentCode/MissingErrorCheck.go new file mode 100644 index 00000000000..d44117920ea --- /dev/null +++ b/ql/src/InconsistentCode/MissingErrorCheck.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + "os" +) + +func user(input string) { + + ptr, err := os.Open(input) + // BAD: ptr is dereferenced before either it or `err` has been checked. + fmt.Printf("Opened %v\n", *ptr) + if err != nil { + fmt.Printf("Bad input: %s\n", input) + } + +} diff --git a/ql/src/InconsistentCode/MissingErrorCheck.qhelp b/ql/src/InconsistentCode/MissingErrorCheck.qhelp new file mode 100644 index 00000000000..d542a728cfe --- /dev/null +++ b/ql/src/InconsistentCode/MissingErrorCheck.qhelp @@ -0,0 +1,38 @@ + + + + +

When a function call returns two values, a pointer and a (subtype of) error, it is conventional to assume that the pointer +might be nil until either the pointer or error value has been checked.

+ +

If the pointer is dereferenced without a check, an unexpected nil pointer dereference panic may occur.

+
+ + +

Ensure that the returned pointer is either directly checked against nil, or the error value is checked before using +the returned pointer.

+ +
+ + +

In the example below, user dereferences ptr without checking either +ptr or err. This might lead to a panic.

+ + + +

The corrected version of user checks err before using ptr.

+ + + +
+ + +
  • + The Go Blog: + Error handling and Go. +
  • + +
    +
    diff --git a/ql/src/InconsistentCode/MissingErrorCheck.ql b/ql/src/InconsistentCode/MissingErrorCheck.ql new file mode 100644 index 00000000000..3b47e827752 --- /dev/null +++ b/ql/src/InconsistentCode/MissingErrorCheck.ql @@ -0,0 +1,120 @@ +/** + * @name Missing error check + * @description When a function returns a pointer alongside an error value, one should normally + * assume that the pointer may be nil until either the pointer or error has been checked. + * @kind problem + * @problem.severity warning + * @id go/missing-error-check + * @tags reliability + * correctness + * logic + * @precision high + */ + +import go + +/** + * Holds if `node` is a reference to the `nil` builtin constant. + */ +predicate isNil(DataFlow::Node node) { node = Builtin::nil().getARead() } + +/** + * Matches if `call` may return a nil pointer alongside an error value. + * + * This is both an over- and under-estimate: over in that we assume opaque functions may use this + * convention, and under in that functions with bodies are only recognized if they use a literal + * `nil` for the pointer return value at some return site. + */ +predicate calleeMayReturnNilWithError(DataFlow::CallNode call) { + not exists(call.getACallee()) + or + exists(FuncDef callee | callee = call.getACallee() | + not exists(callee.getBody()) + or + exists(IR::ReturnInstruction ret, DataFlow::Node ptrReturn, DataFlow::Node errReturn | + callee = ret.getRoot() and + ptrReturn = DataFlow::instructionNode(ret.getResult(0)) and + errReturn = DataFlow::instructionNode(ret.getResult(1)) and + isNil(ptrReturn) and + not isNil(errReturn) + ) + ) +} + +/** + * Matches if `type` is a pointer, slice or interface type, or an alias for such a type. + */ +predicate isDereferenceableType(Type maybePointer) { + exists(Type t | t = maybePointer.getUnderlyingType() | + t instanceof PointerType or t instanceof SliceType or t instanceof InterfaceType + ) +} + +/** + * Matches if `instruction` checks `value`. + * + * We consider testing value for equality (against anything), passing it as a parameter to + * a function call, switching on either its value or its type or casting it to constitute a + * check. + */ +predicate checksValue(IR::Instruction instruction, DataFlow::SsaNode value) { + exists(DataFlow::InstructionNode instNode | instNode.asInstruction() = instruction | + instNode.(DataFlow::CallNode).getAnArgument() = value.getAUse() or + instNode.(DataFlow::EqualityTestNode).getAnOperand() = value.getAUse() + ) + or + value.getAUse().asInstruction() = instruction and + ( + exists(ExpressionSwitchStmt s | instruction.(IR::EvalInstruction).getExpr() = s.getExpr()) + or + // This case accounts for both a type-switch or cast used to check `value` + exists(TypeAssertExpr e | instruction.(IR::EvalInstruction).getExpr() = e.getExpr()) + ) +} + +/** + * Matches if `call` is a function returning (`ptr`, `err`) where `ptr` may be nil, and neither + * `ptr` not `err` has been checked for validity as of `node`. + * + * This is initially true of any callsite that may call either an opaque function or a user-defined + * function that may return (nil, error), and is true of any downstream control-flow node where a + * check has not certainly been made against either `ptr` or `err`. + */ +predicate returnUncheckedAtNode( + DataFlow::CallNode call, ControlFlow::Node node, DataFlow::SsaNode ptr, DataFlow::SsaNode err +) { + ( + // Base case: check that `ptr` and `err` have appropriate types, and that the callee may return + // a nil pointer with an error. + ptr.getAPredecessor() = call.getResult(0) and + err.getAPredecessor() = call.getResult(1) and + call.asInstruction() = node and + isDereferenceableType(ptr.getType()) and + err.getType().implements(Builtin::error().getType().getUnderlyingType()) and + calleeMayReturnNilWithError(call) + or + // Recursive case: check that some predecessor is missing a check, and `node` does not itself + // check either `ptr` or `err`. + // localFlow is used to permit checks via either an SSA phi node or ordinary assignment. + returnUncheckedAtNode(call, node.getAPredecessor(), ptr, err) and + not exists(DataFlow::SsaNode checked | + DataFlow::localFlow(ptr, checked) or DataFlow::localFlow(err, checked) + | + checksValue(node, checked) + ) + ) +} + +from + DataFlow::CallNode call, DataFlow::SsaNode ptr, DataFlow::SsaNode err, + DataFlow::PointerDereferenceNode deref, ControlFlow::Node derefNode +where + // `derefNode` is a control-flow node corresponding to `deref` + deref.getOperand().asInstruction() = derefNode and + // neither `ptr` nor `err`, the return values of `call`, have been checked as of `derefNode` + returnUncheckedAtNode(call, derefNode, ptr, err) and + // `deref` dereferences `ptr` + deref.getOperand() = ptr.getAUse() +select deref.getOperand(), + ptr.getSourceVariable() + " may be nil here, because $@ may not have been checked.", err, + err.getSourceVariable().toString() diff --git a/ql/src/InconsistentCode/MissingErrorCheckGood.go b/ql/src/InconsistentCode/MissingErrorCheckGood.go new file mode 100644 index 00000000000..a7ffbe8a9ec --- /dev/null +++ b/ql/src/InconsistentCode/MissingErrorCheckGood.go @@ -0,0 +1,18 @@ +package main + +import ( + "fmt" + "os" +) + +func user(input string) { + + ptr, err := os.Open(input) + if err != nil { + fmt.Printf("Bad input: %s\n", input) + return + } + // GOOD: `err` has been checked before `ptr` is used + fmt.Printf("Result was %v\n", *ptr) + +} diff --git a/ql/test/query-tests/InconsistentCode/MissingErrorCheck/MissingErrorCheck.expected b/ql/test/query-tests/InconsistentCode/MissingErrorCheck/MissingErrorCheck.expected new file mode 100644 index 00000000000..ec1545a4a16 --- /dev/null +++ b/ql/test/query-tests/InconsistentCode/MissingErrorCheck/MissingErrorCheck.expected @@ -0,0 +1,2 @@ +| tests.go:61:30:61:35 | result | result may be nil here, because $@ may not have been checked. | tests.go:59:10:59:12 | definition of err | err | +| tests.go:243:27:243:32 | result | result may be nil here, because $@ may not have been checked. | tests.go:241:10:241:12 | definition of err | err | diff --git a/ql/test/query-tests/InconsistentCode/MissingErrorCheck/MissingErrorCheck.qlref b/ql/test/query-tests/InconsistentCode/MissingErrorCheck/MissingErrorCheck.qlref new file mode 100644 index 00000000000..519bdd54e68 --- /dev/null +++ b/ql/test/query-tests/InconsistentCode/MissingErrorCheck/MissingErrorCheck.qlref @@ -0,0 +1 @@ +InconsistentCode/MissingErrorCheck.ql diff --git a/ql/test/query-tests/InconsistentCode/MissingErrorCheck/tests.go b/ql/test/query-tests/InconsistentCode/MissingErrorCheck/tests.go new file mode 100644 index 00000000000..da60b272bbe --- /dev/null +++ b/ql/test/query-tests/InconsistentCode/MissingErrorCheck/tests.go @@ -0,0 +1,246 @@ +package test + +import ( + "errors" + "fmt" + "os" +) + +func returnsNonNil(input int) (*int, error) { + + newp := new(int) + *newp = 5 + + if input%2 == 0 { + return newp, nil + } else { + return newp, errors.New("oh no") + } + +} + +func userDefinedDie() { + + os.Exit(1) + +} + +func makesCheckUsingSwitch(fname string) { + + result, err := os.Open(fname) + + switch { + case len(os.Args) >= 3: + fmt.Println("Too many args") + return + case err != nil: + fmt.Println("Open failed") + return + } + + fmt.Printf("Opened: %v\n", *result) // OK + +} + +func definesValueInIf(fname string) { + + var result *os.File + var err error + if result, err = os.Open(fname); err != nil { + return + } + + fmt.Printf("Opened: %v\n", *result) // OK + +} + +func missingCheckMayFail(fname string) { + + result, err := os.Open(fname) + + fmt.Printf("Opened: %v\n", *result) // NOT OK + fmt.Printf("%v\n", err) // use err + +} + +func missingCheckSafe(input int) { + + result, err := returnsNonNil(input) + + fmt.Printf("Got: %d\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func usesUserExitFn(fname string) { + + result, err := os.Open(fname) + if err != nil { + userDefinedDie() + } + + fmt.Printf("Opened: %v\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func userTestFn(e error) bool { + return e != nil +} + +func usesUserTestFn(fname string) { + + result, err := os.Open(fname) + if userTestFn(err) { + return + } + + fmt.Printf("Opened: %v\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func userRequireFn(e error) { + if e != nil { + os.Exit(1) + } +} + +func usesUserRequireFn(fname string) { + + result, err := os.Open(fname) + userRequireFn(err) + + fmt.Printf("Opened: %v\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func userPtrTestFn(ptr *os.File) bool { + return ptr != nil +} + +func usesUserPtrTestFn(fname string) { + + result, err := os.Open(fname) + if userPtrTestFn(result) { + return + } + + fmt.Printf("Opened: %v\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func userPtrRequireFn(ptr *os.File) { + if ptr != nil { + os.Exit(1) + } +} + +func usesUserPtrRequireFn(fname string) { + + result, err := os.Open(fname) + userPtrRequireFn(result) + + fmt.Printf("Opened: %v\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func reusesErrorVar(fname string) { + + result, err := os.Open(fname) + if err == nil { + _, err = os.Open(fname) + } + if err != nil { + return + } + + fmt.Printf("Opened: %v\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func neverReallyErrors() (*int, error) { + + newp := new(int) + *newp = 1 + return newp, nil + +} + +func callsNeverReallyErrors() { + + result, err := neverReallyErrors() + + fmt.Printf("Got: %d\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func checksErrorViaPhiNode(fname string) { + + // Note 'result' must not be forwarded via a phi; + // the deref has to be of exactly the definition + // we're investigating, whereas the error check can + // be of any downstream SSA or ordinary copy. + result, err := os.Open(fname) + if len(fname)%3 == 0 { + _, err = os.Open(fname) + } + if err != nil { + return + } + + fmt.Printf("Opened: %v\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +func checksErrorViaCopy(fname string) { + + var result *os.File + var err error + var err2 error + result, err2 = os.Open(fname) + err = err2 + if err != nil { + return + } + + fmt.Printf("Opened: %v\n", *result) // OK + fmt.Printf("%v\n", err) // use err + +} + +type myError struct { + field int +} + +// Implement error interface: +func (err *myError) Error() string { + return "myError" +} + +func returnsMyError(input int) (*int, *myError) { + + if input%2 == 0 { + newp := new(int) + *newp = 5 + return newp, nil + } else { + return nil, &myError{} + } + +} + +func mishandlesMyError(input int) { + + result, err := returnsMyError(input) + + fmt.Printf("Got: %d\n", *result) // NOT OK + fmt.Printf("%v\n", err) // use err + +}