Move swift/codegen to misc/codegen

This commit is contained in:
Paolo Tranquilli
2023-02-24 13:44:29 +01:00
parent 6d192cdcc1
commit cdd4e8021b
65 changed files with 116 additions and 86 deletions

View File

@@ -0,0 +1,11 @@
load("@codegen_deps//:requirements.bzl", "requirement")
py_library(
name = "lib",
srcs = glob(["*.py"]),
visibility = ["//misc/codegen:__subpackages__"],
deps = [
requirement("pystache"),
requirement("inflection"),
],
)

163
misc/codegen/lib/cpp.py Normal file
View File

@@ -0,0 +1,163 @@
import pathlib
import re
from dataclasses import dataclass, field
from typing import List, ClassVar
# taken from https://en.cppreference.com/w/cpp/keyword
cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "atomic_commit", "atomic_noexcept",
"auto", "bitand", "bitor", "bool", "break", "case", "catch", "char", "char8_t", "char16_t", "char32_t",
"class", "compl", "concept", "const", "consteval", "constexpr", "constinit", "const_cast", "continue",
"co_await", "co_return", "co_yield", "decltype", "default", "delete", "do", "double", "dynamic_cast",
"else", "enum", "explicit", "export", "extern", "false", "float", "for", "friend", "goto", "if",
"inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr",
"operator", "or", "or_eq", "private", "protected", "public", "reflexpr", "register", "reinterpret_cast",
"requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct",
"switch", "synchronized", "template", "this", "thread_local", "throw", "true", "try", "typedef",
"typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while",
"xor", "xor_eq"}
_field_overrides = [
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": "unsigned"}),
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
]
def get_field_override(field: str):
for r, o in _field_overrides:
m = r.fullmatch(field)
if m:
return o(m) if callable(o) else o
return {}
@dataclass
class Field:
field_name: str
base_type: str
is_optional: bool = False
is_repeated: bool = False
is_predicate: bool = False
trap_name: str = None
first: bool = False
def __post_init__(self):
if self.field_name in cpp_keywords:
self.field_name += "_"
@property
def type(self) -> str:
type = self.base_type
if self.is_optional:
type = f"std::optional<{type}>"
if self.is_repeated:
type = f"std::vector<{type}>"
return type
# using @property breaks pystache internals here
def get_streamer(self):
if self.type == "std::string":
return lambda x: f"trapQuoted({x})"
elif self.type == "bool":
return lambda x: f'({x} ? "true" : "false")'
else:
return lambda x: x
@property
def is_single(self):
return not (self.is_optional or self.is_repeated or self.is_predicate)
@property
def is_label(self):
return self.base_type.startswith("TrapLabel<")
@dataclass
class Trap:
table_name: str
name: str
fields: List[Field]
id: Field = None
def __post_init__(self):
assert self.fields
self.fields[0].first = True
@dataclass
class TagBase:
base: str
first: bool = False
@dataclass
class Tag:
name: str
bases: List[TagBase]
id: str
def __post_init__(self):
if self.bases:
self.bases = [TagBase(b) for b in self.bases]
self.bases[0].first = True
@property
def has_bases(self):
return bool(self.bases)
@dataclass
class TrapList:
template: ClassVar = 'trap_traps'
extensions = ["h", "cpp"]
traps: List[Trap]
source: str
trap_library_dir: pathlib.Path
gen_dir: pathlib.Path
@dataclass
class TagList:
template: ClassVar = 'trap_tags'
extensions = ["h"]
tags: List[Tag]
source: str
@dataclass
class ClassBase:
ref: 'Class'
first: bool = False
@dataclass
class Class:
name: str
bases: List[ClassBase] = field(default_factory=list)
final: bool = False
fields: List[Field] = field(default_factory=list)
trap_name: str = None
def __post_init__(self):
self.bases = [ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name)]
if self.bases:
self.bases[0].first = True
@property
def has_bases(self):
return bool(self.bases)
@property
def single_fields(self):
return [f for f in self.fields if f.is_single]
@dataclass
class ClassList:
template: ClassVar = "cpp_classes"
extensions: ClassVar = ["h", "cpp"]
classes: List[Class]
source: str
trap_library: str
include_parent: bool = False

