Add PyTorch code execution sinks, add proper tests

This commit is contained in:
amammad
2024-02-24 15:55:33 +04:00
parent 3d7db0e46b
commit 076faa3a4e
5 changed files with 106 additions and 0 deletions

View File

@@ -68,6 +68,7 @@ private import semmle.python.frameworks.SqlAlchemy
private import semmle.python.frameworks.Starlette
private import semmle.python.frameworks.Stdlib
private import semmle.python.frameworks.Toml
private import semmle.python.frameworks.Torch
private import semmle.python.frameworks.Tornado
private import semmle.python.frameworks.Twisted
private import semmle.python.frameworks.Ujson

View File

@@ -0,0 +1,76 @@
/**
* Provides classes modeling security-relevant aspects of the `torch` PyPI package.
* See https://pypi.org/project/torch/.
*/
private import python
private import semmle.python.Concepts
private import semmle.python.ApiGraphs
/**
 * Provides models for the `torch` PyPI package.
 * See https://pypi.org/project/torch/.
 */
private module Torch {
  /**
   * A call to `torch.load`, modeled as a decoding step with format `pickle`.
   * See https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
   */
  private class TorchLoadCall extends Decoding::Range, API::CallNode {
    TorchLoadCall() { this = API::moduleImport("torch").getMember("load").getACall() }

    // Holds when `pickle_module` (third positional argument or keyword) is
    // either absent or the literal `None`; any other explicit value is
    // treated as not executing the input.
    override predicate mayExecuteInput() {
      not exists(this.getParameter(2, "pickle_module").asSink())
      or
      this.getParameter(2, "pickle_module").asSink().asExpr() instanceof None
    }

    // The data being decoded is the first argument, `f`.
    override DataFlow::Node getAnInput() { result = this.getParameter(0, "f").asSink() }

    // The decoded value is the result of the call itself.
    override DataFlow::Node getOutput() { result = this }

    override string getFormat() { result = "pickle" }
  }

  /**
   * A call to the `torch.package.PackageImporter` constructor whose result
   * has `load_pickle` called on it.
   * See https://pytorch.org/docs/stable/package.html#torch.package.PackageImporter
   */
  private class TorchPackageImporter extends Decoding::Range, API::CallNode {
    TorchPackageImporter() {
      this = API::moduleImport("torch").getMember("package").getMember("PackageImporter").getACall() and
      // Only constructor calls that are followed by a `load_pickle` method call are modeled.
      exists(this.getAMethodCall("load_pickle"))
    }

    override predicate mayExecuteInput() { any() }

    // The data being decoded is the first constructor argument, `file_or_buffer`.
    override DataFlow::Node getAnInput() {
      result = this.getParameter(0, "file_or_buffer").asSink()
    }

    // The decoded value is produced by the `load_pickle` call, not by the constructor.
    override DataFlow::Node getOutput() { result = this.getAMethodCall("load_pickle") }

    override string getFormat() { result = "pickle" }
  }

  /**
   * A call to `torch.jit.load`, modeled as a decoding step with format `pickle`.
   * See https://pytorch.org/docs/stable/generated/torch.jit.load.html#torch.jit.load
   */
  private class TorchJitLoad extends Decoding::Range, API::CallNode {
    TorchJitLoad() {
      this = API::moduleImport("torch").getMember("jit").getMember("load").getACall()
    }

    override predicate mayExecuteInput() { any() }

    // The data being decoded is the first argument, `f`.
    override DataFlow::Node getAnInput() { result = this.getParameter(0, "f").asSink() }

    // The decoded value is the result of the call itself.
    override DataFlow::Node getOutput() { result = this }

    override string getFormat() { result = "pickle" }
  }
}

View File

@@ -0,0 +1,2 @@
testFailures
failures

View File

@@ -0,0 +1,2 @@
import python
import experimental.meta.ConceptsTest

View File

@@ -0,0 +1,25 @@
from io import BytesIO
import torch
def someSafeMethod():
    """Placeholder whose return value is passed as a non-None `pickle_module` in this test file."""
    pass
# Payload object passed as the input to every loader call below.
PicklePayload = BytesIO(b"payload")
# `pickle_module` omitted or explicitly `None`: `decodeMayExecuteInput` is expected.
torch.load(PicklePayload) # $ decodeInput=PicklePayload decodeOutput=torch.load(..) decodeFormat=pickle decodeMayExecuteInput
torch.load(PicklePayload, pickle_module=None) # $ decodeInput=PicklePayload decodeOutput=torch.load(..) decodeFormat=pickle decodeMayExecuteInput
# A non-`None` `pickle_module` is supplied, so `decodeMayExecuteInput` is not expected.
torch.load(PicklePayload, pickle_module=someSafeMethod()) # $ decodeInput=PicklePayload decodeOutput=torch.load(..) decodeFormat=pickle
from torch.package import PackageImporter
# The constructor call carries the input/format tags; the output tag sits on the `load_pickle` call.
importer = PackageImporter(PicklePayload) # $ decodeInput=PicklePayload decodeFormat=pickle decodeMayExecuteInput
my_tensor = importer.load_pickle("my_resources", "tensor.pkl") # $ decodeOutput=importer.load_pickle(..)
# Unannotated, presumably as a negative case: `load_pickle` is never called on this importer.
importer = PackageImporter(PicklePayload)
from torch import jit
jit.load(PicklePayload) # $ decodeInput=PicklePayload decodeOutput=jit.load(..) decodeFormat=pickle decodeMayExecuteInput