Files
codeql/misc/codegen/loaders/schemaloader.py
2024-09-20 12:47:45 +02:00

175 lines
6.2 KiB
Python

""" schema loader """
import sys
import inflection
import typing
import types
import pathlib
import importlib.util
from dataclasses import dataclass
from toposort import toposort_flatten
from misc.codegen.lib import schema, schemadefs
@dataclass
class _PropertyNamer(schema.PropertyModifier):
    """ Property modifier assigning a name to a property.

    Any trailing underscores of the stored name are dropped when it is applied (presumably so that schema
    annotations can use names like `type_` that would otherwise clash with python names).
    """
    name: str

    def modify(self, prop: schema.Property):
        stripped_name = self.name.rstrip("_")
        prop.name = stripped_name
def _get_name(x: typing.Optional[typing.Union[str, type]]):
if x is None:
return None
if isinstance(x, str):
return x
return x.__name__
def _get_class(cls: type) -> schema.Class:
    """ Convert a class object found in a schema module into a `schema.Class`.

    :param cls: the module-level object to convert, expected to be a class definition
    :return: a `schema.Class` with name, bases (excluding `object`), direct subclasses, pragmas, own properties and
        doc filled in
    :raises schema.Error: if `cls` is not a class, if its name does not survive an underscore/camel-case round trip,
        if its bases carry more than one distinct (non-empty) group, or if it is a null class that has subclasses
    """
    if not isinstance(cls, type):
        raise schema.Error(f"Only class definitions allowed in schema, found {cls}")
    # we must check that going to dbscheme names and back is preserved
    # In particular this will not happen if uppercase acronyms are included in the name
    to_underscore_and_back = inflection.camelize(inflection.underscore(cls.__name__), uppercase_first_letter=True)
    if cls.__name__ != to_underscore_and_back:
        raise schema.Error(f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} "
                           f"instead of {to_underscore_and_back}")
    group_attr = f"{schema.inheritable_pragma_prefix}group"
    # falsy group values (missing or empty) do not count as a group
    base_groups = {g for g in (getattr(b, group_attr, None) for b in cls.__bases__) if g}
    if len(base_groups) > 1:
        raise schema.Error(f"Bases with mixed groups for {cls.__name__}")
    pragmas = {
        # dir and getattr inherit from bases
        a[len(schema.inheritable_pragma_prefix):]: getattr(cls, a)
        for a in dir(cls) if a.startswith(schema.inheritable_pragma_prefix)
    }
    # pragmas stored directly on this class (not inherited) are merged last, overriding inherited ones
    pragmas |= cls.__dict__.get("_pragmas", {})
    # only direct subclasses are recorded here
    derived = {d.__name__ for d in cls.__subclasses__()}
    if "null" in pragmas and derived:
        raise schema.Error(f"Null class {cls.__name__} cannot be derived")
    return schema.Class(name=cls.__name__,
                        bases=[b.__name__ for b in cls.__bases__ if b is not object],
                        derived=derived,
                        pragmas=pragmas,
                        # in the following we don't use `getattr` to avoid inheriting
                        properties=[
                            # each annotation value is combined with a namer giving it the annotated name
                            a | _PropertyNamer(n)
                            for n, a in cls.__dict__.get("__annotations__", {}).items()
                        ],
                        doc=schema.split_doc(cls.__doc__),
                        )
def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typing.Dict[str, schema.Class]:
    """ Reorder `classes` by group, and topologically by bases within each group.

    Classes are bucketed by their `group`, buckets are processed in sorted group order, and each bucket is
    flattened with `toposort_flatten` using the class bases as dependencies.

    NOTE(review): `toposort_flatten` presumably also emits dependencies that are not keys of the mapping
    (i.e. bases from other groups); such a base is inserted at that point if it was not emitted yet — confirm
    this is the intended global ordering.
    """
    by_group = {}
    for class_name, class_def in classes.items():
        by_group.setdefault(class_def.group, []).append(class_name)

    ordered = {}
    for _, members in sorted(by_group.items()):
        dependencies = {member: classes[member].bases for member in members}
        # re-assigning an already present key keeps its original position in the dict
        ordered.update((n, classes[n]) for n in toposort_flatten(dependencies))
    return ordered