View File

@@ -0,0 +1,107 @@
""" dbscheme format representation """
import logging
import pathlib
import re
from dataclasses import dataclass
from typing import ClassVar, List
log = logging.getLogger(__name__)
dbscheme_keywords = {"case", "boolean", "int", "string", "type"}
@dataclass
class Column:
schema_name: str
type: str
binding: bool = False
first: bool = False
@property
def name(self):
if self.schema_name in dbscheme_keywords:
return self.schema_name + "_"
return self.schema_name
@property
def lhstype(self):
if self.type[0] == "@":
return "unique int" if self.binding else "int"
return self.type
@property
def rhstype(self):
if self.type[0] == "@" and self.binding:
return self.type
return self.type + " ref"
@dataclass
class KeySetId:
id: str
first: bool = False
@dataclass
class KeySet:
ids: List[KeySetId]
def __post_init__(self):
assert self.ids
self.ids = [KeySetId(x) for x in self.ids]
self.ids[0].first = True
class Decl:
is_table = False
is_union = False
@dataclass
class Table(Decl):
is_table: ClassVar = True
name: str
columns: List[Column]
keyset: KeySet = None
dir: pathlib.Path = None
def __post_init__(self):
if self.columns:
self.columns[0].first = True
@dataclass
class UnionCase:
type: str
first: bool = False
@dataclass
class Union(Decl):
is_union: ClassVar = True
lhs: str
rhs: List[UnionCase]
def __post_init__(self):
assert self.rhs
self.rhs = [UnionCase(x) for x in self.rhs]
self.rhs.sort(key=lambda c: c.type)
self.rhs[0].first = True
@dataclass
class SchemeInclude:
src: str
data: str
@dataclass
class Scheme:
template: ClassVar = 'dbscheme'
src: str
includes: List[SchemeInclude]
declarations: List[Decl]

19
misc/codegen/lib/paths.py Normal file
View File

@@ -0,0 +1,19 @@
""" module providing useful filesystem paths """
import pathlib
import sys
import os
_this_file = pathlib.Path(__file__).resolve()
try:
workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve() # <- means we are using bazel run
root_dir = workspace_dir / 'swift'
except KeyError:
root_dir = _this_file.parents[2]
workspace_dir = root_dir.parent
lib_dir = _this_file.parents[2] / 'codegen' / 'lib'
templates_dir = _this_file.parents[2] / 'codegen' / 'templates'
exe_file = pathlib.Path(sys.argv[0]).resolve()

312
misc/codegen/lib/ql.py Normal file
View File

