Implement websockets models

This commit is contained in:
Joe Farebrother
2025-12-01 16:24:59 +00:00
parent 1c2d8bb70e
commit 384e17a4ef
9 changed files with 193 additions and 0 deletions

View File

@@ -89,6 +89,7 @@ private import semmle.python.frameworks.TRender
private import semmle.python.frameworks.Twisted
private import semmle.python.frameworks.Ujson
private import semmle.python.frameworks.Urllib3
private import semmle.python.frameworks.Websockets
private import semmle.python.frameworks.Xmltodict
private import semmle.python.frameworks.Yaml
private import semmle.python.frameworks.Yarl

View File

@@ -0,0 +1,83 @@
/**
* Provides definitions and modeling for the `websockets` PyPI package.
*
* See https://websockets.readthedocs.io/en/stable/
*/
private import python
private import semmle.python.dataflow.new.RemoteFlowSources
private import semmle.python.Concepts
private import semmle.python.ApiGraphs
private import semmle.python.frameworks.internal.PoorMansFunctionResolution
private import semmle.python.frameworks.internal.InstanceTaintStepsHelper
/**
* Provides models for the `websockets` PyPI package.
* See https://websockets.readthedocs.io/en/stable/
*/
module Websockets {
private class HandlerArg extends DataFlow::Node {
HandlerArg() {
exists(DataFlow::CallCfgNode c |
c =
API::moduleImport("websockets")
.getMember(["asyncio", "sync"])
.getMember("server")
.getMember(["serve", "unix_serve"])
.getACall()
|
(this = c.getArg(0) or this = c.getArgByName("handler"))
)
}
}
/** A websocket handler that is passed to `serve`. */
// TODO: handlers defined via route maps, e.g. through `websockets.asyncio.router.route`, are more complex to handle.
class WebSocketHandler extends Http::Server::RequestHandler::Range {
WebSocketHandler() { poorMansFunctionTracker(this) = any(HandlerArg a) }
override Parameter getARoutedParameter() { result = this.getAnArg() }
override string getFramework() { result = "websockets" }
}
module ServerConnection {
/**
* A source of instances of `websockets.asyncio.ServerConnection` and `websockets.threading.ServerConnection`, extend this class to model new instances.
*
* This can include instantiations of the class, return values from function
* calls, or a special parameter that will be set when functions are called by an external
* library.
*
* Use the predicate `WebSocket::instance()` to get references to instances of `websockets.asyncio.ServerConnection` and `websockets.threading.ServerConnection`.
*/
abstract class InstanceSource extends DataFlow::LocalSourceNode { }
/** Gets a reference to an instance of `websockets.asyncio.ServerConnection` or `websockets.threading.ServerConnection`. */
private DataFlow::TypeTrackingNode instance(DataFlow::TypeTracker t) {
t.start() and
result instanceof InstanceSource
or
exists(DataFlow::TypeTracker t2 | result = instance(t2).track(t2, t))
}
/** Gets a reference to an instance of `websockets.asyncio.ServerConnection` or `websockets.threading.ServerConnection`. */
DataFlow::Node instance() { instance(DataFlow::TypeTracker::end()).flowsTo(result) }
private class HandlerParam extends DataFlow::Node, InstanceSource {
HandlerParam() { exists(WebSocketHandler h | this = DataFlow::parameterNode(h.getArg(0))) }
}
private class InstanceTaintSteps extends InstanceTaintStepsHelper {
InstanceTaintSteps() { this = "websockets.asyncio.ServerConnection" }
override DataFlow::Node getInstance() { result = instance() }
override string getAttributeName() { none() }
override string getAsyncMethodName() { result = ["recv", "recv_streaming"] }
override string getMethodName() { result = ["recv", "recv_streaming"] }
}
}
}

View File

@@ -0,0 +1,2 @@
import python
import experimental.meta.ConceptsTest

View File

@@ -0,0 +1,3 @@
argumentToEnsureNotTaintedNotMarkedAsSpurious
untaintedArgumentToEnsureTaintedNotMarkedAsMissing
testFailures

View File

@@ -0,0 +1,2 @@
import experimental.meta.InlineTaintTest
import MakeInlineTaintTest<TestTaintTrackingConfig>

View File

@@ -0,0 +1,46 @@
import websockets.sync.server
import websockets.sync.router
from werkzeug.routing import Map, Rule
def arg_handler(websocket): # $ requestHandler routedParameter=websocket
websocket.send("arg" + websocket.recv())
s1 = websockets.sync.server.serve(arg_handler, "localhost", 8000)
def kw_handler(websocket): # $ requestHandler routedParameter=websocket
websocket.send("kw" + websocket.recv())
s2 = websockets.sync.server.serve(handler=kw_handler, host="localhost", port=8001)
def route_handler(websocket, x): # $ MISSING: requestHandler routedParameter=websocket routedParameter=x
websocket.send(f"route {x} {websocket.recv()}")
s3 = websockets.sync.router.route(Map([
Rule("/<string:x>", endpoint=route_handler)
]), "localhost", 8002)
def unix_handler(websocket): # $ requestHandler routedParameter=websocket
websocket.send("unix" + websocket.recv())
s4 = websockets.sync.server.unix_serve(unix_handler, path="/tmp/ws.sock")
def unix_route_handler(websocket, x): # $ MISSING: requestHandler routedParameter=websocket routedParameter=x
websocket.send(f"unix route {x} {websocket.recv()}")
s5 = websockets.sync.router.unix_route(Map([
Rule("/<string:x>", endpoint=unix_route_handler)
]), path="/tmp/ws2.sock")
if __name__ == "__main__":
import sys
server = s1
if len(sys.argv) > 1:
if sys.argv[1] == "kw":
server = s2
elif sys.argv[1] == "route":
server = s3
elif sys.argv[1] == "unix":
server = s4
elif sys.argv[1] == "unix_route":
server = s5
server.serve_forever()

View File

@@ -0,0 +1,30 @@
import websockets.asyncio.server
import asyncio
def ensure_tainted(*args):
print("tainted", args)
def ensure_not_tainted(*args):
print("not tainted", args)
async def handler(websocket): # $ requestHandler routedParameter=websocket
ensure_tainted(
websocket, # $ tainted
await websocket.recv() # $ tainted
)
async for msg in websocket:
ensure_tainted(msg) # $ tainted
await websocket.send(msg)
async for msg in websocket.recv_streaming():
ensure_tainted(msg) # $ tainted
await websocket.send(msg)
async def main():
server = await websockets.asyncio.server.serve(handler, "localhost", 8000)
await server.serve_forever()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,26 @@
import websockets.sync.server
def ensure_tainted(*args):
print("tainted", args)
def ensure_not_tainted(*args):
print("not tainted", args)
def handler(websocket): # $ requestHandler routedParameter=websocket
ensure_tainted(
websocket, # $ tainted
websocket.recv() # $ tainted
)
for msg in websocket:
ensure_tainted(msg) # $ tainted
websocket.send(msg)
for msg in websocket.recv_streaming():
ensure_tainted(msg) # $ tainted
websocket.send(msg)
if __name__ == "__main__":
server = websockets.sync.server.serve(handler, "localhost", 8000)
server.serve_forever()