mirror of
https://github.com/github/codeql.git
synced 2025-12-16 16:53:25 +01:00
This switches `codegen` from `autopep8` formatting to `black` formatting and applies `black` to `bulk_mad_generator.py` as well. More Python scripts can be enrolled in the future.
151 lines
4.7 KiB
Python
151 lines
4.7 KiB
Python
"""
|
|
Rust trap class generation
|
|
"""
|
|
|
|
import functools
|
|
import typing
|
|
|
|
import inflection
|
|
|
|
from misc.codegen.lib import rust, schema
|
|
from misc.codegen.loaders import schemaloader
|
|
|
|
|
|
def _get_type(t: str) -> str:
|
|
match t:
|
|
case None: # None means a predicate
|
|
return "bool"
|
|
case "string":
|
|
return "String"
|
|
case "int":
|
|
return "usize"
|
|
case _ if t[0].isupper():
|
|
return f"trap::Label<{t}>"
|
|
case "boolean":
|
|
assert False, "boolean unsupported"
|
|
case _:
|
|
return t
|
|
|
|
|
|
def _get_table_name(cls: schema.Class, p: schema.Property) -> str:
    """Return the name of the trap table holding property ``p`` of ``cls``.

    Single-valued properties use the table named after the class itself.
    Any other property uses the ``ql_db_table_name`` pragma when it is set and
    non-empty. Otherwise the name is built from ``<class>_<property>`` and
    converted with ``inflection.underscore`` for predicates or
    ``inflection.tableize`` for everything else.
    """
    if p.is_single:
        return inflection.tableize(cls.name)

    override = p.pragmas.get("ql_db_table_name")
    if override:
        return override

    combined = f"{cls.name}_{p.name}"
    convert = inflection.underscore if p.is_predicate else inflection.tableize
    return convert(combined)
|
|
|
|
|
|
def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
    """Build the ``rust.Field`` describing property ``p`` declared on ``cls``.

    Field attributes are first derived from the property itself. Any entries
    returned by ``rust.get_field_override`` for the property name are then
    applied on top, replacing the derived values.
    """
    field_kwargs = {
        "field_name": rust.avoid_keywords(p.name),
        "base_type": _get_type(p.type),
        "is_optional": p.is_optional,
        "is_repeated": p.is_repeated,
        "is_predicate": p.is_predicate,
        "is_unordered": p.is_unordered,
        "table_name": _get_table_name(cls, p),
    }
    # per-field overrides win over the derived defaults
    field_kwargs.update(rust.get_field_override(p.name))
    return rust.Field(**field_kwargs)