@@ -0,0 +1,312 @@
"""
QL files generation
`generate(opts, renderer)` will generate QL classes and manage stub files out of a `yml` schema file.
Each class (for example, `Foo`) in the schema triggers:
* generation of a `FooBase` class implementation translating all properties into appropriate getters
* if not created or already customized, generation of a stub file which defines `Foo` as extending `FooBase`. This can
be used to add hand-written code to `Foo`, which requires removal of the `// generated` header comment in that file.
All generated base classes actually import these customizations when referencing other classes.
Generated files that do not correspond any more to any class in the schema are deleted. Customized stubs are however
left behind and must be dealt with by hand.
"""
import pathlib
from dataclasses import dataclass, field
import itertools
from typing import List, ClassVar, Union, Optional
import inflection
@dataclass
class Param:
param: str
first: bool = False
@dataclass
class Property:
singular: str
type: Optional[str] = None
tablename: Optional[str] = None
tableparams: List[Param] = field(default_factory=list)
plural: Optional[str] = None
first: bool = False
is_optional: bool = False
is_predicate: bool = False
prev_child: Optional[str] = None
qltest_skip: bool = False
description: List[str] = field(default_factory=list)
doc: Optional[str] = None
doc_plural: Optional[str] = None
def __post_init__(self):
if self.tableparams:
self.tableparams = [Param(x) for x in self.tableparams]
self.tableparams[0].first = True
@property
def getter(self):
return f"get{self.singular}" if not self.is_predicate else self.singular
@property
def indefinite_getter(self):
if self.plural:
article = "An" if self.singular[0] in "AEIO" else "A"
return f"get{article}{self.singular}"
@property
def type_is_class(self):
return bool(self.type) and self.type[0].isupper()
@property
def is_repeated(self):
return bool(self.plural)
@property
def is_single(self):
return not (self.is_optional or self.is_repeated or self.is_predicate)
@property
def is_child(self):
return self.prev_child is not None
@property
def has_description(self) -> bool:
return bool(self.description)
@dataclass
class Base:
base: str
prev: str = ""
def __str__(self):
return self.base
@dataclass
class Class:
template: ClassVar = 'ql_class'
name: str
bases: List[Base] = field(default_factory=list)
final: bool = False
properties: List[Property] = field(default_factory=list)
dir: pathlib.Path = pathlib.Path()
imports: List[str] = field(default_factory=list)
import_prefix: Optional[str] = None
qltest_skip: bool = False
qltest_collapse_hierarchy: bool = False
qltest_uncollapse_hierarchy: bool = False
ql_internal: bool = False
ipa: bool = False
doc: List[str] = field(default_factory=list)
def __post_init__(self):
self.bases = [Base(str(b), str(prev)) for b, prev in zip(self.bases, itertools.chain([""], self.bases))]
if self.properties:
self.properties[0].first = True
@property
def root(self) -> bool:
return not self.bases
@property
def path(self) -> pathlib.Path:
return self.dir / self.name
@property
def db_id(self) -> str:
return "@" + inflection.underscore(self.name)
@property
def has_children(self) -> bool:
return any(p.is_child for p in self.properties)
@property
def last_base(self) -> str:
return self.bases[-1].base if self.bases else ""
@property
def has_doc(self) -> bool:
return bool(self.doc) or self.ql_internal
@dataclass
class IpaUnderlyingAccessor:
argument: str
type: str
constructorparams: List[Param]
def __post_init__(self):
if self.constructorparams:
self.constructorparams = [Param(x) for x in self.constructorparams]
self.constructorparams[0].first = True
@dataclass
class Stub:
template: ClassVar = 'ql_stub'
name: str
base_import: str
import_prefix: str
ipa_accessors: List[IpaUnderlyingAccessor] = field(default_factory=list)
@property
def has_ipa_accessors(self) -> bool:
return bool(self.ipa_accessors)
@dataclass
class DbClasses:
template: ClassVar = 'ql_db'
classes: List[Class] = field(default_factory=list)
@dataclass
class ImportList:
template: ClassVar = 'ql_imports'
imports: List[str] = field(default_factory=list)
@dataclass
class GetParentImplementation:
template: ClassVar = 'ql_parent'
classes: List[Class] = field(default_factory=list)
imports: List[str] = field(default_factory=list)
@dataclass
class PropertyForTest:
getter: str
is_total: bool = True
type: Optional[str] = None
is_repeated: bool = False
@dataclass
class TesterBase:
class_name: str
elements_module: str
@dataclass
class ClassTester(TesterBase):
template: ClassVar = 'ql_test_class'
properties: List[PropertyForTest] = field(default_factory=list)
show_ql_class: bool = False
@dataclass
class PropertyTester(TesterBase):
template: ClassVar = 'ql_test_property'
property: PropertyForTest
@dataclass
class MissingTestInstructions:
template: ClassVar = 'ql_test_missing'
class Synth:
@dataclass
class Class:
is_final: ClassVar = False
name: str
first: bool = False
@dataclass
class Param:
param: str
type: str
first: bool = False
@dataclass
class FinalClass(Class):
is_final: ClassVar = True
is_derived_ipa: ClassVar = False
is_fresh_ipa: ClassVar = False
is_db: ClassVar = False
params: List["Synth.Param"] = field(default_factory=list)
def __post_init__(self):
if self.params:
self.params[0].first = True
@property
def is_ipa(self):
return self.is_fresh_ipa or self.is_derived_ipa
@property
def has_params(self) -> bool:
return bool(self.params)
@dataclass
class FinalClassIpa(FinalClass):
pass
@dataclass
class FinalClassDerivedIpa(FinalClassIpa):
is_derived_ipa: ClassVar = True
@dataclass
class FinalClassFreshIpa(FinalClassIpa):
is_fresh_ipa: ClassVar = True
@dataclass
class FinalClassDb(FinalClass):
is_db: ClassVar = True
subtracted_ipa_types: List["Synth.Class"] = field(default_factory=list)
def subtract_type(self, type: str):
self.subtracted_ipa_types.append(Synth.Class(type, first=not self.subtracted_ipa_types))
@property
def has_subtracted_ipa_types(self) -> bool:
return bool(self.subtracted_ipa_types)
@property
def db_id(self) -> str:
return "@" + inflection.underscore(self.name)
@dataclass
class NonFinalClass(Class):
derived: List["Synth.Class"] = field(default_factory=list)
root: bool = False
def __post_init__(self):
self.derived = [Synth.Class(c) for c in self.derived]
if self.derived:
self.derived[0].first = True
@dataclass
class Types:
template: ClassVar = "ql_ipa_types"
root: str
import_prefix: str
final_classes: List["Synth.FinalClass"] = field(default_factory=list)
non_final_classes: List["Synth.NonFinalClass"] = field(default_factory=list)
def __post_init__(self):
if self.final_classes:
self.final_classes[0].first = True
@dataclass
class ConstructorStub:
template: ClassVar = "ql_ipa_constructor_stub"
cls: "Synth.FinalClass"
import_prefix: str

