Codegen: allow inheritable pragmas

This commit is contained in:
Paolo Tranquilli
2024-09-20 11:46:24 +02:00
parent 1bffc2a7d7
commit 3e2f886595
3 changed files with 35 additions and 20 deletions

View File

@@ -211,3 +211,6 @@ def split_doc(doc):
while trimmed and not trimmed[0]:
trimmed.pop(0)
return trimmed
inheritable_pragma_prefix = "_inheritable_pragma_"

View File

@@ -5,6 +5,8 @@ from dataclasses import dataclass as _dataclass
from misc.codegen.lib.schema import Property
_set = set
@_dataclass
class _ChildModifier(_schema.PropertyModifier):
@@ -79,7 +81,7 @@ class _SynthModifier(_schema.PropertyModifier, _Namespace):
def modify(self, prop: _schema.Property):
prop.synth = self.synth
def negate(self) -> "PropertyModifier":
def negate(self) -> _schema.PropertyModifier:
return _SynthModifier(self.name, False)
@@ -100,14 +102,18 @@ class _ClassPragma(_PragmaBase):
""" A class pragma.
For schema classes it acts as a python decorator with `@`.
"""
inherited: bool = False
value: object = None
def __call__(self, cls: type) -> type:
""" use this pragma as a decorator on classes """
# not using hasattr as we don't want to land on inherited pragmas
if "_pragmas" not in cls.__dict__:
cls._pragmas = {}
self._apply(cls._pragmas)
if self.inherited:
setattr(cls, f"{_schema.inheritable_pragma_prefix}{self.pragma}", self.value)
else:
# not using hasattr as we don't want to land on inherited pragmas
if "_pragmas" not in cls.__dict__:
cls._pragmas = {}
self._apply(cls._pragmas)
return cls
def _apply(self, pragmas: _Dict[str, object]) -> None:
@@ -125,7 +131,7 @@ class _Pragma(_ClassPragma, _schema.PropertyModifier):
def modify(self, prop: _schema.Property):
self._apply(prop.pragmas)
def negate(self) -> "PropertyModifier":
def negate(self) -> _schema.PropertyModifier:
return _Pragma(self.pragma, remove=True)
def _apply(self, pragmas: _Dict[str, object]) -> None:
@@ -142,13 +148,14 @@ class _ParametrizedClassPragma(_PragmaBase):
"""
_pragma_class: _ClassVar[type] = _ClassPragma
function: _Callable[..., object] = None
inherited: bool = False
factory: _Callable[..., object] = None
def __post_init__(self):
self.__signature__ = _inspect.signature(self.function).replace(return_annotation=self._pragma_class)
self.__signature__ = _inspect.signature(self.factory).replace(return_annotation=self._pragma_class)
def __call__(self, *args, **kwargs) -> _pragma_class:
return self._pragma_class(self.pragma, value=self.function(*args, **kwargs))
return self._pragma_class(self.pragma, self.inherited, value=self.factory(*args, **kwargs))
@_dataclass
@@ -233,7 +240,7 @@ qltest.add(_ClassPragma("collapse_hierarchy"))
qltest.add(_ClassPragma("uncollapse_hierarchy"))
qltest.test_with = lambda cls: _annotate(test_with=cls) # inheritable
ql.add(_ParametrizedClassPragma("default_doc_name", lambda doc: doc))
ql.add(_ParametrizedClassPragma("default_doc_name", factory=lambda doc: doc))
ql.hideable = _annotate(hideable=True) # inheritable
ql.add(_Pragma("internal"))
@@ -241,16 +248,16 @@ cpp.add(_Pragma("skip"))
rust.add(_Pragma("skip_doc_test"))
rust.add(_ParametrizedClassPragma("doc_test_signature", lambda signature: signature))
rust.add(_ParametrizedClassPragma("doc_test_signature", factory=lambda signature: signature))
def group(name: str = "") -> _ClassDecorator:
return _annotate(group=name)
synth.add(_ParametrizedClassPragma("from_class", lambda ref: _schema.SynthInfo(
synth.add(_ParametrizedClassPragma("from_class", factory=lambda ref: _schema.SynthInfo(
from_class=_schema.get_type_name(ref))), key="synth")
synth.add(_ParametrizedClassPragma("on_arguments", lambda **kwargs:
synth.add(_ParametrizedClassPragma("on_arguments", factory=lambda **kwargs:
_schema.SynthInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()})), key="synth")
@@ -288,12 +295,11 @@ def annotate(annotated_cls: type) -> _Callable[[type], _PropertyAnnotation]:
raise _schema.Error("Annotation classes must be named _")
if cls.__doc__ is not None:
annotated_cls.__doc__ = cls.__doc__
old_pragmas = getattr(annotated_cls, "_pragmas", None)
new_pragmas = getattr(cls, "_pragmas", {})
if old_pragmas:
old_pragmas.update(new_pragmas)
else:
annotated_cls._pragmas = new_pragmas
for p, v in cls.__dict__.get("_pragmas", {}).items():
_ClassPragma(p, value=v)(annotated_cls)
for a in dir(cls):
if a.startswith(_schema.inheritable_pragma_prefix):
setattr(annotated_cls, a, getattr(cls, a))
for a, v in cls.__dict__.items():
# transfer annotations
if a.startswith("_") and not a.startswith("__") and a != "_pragmas":

View File

@@ -41,6 +41,12 @@ def _get_class(cls: type) -> schema.Class:
raise schema.Error(f"Bases with mixed groups for {cls.__name__}")
if any(getattr(b, "_null", False) for b in cls.__bases__):
raise schema.Error(f"Null class cannot be derived")
pragmas = {
# dir and getattr inherit from bases
a[len(schema.inheritable_pragma_prefix):]: getattr(cls, a)
for a in dir(cls) if a.startswith(schema.inheritable_pragma_prefix)
}
pragmas |= cls.__dict__.get("_pragmas", {})
return schema.Class(name=cls.__name__,
bases=[b.__name__ for b in cls.__bases__ if b is not object],
derived={d.__name__ for d in cls.__subclasses__()},
@@ -48,8 +54,8 @@ def _get_class(cls: type) -> schema.Class:
group=getattr(cls, "_group", ""),
hideable=getattr(cls, "_hideable", False),
test_with=_get_name(getattr(cls, "_test_with", None)),
pragmas=pragmas,
# in the following we don't use `getattr` to avoid inheriting
pragmas=cls.__dict__.get("_pragmas", {}),
properties=[
a | _PropertyNamer(n)
for n, a in cls.__dict__.get("__annotations__", {}).items()