Rust: make File usable in codegen

This commit is contained in:
Paolo Tranquilli
2024-12-02 15:15:46 +01:00
parent 7e0e5a3f4e
commit b57a37479b
40 changed files with 364 additions and 142 deletions

View File

@@ -110,7 +110,8 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
def get_declarations(data: schema.Schema):
add_or_none_except = data.root_class.name if data.null else None
declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, data.classes, add_or_none_except)]
declarations = [d for cls in data.classes.values() if not cls.imported for d in cls_to_dbscheme(cls,
data.classes, add_or_none_except)]
if data.null:
property_classes = {
prop.type for cls in data.classes.values() for prop in cls.properties

View File

@@ -104,8 +104,17 @@ def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
return f"{prop_name} of this {class_name}"
def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.Class],
def _type_is_hideable(t: str, lookup: typing.Dict[str, schema.ClassBase]) -> bool:
if t in lookup:
match lookup[t]:
case schema.Class() as cls:
return "ql_hideable" in cls.pragmas
return False
def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.ClassBase],
prev_child: str = "") -> ql.Property:
args = dict(
type=prop.type if not prop.is_predicate else "predicate",
qltest_skip="qltest_skip" in prop.pragmas,
@@ -115,7 +124,8 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
is_unordered=prop.is_unordered,
description=prop.description,
synth=bool(cls.synth) or prop.synth,
type_is_hideable="ql_hideable" in lookup[prop.type].pragmas if prop.type in lookup else False,
type_is_hideable=_type_is_hideable(prop.type, lookup),
type_is_codegen_class=prop.type in lookup and not lookup[prop.type].imported,
internal="ql_internal" in prop.pragmas,
)
ql_name = prop.pragmas.get("ql_name", prop.name)
@@ -154,7 +164,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
return ql.Property(**args)
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> ql.Class:
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]) -> ql.Class:
if "ql_name" in cls.pragmas:
raise Error("ql_name is not supported yet for classes, only for properties")
prev_child = ""
@@ -391,14 +401,15 @@ def generate(opts, renderer):
data = schemaloader.load_file(input)
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items()}
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items() if not cls.imported}
if not classes:
raise NoClasses
root = next(iter(classes.values()))
if root.has_children:
raise RootElementHasChildren(root)
imports = {}
pre_imports = {n: cls.module for n, cls in data.classes.items() if cls.imported}
imports = dict(pre_imports)
imports_impl = {}
classes_used_by = {}
cfg_classes = []
@@ -410,7 +421,7 @@ def generate(opts, renderer):
force=opts.force) as renderer:
db_classes = [cls for name, cls in classes.items() if not data.classes[name].synth]
renderer.render(ql.DbClasses(db_classes), out / "Raw.qll")
renderer.render(ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))), out / "Raw.qll")
classes_by_dir_and_name = sorted(classes.values(), key=lambda cls: (cls.dir, cls.name))
for c in classes_by_dir_and_name:
@@ -439,6 +450,8 @@ def generate(opts, renderer):
renderer.render(cfg_classes_val, cfg_qll)
for c in data.classes.values():
if c.imported:
continue
path = _get_path(c)
path_impl = _get_path_impl(c)
stub_file = stub_out / path_impl
@@ -457,7 +470,7 @@ def generate(opts, renderer):
renderer.render(class_public, class_public_file)
# for example path/to/elements -> path/to/elements.qll
renderer.render(ql.ImportList([i for name, i in imports.items() if not classes[name].internal]),
renderer.render(ql.ImportList([i for name, i in imports.items() if name not in classes or not classes[name].internal]),
include_file)
elements_module = get_import(include_file, opts.root_dir)
@@ -465,12 +478,15 @@ def generate(opts, renderer):
renderer.render(
ql.GetParentImplementation(
classes=list(classes.values()),
imports=[elements_module] + [i for name, i in imports.items() if classes[name].internal],
imports=[elements_module] + [i for name,
i in imports.items() if name in classes and classes[name].internal],
),
out / 'ParentChild.qll')
if test_out:
for c in data.classes.values():
if c.imported:
continue
if should_skip_qltest(c, data.classes):
continue
test_with_name = c.pragmas.get("qltest_test_with")
@@ -500,7 +516,8 @@ def generate(opts, renderer):
constructor_imports = []
synth_constructor_imports = []
stubs = {}
for cls in sorted(data.classes.values(), key=lambda cls: (cls.group, cls.name)):
for cls in sorted((cls for cls in data.classes.values() if not cls.imported),
key=lambda cls: (cls.group, cls.name)):
synth_type = get_ql_synth_class(cls)
if synth_type.is_final:
final_synth_types.append(synth_type)

View File

@@ -49,7 +49,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
def _get_properties(
cls: schema.Class, lookup: dict[str, schema.Class],
cls: schema.Class, lookup: dict[str, schema.ClassBase],
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
for b in cls.bases:
yield from _get_properties(lookup[b], lookup)
@@ -58,12 +58,14 @@ def _get_properties(
def _get_ancestors(
cls: schema.Class, lookup: dict[str, schema.Class]
cls: schema.Class, lookup: dict[str, schema.ClassBase]
) -> typing.Iterable[schema.Class]:
for b in cls.bases:
base = lookup[b]
yield base
yield from _get_ancestors(base, lookup)
if not base.imported:
base = typing.cast(schema.Class, base)
yield base
yield from _get_ancestors(base, lookup)
class Processor:
@@ -71,7 +73,7 @@ class Processor:
self._classmap = data.classes
def _get_class(self, name: str) -> rust.Class:
cls = self._classmap[name]
cls = typing.cast(schema.Class, self._classmap[name])
properties = [
(c, p)
for c, p in _get_properties(cls, self._classmap)
@@ -101,8 +103,10 @@ class Processor:
def get_classes(self):
ret = {"": []}
for k, cls in self._classmap.items():
if not cls.synth:
if not cls.imported and not cls.synth:
ret.setdefault(cls.group, []).append(self._get_class(cls.name))
elif cls.imported:
ret[""].append(rust.Class(name=cls.name))
return ret

View File

@@ -56,6 +56,8 @@ def generate(opts, renderer):
registry=opts.ql_test_output / ".generated_tests.list",
force=opts.force) as renderer:
for cls in schema.classes.values():
if cls.imported:
continue
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_doc_test" in cls.pragmas):
continue

View File

@@ -44,6 +44,7 @@ class Property:
doc_plural: Optional[str] = None
synth: bool = False
type_is_hideable: bool = False
type_is_codegen_class: bool = False
internal: bool = False
cfg: bool = False
@@ -66,10 +67,6 @@ class Property:
article = "An" if self.singular[0] in "AEIO" else "A"
return f"get{article}{self.singular}"
@property
def type_is_class(self):
return bool(self.type) and self.type[0].isupper()
@property
def is_repeated(self):
return bool(self.plural)
@@ -191,6 +188,7 @@ class DbClasses:
template: ClassVar = 'ql_db'
classes: List[Class] = field(default_factory=list)
imports: List[str] = field(default_factory=list)
@dataclass

View File

@@ -3,7 +3,7 @@ import abc
import typing
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import List, Set, Union, Dict, Optional
from typing import List, Set, Union, Dict, Optional, FrozenSet
from enum import Enum, auto
import functools
@@ -87,8 +87,22 @@ class SynthInfo:
@dataclass
class Class:
class ClassBase:
imported: typing.ClassVar[bool]
name: str
@dataclass
class ImportedClass(ClassBase):
imported: typing.ClassVar[bool] = True
module: str
@dataclass
class Class(ClassBase):
imported: typing.ClassVar[bool] = False
bases: List[str] = field(default_factory=list)
derived: Set[str] = field(default_factory=set)
properties: List[Property] = field(default_factory=list)
@@ -133,7 +147,7 @@ class Class:
@dataclass
class Schema:
classes: Dict[str, Class] = field(default_factory=dict)
classes: Dict[str, ClassBase] = field(default_factory=dict)
includes: List[str] = field(default_factory=list)
null: Optional[str] = None
@@ -155,7 +169,7 @@ class Schema:
predicate_marker = object()
TypeRef = Union[type, str]
TypeRef = type | str | ImportedClass
def get_type_name(arg: TypeRef) -> str:
@@ -164,6 +178,8 @@ def get_type_name(arg: TypeRef) -> str:
return arg.__name__
case str():
return arg
case ImportedClass():
return arg.name
case _:
raise Error(f"Not a schema type or string ({arg})")
@@ -172,9 +188,9 @@ def _make_property(arg: object) -> Property:
match arg:
case _ if arg is predicate_marker:
return PredicateProperty()
case str() | type():
case (str() | type() | ImportedClass()) as arg:
return SingleProperty(type=get_type_name(arg))
case Property():
case Property() as arg:
return arg
case _:
raise Error(f"Illegal property specifier {arg}")

View File

@@ -8,8 +8,6 @@ from misc.codegen.lib import schema as _schema
import inspect as _inspect
from dataclasses import dataclass as _dataclass
from misc.codegen.lib.schema import Property
_set = set
@@ -69,6 +67,9 @@ def include(source: str):
_inspect.currentframe().f_back.f_locals.setdefault("includes", []).append(source)
imported = _schema.ImportedClass
@_dataclass
class _Namespace:
""" simple namespacing mechanism """
@@ -264,7 +265,7 @@ class _PropertyModifierList(_schema.PropertyModifier):
def __or__(self, other: _schema.PropertyModifier):
return _PropertyModifierList(self._mods + (other,))
def modify(self, prop: Property):
def modify(self, prop: _schema.Property):
for m in self._mods:
m.modify(prop)

View File

@@ -132,6 +132,7 @@ def _check_test_with(classes: typing.Dict[str, schema.Class]):
def load(m: types.ModuleType) -> schema.Schema:
includes = set()
classes = {}
imported_classes = {}
known = {"int", "string", "boolean"}
known.update(n for n in m.__dict__ if not n.startswith("__"))
import misc.codegen.lib.schemadefs as defs
@@ -146,6 +147,9 @@ def load(m: types.ModuleType) -> schema.Schema:
continue
if isinstance(data, types.ModuleType):
continue
if isinstance(data, schema.ImportedClass):
imported_classes[name] = data
continue
cls = _get_class(data)
if classes and not cls.bases:
raise schema.Error(
@@ -162,7 +166,7 @@ def load(m: types.ModuleType) -> schema.Schema:
_fill_hideable_information(classes)
_check_test_with(classes)
return schema.Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
return schema.Schema(includes=includes, classes=imported_classes | _toposort_classes_by_group(classes), null=null)
def load_file(path: pathlib.Path) -> schema.Schema:

View File

@@ -113,7 +113,7 @@ module Generated {
*/
{{type}} {{getter}}({{#is_indexed}}int index{{/is_indexed}}) {
{{^synth}}
{{^is_predicate}}result = {{/is_predicate}}{{#type_is_class}}Synth::convert{{type}}FromRaw({{/type_is_class}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_indexed}}index{{/is_indexed}}){{#type_is_class}}){{/type_is_class}}
{{^is_predicate}}result = {{/is_predicate}}{{#type_is_codegen_class}}Synth::convert{{type}}FromRaw({{/type_is_codegen_class}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_indexed}}index{{/is_indexed}}){{#type_is_codegen_class}}){{/type_is_codegen_class}}
{{/synth}}
{{#synth}}
none()

View File

@@ -3,6 +3,10 @@
* This module holds thin fully generated class definitions around DB entities.
*/
module Raw {
{{#imports}}
private import {{.}}
{{/imports}}
{{#classes}}
/**
* INTERNAL: Do not use.

View File

@@ -12,21 +12,6 @@ def test_property_has_first_table_param_marked():
assert [p.param for p in prop.tableparams] == tableparams
@pytest.mark.parametrize("type,expected", [
("Foo", True),
("Bar", True),
("foo", False),
("bar", False),
(None, False),
])
def test_property_is_a_class(type, expected):
tableparams = ["a", "result", "b"]
expected_tableparams = ["a", "result" if expected else "result", "b"]
prop = ql.Property("Prop", type, tableparams=tableparams)
assert prop.type_is_class is expected
assert [p.param for p in prop.tableparams] == expected_tableparams
indefinite_getters = [
("Argument", "getAnArgument"),
("Element", "getAnElement"),

View File

@@ -448,7 +448,8 @@ def test_single_class_property(generate_classes, is_child, prev_child):
ql.Property(singular="Foo", type="Bar", tablename="my_objects",
tableparams=[
"this", "result"],
prev_child=prev_child, doc="foo of this my object"),
prev_child=prev_child, doc="foo of this my object",
type_is_codegen_class=True),
],
)),
"Bar.qll": (a_ql_class_public(name="Bar"), a_ql_stub(name="Bar"), a_ql_class(name="Bar", final=True, imports=[stub_import_prefix + "Bar"])),
@@ -1006,6 +1007,7 @@ def test_hideable_property(generate_classes):
final=True, properties=[
ql.Property(singular="X", type="MyObject", tablename="others",
type_is_hideable=True,
type_is_codegen_class=True,
tableparams=["this", "result"], doc="x of this other"),
])),
}