198
misc/codegen/lib/render.py Normal file
View File

@@ -0,0 +1,198 @@
""" template renderer module, wrapping around `pystache.Renderer`
`pystache` is a python mustache engine, and mustache is a template language. More information on
https://mustache.github.io/
"""
import logging
import pathlib
import typing
import hashlib
from dataclasses import dataclass
import pystache
from . import paths
log = logging.getLogger(__name__)
class Error(Exception):
pass
class Renderer:
""" Template renderer using mustache templates in the `templates` directory """
def __init__(self, generator: pathlib.Path, root_dir: pathlib.Path):
self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u)
self._root_dir = root_dir
self._generator = generator
def _get_path(self, file: pathlib.Path):
return file.relative_to(self._root_dir)
def render(self, data: object, output: pathlib.Path):
""" Render `data` to `output`.
`data` must have a `template` attribute denoting which template to use from the template directory.
Optionally, `data` can also have an `extensions` attribute denoting list of file extensions: they will all be
appended to the template name with an underscore and be generated in turn.
"""
mnemonic = type(data).__name__
output.parent.mkdir(parents=True, exist_ok=True)
extensions = getattr(data, "extensions", [None])
for ext in extensions:
output_filename = output
template = data.template
if ext:
output_filename = output_filename.with_suffix(f".{ext}")
template += f"_{ext}"
contents = self._r.render_name(template, data, generator=self._generator)
self._do_write(mnemonic, contents, output_filename)
def _do_write(self, mnemonic: str, contents: str, output: pathlib.Path):
with open(output, "w") as out:
out.write(contents)
log.debug(f"{mnemonic}: generated {output.name}")
def manage(self, generated: typing.Iterable[pathlib.Path], stubs: typing.Iterable[pathlib.Path],
registry: pathlib.Path, force: bool = False) -> "RenderManager":
return RenderManager(self._generator, self._root_dir, generated, stubs, registry, force)
class RenderManager(Renderer):
""" A context manager allowing to manage checked in generated files and their cleanup, able
to skip unneeded writes.
This is done by using and updating a checked in list of generated files that assigns two
hashes to each file:
* one is the hash of the mustache rendered contents, that can be used to quickly check whether a
write is needed
* the other is the hash of the actual file after code generation has finished. This will be
different from the above because of post-processing like QL formatting. This hash is used
to detect invalid modification of generated files"""
written: typing.Set[pathlib.Path]
@dataclass
class Hashes:
"""
pre contains the hash of a file as rendered, post is the hash after
postprocessing (for example QL formatting)
"""
pre: str
post: typing.Optional[str] = None
def __init__(self, generator: pathlib.Path, root_dir: pathlib.Path, generated: typing.Iterable[pathlib.Path],
stubs: typing.Iterable[pathlib.Path],
registry: pathlib.Path, force: bool = False):
super().__init__(generator, root_dir)
self._registry_path = registry
self._force = force
self._hashes = {}
self.written = set()
self._existing = set()
self._skipped = set()
self._load_registry()
self._process_generated(generated)
self._process_stubs(stubs)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is None:
for f in self._existing - self._skipped - self.written:
f.unlink(missing_ok=True)
log.info(f"removed {f.name}")
for f in self.written:
self._hashes[self._get_path(f)].post = self._hash_file(f)
else:
# if an error was encountered, drop already written files from the registry
# so that they get the chance to be regenerated again during the next run
for f in self.written:
self._hashes.pop(self._get_path(f), None)
# clean up the registry from files that do not exist any more
for f in list(self._hashes):
if not (self._root_dir / f).exists():
self._hashes.pop(f)
self._dump_registry()
def _do_write(self, mnemonic: str, contents: str, output: pathlib.Path):
hash = self._hash_string(contents)
rel_output = self._get_path(output)
if rel_output in self._hashes and self._hashes[rel_output].pre == hash:
self._skipped.add(output)
log.debug(f"{mnemonic}: skipped {output.name}")
else:
self.written.add(output)
super()._do_write(mnemonic, contents, output)
self._hashes[rel_output] = self.Hashes(pre=hash)
def _process_generated(self, generated: typing.Iterable[pathlib.Path]):
for f in generated:
self._existing.add(f)
rel_path = self._get_path(f)
if self._force:
pass
elif rel_path not in self._hashes:
log.warning(f"{rel_path} marked as generated but absent from the registry")
elif self._hashes[rel_path].post != self._hash_file(f):
raise Error(f"{rel_path} is generated but was modified, please revert the file "
"or pass --force to overwrite")
def _process_stubs(self, stubs: typing.Iterable[pathlib.Path]):
for f in stubs:
rel_path = self._get_path(f)
if self.is_customized_stub(f):
self._hashes.pop(rel_path, None)
continue
self._existing.add(f)
if self._force:
pass
elif rel_path not in self._hashes:
log.warning(f"{rel_path} marked as stub but absent from the registry")
elif self._hashes[rel_path].post != self._hash_file(f):
raise Error(f"{rel_path} is a stub marked as generated, but it was modified, "
"please remove the `// generated` header, revert the file or pass --force to overwrite it")
@staticmethod
def is_customized_stub(file: pathlib.Path) -> bool:
if not file.is_file():
return False
with open(file) as contents:
for line in contents:
return not line.startswith("// generated")
# no lines
return True
@staticmethod
def _hash_file(filename: pathlib.Path) -> str:
with open(filename) as inp:
return RenderManager._hash_string(inp.read())
@staticmethod
def _hash_string(data: str) -> str:
h = hashlib.sha256()
h.update(data.encode())
return h.hexdigest()
def _load_registry(self):
if self._force:
return
try:
with open(self._registry_path) as reg:
for line in reg:
filename, prehash, posthash = line.split()
self._hashes[pathlib.Path(filename)] = self.Hashes(prehash, posthash)
except FileNotFoundError:
pass
def _dump_registry(self):
self._registry_path.parent.mkdir(parents=True, exist_ok=True)
with open(self._registry_path, 'w') as out:
for f, hashes in sorted(self._hashes.items()):
print(f, hashes.pre, hashes.post, file=out)

