Files
codeql/misc/codegen/generators/rustgen.py
Paolo Tranquilli 4110636032 Rust: preserve ordering in rust generated code
This is a small devex improvement to the rust code generator.

Usage of `sorted` in `rustgen.py` was causing the generated code to be
completely reshuffled on renames, which made diffs hard to follow. As an
example see [this generated file diff](https://github.com/github/codeql/pull/19059/files#diff-c938ba77a3398dd4c633ada5702a03477705c24740a2f7d1e40d4b270d8c3f86).

This makes the order deterministic, following the order of
definitions in the schema file. As a result, a renamed definition keeps
its place in the generated file, and where things land in the generated
file is generally more predictable with respect to the schema.

However, that does mean this change is heavily reshuffling the generated
code.
2025-03-20 12:12:52 +01:00

141 lines
4.5 KiB
Python

"""
Rust trap class generation
"""
import functools
import typing
import inflection
from misc.codegen.lib import rust, schema
from misc.codegen.loaders import schemaloader
def _get_type(t: str) -> str:
match t:
case None: # None means a predicate
return "bool"
case "string":
return "String"
case "int":
return "usize"
case _ if t[0].isupper():
return f"trap::Label<{t}>"
case "boolean":
assert False, "boolean unsupported"
case _:
return t
def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
    """Build the `rust.Field` for property `p`, emitted on behalf of class `cls`.

    A single property uses the tableized class name as its table name. Any
    other property uses `<class name>_<property name>`, underscored when the
    property is a predicate and tableized otherwise.

    Entries returned by `rust.get_field_override` for the property name are
    applied last, so they replace the computed arguments.
    """
    if p.is_single:
        table_name = inflection.tableize(cls.name)
    else:
        qualified_name = f"{cls.name}_{p.name}"
        if p.is_predicate:
            table_name = inflection.underscore(qualified_name)
        else:
            table_name = inflection.tableize(qualified_name)

    field_args = {
        "field_name": rust.avoid_keywords(p.name),
        "base_type": _get_type(p.type),
        "is_optional": p.is_optional,
        "is_repeated": p.is_repeated,
        "is_predicate": p.is_predicate,
        "is_unordered": p.is_unordered,
        "table_name": table_name,
    }
    field_args.update(rust.get_field_override(p.name))
    return rust.Field(**field_args)
def _get_properties(
    cls: schema.Class, lookup: dict[str, schema.ClassBase],
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
    """Yield `(defining class, property)` pairs for `cls`, inherited ones first.

    Bases are walked recursively in the order listed in `cls.bases`, each one
    resolved by name through `lookup`. The properties declared on `cls` itself
    come last. No de-duplication is done, so a property reachable through
    several bases is yielded once per path.
    """
    for base_name in cls.bases:
        for inherited in _get_properties(lookup[base_name], lookup):
            yield inherited
    yield from ((cls, prop) for prop in cls.properties)
def _get_ancestors(
    cls: schema.Class, lookup: dict[str, schema.ClassBase]
) -> typing.Iterable[schema.Class]:
    """Yield the non-imported ancestors of `cls`, depth first.

    Each base is resolved by name through `lookup` and yielded before its own
    ancestors. Imported bases are neither yielded nor descended into. An
    ancestor reachable through several paths is yielded once per path.
    """
    for base_name in cls.bases:
        candidate = lookup[base_name]
        if candidate.imported:
            continue
        parent = typing.cast(schema.Class, candidate)
        yield parent
        yield from _get_ancestors(parent, lookup)
class Processor:
    """Turns the classes of a loaded schema into `rust.Class` descriptions."""

    def __init__(self, data: schema.Schema):
        # mapping from class name to its schema definition
        self._classmap = data.classes

    def _get_class(self, name: str) -> rust.Class:
        """Build the `rust.Class` for the schema class called `name`.

        Properties that are synthesized or carry the `rust_skip` pragma are
        ignored. Properties carrying `rust_detach` become detached fields, but
        only on the class that declares them. All remaining properties,
        inherited ones included, become regular fields, and only when the
        class has no derived classes.
        """
        cls = typing.cast(schema.Class, self._classmap[name])
        kept = [
            (owner, prop)
            for owner, prop in _get_properties(cls, self._classmap)
            if not ("rust_skip" in prop.pragmas or prop.synth)
        ]

        fields = []
        detached_fields = []
        for owner, prop in kept:
            if "rust_detach" not in prop.pragmas:
                # for non-detached ones, only generate fields in the concrete classes
                if not cls.derived:
                    fields.append(_get_field(owner, prop))
                continue
            # only generate detached fields in the actual class defining them, not the derived ones
            if owner is not cls:
                continue
            # TODO lift this restriction if required (requires change in dbschemegen as well)
            assert owner.derived or not prop.is_single, \
                f"property {prop.name} in concrete class marked as detached but not optional"
            detached_fields.append(_get_field(owner, prop))

        # `dict.fromkeys` removes duplicates while keeping first-seen order
        # (a `set` would not preserve it)
        ancestor_names = list(
            dict.fromkeys(a.name for a in _get_ancestors(cls, self._classmap))
        )
        entry_table = None if cls.derived else inflection.tableize(cls.name)
        return rust.Class(
            name=name,
            fields=fields,
            detached_fields=detached_fields,
            ancestors=ancestor_names,
            entry_table=entry_table,
        )

    def get_classes(self):
        """Return the generated classes grouped by schema group name.

        The result maps each group name to a list of `rust.Class` objects, in
        the iteration order of the class map. The `""` key is always present.
        It additionally receives a bare `rust.Class(name=...)` for every
        imported class. Synthesized non-imported classes are left out.
        """
        grouped = {"": []}
        for cls in self._classmap.values():
            if cls.imported:
                grouped[""].append(rust.Class(name=cls.name))
            elif not cls.synth:
                grouped.setdefault(cls.group, []).append(self._get_class(cls.name))
        return grouped
def generate(opts, renderer):
    """Render the Rust trap class files for the schema at `opts.schema`.

    One `<group>.rs` file is rendered into `opts.rust_output` per class group
    (the unnamed group goes to `top.rs`). A `mod.rs` is then rendered from a
    `rust.ModuleList` of the group names. All rendering happens inside
    `renderer.manage(...)`, with `.generated.list` in the output directory as
    registry.
    """
    assert opts.rust_output
    processor = Processor(schemaloader.load_file(opts.schema))
    out = opts.rust_output
    # Group names are kept duplicate-free in first-seen order. A `set` of
    # strings iterates in an order that can change between runs (hash
    # randomization), which would make the module list unstable.
    groups = {}
    with renderer.manage(generated=out.rglob("*.rs"),
                         stubs=(),
                         registry=out / ".generated.list",
                         force=opts.force) as renderer:
        for group, classes in processor.get_classes().items():
            group = group or "top"
            groups[group] = None
            renderer.render(
                rust.ClassList(
                    classes,
                    opts.schema,
                ),
                out / f"{group}.rs",
            )
        renderer.render(
            rust.ModuleList(
                list(groups),
                opts.schema,
            ),
            out / "mod.rs",
        )