|
|
|
|
|
|
def _get_properties(
    cls: schema.Class,
    lookup: dict[str, schema.ClassBase],
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
    """Yield ``(declaring_class, property)`` pairs for ``cls`` and its bases.

    Bases are walked recursively, depth-first and in declaration order. Their
    properties are yielded before the ones declared directly on ``cls``.

    NOTE(review): there is no deduplication, so a base reachable through
    several inheritance paths has its properties yielded once per path.
    """
    for base_name in cls.bases:
        yield from _get_properties(lookup[base_name], lookup)
    yield from ((cls, prop) for prop in cls.properties)
|
|
|
|
|
|
def _get_ancestors(
    cls: schema.Class, lookup: dict[str, schema.ClassBase]
) -> typing.Iterable[schema.Class]:
    """Yield the non-imported ancestors of ``cls``, depth-first.

    Each direct base is yielded before its own ancestors. An imported base is
    skipped, and so is everything above it. Ancestors reachable along several
    paths are yielded once per path; no deduplication happens here.
    """
    for base_name in cls.bases:
        parent = lookup[base_name]
        if parent.imported:
            continue
        # narrow the static type: non-imported entries are full classes
        parent = typing.cast(schema.Class, parent)
        yield parent
        yield from _get_ancestors(parent, lookup)
|
|
|
|
|
|
class Processor:
    """Converts a loaded schema into the ``rust.Class`` descriptions to render."""

    def __init__(self, data: schema.Schema):
        # schema classes keyed by class name
        self._classmap = data.classes

    def _get_class(self, name: str) -> rust.Class:
        """Build the ``rust.Class`` for the schema class called ``name``.

        Properties with the ``rust_skip`` pragma and synthesized properties are
        dropped.

        Properties with the ``rust_detach`` pragma become detached fields, but
        only on the class that declares them.

        All other properties, including inherited ones, become regular fields,
        but only for concrete (non-derived) classes.

        Raises:
            AssertionError: if a concrete class declares a detached property
                that is single-valued.
        """
        cls = typing.cast(schema.Class, self._classmap[name])
        properties = [
            (c, p)
            for c, p in _get_properties(cls, self._classmap)
            if "rust_skip" not in p.pragmas and not p.synth
        ]
        fields = []
        detached_fields = []
        for c, p in properties:
            if "rust_detach" in p.pragmas:
                # only generate detached fields in the actual class defining them, not the derived ones
                if c is cls:
                    # TODO lift this restriction if required (requires change in dbschemegen as well)
                    assert (
                        c.derived or not p.is_single
                    ), f"property {p.name} in concrete class marked as detached but not optional"
                    detached_fields.append(_get_field(c, p))
            elif not cls.derived:
                # for non-detached ones, only generate fields in the concrete classes
                fields.append(_get_field(c, p))
        return rust.Class(
            name=name,
            fields=fields,
            detached_fields=detached_fields,
            # remove duplicates but preserve ordering
            # (`dict` preserves insertion order while `set` doesn't)
            ancestors=[*{a.name: None for a in _get_ancestors(cls, self._classmap)}],
            entry_table=inflection.tableize(cls.name) if not cls.derived else None,
        )

    def get_classes(self) -> dict[str, list[rust.Class]]:
        """Return the ``rust.Class`` descriptions to render, grouped by schema group.

        Classes that are neither imported nor synthesized are built in full by
        ``_get_class`` and filed under their ``group``.

        Imported classes contribute a name-only ``rust.Class`` under the ``""``
        key, which is always present, even when empty.

        Synthesized classes that are not imported are omitted.
        """
        ret: dict[str, list[rust.Class]] = {"": []}
        # only the class objects are needed; the mapping keys were unused
        for cls in self._classmap.values():
            if not cls.imported and not cls.synth:
                ret.setdefault(cls.group, []).append(self._get_class(cls.name))
            elif cls.imported:
                ret[""].append(rust.Class(name=cls.name))
        return ret
|
|
|
|
|
|
def generate(opts, renderer):
    """Render the Rust trap class modules for the schema in ``opts.schema``.

    Writes one ``<group>.rs`` file per schema group into ``opts.rust_output``,
    with the unnamed group written as ``top.rs``. It also writes a ``mod.rs``
    built from the set of group names.

    Rendering happens inside ``renderer.manage``. That call receives:
    - the existing ``*.rs`` files found recursively under the output directory;
    - the ``.generated.list`` registry path in that directory;
    - ``opts.force``.

    Args:
        opts: options providing ``schema``, ``rust_output`` (a directory path,
            which must be set) and ``force``.
        renderer: object whose ``manage`` context yields the renderer to use.
    """
    assert opts.rust_output
    processor = Processor(schemaloader.load_file(opts.schema))
    out = opts.rust_output
    groups = set()
    # bind the managed renderer to its own name instead of shadowing the parameter
    with renderer.manage(
        generated=out.rglob("*.rs"),
        stubs=(),
        registry=out / ".generated.list",
        force=opts.force,
    ) as managed:
        for group, classes in processor.get_classes().items():
            group = group or "top"
            groups.add(group)
            managed.render(
                rust.ClassList(
                    classes,
                    opts.schema,
                ),
                out / f"{group}.rs",
            )
        # NOTE(review): `groups` is a set of strings; its iteration order can differ
        # between interpreter runs (hash randomization) - confirm that
        # `rust.ModuleList` or its template orders the modules deterministically
        managed.render(
            rust.ModuleList(
                groups,
                opts.schema,
            ),
            out / "mod.rs",
        )
|