Files
codeql/misc/codegen/loaders/schemaloader.py
Paolo Tranquilli 6a540d833e Merge pull request #17523 from github/redsun82/rust-break-up-schema
Codegen/Rust: allow breaking up schema file
2024-09-19 16:57:58 +02:00

173 lines
6.4 KiB
Python

""" schema loader """
import sys
import inflection
import typing
import types
import pathlib
import importlib.util
from dataclasses import dataclass
from toposort import toposort_flatten
from misc.codegen.lib import schema, schemadefs
@dataclass
class _PropertyNamer(schema.PropertyModifier):
    """ Property modifier that sets the name of the property it is applied to.

    The stored `name` is used with any trailing underscores removed
    (presumably so schema attributes can carry a trailing `_` to avoid
    clashing with Python keywords -- confirm against the schema files).
    """
    # name to assign, as spelled in the schema definition
    name: str

    def modify(self, prop: schema.Property):
        """ Assign `self.name`, stripped of trailing underscores, to `prop.name` """
        stripped = self.name.rstrip("_")
        prop.name = stripped
def _get_name(x: typing.Optional[typing.Union[str, type]]):
if x is None:
return None
if isinstance(x, str):
return x
return x.__name__
def _get_class(cls: type) -> schema.Class:
    """ Build the `schema.Class` description of a python class taken from a schema module.

    Raises `schema.Error` if:
    * `cls` is not a class,
    * its name does not survive the underscore -> camel-case round trip,
    * its direct bases carry more than one distinct `_group` value,
    * any of its direct bases has a truthy `_null` attribute.

    `group`, `hideable` and `test_with` are read with `getattr` and can therefore
    come from a base class; all the other private attributes, as well as the
    annotations turned into properties, are read from the class' own `__dict__`.
    """
    if not isinstance(cls, type):
        raise schema.Error(f"Only class definitions allowed in schema, found {cls}")

    class_name = cls.__name__
    # converting to a dbscheme (underscore) name and back must give the same name;
    # this fails in particular when the name contains uppercase acronyms
    round_tripped = inflection.camelize(inflection.underscore(class_name), uppercase_first_letter=True)
    if class_name != round_tripped:
        raise schema.Error(f"Class name must be upper camel-case, without capitalized acronyms, found {class_name} "
                           f"instead of {round_tripped}")

    base_groups = {b._group for b in cls.__bases__ if hasattr(b, "_group")}
    if len(base_groups) > 1:
        raise schema.Error(f"Bases with mixed groups for {class_name}")

    for base in cls.__bases__:
        if getattr(base, "_null", False):
            raise schema.Error("Null class cannot be derived")

    base_names = [b.__name__ for b in cls.__bases__ if b is not object]
    derived_names = {d.__name__ for d in cls.__subclasses__()}
    # `getattr` is used for these so that values set on bases are inherited
    group = getattr(cls, "_group", "")
    hideable = getattr(cls, "_hideable", False)
    test_with = _get_name(getattr(cls, "_test_with", None))
    # from here on only the class' own namespace is consulted, to avoid inheriting
    own = cls.__dict__
    pragmas = own.get("_pragmas", [])
    synth = own.get("_synth", None)
    properties = [
        annotation | _PropertyNamer(attribute)
        for attribute, annotation in own.get("__annotations__", {}).items()
    ]
    doc = schema.split_doc(cls.__doc__)
    doc_name = own.get("_doc_name")
    doc_test_function = own.get("_rust_doc_test_function",
                                schema.Class.rust_doc_test_function)

    return schema.Class(
        name=class_name,
        bases=base_names,
        derived=derived_names,
        group=group,
        hideable=hideable,
        test_with=test_with,
        pragmas=pragmas,
        synth=synth,
        properties=properties,
        doc=doc,
        default_doc_name=doc_name,
        rust_doc_test_function=doc_test_function,
    )
def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typing.Dict[str, schema.Class]:
    """ Return a new class map whose insertion order follows groups and inheritance.

    Class names are bucketed by their `group`; the buckets are visited in sorted
    group order and, for each bucket, the names produced by `toposort_flatten`
    over the `name -> bases` mapping are inserted in turn. A name that was already
    inserted keeps its first position.
    """
    names_by_group = {}
    for class_name, definition in classes.items():
        group_key = definition.group
        if group_key not in names_by_group:
            names_by_group[group_key] = []
        names_by_group[group_key].append(class_name)

    ordered = {}
    for group_key in sorted(names_by_group):
        dependencies = {n: classes[n].bases for n in names_by_group[group_key]}
        ordered.update((n, classes[n]) for n in toposort_flatten(dependencies))
    return ordered
