diff --git a/python/ql/lib/semmle/python/frameworks/Asyncpg.qll b/python/ql/lib/semmle/python/frameworks/Asyncpg.qll index 124ca2e128e..0b3af86abca 100644 --- a/python/ql/lib/semmle/python/frameworks/Asyncpg.qll +++ b/python/ql/lib/semmle/python/frameworks/Asyncpg.qll @@ -18,7 +18,7 @@ private module Asyncpg { /** * A `Connection` is created when * - the result of `asyncpg.connect()` is awaited. - * - the result of calling `aquire` on a c=`ConnectionPool` is awaited. + * - the result of calling `aquire` on a `ConnectionPool` is awaited. */ API::Node connection() { result = API::moduleImport("asyncpg").getMember("connect").getReturn().getAwaited() @@ -112,15 +112,17 @@ private module Asyncpg { * TODO: Rewrite this, once we have `API::CallNode` available. */ module PreparedStatement { + class PreparedStatementConstruction extends SqlConstruction::Range, DataFlow::CallCfgNode { + PreparedStatementConstruction() { this = connection().getMember("prepare").getACall() } + + override DataFlow::Node getSql() { result in [this.getArg(0), this.getArgByName("query")] } + } + private DataFlow::TypeTrackingNode preparedStatementFactory( DataFlow::TypeTracker t, DataFlow::Node sql ) { t.start() and - result = connection().getMember("prepare").getACall() and - sql in [ - result.(DataFlow::CallCfgNode).getArg(0), - result.(DataFlow::CallCfgNode).getArgByName("query") - ] + sql = result.(PreparedStatementConstruction).getSql() or exists(DataFlow::TypeTracker t2 | result = preparedStatementFactory(t2, sql).track(t2, t)) } @@ -163,14 +165,16 @@ private module Asyncpg { * TODO: Rewrite this, once we have `API::CallNode` available. */ module Cursor { + class CursorConstruction extends SqlConstruction::Range, DataFlow::CallCfgNode { + CursorConstruction() { this = connection().getMember("cursor").getACall() } + + override DataFlow::Node getSql() { result in [this.getArg(0), this.getArgByName("query")] } + } + private DataFlow::TypeTrackingNode cursorFactory(DataFlow::TypeTracker t, DataFlow::Node sql) { // cursor created from connection t.start() and - result = connection().getMember("cursor").getACall() and - sql in [ - result.(DataFlow::CallCfgNode).getArg(0), - result.(DataFlow::CallCfgNode).getArgByName("query") - ] + sql = result.(CursorConstruction).getSql() or // cursor created from prepared statement t.start() and diff --git a/python/ql/test/library-tests/frameworks/asyncpg/SqlExecution.py b/python/ql/test/library-tests/frameworks/asyncpg/SqlExecution.py index e9e619afa66..0ffaaacba62 100644 --- a/python/ql/test/library-tests/frameworks/asyncpg/SqlExecution.py +++ b/python/ql/test/library-tests/frameworks/asyncpg/SqlExecution.py @@ -20,7 +20,7 @@ async def test_prepared_statement(): conn = await asyncpg.connect() try: - pstmt = await conn.prepare("psql") + pstmt = await conn.prepare("psql") # $ constructedSql="psql" pstmt.executemany() # $ getSql="psql" pstmt.fetch() # $ getSql="psql" pstmt.fetchrow() # $ getSql="psql" @@ -36,20 +36,20 @@ async def test_cursor(): try: async with conn.transaction(): - cursor = await conn.cursor("sql") # $ getSql="sql" + cursor = await conn.cursor("sql") # $ getSql="sql" constructedSql="sql" await cursor.fetch() - pstmt = await conn.prepare("psql") + pstmt = await conn.prepare("psql") # $ constructedSql="psql" pcursor = await pstmt.cursor() # $ getSql="psql" await pcursor.fetch() - async for record in conn.cursor("sql"): # $ getSql="sql" + async for record in conn.cursor("sql"): # $ getSql="sql" constructedSql="sql" pass async for record in pstmt.cursor(): # $ getSql="psql" pass - cursor_factory = conn.cursor("sql") + cursor_factory = conn.cursor("sql") # $ constructedSql="sql" cursor = await cursor_factory # $ getSql="sql" pcursor_factory = pstmt.cursor()