def _fill_synth_information(classes: typing.Dict[str, schema.Class]):
    """ Take a dictionary where the `synth` field is filled for all explicitly synthesized classes
    and update it so that all non-final classes that have only synthesized final descendants
    get `True` as value for the `synth` field (via `mark_synth`)
    """
    is_synth: typing.Dict[str, bool] = {}

    def fill_is_synth(name: str):
        # memoized post-order traversal over the `derived` relation
        if name in is_synth:
            return
        cls = classes[name]
        for d in cls.derived:
            fill_is_synth(d)
        if cls.synth is not None:
            is_synth[name] = True
        elif not cls.derived:
            is_synth[name] = False
        else:
            is_synth[name] = all(is_synth[d] for d in cls.derived)

    # visit every class rather than only the first entry: this does not rely on the first class being the root,
    # and guarantees `is_synth` is filled for every class read below (an empty `classes` is a no-op)
    for name in classes:
        fill_is_synth(name)

    for name, cls in classes.items():
        if is_synth[name]:
            cls.mark_synth()
def _fill_hideable_information(classes: typing.Dict[str, schema.Class]):
    """ Propagate the `ql_hideable` pragma upwards in the hierarchy.

    Every transitive base of a class carrying `ql_hideable` gets the pragma too (with a `None` value) if it does
    not have it already.
    """
    worklist = [c for c in classes.values() if "ql_hideable" in c.pragmas]
    while worklist:
        current = worklist.pop()
        for base_name in current.bases:
            base = classes[base_name]
            if "ql_hideable" in base.pragmas:
                # already marked, its own bases are (or will be) handled from there
                continue
            base.pragmas["ql_hideable"] = None
            worklist.append(base)
def _check_test_with(classes: typing.Dict[str, schema.Class]):
    """ Validate the `qltest_test_with` pragmas of `classes`.

    :raises schema.Error: if a class is tested with a name that is not a class of the schema, or with a class that
        in turn has its own `qltest_test_with` (in which case that final target should be used directly)
    """
    for cls in classes.values():
        test_with = typing.cast(str, cls.pragmas.get("qltest_test_with"))
        if not test_with:
            continue
        target = classes.get(test_with)
        if target is None:
            # report a schema error rather than letting a bare `KeyError` escape
            raise schema.Error(f"{cls.name} has test_with {test_with} which is not a class of the schema")
        transitive_test_with = target.pragmas.get("qltest_test_with")
        if transitive_test_with:
            raise schema.Error(f"{cls.name} has test_with {test_with} which in turn "
                               f"has test_with {transitive_test_with}, use that directly")
def load(m: types.ModuleType) -> schema.Schema:
    """ Build a `schema.Schema` out of the contents of schema module `m`.

    Module-level names are processed in definition order:
    * names that are also attributes of `schemadefs` are skipped (presumably re-exported definitions);
    * `includes`, if present, is passed through as the schema includes;
    * dunder names, `_` and modules are skipped;
    * everything else is converted with `_get_class` (which rejects non-class objects).

    Only the first class may have no bases, and at most one class may carry the `null` pragma; that pragma is
    removed from the class and its name is recorded as the schema `null` class.

    :raises schema.Error: on any of the above violations, or on errors reported by the conversion and checks
    """
    includes = set()
    classes = {}
    # builtin type names plus every non-dunder module-level name can be referenced in property types
    known = {"int", "string", "boolean"}
    known.update(n for n in m.__dict__ if not n.startswith("__"))
    null = None
    for name, data in m.__dict__.items():
        if hasattr(schemadefs, name):
            continue
        if name == "includes":
            includes = data
            continue
        if name.startswith("__") or name == "_" or isinstance(data, types.ModuleType):
            continue
        cls = _get_class(data)
        if classes and not cls.bases:
            raise schema.Error(
                f"Only one root class allowed, found second root {name}")
        cls.check_types(known)
        classes[name] = cls
        if "null" not in cls.pragmas:
            continue
        del cls.pragmas["null"]
        if null is not None:
            raise schema.Error(f"Null class {null} already defined, second null class {name} not allowed")
        null = name

    _fill_synth_information(classes)
    _fill_hideable_information(classes)
    _check_test_with(classes)
    return schema.Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
def load_file(path: pathlib.Path) -> schema.Schema:
    """ Import the schema module at `path` and load it with `load`.

    The parent directory of `path` is temporarily prepended to `sys.path` so that the module (and any sibling
    module it imports) can be imported by name.

    NOTE: `importlib.import_module` returns the module already registered in `sys.modules` under the same name,
    if any, so loading two different schema files with the same module name in one process yields the first one.

    :param path: a `.py` file, or a path without suffix, imported by name from its parent directory
    :raises schema.Error: if `path` has a suffix other than `.py`
    """
    # explicit check instead of `assert`, which would be stripped when running with `-O`
    if path.suffix not in ("", ".py"):
        raise schema.Error(f"Schema must be a python module, found {path}")
    parent = str(path.parent)
    sys.path.insert(0, parent)
    try:
        module = importlib.import_module(path.with_suffix("").name)
    finally:
        sys.path.remove(parent)
    return load(module)