194
misc/codegen/lib/schema.py Normal file
View File

@@ -0,0 +1,194 @@
""" schema format representation """
import typing
from dataclasses import dataclass, field
from typing import List, Set, Union, Dict, Optional
from enum import Enum, auto
import functools
class Error(Exception):
def __str__(self):
return self.args[0]
def _check_type(t: Optional[str], known: typing.Iterable[str]):
if t is not None and t not in known:
raise Error(f"Unknown type {t}")
@dataclass
class Property:
class Kind(Enum):
SINGLE = auto()
REPEATED = auto()
OPTIONAL = auto()
REPEATED_OPTIONAL = auto()
PREDICATE = auto()
kind: Kind
name: Optional[str] = None
type: Optional[str] = None
is_child: bool = False
pragmas: List[str] = field(default_factory=list)
doc: Optional[str] = None
description: List[str] = field(default_factory=list)
@property
def is_single(self) -> bool:
return self.kind == self.Kind.SINGLE
@property
def is_optional(self) -> bool:
return self.kind in (self.Kind.OPTIONAL, self.Kind.REPEATED_OPTIONAL)
@property
def is_repeated(self) -> bool:
return self.kind in (self.Kind.REPEATED, self.Kind.REPEATED_OPTIONAL)
@property
def is_predicate(self) -> bool:
return self.kind == self.Kind.PREDICATE
@property
def has_class_type(self) -> bool:
return bool(self.type) and self.type[0].isupper()
@property
def has_builtin_type(self) -> bool:
return bool(self.type) and self.type[0].islower()
SingleProperty = functools.partial(Property, Property.Kind.SINGLE)
OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL)
RepeatedProperty = functools.partial(Property, Property.Kind.REPEATED)
RepeatedOptionalProperty = functools.partial(
Property, Property.Kind.REPEATED_OPTIONAL)
PredicateProperty = functools.partial(Property, Property.Kind.PREDICATE)
@dataclass
class IpaInfo:
from_class: Optional[str] = None
on_arguments: Optional[Dict[str, str]] = None
@dataclass
class Class:
name: str
bases: List[str] = field(default_factory=list)
derived: Set[str] = field(default_factory=set)
properties: List[Property] = field(default_factory=list)
group: str = ""
pragmas: List[str] = field(default_factory=list)
ipa: Optional[Union[IpaInfo, bool]] = None
"""^^^ filled with `True` for non-final classes with only synthesized final descendants """
doc: List[str] = field(default_factory=list)
default_doc_name: Optional[str] = None
@property
def final(self):
return not self.derived
def check_types(self, known: typing.Iterable[str]):
for b in self.bases:
_check_type(b, known)
for d in self.derived:
_check_type(d, known)
for p in self.properties:
_check_type(p.type, known)
if self.ipa is not None:
_check_type(self.ipa.from_class, known)
if self.ipa.on_arguments is not None:
for t in self.ipa.on_arguments.values():
_check_type(t, known)
@dataclass
class Schema:
classes: Dict[str, Class] = field(default_factory=dict)
includes: Set[str] = field(default_factory=set)
null: Optional[str] = None
@property
def root_class(self):
# always the first in the dictionary
return next(iter(self.classes.values()))
@property
def null_class(self):
return self.classes[self.null] if self.null else None
predicate_marker = object()
TypeRef = Union[type, str]
@functools.singledispatch
def get_type_name(arg: TypeRef) -> str:
raise Error(f"Not a schema type or string ({arg})")
@get_type_name.register
def _(arg: type):
return arg.__name__
@get_type_name.register
def _(arg: str):
return arg
@functools.singledispatch
def _make_property(arg: object) -> Property:
if arg is predicate_marker:
return PredicateProperty()
raise Error(f"Illegal property specifier {arg}")
@_make_property.register(str)
@_make_property.register(type)
def _(arg: TypeRef):
return SingleProperty(type=get_type_name(arg))
@_make_property.register
def _(arg: Property):
return arg
class PropertyModifier:
""" Modifier of `Property` objects.
Being on the right of `|` it will trigger construction of a `Property` from
the left operand.
"""
def __ror__(self, other: object) -> Property:
ret = _make_property(other)
self.modify(ret)
return ret
def modify(self, prop: Property):
raise NotImplementedError
def split_doc(doc):
# implementation inspired from https://peps.python.org/pep-0257/
if not doc:
return []
lines = doc.splitlines()
# Determine minimum indentation (first line doesn't count):
strippedlines = (line.lstrip() for line in lines[1:])
indents = [len(line) - len(stripped) for line, stripped in zip(lines[1:], strippedlines) if stripped]
# Remove indentation (first line is special):
trimmed = [lines[0].strip()]
if indents:
indent = min(indents)
trimmed.extend(line[indent:].rstrip() for line in lines[1:])
# Strip off trailing and leading blank lines:
while trimmed and not trimmed[-1]:
trimmed.pop()
while trimmed and not trimmed[0]:
trimmed.pop(0)
return trimmed

