Python: Add test for async taint

(which we belive we have just broken)
This commit is contained in:
Rasmus Lerchedahl Petersen
2021-10-27 15:33:54 +02:00
parent 06586a13a3
commit cca675a161
4 changed files with 70 additions and 1 deletions

View File

@@ -13,7 +13,7 @@ async def bar():
async def test_async_with():
async with pkg.async_func() as result: # $ use=moduleImport("pkg").getMember("async_func").getReturn().getAwaited() awaited=moduleImport("pkg").getMember("async_func").getReturn()
return result # $ use=moduleImport("pkg").getMember("async_func").getReturn() awaited=moduleImport("pkg").getMember("async_func").getReturn()
return result # $ awaited=moduleImport("pkg").getMember("async_func").getReturn()
async def test_async_for():
async for _ in pkg.async_func(): # $ use=moduleImport("pkg").getMember("async_func").getReturn() awaited=moduleImport("pkg").getMember("async_func").getReturn()

View File

@@ -7,9 +7,16 @@ class TestTaintTrackingConfiguration extends TaintTracking::Configuration {
TestTaintTrackingConfiguration() { this = "TestTaintTrackingConfiguration" }
override predicate isSource(DataFlow::Node source) {
// Standard sources
source.(DataFlow::CfgNode).getNode().(NameNode).getId() in [
"TAINTED_STRING", "TAINTED_BYTES", "TAINTED_LIST", "TAINTED_DICT"
]
or
// User defined sources
exists(CallNode call |
call.getFunction().(NameNode).getId() = "taint" and
source.(DataFlow::CfgNode).getNode() = call.getAnArg()
)
}
override predicate isSink(DataFlow::Node sink) {

View File

@@ -0,0 +1,57 @@
# Add taintlib to PATH so it can be imported during runtime without any hassle
import sys; import os; sys.path.append(os.path.dirname(os.path.dirname((__file__))))
from taintlib import *
# This has no runtime impact, but allows autocomplete to work
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..taintlib import *
# Actual tests
async def tainted_coro():
return TAINTED_STRING
async def test_await():
coro = tainted_coro()
taint(coro)
s = await coro
ensure_tainted(coro, s) # $ tainted
class AsyncContext:
async def __aenter__(self):
return TAINTED_STRING
async def __aexit__(self, exc_type, exc, tb):
pass
async def test_async_with():
ctx = AsyncContext()
taint(ctx)
async with ctx as tainted:
ensure_tainted(tainted) # $ MISSING: tainted
class AsyncIter:
def __aiter__(self):
return self
async def __anext__(self):
raise StopAsyncIteration
async def test_async_for():
iter = AsyncIter()
taint(iter)
async for tainted in iter:
ensure_tainted(tainted) # $ MISSING: tainted
# Make tests runable
import asyncio
asyncio.run(test_await())
asyncio.run(test_async_with())
asyncio.run(test_async_for())

View File

@@ -5,6 +5,11 @@ TAINTED_DICT = {"name": TAINTED_STRING, "some key": "foo"}
NOT_TAINTED = "NOT_TAINTED"
# Use this to force expressions to be tainted
def taint(*args):
pass
def ensure_tainted(*args):
print("- ensure_tainted")
for i, arg in enumerate(args):