Files
codeql/misc/codegen/generators/rustgen.py
Paolo Tranquilli 14d48e9d58 Add black pre-commit hook
This switched `codegen` from the `autopep8` formatting to the `black`
one, and applies it to `bulk_mad_generator.py` as well. We can enroll
more python scripts to it in the future.
2025-06-10 12:25:39 +02:00

151 lines
4.7 KiB
Python

"""
Rust trap class generation
"""
import functools
import typing
import inflection
from misc.codegen.lib import rust, schema
from misc.codegen.loaders import schemaloader
def _get_type(t: str) -> str:
match t:
case None: # None means a predicate
return "bool"
case "string":
return "String"
case "int":
return "usize"
case _ if t[0].isupper():
return f"trap::Label<{t}>"
case "boolean":
assert False, "boolean unsupported"
case _:
return t
def _get_table_name(cls: schema.Class, p: schema.Property) -> str:
    """Return the database table name backing property ``p`` of ``cls``.

    Single-valued properties use the class's own tableized name. Other
    properties use the ``ql_db_table_name`` pragma when it is set to a truthy
    value. Otherwise the name is derived from ``<class>_<property>``:
    underscored for predicates, tableized for everything else.
    """
    if p.is_single:
        return inflection.tableize(cls.name)
    if overridden := p.pragmas.get("ql_db_table_name"):
        return overridden
    combined = f"{cls.name}_{p.name}"
    return (
        inflection.underscore(combined)
        if p.is_predicate
        else inflection.tableize(combined)
    )
def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
    """Build the ``rust.Field`` describing property ``p`` declared on ``cls``.

    The keyword arguments are first computed from the property itself. Whatever
    ``rust.get_field_override(p.name)`` returns is then applied on top, so the
    overrides take precedence over the computed values.
    """
    field_kwargs = {
        "field_name": rust.avoid_keywords(p.name),
        "base_type": _get_type(p.type),
        "is_optional": p.is_optional,
        "is_repeated": p.is_repeated,
        "is_predicate": p.is_predicate,
        "is_unordered": p.is_unordered,
        "table_name": _get_table_name(cls, p),
    }
    overrides = rust.get_field_override(p.name)
    field_kwargs.update(overrides)
    return rust.Field(**field_kwargs)
def _get_properties(
    cls: schema.Class,
    lookup: dict[str, schema.ClassBase],
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
    """Yield ``(declaring_class, property)`` pairs for ``cls`` and all its bases.

    Bases are walked recursively, depth-first, in declaration order. All of a
    base's properties come before the properties declared directly on ``cls``.
    No deduplication is done: a base reachable through several inheritance
    paths contributes its properties once per path.
    """
    for base_name in cls.bases:
        base = lookup[base_name]
        yield from _get_properties(base, lookup)
    yield from ((cls, prop) for prop in cls.properties)
def _get_ancestors(
    cls: schema.Class, lookup: dict[str, schema.ClassBase]
) -> typing.Iterable[schema.Class]:
    """Yield every non-imported ancestor of ``cls``, in depth-first pre-order.

    An imported base is skipped together with its own ancestry. No
    deduplication is done here: an ancestor reachable along several
    inheritance paths is yielded once per path.
    """
    for base_name in cls.bases:
        base = lookup[base_name]
        if base.imported:
            continue
        ancestor = typing.cast(schema.Class, base)
        yield ancestor
        yield from _get_ancestors(ancestor, lookup)
class Processor:
    """Turns a loaded schema into ``rust.Class`` descriptions grouped for rendering."""

    def __init__(self, data: schema.Schema):
        # class name -> schema class, as provided by the loaded schema
        self._classmap = data.classes

    def _get_class(self, name: str) -> rust.Class:
        """Build the ``rust.Class`` for the schema class called ``name``.

        Properties carrying the ``rust_skip`` pragma, and synthesized (``synth``)
        properties, are left out.

        Properties carrying ``rust_detach`` become detached fields, but only on
        the class that declares them. All remaining properties, inherited ones
        included, become regular fields, but only on non-derived classes.
        """
        cls = typing.cast(schema.Class, self._classmap[name])
        properties = [
            (c, p)
            for c, p in _get_properties(cls, self._classmap)
            if "rust_skip" not in p.pragmas and not p.synth
        ]
        fields = []
        detached_fields = []
        for c, p in properties:
            if "rust_detach" in p.pragmas:
                # only generate detached fields in the actual class defining them, not the derived ones
                if c is cls:
                    # TODO lift this restriction if required (requires change in dbschemegen as well)
                    assert (
                        c.derived or not p.is_single
                    ), f"property {p.name} in concrete class marked as detached but not optional"
                    detached_fields.append(_get_field(c, p))
            elif not cls.derived:
                # for non-detached ones, only generate fields in the concrete classes
                fields.append(_get_field(c, p))
        # an ancestor reachable through several inheritance paths shows up more
        # than once; drop duplicates while keeping first-seen order (`dict`
        # preserves insertion order while `set` doesn't)
        ancestors = list(
            dict.fromkeys(a.name for a in _get_ancestors(cls, self._classmap))
        )
        return rust.Class(
            name=name,
            fields=fields,
            detached_fields=detached_fields,
            ancestors=ancestors,
            entry_table=inflection.tableize(cls.name) if not cls.derived else None,
        )

    def get_classes(self) -> dict[str, list[rust.Class]]:
        """Return the classes to render, keyed by schema group.

        Each class that is neither imported nor synthesized is fully built and
        filed under its ``group``. Each imported class contributes only a
        name-only ``rust.Class`` to the ``""`` group, which is always present
        even when empty. Synthesized, non-imported classes are skipped.
        """
        ret = {"": []}
        # only the class objects are needed: the map key is just `cls.name`
        for cls in self._classmap.values():
            if not cls.imported and not cls.synth:
                ret.setdefault(cls.group, []).append(self._get_class(cls.name))
            elif cls.imported:
                ret[""].append(rust.Class(name=cls.name))
        return ret
def generate(opts, renderer):
    """Generate the Rust trap class modules for the schema in ``opts.schema``.

    Renders one ``<group>.rs`` file under ``opts.rust_output`` per schema group
    (the unnamed ``""`` group becomes ``top.rs``), plus a ``mod.rs`` listing
    every rendered group.

    ``renderer.manage`` receives the existing ``*.rs`` files and the
    ``.generated.list`` registry. Presumably this lets it detect and clean up
    stale outputs — confirm against the renderer implementation.
    """
    assert opts.rust_output
    processor = Processor(schemaloader.load_file(opts.schema))
    out = opts.rust_output
    groups = set()
    # bind the managed renderer to its own name instead of shadowing the parameter
    with renderer.manage(
        generated=out.rglob("*.rs"),
        stubs=(),
        registry=out / ".generated.list",
        force=opts.force,
    ) as manager:
        for group, classes in processor.get_classes().items():
            group = group or "top"
            groups.add(group)
            manager.render(
                rust.ClassList(
                    classes,
                    opts.schema,
                ),
                out / f"{group}.rs",
            )
        # NOTE(review): `groups` is an unordered set; presumably ModuleList (or its
        # template) sorts it so mod.rs is stable across runs — confirm
        manager.render(
            rust.ModuleList(
                groups,
                opts.schema,
            ),
            out / "mod.rs",
        )