def _fill_synth_information(classes: typing.Dict[str, schema.Class]):
    """ Take a dictionary where the `synth` field is filled for all explicitly synthesized classes
    and update it so that all non-final classes that have only synthesized final descendants
    get `True` as value for the `synth` field.

    The computation starts from the first class in `classes` and walks down through
    `derived`; nothing is done for an empty dictionary.
    """
    if not classes:
        return

    # memo of the synthesized status computed so far, by class name
    synthesized: typing.Dict[str, bool] = {}

    def visit(name: str):
        if name in synthesized:
            return
        current = classes[name]
        # children first, so that their status is available below
        for child in current.derived:
            visit(child)
        if current.synth is not None:
            # explicitly marked in the schema
            synthesized[name] = True
        elif current.derived:
            synthesized[name] = all(synthesized[child] for child in current.derived)
        else:
            # final class without explicit synth information
            synthesized[name] = False

    visit(next(iter(classes)))

    for name, current in classes.items():
        if current.synth is None and synthesized[name]:
            current.synth = True
def _fill_hideable_information(classes: typing.Dict[str, schema.Class]):
    """ Propagate the `hideable` flag from every hideable class to all of its ancestors.

    The classes in `classes` are updated in place.
    """
    # worklist of classes whose bases still need to be marked
    pending = [c for c in classes.values() if c.hideable]
    while pending:
        current = pending.pop()
        for base_name in current.bases:
            parent = classes[base_name]
            if parent.hideable:
                # already marked, its own bases have been (or will be) handled
                continue
            parent.hideable = True
            pending.append(parent)
def _check_test_with(classes: typing.Dict[str, schema.Class]):
    """ Reject chained `test_with` declarations.

    Raises `schema.Error` if a class has a `test_with` target which itself
    has a `test_with` set.
    """
    for cls in classes.values():
        target_name = cls.test_with
        if target_name is None:
            continue
        chained = classes[target_name].test_with
        if chained is None:
            continue
        raise schema.Error(f"{cls.name} has test_with {target_name} which in turn "
                           f"has test_with {chained}, use that directly")
def load(m: types.ModuleType) -> schema.Schema:
    """ Build a `schema.Schema` out of the contents of the python module `m`.

    Every entry of the module namespace is turned into a schema class, except for:
    * names that are also attributes of `misc.codegen.lib.schemadefs`,
    * `__includes`, whose value becomes the `includes` of the schema,
    * any other name starting with `__`, and the name `_`,
    * values that are modules.

    Raises `schema.Error` if a class without bases is found after other classes have
    already been collected, or if more than one class is marked with `_null` (on top
    of whatever `_get_class` and the post-processing checks raise).
    """
    namespace = m.__dict__
    includes = []
    classes = {}
    # type names that properties are allowed to refer to
    known = {"int", "string", "boolean"} | {n for n in namespace if not n.startswith("__")}
    import misc.codegen.lib.schemadefs as defs
    null = None

    for name, data in namespace.items():
        if hasattr(defs, name):
            # helper coming from the schema definition library, not a schema class
            continue
        if name == "__includes":
            includes = data
            continue
        if name.startswith("__") or name == "_" or isinstance(data, types.ModuleType):
            continue

        cls = _get_class(data)
        if classes and not cls.bases:
            raise schema.Error(
                f"Only one root class allowed, found second root {name}")
        cls.check_types(known)
        classes[name] = cls

        if not getattr(data, "_null", False):
            continue
        if null is not None:
            raise schema.Error(f"Null class {null} already defined, second null class {name} not allowed")
        null = name
        cls.is_null_class = True

    _fill_synth_information(classes)
    _fill_hideable_information(classes)
    _check_test_with(classes)

    ordered_classes = _toposort_classes_by_group(classes)
    return schema.Schema(includes=includes, classes=ordered_classes, null=null)
def load_file(path: pathlib.Path) -> schema.Schema:
    """ Import the schema module located at `path` and load the schema it defines.

    `path` must have either no suffix or a `.py` suffix. Its parent directory is
    put at the front of `sys.path` only for the duration of the import, and the
    module is imported by the file name without suffix.
    """
    assert path.suffix in ("", ".py")
    search_dir = str(path.parent)
    module_name = path.with_suffix("").name
    sys.path.insert(0, search_dir)
    try:
        module = importlib.import_module(module_name)
    finally:
        # always undo the temporary `sys.path` change, even if the import fails
        sys.path.remove(search_dir)
    return load(module)