Python: adjust logic and add tests

Due to the way paths a re printed, the tests look surprising
This commit is contained in:
Rasmus Lerchedahl Petersen
2021-10-26 09:55:47 +02:00
parent 149b235c7a
commit a8a181a32f
3 changed files with 26 additions and 11 deletions

View File

@@ -502,20 +502,23 @@ module API {
// - `awaitedValue` is `l`
// - `result` is `l` (should perhaps be `x`, but that should really be a read)
exists(AsyncFor asyncFor |
result.asExpr() = asyncFor.getTarget() and
// Morally, we should perhaps use asyncFor.getIter() = awaitedValue.asExpr(),
// but that is actually behind a read step rather than a flow step.
asyncFor.getTarget() = awaitedValue.asExpr()
result.asExpr() = asyncFor.getIter() and
// To consider `x` the result of awaiting, we would use asyncFor.getTarget() = awaitedValue.asExpr(),
// but that is behind a read step rather than a flow step.
asyncFor.getIter() = awaitedValue.asExpr()
)
or
// `async with x as y`
// - `awaitedValue` is `x`
// - `result` is `x` (should probably be `y` but it might not exist)
// - `result` is `x` and `y` if it exists
exists(AsyncWith asyncWith |
result.asExpr() = asyncWith.getContextExpr() and
// Morally, we should perhaps use asyncWith.getOptionalVars() = awaitedValue.asExpr(),
// but that might not exist.
asyncWith.getContextExpr() = awaitedValue.asExpr()
awaitedValue.asExpr() in [
// `x`
asyncWith.getContextExpr(),
// `y`, if it exists
asyncWith.getOptionalVars()
]
)
}

View File

@@ -83,17 +83,17 @@ private module Asyncpg {
// - `awaitedValue` is local source of `l`
// - `result` is `l`
exists(AsyncFor asyncFor, DataFlow::Node awaited |
result.asExpr() = asyncFor.getTarget() and
result.asExpr() = asyncFor.getIter() and
asyncFor.getIter() = awaited.asExpr() and
awaited.getALocalSource() = awaitedValue
)
or
// `async with x as y`
// - `awaitedValue` is local source of `x`
// - `result` is `x`
// - `result` is `x` and `y`
exists(AsyncWith asyncWith, DataFlow::Node awaited |
result.asExpr() = asyncWith.getContextExpr() and
asyncWith.getOptionalVars() = awaited.asExpr() and
awaited.asExpr() in [asyncWith.getContextExpr(), asyncWith.getOptionalVars()] and
awaited.getALocalSource() = awaitedValue
)
}

View File

@@ -11,6 +11,18 @@ async def bar():
result = await pkg.async_func() # $ use=moduleImport("pkg").getMember("async_func").getReturn().getAwaited()
return result # $ use=moduleImport("pkg").getMember("async_func").getReturn().getAwaited()
async def test_async_with():
async with pkg.async_func() as result: # $ use=moduleImport("pkg").getMember("async_func").getReturn()
return result # $ use=moduleImport("pkg").getMember("async_func").getReturn()
async def test_async_for():
async for _ in pkg.async_func(): # $ use=moduleImport("pkg").getMember("async_func").getReturn()
pass
coro = pkg.async_func() # $ use=moduleImport("pkg").getMember("async_func").getReturn()
async for _ in coro: # $ use=moduleImport("pkg").getMember("async_func").getReturn()
pass
def check_annotations():
# Just to make sure how annotations should look like :)
result = pkg.sync_func() # $ use=moduleImport("pkg").getMember("sync_func").getReturn()