View File

@@ -0,0 +1,149 @@
from typing import Callable as _Callable
from misc.codegen.lib import schema as _schema
import inspect as _inspect
from dataclasses import dataclass as _dataclass
class _ChildModifier(_schema.PropertyModifier):
def modify(self, prop: _schema.Property):
if prop.type is None or prop.type[0].islower():
raise _schema.Error("Non-class properties cannot be children")
prop.is_child = True
@_dataclass
class _DocModifier(_schema.PropertyModifier):
doc: str
def modify(self, prop: _schema.Property):
if "\n" in self.doc or self.doc[-1] == ".":
raise _schema.Error("No newlines or trailing dots are allowed in doc, did you intend to use desc?")
prop.doc = self.doc
@_dataclass
class _DescModifier(_schema.PropertyModifier):
description: str
def modify(self, prop: _schema.Property):
prop.description = _schema.split_doc(self.description)
def include(source: str):
# add to `includes` variable in calling context
_inspect.currentframe().f_back.f_locals.setdefault(
"__includes", []).append(source)
class _Namespace:
""" simple namespacing mechanism """
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
qltest = _Namespace()
ql = _Namespace()
cpp = _Namespace()
synth = _Namespace()
@_dataclass
class _Pragma(_schema.PropertyModifier):
""" A class or property pragma.
For properties, it functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
For schema classes it acts as a python decorator with `@`.
"""
pragma: str
def __post_init__(self):
namespace, _, name = self.pragma.partition('_')
setattr(globals()[namespace], name, self)
def modify(self, prop: _schema.Property):
prop.pragmas.append(self.pragma)
def __call__(self, cls: type) -> type:
""" use this pragma as a decorator on classes """
if "_pragmas" in cls.__dict__: # not using hasattr as we don't want to land on inherited pragmas
cls._pragmas.append(self.pragma)
else:
cls._pragmas = [self.pragma]
return cls
class _Optionalizer(_schema.PropertyModifier):
def modify(self, prop: _schema.Property):
K = _schema.Property.Kind
if prop.kind != K.SINGLE:
raise _schema.Error(
"Optional should only be applied to simple property types")
prop.kind = K.OPTIONAL
class _Listifier(_schema.PropertyModifier):
def modify(self, prop: _schema.Property):
K = _schema.Property.Kind
if prop.kind == K.SINGLE:
prop.kind = K.REPEATED
elif prop.kind == K.OPTIONAL:
prop.kind = K.REPEATED_OPTIONAL
else:
raise _schema.Error(
"Repeated should only be applied to simple or optional property types")
class _TypeModifier:
""" Modifies types using get item notation """
def __init__(self, modifier: _schema.PropertyModifier):
self.modifier = modifier
def __getitem__(self, item):
return item | self.modifier
_ClassDecorator = _Callable[[type], type]
def _annotate(**kwargs) -> _ClassDecorator:
def f(cls: type) -> type:
for k, v in kwargs.items():
setattr(cls, f"_{k}", v)
return cls
return f
boolean = "boolean"
int = "int"
string = "string"
predicate = _schema.predicate_marker
optional = _TypeModifier(_Optionalizer())
list = _TypeModifier(_Listifier())
child = _ChildModifier()
doc = _DocModifier
desc = _DescModifier
use_for_null = _annotate(null=True)
_Pragma("qltest_skip")
_Pragma("qltest_collapse_hierarchy")
_Pragma("qltest_uncollapse_hierarchy")
ql.default_doc_name = lambda doc: _annotate(doc_name=doc)
_Pragma("ql_internal")
_Pragma("cpp_skip")
def group(name: str = "") -> _ClassDecorator:
return _annotate(group=name)
synth.from_class = lambda ref: _annotate(ipa=_schema.IpaInfo(
from_class=_schema.get_type_name(ref)))
synth.on_arguments = lambda **kwargs: _annotate(
ipa=_schema.IpaInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()}))