diff --git a/ql/src/semmle/go/frameworks/Protobuf.qll b/ql/src/semmle/go/frameworks/Protobuf.qll index dc99a00fc62..9447a46b062 100644 --- a/ql/src/semmle/go/frameworks/Protobuf.qll +++ b/ql/src/semmle/go/frameworks/Protobuf.qll @@ -140,12 +140,22 @@ module Protobuf { } /** - * Gets the data-flow node representing the bottom of a stack of zero or more `ComponentReadNode`s. + * Gets the base of `node`, looking through any dereference node found. + */ + private DataFlow::Node getBaseLookingThroughDerefs(DataFlow::ComponentReadNode node) { + result = node.getBase().(DataFlow::PointerDereferenceNode).getOperand() + or + result = node.getBase() and not node.getBase() instanceof DataFlow::PointerDereferenceNode + } + + /** + * Gets the data-flow node representing the bottom of a stack of zero or more `ComponentReadNode`s + * perhaps with interleaved dereferences. * * For example, in the expression a.b[c].d[e], this would return the dataflow node for the read from `a`. */ DataFlow::Node getUnderlyingNode(DataFlow::ReadNode read) { - (result = read or result = read.(DataFlow::ComponentReadNode).getBase+()) and + (result = read or result = getBaseLookingThroughDerefs+(read)) and not result instanceof DataFlow::ComponentReadNode } @@ -155,7 +165,9 @@ module Protobuf { private class WriteMessageFieldStep extends TaintTracking::AdditionalTaintStep { override predicate step(DataFlow::Node pred, DataFlow::Node succ) { [succ.getType(), succ.getType().getPointerType()] instanceof MessageType and - exists(DataFlow::ReadNode base | succ = getUnderlyingNode(base) | + exists(DataFlow::ReadNode base | + succ.(DataFlow::PostUpdateNode).getPreUpdateNode() = getUnderlyingNode(base) + | any(DataFlow::Write w).writesComponent(base, pred) ) } diff --git a/ql/test/library-tests/semmle/go/frameworks/Protobuf/testDeprecatedApi.go b/ql/test/library-tests/semmle/go/frameworks/Protobuf/testDeprecatedApi.go index be1e0aa77e3..ee998768a0b 100644 --- a/ql/test/library-tests/semmle/go/frameworks/Protobuf/testDeprecatedApi.go +++ b/ql/test/library-tests/semmle/go/frameworks/Protobuf/testDeprecatedApi.go @@ -167,3 +167,15 @@ func testTaintedMapFieldReadViaAlias() { sinkString((*alias)[123]) // BAD } + +func testTaintedSubmessageInPlaceNonPointerBase() { + alert := query.Query_Alert{} + + query := query.Query{} + query.Alerts = append(query.Alerts, &alert) + query.Alerts[0].Msg = getUntrustedString() + + serialized, _ := proto.Marshal(query) + + sinkBytes(serialized) // BAD (but not detected by our current analysis) +} diff --git a/ql/test/library-tests/semmle/go/frameworks/Protobuf/testModernApi.go b/ql/test/library-tests/semmle/go/frameworks/Protobuf/testModernApi.go index 070e6deeebd..6beac40a3ab 100644 --- a/ql/test/library-tests/semmle/go/frameworks/Protobuf/testModernApi.go +++ b/ql/test/library-tests/semmle/go/frameworks/Protobuf/testModernApi.go @@ -224,3 +224,15 @@ func testTaintedMapFieldReadViaAliasModern() { sinkString((*alias)[123]) // BAD } + +func testTaintedSubmessageInPlaceNonPointerBaseModern() { + alert := query.Query_Alert{} + + query := query.Query{} + query.Alerts = append(query.Alerts, &alert) + query.Alerts[0].Msg = getUntrustedString() + + serialized, _ := proto.Marshal(query) + + sinkBytes(serialized) // BAD (but not detected by our current implementation) +}