Codegen: move qltest.test_with to parametrized pragmas

This commit is contained in:
Paolo Tranquilli
2024-09-20 12:15:10 +02:00
parent 3e2f886595
commit 8d291ab938
8 changed files with 18 additions and 17 deletions

View File

@@ -154,7 +154,6 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> ql.Class:
pragmas = {k: True for k in cls.pragmas if k.startswith("qltest")}
prev_child = ""
properties = []
for p in cls.properties:
@@ -172,7 +171,6 @@ def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> q
doc=cls.doc,
hideable=cls.hideable,
internal="ql_internal" in cls.pragmas,
**pragmas,
)
@@ -448,7 +446,8 @@ def generate(opts, renderer):
for c in data.classes.values():
if should_skip_qltest(c, data.classes):
continue
test_with = data.classes[c.test_with] if c.test_with else c
test_with_name = c.pragmas.get("qltest_test_with")
test_with = data.classes[test_with_name] if test_with_name else c
test_dir = test_out / test_with.group / test_with.name
test_dir.mkdir(parents=True, exist_ok=True)
if all(f.suffix in (".txt", ".ql", ".actual", ".expected") for f in test_dir.glob("*.*")):

View File

@@ -60,6 +60,7 @@ def generate(opts, renderer):
if fn:
indent = 4 * " "
code = [indent + l for l in code]
test_with = schema.classes[cls.test_with] if cls.test_with else cls
test_with_name = typing.cast(str, cls.pragmas.get("qltest_test_with"))
test_with = schema.classes[test_with_name] if test_with_name else cls
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
renderer.render(TestCode(code="\n".join(code), function=fn), test)

View File

@@ -107,9 +107,6 @@ class Class:
dir: pathlib.Path = pathlib.Path()
imports: List[str] = field(default_factory=list)
import_prefix: Optional[str] = None
qltest_skip: bool = False
qltest_collapse_hierarchy: bool = False
qltest_uncollapse_hierarchy: bool = False
internal: bool = False
doc: List[str] = field(default_factory=list)
hideable: bool = False

View File

@@ -95,7 +95,6 @@ class Class:
pragmas: List[str] | Dict[str, object] = field(default_factory=dict)
doc: List[str] = field(default_factory=list)
hideable: bool = False
test_with: Optional[str] = None
def __post_init__(self):
if not isinstance(self.pragmas, dict):
@@ -118,7 +117,7 @@ class Class:
if synth.on_arguments is not None:
for t in synth.on_arguments.values():
_check_type(t, known)
_check_type(self.test_with, known)
_check_type(self.pragmas.get("qltest_test_with"), known)
@property
def synth(self) -> SynthInfo | bool | None:

View File

@@ -238,7 +238,7 @@ use_for_null = _annotate(null=True)
qltest.add(_Pragma("skip"))
qltest.add(_ClassPragma("collapse_hierarchy"))
qltest.add(_ClassPragma("uncollapse_hierarchy"))
qltest.test_with = lambda cls: _annotate(test_with=cls) # inheritable
qltest.add(_ParametrizedClassPragma("test_with", inherited=True, factory=_schema.get_type_name))
ql.add(_ParametrizedClassPragma("default_doc_name", factory=lambda doc: doc))
ql.hideable = _annotate(hideable=True) # inheritable

View File

@@ -53,7 +53,6 @@ def _get_class(cls: type) -> schema.Class:
# getattr to inherit from bases
group=getattr(cls, "_group", ""),
hideable=getattr(cls, "_hideable", False),
test_with=_get_name(getattr(cls, "_test_with", None)),
pragmas=pragmas,
# in the following we don't use `getattr` to avoid inheriting
properties=[
@@ -123,9 +122,11 @@ def _fill_hideable_information(classes: typing.Dict[str, schema.Class]):
def _check_test_with(classes: typing.Dict[str, schema.Class]):
for cls in classes.values():
if cls.test_with is not None and classes[cls.test_with].test_with is not None:
raise schema.Error(f"{cls.name} has test_with {cls.test_with} which in turn "
f"has test_with {classes[cls.test_with].test_with}, use that directly")
test_with = typing.cast(str, cls.pragmas.get("qltest_test_with"))
transitive_test_with = test_with and classes[test_with].pragmas.get("qltest_test_with")
if test_with and transitive_test_with:
raise schema.Error(f"{cls.name} has test_with {test_with} which in turn "
f"has test_with {transitive_test_with}, use that directly")
def load(m: types.ModuleType) -> schema.Schema:

View File

@@ -749,7 +749,7 @@ def test_test_with(opts, generate_tests):
write(opts.ql_test_output / "B" / "test.swift")
assert generate_tests([
schema.Class("Base", derived={"A", "B"}),
schema.Class("A", bases=["Base"], test_with="B"),
schema.Class("A", bases=["Base"], pragmas={"qltest_test_with": "B"}),
schema.Class("B", bases=["Base"]),
]) == {
"B/A.ql": a_ql_class_tester(class_name="A"),

View File

@@ -754,12 +754,16 @@ def test_test_with():
class D(Root):
pass
class E(B):
pass
assert data.classes == {
"Root": schema.Class("Root", derived=set("ABCD")),
"A": schema.Class("A", bases=["Root"]),
"B": schema.Class("B", bases=["Root"], test_with="A"),
"C": schema.Class("C", bases=["Root"], test_with="D"),
"B": schema.Class("B", bases=["Root"], pragmas={"qltest_test_with": "A"}, derived={'E'}),
"C": schema.Class("C", bases=["Root"], pragmas={"qltest_test_with": "D"}),
"D": schema.Class("D", bases=["Root"]),
"E": schema.Class("E", bases=["B"], pragmas={"qltest_test_with": "A"}),
}