mirror of
https://github.com/github/codeql.git
synced 2026-04-26 09:15:12 +02:00
Codegen: allow to include .py files in schema.py
This commit is contained in:
@@ -32,9 +32,16 @@ class _DescModifier(_schema.PropertyModifier):
|
||||
|
||||
|
||||
def include(source: str):
|
||||
# add to `includes` variable in calling context
|
||||
_inspect.currentframe().f_back.f_locals.setdefault(
|
||||
"__includes", []).append(source)
|
||||
scope = _inspect.currentframe().f_back.f_locals
|
||||
if source.endswith(".dbscheme"):
|
||||
# add to `includes` variable in calling context
|
||||
scope.setdefault("__includes", []).append(source)
|
||||
elif source.endswith(".py"):
|
||||
# just load the contents
|
||||
with open(source) as input:
|
||||
exec(input.read(), scope)
|
||||
else:
|
||||
raise _schema.Error(f"Unsupported file for inclusion: {source}")
|
||||
|
||||
|
||||
class _Namespace:
|
||||
|
||||
@@ -126,7 +126,7 @@ def _check_test_with(classes: typing.Dict[str, schema.Class]):
|
||||
|
||||
|
||||
def load(m: types.ModuleType) -> schema.Schema:
|
||||
includes = set()
|
||||
includes = []
|
||||
classes = {}
|
||||
known = {"int", "string", "boolean"}
|
||||
known.update(n for n in m.__dict__ if not n.startswith("__"))
|
||||
|
||||
@@ -13,7 +13,7 @@ def test_empty_schema():
|
||||
pass
|
||||
|
||||
assert data.classes == {}
|
||||
assert data.includes == set()
|
||||
assert data.includes == []
|
||||
assert data.null is None
|
||||
assert data.null_class is None
|
||||
|
||||
@@ -805,5 +805,51 @@ def test_test_with_double():
|
||||
pass
|
||||
|
||||
|
||||
def test_include_dbscheme():
|
||||
@load
|
||||
class data:
|
||||
defs.include("foo.dbscheme")
|
||||
defs.include("bar.dbscheme")
|
||||
|
||||
assert data.includes == ["foo.dbscheme", "bar.dbscheme"]
|
||||
|
||||
|
||||
def test_include_source(tmp_path):
|
||||
(tmp_path / "foo.py").write_text("""
|
||||
class A(Root):
|
||||
pass
|
||||
""")
|
||||
(tmp_path / "bar.py").write_text("""
|
||||
class C(Root):
|
||||
pass
|
||||
""")
|
||||
|
||||
@load
|
||||
class data:
|
||||
class Root:
|
||||
pass
|
||||
|
||||
defs.include(str(tmp_path / "foo.py"))
|
||||
|
||||
class B(Root):
|
||||
pass
|
||||
|
||||
defs.include(str(tmp_path / "bar.py"))
|
||||
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", derived=set("ABC")),
|
||||
"A": schema.Class("A", bases=["Root"]),
|
||||
"B": schema.Class("B", bases=["Root"]),
|
||||
"C": schema.Class("C", bases=["Root"]),
|
||||
}
|
||||
|
||||
|
||||
def test_include_not_supported(tmp_path):
|
||||
with pytest.raises(schema.Error):
|
||||
@load
|
||||
class data:
|
||||
defs.include("foo.bar")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
Reference in New Issue
Block a user