Add black pre-commit hook

This switched `codegen` from the `autopep8` formatting to the `black`
one, and applies it to `bulk_mad_generator.py` as well. We can enroll
more python scripts to it in the future.
This commit is contained in:
Paolo Tranquilli
2025-06-10 12:25:39 +02:00
parent 7a632e8a47
commit 14d48e9d58
28 changed files with 3479 additions and 1641 deletions

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3
""" Driver script to run all code generation """
"""Driver script to run all code generation"""
import argparse
import logging
@@ -9,7 +9,7 @@ import pathlib
import typing
import shlex
if 'BUILD_WORKSPACE_DIRECTORY' not in os.environ:
if "BUILD_WORKSPACE_DIRECTORY" not in os.environ:
# we are not running with `bazel run`, set up module search path
_repo_root = pathlib.Path(__file__).resolve().parents[2]
sys.path.append(str(_repo_root))
@@ -29,57 +29,105 @@ def _parse_args() -> argparse.Namespace:
conf = None
p = argparse.ArgumentParser(description="Code generation suite")
p.add_argument("--generate", type=lambda x: x.split(","),
help="specify what targets to generate as a comma separated list, choosing among dbscheme, ql, "
"trap, cpp and rust")
p.add_argument("--verbose", "-v", action="store_true", help="print more information")
p.add_argument(
"--generate",
type=lambda x: x.split(","),
help="specify what targets to generate as a comma separated list, choosing among dbscheme, ql, "
"trap, cpp and rust",
)
p.add_argument(
"--verbose", "-v", action="store_true", help="print more information"
)
p.add_argument("--quiet", "-q", action="store_true", help="only print errors")
p.add_argument("--configuration-file", "-c", type=_abspath, default=conf,
help="A configuration file to load options from. By default, the first codegen.conf file found by "
"going up directories from the current location. If present all paths provided in options are "
"considered relative to its directory")
p.add_argument("--root-dir", type=_abspath,
help="the directory that should be regarded as the root of the language pack codebase. Used to "
"compute QL imports and in some comments and as root for relative paths provided as options. "
"If not provided it defaults to the directory of the configuration file, if any")
p.add_argument(
"--configuration-file",
"-c",
type=_abspath,
default=conf,
help="A configuration file to load options from. By default, the first codegen.conf file found by "
"going up directories from the current location. If present all paths provided in options are "
"considered relative to its directory",
)
p.add_argument(
"--root-dir",
type=_abspath,
help="the directory that should be regarded as the root of the language pack codebase. Used to "
"compute QL imports and in some comments and as root for relative paths provided as options. "
"If not provided it defaults to the directory of the configuration file, if any",
)
path_arguments = [
p.add_argument("--schema",
help="input schema file (default schema.py)"),
p.add_argument("--dbscheme",
help="output file for dbscheme generation, input file for trap generation"),
p.add_argument("--ql-output",
help="output directory for generated QL files"),
p.add_argument("--ql-stub-output",
help="output directory for QL stub/customization files. Defines also the "
"generated qll file importing every class file"),
p.add_argument("--ql-test-output",
help="output directory for QL generated extractor test files"),
p.add_argument("--ql-cfg-output",
help="output directory for QL CFG layer (optional)."),
p.add_argument("--cpp-output",
help="output directory for generated C++ files, required if trap or cpp is provided to "
"--generate"),
p.add_argument("--rust-output",
help="output directory for generated Rust files, required if rust is provided to "
"--generate"),
p.add_argument("--generated-registry",
help="registry file containing information about checked-in generated code. A .gitattributes"
"file is generated besides it to mark those files with linguist-generated=true. Must"
"be in a directory containing all generated code."),
p.add_argument("--schema", help="input schema file (default schema.py)"),
p.add_argument(
"--dbscheme",
help="output file for dbscheme generation, input file for trap generation",
),
p.add_argument("--ql-output", help="output directory for generated QL files"),
p.add_argument(
"--ql-stub-output",
help="output directory for QL stub/customization files. Defines also the "
"generated qll file importing every class file",
),
p.add_argument(
"--ql-test-output",
help="output directory for QL generated extractor test files",
),
p.add_argument(
"--ql-cfg-output", help="output directory for QL CFG layer (optional)."
),
p.add_argument(
"--cpp-output",
help="output directory for generated C++ files, required if trap or cpp is provided to "
"--generate",
),
p.add_argument(
"--rust-output",
help="output directory for generated Rust files, required if rust is provided to "
"--generate",
),
p.add_argument(
"--generated-registry",
help="registry file containing information about checked-in generated code. A .gitattributes"
"file is generated besides it to mark those files with linguist-generated=true. Must"
"be in a directory containing all generated code.",
),
]
p.add_argument("--script-name",
help="script name to put in header comments of generated files. By default, the path of this "
"script relative to the root directory")
p.add_argument("--trap-library",
help="path to the trap library from an include directory, required if generating C++ trap bindings"),
p.add_argument("--ql-format", action="store_true", default=True,
help="use codeql to autoformat QL files (which is the default)")
p.add_argument("--no-ql-format", action="store_false", dest="ql_format", help="do not format QL files")
p.add_argument("--codeql-binary", default="codeql", help="command to use for QL formatting (default %(default)s)")
p.add_argument("--force", "-f", action="store_true",
help="generate all files without skipping unchanged files and overwriting modified ones")
p.add_argument("--use-current-directory", action="store_true",
help="do not consider paths as relative to --root-dir or the configuration directory")
p.add_argument(
"--script-name",
help="script name to put in header comments of generated files. By default, the path of this "
"script relative to the root directory",
)
p.add_argument(
"--trap-library",
help="path to the trap library from an include directory, required if generating C++ trap bindings",
),
p.add_argument(
"--ql-format",
action="store_true",
default=True,
help="use codeql to autoformat QL files (which is the default)",
)
p.add_argument(
"--no-ql-format",
action="store_false",
dest="ql_format",
help="do not format QL files",
)
p.add_argument(
"--codeql-binary",
default="codeql",
help="command to use for QL formatting (default %(default)s)",
)
p.add_argument(
"--force",
"-f",
action="store_true",
help="generate all files without skipping unchanged files and overwriting modified ones",
)
p.add_argument(
"--use-current-directory",
action="store_true",
help="do not consider paths as relative to --root-dir or the configuration directory",
)
opts = p.parse_args()
if opts.configuration_file is not None:
with open(opts.configuration_file) as config:
@@ -97,7 +145,15 @@ def _parse_args() -> argparse.Namespace:
for arg in path_arguments:
path = getattr(opts, arg.dest)
if path is not None:
setattr(opts, arg.dest, _abspath(path) if opts.use_current_directory else (opts.root_dir / path))
setattr(
opts,
arg.dest,
(
_abspath(path)
if opts.use_current_directory
else (opts.root_dir / path)
),
)
if not opts.script_name:
opts.script_name = paths.exe_file.relative_to(opts.root_dir)
return opts
@@ -115,7 +171,7 @@ def run():
log_level = logging.ERROR
else:
log_level = logging.INFO
logging.basicConfig(format="{levelname} {message}", style='{', level=log_level)
logging.basicConfig(format="{levelname} {message}", style="{", level=log_level)
for target in opts.generate:
generate(target, opts, render.Renderer(opts.script_name))

View File

@@ -49,7 +49,11 @@ def _get_trap_name(cls: schema.Class, p: schema.Property) -> str | None:
return inflection.pluralize(trap_name)
def _get_field(cls: schema.Class, p: schema.Property, add_or_none_except: typing.Optional[str] = None) -> cpp.Field:
def _get_field(
cls: schema.Class,
p: schema.Property,
add_or_none_except: typing.Optional[str] = None,
) -> cpp.Field:
args = dict(
field_name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),
base_type=_get_type(p.type, add_or_none_except),
@@ -83,14 +87,15 @@ class Processor:
bases=[self._get_class(b) for b in cls.bases],
fields=[
_get_field(cls, p, self._add_or_none_except)
for p in cls.properties if "cpp_skip" not in p.pragmas and not p.synth
for p in cls.properties
if "cpp_skip" not in p.pragmas and not p.synth
],
final=not cls.derived,
trap_name=trap_name,
)
def get_classes(self):
ret = {'': []}
ret = {"": []}
for k, cls in self._classmap.items():
if not cls.synth:
ret.setdefault(cls.group, []).append(self._get_class(cls.name))
@@ -102,6 +107,12 @@ def generate(opts, renderer):
processor = Processor(schemaloader.load_file(opts.schema))
out = opts.cpp_output
for dir, classes in processor.get_classes().items():
renderer.render(cpp.ClassList(classes, opts.schema,
include_parent=bool(dir),
trap_library=opts.trap_library), out / dir / "TrapClasses")
renderer.render(
cpp.ClassList(
classes,
opts.schema,
include_parent=bool(dir),
trap_library=opts.trap_library,
),
out / dir / "TrapClasses",
)

View File

@@ -13,6 +13,7 @@ Moreover:
as columns
The type hierarchy will be translated to corresponding `union` declarations.
"""
import typing
import inflection
@@ -29,7 +30,7 @@ class Error(Exception):
def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> str:
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.
"""translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.
For class types, appends an underscore followed by `null` if provided
"""
if typename[0].isupper():
@@ -42,12 +43,18 @@ def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> st
return typename
def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], add_or_none_except: typing.Optional[str] = None):
""" Yield all dbscheme entities needed to model class `cls` """
def cls_to_dbscheme(
cls: schema.Class,
lookup: typing.Dict[str, schema.Class],
add_or_none_except: typing.Optional[str] = None,
):
"""Yield all dbscheme entities needed to model class `cls`"""
if cls.synth:
return
if cls.derived:
yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived if not lookup[c].synth))
yield Union(
dbtype(cls.name), (dbtype(c) for c in cls.derived if not lookup[c].synth)
)
dir = pathlib.Path(cls.group) if cls.group else None
# output a table specific to a class only if it is a leaf class or it has 1-to-1 properties
# Leaf classes need a table to bind the `@` ids
@@ -61,9 +68,11 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
name=inflection.tableize(cls.name),
columns=[
Column("id", type=dbtype(cls.name), binding=binding),
] + [
]
+ [
Column(f.name, dbtype(f.type, add_or_none_except))
for f in cls.properties if f.is_single and not f.synth
for f in cls.properties
if f.is_single and not f.synth
],
dir=dir,
)
@@ -74,28 +83,37 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
continue
if f.is_unordered:
yield Table(
name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"),
name=overridden_table_name
or inflection.tableize(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
Column(
inflection.singularize(f.name),
dbtype(f.type, add_or_none_except),
),
],
dir=dir,
)
elif f.is_repeated:
yield Table(
keyset=KeySet(["id", "index"]),
name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"),
name=overridden_table_name
or inflection.tableize(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
Column("index", type="int"),
Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
Column(
inflection.singularize(f.name),
dbtype(f.type, add_or_none_except),
),
],
dir=dir,
)
elif f.is_optional:
yield Table(
keyset=KeySet(["id"]),
name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"),
name=overridden_table_name
or inflection.tableize(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
Column(f.name, dbtype(f.type, add_or_none_except)),
@@ -105,7 +123,8 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
elif f.is_predicate:
yield Table(
keyset=KeySet(["id"]),
name=overridden_table_name or inflection.underscore(f"{cls.name}_{f.name}"),
name=overridden_table_name
or inflection.underscore(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
],
@@ -119,33 +138,46 @@ def check_name_conflicts(decls: list[Table | Union]):
match decl:
case Table(name=name):
if name in names:
raise Error(f"Duplicate table name: {
name}, you can use `@ql.db_table_name` on a property to resolve this")
raise Error(
f"Duplicate table name: {
name}, you can use `@ql.db_table_name` on a property to resolve this"
)
names.add(name)
def get_declarations(data: schema.Schema):
add_or_none_except = data.root_class.name if data.null else None
declarations = [d for cls in data.classes.values() if not cls.imported for d in cls_to_dbscheme(cls,
data.classes, add_or_none_except)]
declarations = [
d
for cls in data.classes.values()
if not cls.imported
for d in cls_to_dbscheme(cls, data.classes, add_or_none_except)
]
if data.null:
property_classes = {
prop.type for cls in data.classes.values() for prop in cls.properties
prop.type
for cls in data.classes.values()
for prop in cls.properties
if cls.name != data.null and prop.type and prop.type[0].isupper()
}
declarations += [
Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) for t in sorted(property_classes)
Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)])
for t in sorted(property_classes)
]
check_name_conflicts(declarations)
return declarations
def get_includes(data: schema.Schema, include_dir: pathlib.Path, root_dir: pathlib.Path):
def get_includes(
data: schema.Schema, include_dir: pathlib.Path, root_dir: pathlib.Path
):
includes = []
for inc in data.includes:
inc = include_dir / inc
with open(inc) as inclusion:
includes.append(SchemeInclude(src=inc.relative_to(root_dir), data=inclusion.read()))
includes.append(
SchemeInclude(src=inc.relative_to(root_dir), data=inclusion.read())
)
return includes
@@ -155,8 +187,10 @@ def generate(opts, renderer):
data = schemaloader.load_file(input)
dbscheme = Scheme(src=input.name,
includes=get_includes(data, include_dir=input.parent, root_dir=input.parent),
declarations=get_declarations(data))
dbscheme = Scheme(
src=input.name,
includes=get_includes(data, include_dir=input.parent, root_dir=input.parent),
declarations=get_declarations(data),
)
renderer.render(dbscheme, out)

View File

@@ -19,6 +19,7 @@ Moreover in the test directory for each <Class> in <group> it will generate bene
* one `<Class>.ql` test query for all single properties and on `<Class>_<property>.ql` test query for each optional or
repeated property
"""
# TODO this should probably be split in different generators now: ql, qltest, maybe qlsynth
import logging
@@ -70,7 +71,7 @@ abbreviations = {
abbreviations.update({f"{k}s": f"{v}s" for k, v in abbreviations.items()})
_abbreviations_re = re.compile("|".join(fr"\b{abbr}\b" for abbr in abbreviations))
_abbreviations_re = re.compile("|".join(rf"\b{abbr}\b" for abbr in abbreviations))
def _humanize(s: str) -> str:
@@ -98,11 +99,17 @@ def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
return format.format(**{noun: transform(noun) for noun in nouns})
prop_name = _humanize(prop.name)
class_name = cls.pragmas.get("ql_default_doc_name", _humanize(inflection.underscore(cls.name)))
class_name = cls.pragmas.get(
"ql_default_doc_name", _humanize(inflection.underscore(cls.name))
)
if prop.is_predicate:
return f"this {class_name} {prop_name}"
if plural is not None:
prop_name = inflection.pluralize(prop_name) if plural else inflection.singularize(prop_name)
prop_name = (
inflection.pluralize(prop_name)
if plural
else inflection.singularize(prop_name)
)
return f"{prop_name} of this {class_name}"
@@ -114,8 +121,12 @@ def _type_is_hideable(t: str, lookup: typing.Dict[str, schema.ClassBase]) -> boo
return False
def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.ClassBase],
prev_child: str = "") -> ql.Property:
def get_ql_property(
cls: schema.Class,
prop: schema.Property,
lookup: typing.Dict[str, schema.ClassBase],
prev_child: str = "",
) -> ql.Property:
args = dict(
type=prop.type if not prop.is_predicate else "predicate",
@@ -133,12 +144,15 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
ql_name = prop.pragmas.get("ql_name", prop.name)
db_table_name = prop.pragmas.get("ql_db_table_name")
if db_table_name and prop.is_single:
raise Error(f"`db_table_name` pragma is not supported for single properties, but {cls.name}.{prop.name} has it")
raise Error(
f"`db_table_name` pragma is not supported for single properties, but {cls.name}.{prop.name} has it"
)
if prop.is_single:
args.update(
singular=inflection.camelize(ql_name),
tablename=inflection.tableize(cls.name),
tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
tableparams=["this"]
+ ["result" if p is prop else "_" for p in cls.properties if p.is_single],
doc=_get_doc(cls, prop),
)
elif prop.is_repeated:
@@ -146,7 +160,11 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
singular=inflection.singularize(inflection.camelize(ql_name)),
plural=inflection.pluralize(inflection.camelize(ql_name)),
tablename=db_table_name or inflection.tableize(f"{cls.name}_{prop.name}"),
tableparams=["this", "index", "result"] if not prop.is_unordered else ["this", "result"],
tableparams=(
["this", "index", "result"]
if not prop.is_unordered
else ["this", "result"]
),
doc=_get_doc(cls, prop, plural=False),
doc_plural=_get_doc(cls, prop, plural=True),
)
@@ -169,7 +187,9 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
return ql.Property(**args)
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]) -> ql.Class:
def get_ql_class(
cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]
) -> ql.Class:
if "ql_name" in cls.pragmas:
raise Error("ql_name is not supported yet for classes, only for properties")
prev_child = ""
@@ -195,12 +215,14 @@ def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase])
)
def get_ql_cfg_class(cls: schema.Class, lookup: typing.Dict[str, ql.Class]) -> ql.CfgClass:
def get_ql_cfg_class(
cls: schema.Class, lookup: typing.Dict[str, ql.Class]
) -> ql.CfgClass:
return ql.CfgClass(
name=cls.name,
bases=[base for base in cls.bases if lookup[base.base].cfg],
properties=cls.properties,
doc=cls.doc
doc=cls.doc,
)
@@ -214,24 +236,33 @@ _final_db_class_lookup = {}
def get_ql_synth_class_db(name: str) -> ql.Synth.FinalClassDb:
return _final_db_class_lookup.setdefault(name, ql.Synth.FinalClassDb(name=name,
params=[
ql.Synth.Param("id", _to_db_type(name))]))
return _final_db_class_lookup.setdefault(
name,
ql.Synth.FinalClassDb(
name=name, params=[ql.Synth.Param("id", _to_db_type(name))]
),
)
def get_ql_synth_class(cls: schema.Class):
if cls.derived:
return ql.Synth.NonFinalClass(name=cls.name, derived=sorted(cls.derived),
root=not cls.bases)
return ql.Synth.NonFinalClass(
name=cls.name, derived=sorted(cls.derived), root=not cls.bases
)
if cls.synth and cls.synth.from_class is not None:
source = cls.synth.from_class
get_ql_synth_class_db(source).subtract_type(cls.name)
return ql.Synth.FinalClassDerivedSynth(name=cls.name,
params=[ql.Synth.Param("id", _to_db_type(source))])
return ql.Synth.FinalClassDerivedSynth(
name=cls.name, params=[ql.Synth.Param("id", _to_db_type(source))]
)
if cls.synth and cls.synth.on_arguments is not None:
return ql.Synth.FinalClassFreshSynth(name=cls.name,
params=[ql.Synth.Param(k, _to_db_type(v))
for k, v in cls.synth.on_arguments.items()])
return ql.Synth.FinalClassFreshSynth(
name=cls.name,
params=[
ql.Synth.Param(k, _to_db_type(v))
for k, v in cls.synth.on_arguments.items()
],
)
return get_ql_synth_class_db(cls.name)
@@ -250,7 +281,13 @@ def get_types_used_by(cls: ql.Class, is_impl: bool) -> typing.Iterable[str]:
def get_classes_used_by(cls: ql.Class, is_impl: bool) -> typing.List[str]:
return sorted(set(t for t in get_types_used_by(cls, is_impl) if t[0].isupper() and (is_impl or t != cls.name)))
return sorted(
set(
t
for t in get_types_used_by(cls, is_impl)
if t[0].isupper() and (is_impl or t != cls.name)
)
)
def format(codeql, files):
@@ -265,7 +302,8 @@ def format(codeql, files):
codeql_path = shutil.which(codeql)
if not codeql_path:
raise FormatError(
f"`{codeql}` not found in PATH. Either install it, or pass `-- --codeql-binary` with a full path")
f"`{codeql}` not found in PATH. Either install it, or pass `-- --codeql-binary` with a full path"
)
codeql = codeql_path
res = subprocess.run(format_cmd, stderr=subprocess.PIPE, text=True)
if res.returncode:
@@ -281,16 +319,22 @@ def _get_path(cls: schema.Class) -> pathlib.Path:
def _get_path_impl(cls: schema.Class) -> pathlib.Path:
return pathlib.Path(cls.group or "", "internal", cls.name+"Impl").with_suffix(".qll")
return pathlib.Path(cls.group or "", "internal", cls.name + "Impl").with_suffix(
".qll"
)
def _get_path_public(cls: schema.Class) -> pathlib.Path:
return pathlib.Path(cls.group or "", "internal" if "ql_internal" in cls.pragmas else "", cls.name).with_suffix(".qll")
return pathlib.Path(
cls.group or "", "internal" if "ql_internal" in cls.pragmas else "", cls.name
).with_suffix(".qll")
def _get_all_properties(cls: schema.Class, lookup: typing.Dict[str, schema.Class],
already_seen: typing.Optional[typing.Set[int]] = None) -> \
typing.Iterable[typing.Tuple[schema.Class, schema.Property]]:
def _get_all_properties(
cls: schema.Class,
lookup: typing.Dict[str, schema.Class],
already_seen: typing.Optional[typing.Set[int]] = None,
) -> typing.Iterable[typing.Tuple[schema.Class, schema.Property]]:
# deduplicate using ids
if already_seen is None:
already_seen = set()
@@ -304,14 +348,19 @@ def _get_all_properties(cls: schema.Class, lookup: typing.Dict[str, schema.Class
yield cls, p
def _get_all_properties_to_be_tested(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> \
typing.Iterable[ql.PropertyForTest]:
def _get_all_properties_to_be_tested(
cls: schema.Class, lookup: typing.Dict[str, schema.Class]
) -> typing.Iterable[ql.PropertyForTest]:
for c, p in _get_all_properties(cls, lookup):
if not ("qltest_skip" in c.pragmas or "qltest_skip" in p.pragmas):
# TODO here operations are duplicated, but should be better if we split ql and qltest generation
p = get_ql_property(c, p, lookup)
yield ql.PropertyForTest(p.getter, is_total=p.is_single or p.is_predicate,
type=p.type if not p.is_predicate else None, is_indexed=p.is_indexed)
yield ql.PropertyForTest(
p.getter,
is_total=p.is_single or p.is_predicate,
type=p.type if not p.is_predicate else None,
is_indexed=p.is_indexed,
)
if p.is_repeated and not p.is_optional:
yield ql.PropertyForTest(f"getNumberOf{p.plural}", type="int")
elif p.is_optional and not p.is_repeated:
@@ -324,33 +373,45 @@ def _partition_iter(x, pred):
def _partition(l, pred):
""" partitions a list according to boolean predicate """
"""partitions a list according to boolean predicate"""
return map(list, _partition_iter(l, pred))
def _is_in_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
return "qltest_collapse_hierarchy" in cls.pragmas or _is_under_qltest_collapsed_hierarchy(cls, lookup)
def _is_in_qltest_collapsed_hierarchy(
cls: schema.Class, lookup: typing.Dict[str, schema.Class]
):
return (
"qltest_collapse_hierarchy" in cls.pragmas
or _is_under_qltest_collapsed_hierarchy(cls, lookup)
)
def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
def _is_under_qltest_collapsed_hierarchy(
cls: schema.Class, lookup: typing.Dict[str, schema.Class]
):
return "qltest_uncollapse_hierarchy" not in cls.pragmas and any(
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases
)
def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
return "qltest_skip" in cls.pragmas or not (
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
cls, lookup)
return (
"qltest_skip" in cls.pragmas
or not (cls.final or "qltest_collapse_hierarchy" in cls.pragmas)
or _is_under_qltest_collapsed_hierarchy(cls, lookup)
)
def _get_stub(cls: schema.Class, base_import: str, generated_import_prefix: str) -> ql.Stub:
def _get_stub(
cls: schema.Class, base_import: str, generated_import_prefix: str
) -> ql.Stub:
if isinstance(cls.synth, schema.SynthInfo):
if cls.synth.from_class is not None:
accessors = [
ql.SynthUnderlyingAccessor(
argument="Entity",
type=_to_db_type(cls.synth.from_class),
constructorparams=["result"]
constructorparams=["result"],
)
]
elif cls.synth.on_arguments is not None:
@@ -358,28 +419,39 @@ def _get_stub(cls: schema.Class, base_import: str, generated_import_prefix: str)
ql.SynthUnderlyingAccessor(
argument=inflection.camelize(arg),
type=_to_db_type(type),
constructorparams=["result" if a == arg else "_" for a in cls.synth.on_arguments]
) for arg, type in cls.synth.on_arguments.items()
constructorparams=[
"result" if a == arg else "_" for a in cls.synth.on_arguments
],
)
for arg, type in cls.synth.on_arguments.items()
]
else:
accessors = []
return ql.Stub(name=cls.name, base_import=base_import, import_prefix=generated_import_prefix,
doc=cls.doc, synth_accessors=accessors)
return ql.Stub(
name=cls.name,
base_import=base_import,
import_prefix=generated_import_prefix,
doc=cls.doc,
synth_accessors=accessors,
)
def _get_class_public(cls: schema.Class) -> ql.ClassPublic:
return ql.ClassPublic(name=cls.name, doc=cls.doc, internal="ql_internal" in cls.pragmas)
return ql.ClassPublic(
name=cls.name, doc=cls.doc, internal="ql_internal" in cls.pragmas
)
_stub_qldoc_header = "// the following QLdoc is generated: if you need to edit it, do it in the schema file\n "
_class_qldoc_re = re.compile(
rf"(?P<qldoc>(?:{re.escape(_stub_qldoc_header)})?/\*\*.*?\*/\s*|^\s*)(?:class\s+(?P<class>\w+))?",
re.MULTILINE | re.DOTALL)
re.MULTILINE | re.DOTALL,
)
def _patch_class_qldoc(cls: str, qldoc: str, stub_file: pathlib.Path):
""" Replace or insert `qldoc` as the QLdoc of class `cls` in `stub_file` """
"""Replace or insert `qldoc` as the QLdoc of class `cls` in `stub_file`"""
if not qldoc or not stub_file.exists():
return
qldoc = "\n ".join(l.rstrip() for l in qldoc.splitlines())
@@ -415,7 +487,11 @@ def generate(opts, renderer):
data = schemaloader.load_file(input)
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items() if not cls.imported}
classes = {
name: get_ql_class(cls, data.classes)
for name, cls in data.classes.items()
if not cls.imported
}
if not classes:
raise NoClasses
root = next(iter(classes.values()))
@@ -429,28 +505,47 @@ def generate(opts, renderer):
cfg_classes = []
generated_import_prefix = get_import(out, opts.root_dir)
registry = opts.generated_registry or pathlib.Path(
os.path.commonpath((out, stub_out, test_out)), ".generated.list")
os.path.commonpath((out, stub_out, test_out)), ".generated.list"
)
with renderer.manage(generated=generated, stubs=stubs, registry=registry,
force=opts.force) as renderer:
with renderer.manage(
generated=generated, stubs=stubs, registry=registry, force=opts.force
) as renderer:
db_classes = [cls for name, cls in classes.items() if not data.classes[name].synth]
renderer.render(ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))), out / "Raw.qll")
db_classes = [
cls for name, cls in classes.items() if not data.classes[name].synth
]
renderer.render(
ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))),
out / "Raw.qll",
)
classes_by_dir_and_name = sorted(classes.values(), key=lambda cls: (cls.dir, cls.name))
classes_by_dir_and_name = sorted(
classes.values(), key=lambda cls: (cls.dir, cls.name)
)
for c in classes_by_dir_and_name:
path = get_import(stub_out / c.dir / "internal" /
c.name if c.internal else stub_out / c.path, opts.root_dir)
path = get_import(
(
stub_out / c.dir / "internal" / c.name
if c.internal
else stub_out / c.path
),
opts.root_dir,
)
imports[c.name] = path
path_impl = get_import(stub_out / c.dir / "internal" / c.name, opts.root_dir)
path_impl = get_import(
stub_out / c.dir / "internal" / c.name, opts.root_dir
)
imports_impl[c.name + "Impl"] = path_impl + "Impl"
if c.cfg:
cfg_classes.append(get_ql_cfg_class(c, classes))
for c in classes.values():
qll = out / c.path.with_suffix(".qll")
c.imports = [imports[t] if t in imports else imports_impl[t] +
"::Impl as " + t for t in get_classes_used_by(c, is_impl=True)]
c.imports = [
imports[t] if t in imports else imports_impl[t] + "::Impl as " + t
for t in get_classes_used_by(c, is_impl=True)
]
classes_used_by[c.name] = get_classes_used_by(c, is_impl=False)
c.import_prefix = generated_import_prefix
renderer.render(c, qll)
@@ -458,7 +553,7 @@ def generate(opts, renderer):
if cfg_out:
cfg_classes_val = ql.CfgClasses(
include_file_import=get_import(include_file, opts.root_dir),
classes=cfg_classes
classes=cfg_classes,
)
cfg_qll = cfg_out / "CfgNodes.qll"
renderer.render(cfg_classes_val, cfg_qll)
@@ -475,7 +570,7 @@ def generate(opts, renderer):
if not renderer.is_customized_stub(stub_file):
renderer.render(stub, stub_file)
else:
qldoc = renderer.render_str(stub, template='ql_stub_class_qldoc')
qldoc = renderer.render_str(stub, template="ql_stub_class_qldoc")
_patch_class_qldoc(c.name, qldoc, stub_file)
class_public = _get_class_public(c)
path_public = _get_path_public(c)
@@ -484,18 +579,31 @@ def generate(opts, renderer):
renderer.render(class_public, class_public_file)
# for example path/to/elements -> path/to/elements.qll
renderer.render(ql.ImportList([i for name, i in imports.items() if name not in classes or not classes[name].internal]),
include_file)
renderer.render(
ql.ImportList(
[
i
for name, i in imports.items()
if name not in classes or not classes[name].internal
]
),
include_file,
)
elements_module = get_import(include_file, opts.root_dir)
renderer.render(
ql.GetParentImplementation(
classes=list(classes.values()),
imports=[elements_module] + [i for name,
i in imports.items() if name in classes and classes[name].internal],
imports=[elements_module]
+ [
i
for name, i in imports.items()
if name in classes and classes[name].internal
],
),
out / 'ParentChild.qll')
out / "ParentChild.qll",
)
if test_out:
for c in data.classes.values():
@@ -507,39 +615,61 @@ def generate(opts, renderer):
test_with = data.classes[test_with_name] if test_with_name else c
test_dir = test_out / test_with.group / test_with.name
test_dir.mkdir(parents=True, exist_ok=True)
if all(f.suffix in (".txt", ".ql", ".actual", ".expected") for f in test_dir.glob("*.*")):
if all(
f.suffix in (".txt", ".ql", ".actual", ".expected")
for f in test_dir.glob("*.*")
):
log.warning(f"no test source in {test_dir.relative_to(test_out)}")
renderer.render(ql.MissingTestInstructions(),
test_dir / missing_test_source_filename)
renderer.render(
ql.MissingTestInstructions(),
test_dir / missing_test_source_filename,
)
continue
total_props, partial_props = _partition(_get_all_properties_to_be_tested(c, data.classes),
lambda p: p.is_total)
renderer.render(ql.ClassTester(class_name=c.name,
properties=total_props,
elements_module=elements_module,
# in case of collapsed hierarchies we want to see the actual QL class in results
show_ql_class="qltest_collapse_hierarchy" in c.pragmas),
test_dir / f"{c.name}.ql")
total_props, partial_props = _partition(
_get_all_properties_to_be_tested(c, data.classes),
lambda p: p.is_total,
)
renderer.render(
ql.ClassTester(
class_name=c.name,
properties=total_props,
elements_module=elements_module,
# in case of collapsed hierarchies we want to see the actual QL class in results
show_ql_class="qltest_collapse_hierarchy" in c.pragmas,
),
test_dir / f"{c.name}.ql",
)
for p in partial_props:
renderer.render(ql.PropertyTester(class_name=c.name,
elements_module=elements_module,
property=p), test_dir / f"{c.name}_{p.getter}.ql")
renderer.render(
ql.PropertyTester(
class_name=c.name,
elements_module=elements_module,
property=p,
),
test_dir / f"{c.name}_{p.getter}.ql",
)
final_synth_types = []
non_final_synth_types = []
constructor_imports = []
synth_constructor_imports = []
stubs = {}
for cls in sorted((cls for cls in data.classes.values() if not cls.imported),
key=lambda cls: (cls.group, cls.name)):
for cls in sorted(
(cls for cls in data.classes.values() if not cls.imported),
key=lambda cls: (cls.group, cls.name),
):
synth_type = get_ql_synth_class(cls)
if synth_type.is_final:
final_synth_types.append(synth_type)
if synth_type.has_params:
stub_file = stub_out / cls.group / "internal" / f"{cls.name}Constructor.qll"
stub_file = (
stub_out / cls.group / "internal" / f"{cls.name}Constructor.qll"
)
if not renderer.is_customized_stub(stub_file):
# stub rendering must be postponed as we might not have yet all subtracted synth types in `synth_type`
stubs[stub_file] = ql.Synth.ConstructorStub(synth_type, import_prefix=generated_import_prefix)
stubs[stub_file] = ql.Synth.ConstructorStub(
synth_type, import_prefix=generated_import_prefix
)
constructor_import = get_import(stub_file, opts.root_dir)
constructor_imports.append(constructor_import)
if synth_type.is_synth:
@@ -549,9 +679,20 @@ def generate(opts, renderer):
for stub_file, data in stubs.items():
renderer.render(data, stub_file)
renderer.render(ql.Synth.Types(root.name, generated_import_prefix,
final_synth_types, non_final_synth_types), out / "Synth.qll")
renderer.render(ql.ImportList(constructor_imports), out / "SynthConstructors.qll")
renderer.render(ql.ImportList(synth_constructor_imports), out / "PureSynthConstructors.qll")
renderer.render(
ql.Synth.Types(
root.name,
generated_import_prefix,
final_synth_types,
non_final_synth_types,
),
out / "Synth.qll",
)
renderer.render(
ql.ImportList(constructor_imports), out / "SynthConstructors.qll"
)
renderer.render(
ql.ImportList(synth_constructor_imports), out / "PureSynthConstructors.qll"
)
if opts.ql_format:
format(opts.codeql_binary, renderer.written)

View File

@@ -55,7 +55,8 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
def _get_properties(
cls: schema.Class, lookup: dict[str, schema.ClassBase],
cls: schema.Class,
lookup: dict[str, schema.ClassBase],
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
for b in cls.bases:
yield from _get_properties(lookup[b], lookup)
@@ -92,8 +93,9 @@ class Processor:
# only generate detached fields in the actual class defining them, not the derived ones
if c is cls:
# TODO lift this restriction if required (requires change in dbschemegen as well)
assert c.derived or not p.is_single, \
f"property {p.name} in concrete class marked as detached but not optional"
assert (
c.derived or not p.is_single
), f"property {p.name} in concrete class marked as detached but not optional"
detached_fields.append(_get_field(c, p))
elif not cls.derived:
# for non-detached ones, only generate fields in the concrete classes
@@ -123,10 +125,12 @@ def generate(opts, renderer):
processor = Processor(schemaloader.load_file(opts.schema))
out = opts.rust_output
groups = set()
with renderer.manage(generated=out.rglob("*.rs"),
stubs=(),
registry=out / ".generated.list",
force=opts.force) as renderer:
with renderer.manage(
generated=out.rglob("*.rs"),
stubs=(),
registry=out / ".generated.list",
force=opts.force,
) as renderer:
for group, classes in processor.get_classes().items():
group = group or "top"
groups.add(group)

View File

@@ -42,7 +42,9 @@ def _get_code(doc: list[str]) -> list[str]:
code.append(f"// {line}")
case _, True:
code.append(line)
assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(doc)
assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(
doc
)
if has_code:
return code
return []
@@ -51,15 +53,19 @@ def _get_code(doc: list[str]) -> list[str]:
def generate(opts, renderer):
assert opts.ql_test_output
schema = schemaloader.load_file(opts.schema)
with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
stubs=(),
registry=opts.ql_test_output / ".generated_tests.list",
force=opts.force) as renderer:
with renderer.manage(
generated=opts.ql_test_output.rglob("gen_*.rs"),
stubs=(),
registry=opts.ql_test_output / ".generated_tests.list",
force=opts.force,
) as renderer:
for cls in schema.classes.values():
if cls.imported:
continue
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_doc_test" in cls.pragmas):
if (
qlgen.should_skip_qltest(cls, schema.classes)
or "rust_skip_doc_test" in cls.pragmas
):
continue
code = _get_code(cls.doc)
for p in schema.iter_properties(cls.name):
@@ -79,5 +85,10 @@ def generate(opts, renderer):
code = [indent + l for l in code]
test_with_name = typing.cast(str, cls.pragmas.get("qltest_test_with"))
test_with = schema.classes[test_with_name] if test_with_name else cls
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
test = (
opts.ql_test_output
/ test_with.group
/ test_with.name
/ f"gen_{test_name}.rs"
)
renderer.render(TestCode(code="\n".join(code), function=fn), test)

View File

@@ -86,13 +86,18 @@ def generate(opts, renderer):
for dir, entries in traps.items():
dir = dir or pathlib.Path()
relative_gen_dir = pathlib.Path(*[".." for _ in dir.parents])
renderer.render(cpp.TrapList(entries, opts.dbscheme, trap_library, relative_gen_dir), out / dir / "TrapEntries")
renderer.render(
cpp.TrapList(entries, opts.dbscheme, trap_library, relative_gen_dir),
out / dir / "TrapEntries",
)
tags = []
for tag in toposort_flatten(tag_graph):
tags.append(cpp.Tag(
name=get_tag_name(tag),
bases=[get_tag_name(b) for b in sorted(tag_graph[tag])],
id=tag,
))
tags.append(
cpp.Tag(
name=get_tag_name(tag),
bases=[get_tag_name(b) for b in sorted(tag_graph[tag])],
id=tag,
)
)
renderer.render(cpp.TagList(tags, opts.dbscheme), out / "TrapTags")

View File

@@ -4,20 +4,111 @@ from dataclasses import dataclass, field
from typing import List, ClassVar
# taken from https://en.cppreference.com/w/cpp/keyword
cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "atomic_commit", "atomic_noexcept",
"auto", "bitand", "bitor", "bool", "break", "case", "catch", "char", "char8_t", "char16_t", "char32_t",
"class", "compl", "concept", "const", "consteval", "constexpr", "constinit", "const_cast", "continue",
"co_await", "co_return", "co_yield", "decltype", "default", "delete", "do", "double", "dynamic_cast",
"else", "enum", "explicit", "export", "extern", "false", "float", "for", "friend", "goto", "if",
"inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr",
"operator", "or", "or_eq", "private", "protected", "public", "reflexpr", "register", "reinterpret_cast",
"requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct",
"switch", "synchronized", "template", "this", "thread_local", "throw", "true", "try", "typedef",
"typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while",
"xor", "xor_eq"}
cpp_keywords = {
"alignas",
"alignof",
"and",
"and_eq",
"asm",
"atomic_cancel",
"atomic_commit",
"atomic_noexcept",
"auto",
"bitand",
"bitor",
"bool",
"break",
"case",
"catch",
"char",
"char8_t",
"char16_t",
"char32_t",
"class",
"compl",
"concept",
"const",
"consteval",
"constexpr",
"constinit",
"const_cast",
"continue",
"co_await",
"co_return",
"co_yield",
"decltype",
"default",
"delete",
"do",
"double",
"dynamic_cast",
"else",
"enum",
"explicit",
"export",
"extern",
"false",
"float",
"for",
"friend",
"goto",
"if",
"inline",
"int",
"long",
"mutable",
"namespace",
"new",
"noexcept",
"not",
"not_eq",
"nullptr",
"operator",
"or",
"or_eq",
"private",
"protected",
"public",
"reflexpr",
"register",
"reinterpret_cast",
"requires",
"return",
"short",
"signed",
"sizeof",
"static",
"static_assert",
"static_cast",
"struct",
"switch",
"synchronized",
"template",
"this",
"thread_local",
"throw",
"true",
"try",
"typedef",
"typeid",
"typename",
"union",
"unsigned",
"using",
"virtual",
"void",
"volatile",
"wchar_t",
"while",
"xor",
"xor_eq",
}
_field_overrides = [
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": "unsigned"}),
(
re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"),
{"base_type": "unsigned"},
),
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
]
@@ -108,7 +199,7 @@ class Tag:
@dataclass
class TrapList:
template: ClassVar = 'trap_traps'
template: ClassVar = "trap_traps"
extensions = ["h", "cpp"]
traps: List[Trap]
source: str
@@ -118,7 +209,7 @@ class TrapList:
@dataclass
class TagList:
template: ClassVar = 'trap_tags'
template: ClassVar = "trap_tags"
extensions = ["h"]
tags: List[Tag]
@@ -127,7 +218,7 @@ class TagList:
@dataclass
class ClassBase:
ref: 'Class'
ref: "Class"
first: bool = False
@@ -140,7 +231,9 @@ class Class:
trap_name: str = None
def __post_init__(self):
self.bases = [ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name)]
self.bases = [
ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name)
]
if self.bases:
self.bases[0].first = True

View File

@@ -1,4 +1,4 @@
""" dbscheme format representation """
"""dbscheme format representation"""
import logging
import pathlib
@@ -100,7 +100,7 @@ class SchemeInclude:
@dataclass
class Scheme:
template: ClassVar = 'dbscheme'
template: ClassVar = "dbscheme"
src: str
includes: List[SchemeInclude]

View File

@@ -1,4 +1,4 @@
""" module providing useful filesystem paths """
"""module providing useful filesystem paths"""
import pathlib
import sys
@@ -7,13 +7,15 @@ import os
_this_file = pathlib.Path(__file__).resolve()
try:
workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve() # <- means we are using bazel run
root_dir = workspace_dir / 'swift'
workspace_dir = pathlib.Path(
os.environ["BUILD_WORKSPACE_DIRECTORY"]
).resolve() # <- means we are using bazel run
root_dir = workspace_dir / "swift"
except KeyError:
root_dir = _this_file.parents[2]
workspace_dir = root_dir.parent
lib_dir = _this_file.parents[2] / 'codegen' / 'lib'
templates_dir = _this_file.parents[2] / 'codegen' / 'templates'
lib_dir = _this_file.parents[2] / "codegen" / "lib"
templates_dir = _this_file.parents[2] / "codegen" / "templates"
exe_file = pathlib.Path(sys.argv[0]).resolve()

View File

@@ -100,7 +100,7 @@ class Base:
@dataclass
class Class:
template: ClassVar = 'ql_class'
template: ClassVar = "ql_class"
name: str
bases: List[Base] = field(default_factory=list)
@@ -116,7 +116,12 @@ class Class:
cfg: bool = False
def __post_init__(self):
def get_bases(bases): return [Base(str(b), str(prev)) for b, prev in zip(bases, itertools.chain([""], bases))]
def get_bases(bases):
return [
Base(str(b), str(prev))
for b, prev in zip(bases, itertools.chain([""], bases))
]
self.bases = get_bases(self.bases)
self.bases_impl = get_bases(self.bases_impl)
if self.properties:
@@ -164,7 +169,7 @@ class SynthUnderlyingAccessor:
@dataclass
class Stub:
template: ClassVar = 'ql_stub'
template: ClassVar = "ql_stub"
name: str
base_import: str
@@ -183,7 +188,7 @@ class Stub:
@dataclass
class ClassPublic:
template: ClassVar = 'ql_class_public'
template: ClassVar = "ql_class_public"
name: str
imports: List[str] = field(default_factory=list)
@@ -197,7 +202,7 @@ class ClassPublic:
@dataclass
class DbClasses:
template: ClassVar = 'ql_db'
template: ClassVar = "ql_db"
classes: List[Class] = field(default_factory=list)
imports: List[str] = field(default_factory=list)
@@ -205,14 +210,14 @@ class DbClasses:
@dataclass
class ImportList:
template: ClassVar = 'ql_imports'
template: ClassVar = "ql_imports"
imports: List[str] = field(default_factory=list)
@dataclass
class GetParentImplementation:
template: ClassVar = 'ql_parent'
template: ClassVar = "ql_parent"
classes: List[Class] = field(default_factory=list)
imports: List[str] = field(default_factory=list)
@@ -234,7 +239,7 @@ class TesterBase:
@dataclass
class ClassTester(TesterBase):
template: ClassVar = 'ql_test_class'
template: ClassVar = "ql_test_class"
properties: List[PropertyForTest] = field(default_factory=list)
show_ql_class: bool = False
@@ -242,14 +247,14 @@ class ClassTester(TesterBase):
@dataclass
class PropertyTester(TesterBase):
template: ClassVar = 'ql_test_property'
template: ClassVar = "ql_test_property"
property: PropertyForTest
@dataclass
class MissingTestInstructions:
template: ClassVar = 'ql_test_missing'
template: ClassVar = "ql_test_missing"
class Synth:
@@ -306,7 +311,9 @@ class Synth:
subtracted_synth_types: List["Synth.Class"] = field(default_factory=list)
def subtract_type(self, type: str):
self.subtracted_synth_types.append(Synth.Class(type, first=not self.subtracted_synth_types))
self.subtracted_synth_types.append(
Synth.Class(type, first=not self.subtracted_synth_types)
)
@property
def has_subtracted_synth_types(self) -> bool:
@@ -357,6 +364,6 @@ class CfgClass:
@dataclass
class CfgClasses:
template: ClassVar = 'ql_cfg_nodes'
template: ClassVar = "ql_cfg_nodes"
include_file_import: Optional[str] = None
classes: List[CfgClass] = field(default_factory=list)

View File

@@ -1,4 +1,4 @@
""" template renderer module, wrapping around `pystache.Renderer`
"""template renderer module, wrapping around `pystache.Renderer`
`pystache` is a python mustache engine, and mustache is a template language. More information on
@@ -23,14 +23,21 @@ class Error(Exception):
class Renderer:
""" Template renderer using mustache templates in the `templates` directory """
"""Template renderer using mustache templates in the `templates` directory"""
def __init__(self, generator: pathlib.Path):
self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u)
self._r = pystache.Renderer(
search_dirs=str(paths.templates_dir), escape=lambda u: u
)
self._generator = generator
def render(self, data: object, output: typing.Optional[pathlib.Path], template: typing.Optional[str] = None):
""" Render `data` to `output`.
def render(
self,
data: object,
output: typing.Optional[pathlib.Path],
template: typing.Optional[str] = None,
):
"""Render `data` to `output`.
`data` must have a `template` attribute denoting which template to use from the template directory.
@@ -58,13 +65,18 @@ class Renderer:
out.write(contents)
log.debug(f"{mnemonic}: generated {output.name}")
def manage(self, generated: typing.Iterable[pathlib.Path], stubs: typing.Iterable[pathlib.Path],
registry: pathlib.Path, force: bool = False) -> "RenderManager":
def manage(
self,
generated: typing.Iterable[pathlib.Path],
stubs: typing.Iterable[pathlib.Path],
registry: pathlib.Path,
force: bool = False,
) -> "RenderManager":
return RenderManager(self._generator, generated, stubs, registry, force)
class RenderManager(Renderer):
""" A context manager allowing to manage checked in generated files and their cleanup, able
"""A context manager allowing to manage checked in generated files and their cleanup, able
to skip unneeded writes.
This is done by using and updating a checked in list of generated files that assigns two
@@ -74,6 +86,7 @@ class RenderManager(Renderer):
* the other is the hash of the actual file after code generation has finished. This will be
different from the above because of post-processing like QL formatting. This hash is used
to detect invalid modification of generated files"""
written: typing.Set[pathlib.Path]
@dataclass
@@ -82,12 +95,18 @@ class RenderManager(Renderer):
pre contains the hash of a file as rendered, post is the hash after
postprocessing (for example QL formatting)
"""
pre: str
post: typing.Optional[str] = None
def __init__(self, generator: pathlib.Path, generated: typing.Iterable[pathlib.Path],
stubs: typing.Iterable[pathlib.Path],
registry: pathlib.Path, force: bool = False):
def __init__(
self,
generator: pathlib.Path,
generated: typing.Iterable[pathlib.Path],
stubs: typing.Iterable[pathlib.Path],
registry: pathlib.Path,
force: bool = False,
):
super().__init__(generator)
self._registry_path = registry
self._force = force
@@ -142,10 +161,14 @@ class RenderManager(Renderer):
if self._force:
pass
elif rel_path not in self._hashes:
log.warning(f"{rel_path} marked as generated but absent from the registry")
log.warning(
f"{rel_path} marked as generated but absent from the registry"
)
elif self._hashes[rel_path].post != self._hash_file(f):
raise Error(f"{rel_path} is generated but was modified, please revert the file "
"or pass --force to overwrite")
raise Error(
f"{rel_path} is generated but was modified, please revert the file "
"or pass --force to overwrite"
)
def _process_stubs(self, stubs: typing.Iterable[pathlib.Path]):
for f in stubs:
@@ -159,8 +182,10 @@ class RenderManager(Renderer):
elif rel_path not in self._hashes:
log.warning(f"{rel_path} marked as stub but absent from the registry")
elif self._hashes[rel_path].post != self._hash_file(f):
raise Error(f"{rel_path} is a stub marked as generated, but it was modified, "
"please remove the `// generated` header, revert the file or pass --force to overwrite it")
raise Error(
f"{rel_path} is a stub marked as generated, but it was modified, "
"please remove the `// generated` header, revert the file or pass --force to overwrite it"
)
@staticmethod
def is_customized_stub(file: pathlib.Path) -> bool:
@@ -191,13 +216,17 @@ class RenderManager(Renderer):
for line in reg:
if line.strip():
filename, prehash, posthash = line.split()
self._hashes[pathlib.Path(filename)] = self.Hashes(prehash, posthash)
self._hashes[pathlib.Path(filename)] = self.Hashes(
prehash, posthash
)
except FileNotFoundError:
pass
def _dump_registry(self):
self._registry_path.parent.mkdir(parents=True, exist_ok=True)
with open(self._registry_path, 'w') as out, open(self._registry_path.parent / ".gitattributes", "w") as attrs:
with open(self._registry_path, "w") as out, open(
self._registry_path.parent / ".gitattributes", "w"
) as attrs:
print(f"/{self._registry_path.name}", "linguist-generated", file=attrs)
print("/.gitattributes", "linguist-generated", file=attrs)
for f, hashes in sorted(self._hashes.items()):

View File

@@ -1,4 +1,5 @@
""" schema format representation """
"""schema format representation"""
import abc
import typing
from collections.abc import Iterable
@@ -52,7 +53,11 @@ class Property:
@property
def is_repeated(self) -> bool:
return self.kind in (self.Kind.REPEATED, self.Kind.REPEATED_OPTIONAL, self.Kind.REPEATED_UNORDERED)
return self.kind in (
self.Kind.REPEATED,
self.Kind.REPEATED_OPTIONAL,
self.Kind.REPEATED_UNORDERED,
)
@property
def is_unordered(self) -> bool:
@@ -74,10 +79,11 @@ class Property:
SingleProperty = functools.partial(Property, Property.Kind.SINGLE)
OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL)
RepeatedProperty = functools.partial(Property, Property.Kind.REPEATED)
RepeatedOptionalProperty = functools.partial(
Property, Property.Kind.REPEATED_OPTIONAL)
RepeatedOptionalProperty = functools.partial(Property, Property.Kind.REPEATED_OPTIONAL)
PredicateProperty = functools.partial(Property, Property.Kind.PREDICATE)
RepeatedUnorderedProperty = functools.partial(Property, Property.Kind.REPEATED_UNORDERED)
RepeatedUnorderedProperty = functools.partial(
Property, Property.Kind.REPEATED_UNORDERED
)
@dataclass
@@ -197,9 +203,9 @@ def _make_property(arg: object) -> Property:
class PropertyModifier(abc.ABC):
""" Modifier of `Property` objects.
Being on the right of `|` it will trigger construction of a `Property` from
the left operand.
"""Modifier of `Property` objects.
Being on the right of `|` it will trigger construction of a `Property` from
the left operand.
"""
def __ror__(self, other: object) -> Property:
@@ -210,11 +216,9 @@ class PropertyModifier(abc.ABC):
def __invert__(self) -> "PropertyModifier":
return self.negate()
def modify(self, prop: Property):
...
def modify(self, prop: Property): ...
def negate(self) -> "PropertyModifier":
...
def negate(self) -> "PropertyModifier": ...
def split_doc(doc):
@@ -224,7 +228,11 @@ def split_doc(doc):
lines = doc.splitlines()
# Determine minimum indentation (first line doesn't count):
strippedlines = (line.lstrip() for line in lines[1:])
indents = [len(line) - len(stripped) for line, stripped in zip(lines[1:], strippedlines) if stripped]
indents = [
len(line) - len(stripped)
for line, stripped in zip(lines[1:], strippedlines)
if stripped
]
# Remove indentation (first line is special):
trimmed = [lines[0].strip()]
if indents:

View File

@@ -39,7 +39,9 @@ class _DocModifier(_schema.PropertyModifier, metaclass=_DocModifierMetaclass):
def modify(self, prop: _schema.Property):
if self.doc and ("\n" in self.doc or self.doc[-1] == "."):
raise _schema.Error("No newlines or trailing dots are allowed in doc, did you intend to use desc?")
raise _schema.Error(
"No newlines or trailing dots are allowed in doc, did you intend to use desc?"
)
prop.doc = self.doc
def negate(self) -> _schema.PropertyModifier:
@@ -73,10 +75,13 @@ imported = _schema.ImportedClass
@_dataclass
class _Namespace:
""" simple namespacing mechanism """
"""simple namespacing mechanism"""
_name: str
def add(self, pragma: _Union["_PragmaBase", "_Parametrized"], key: str | None = None):
def add(
self, pragma: _Union["_PragmaBase", "_Parametrized"], key: str | None = None
):
self.__dict__[pragma.pragma] = pragma
pragma.pragma = key or f"{self._name}_{pragma.pragma}"
@@ -110,15 +115,18 @@ class _PragmaBase:
@_dataclass
class _ClassPragma(_PragmaBase):
""" A class pragma.
"""A class pragma.
For schema classes it acts as a python decorator with `@`.
"""
inherited: bool = False
def __call__(self, cls: type) -> type:
""" use this pragma as a decorator on classes """
"""use this pragma as a decorator on classes"""
if self.inherited:
setattr(cls, f"{_schema.inheritable_pragma_prefix}{self.pragma}", self.value)
setattr(
cls, f"{_schema.inheritable_pragma_prefix}{self.pragma}", self.value
)
else:
# not using hasattr as we don't want to land on inherited pragmas
if "_pragmas" not in cls.__dict__:
@@ -129,9 +137,10 @@ class _ClassPragma(_PragmaBase):
@_dataclass
class _PropertyPragma(_PragmaBase, _schema.PropertyModifier):
""" A property pragma.
"""A property pragma.
It functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
"""
remove: bool = False
def modify(self, prop: _schema.Property):
@@ -149,21 +158,23 @@ class _PropertyPragma(_PragmaBase, _schema.PropertyModifier):
@_dataclass
class _Pragma(_ClassPragma, _PropertyPragma):
""" A class or property pragma.
"""A class or property pragma.
For properties, it functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
For schema classes it acts as a python decorator with `@`.
"""
class _Parametrized[P, **Q, T]:
""" A parametrized pragma.
"""A parametrized pragma.
Needs to be applied to a parameter to give a pragma.
"""
def __init__(self, pragma_instance: P, factory: _Callable[Q, T]):
self.pragma_instance = pragma_instance
self.factory = factory
self.__signature__ = _inspect.signature(self.factory).replace(return_annotation=type(self.pragma_instance))
self.__signature__ = _inspect.signature(self.factory).replace(
return_annotation=type(self.pragma_instance)
)
@property
def pragma(self):
@@ -187,7 +198,8 @@ class _Optionalizer(_schema.PropertyModifier):
K = _schema.Property.Kind
if prop.kind != K.SINGLE:
raise _schema.Error(
"optional should only be applied to simple property types")
"optional should only be applied to simple property types"
)
prop.kind = K.OPTIONAL
@@ -200,7 +212,8 @@ class _Listifier(_schema.PropertyModifier):
prop.kind = K.REPEATED_OPTIONAL
else:
raise _schema.Error(
"list should only be applied to simple or optional property types")
"list should only be applied to simple or optional property types"
)
class _Setifier(_schema.PropertyModifier):
@@ -212,7 +225,7 @@ class _Setifier(_schema.PropertyModifier):
class _TypeModifier:
""" Modifies types using get item notation """
"""Modifies types using get item notation"""
def __init__(self, modifier: _schema.PropertyModifier):
self.modifier = modifier
@@ -242,7 +255,11 @@ use_for_null = _ClassPragma("null")
qltest.add(_ClassPragma("skip"))
qltest.add(_ClassPragma("collapse_hierarchy"))
qltest.add(_ClassPragma("uncollapse_hierarchy"))
qltest.add(_Parametrized(_ClassPragma("test_with", inherited=True), factory=_schema.get_type_name))
qltest.add(
_Parametrized(
_ClassPragma("test_with", inherited=True), factory=_schema.get_type_name
)
)
ql.add(_Parametrized(_ClassPragma("default_doc_name"), factory=lambda doc: doc))
ql.add(_ClassPragma("hideable", inherited=True))
@@ -255,15 +272,33 @@ cpp.add(_Pragma("skip"))
rust.add(_PropertyPragma("detach"))
rust.add(_Pragma("skip_doc_test"))
rust.add(_Parametrized(_ClassPragma("doc_test_signature"), factory=lambda signature: signature))
rust.add(
_Parametrized(
_ClassPragma("doc_test_signature"), factory=lambda signature: signature
)
)
group = _Parametrized(_ClassPragma("group", inherited=True), factory=lambda group: group)
group = _Parametrized(
_ClassPragma("group", inherited=True), factory=lambda group: group
)
synth.add(_Parametrized(_ClassPragma("from_class"), factory=lambda ref: _schema.SynthInfo(
from_class=_schema.get_type_name(ref))), key="synth")
synth.add(_Parametrized(_ClassPragma("on_arguments"), factory=lambda **kwargs:
_schema.SynthInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()})), key="synth")
synth.add(
_Parametrized(
_ClassPragma("from_class"),
factory=lambda ref: _schema.SynthInfo(from_class=_schema.get_type_name(ref)),
),
key="synth",
)
synth.add(
_Parametrized(
_ClassPragma("on_arguments"),
factory=lambda **kwargs: _schema.SynthInfo(
on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()}
),
),
key="synth",
)
@_dataclass(frozen=True)
@@ -283,7 +318,12 @@ _ = _PropertyModifierList(())
drop = object()
def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, replace_bases: _Dict[type, type] | None = None, cfg: bool = False) -> _Callable[[type], _PropertyModifierList]:
def annotate(
annotated_cls: type,
add_bases: _Iterable[type] | None = None,
replace_bases: _Dict[type, type] | None = None,
cfg: bool = False,
) -> _Callable[[type], _PropertyModifierList]:
"""
Add or modify schema annotations after a class has been defined previously.
@@ -291,6 +331,7 @@ def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, repl
`replace_bases` can be used to replace bases on the annotated class.
"""
def decorator(cls: type) -> _PropertyModifierList:
if cls.__name__ != "_":
raise _schema.Error("Annotation classes must be named _")
@@ -299,7 +340,9 @@ def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, repl
for p, v in cls.__dict__.get("_pragmas", {}).items():
_ClassPragma(p, value=v)(annotated_cls)
if replace_bases:
annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__)
annotated_cls.__bases__ = tuple(
replace_bases.get(b, b) for b in annotated_cls.__bases__
)
if add_bases:
annotated_cls.__bases__ += tuple(add_bases)
annotated_cls.__cfg__ = cfg
@@ -312,9 +355,12 @@ def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, repl
elif p in annotated_cls.__annotations__:
annotated_cls.__annotations__[p] |= a
elif isinstance(a, (_PropertyModifierList, _PropertyModifierList)):
raise _schema.Error(f"annotated property {p} not present in annotated class "
f"{annotated_cls.__name__}")
raise _schema.Error(
f"annotated property {p} not present in annotated class "
f"{annotated_cls.__name__}"
)
else:
annotated_cls.__annotations__[p] = a
return _
return decorator

View File

@@ -12,9 +12,13 @@ class _Re:
"|"
r"^(?P<union>@\w+)\s*=\s*(?P<unionbody>@\w+(?:\s*\|\s*@\w+)*)\s*;?"
)
field = re.compile(r"(?m)[\w\s]*\s(?P<field>\w+)\s*:\s*(?P<type>@?\w+)(?P<ref>\s+ref)?")
field = re.compile(
r"(?m)[\w\s]*\s(?P<field>\w+)\s*:\s*(?P<type>@?\w+)(?P<ref>\s+ref)?"
)
key = re.compile(r"@\w+")
comment = re.compile(r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$") # lookahead avoid ignoring metadata like //dir=foo
comment = re.compile(
r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$"
) # lookahead avoid ignoring metadata like //dir=foo
def _get_column(match):

View File

@@ -1,4 +1,5 @@
""" schema loader """
"""schema loader"""
import sys
import inflection
@@ -33,37 +34,56 @@ def _get_class(cls: type) -> schema.Class:
raise schema.Error(f"Only class definitions allowed in schema, found {cls}")
# we must check that going to dbscheme names and back is preserved
# In particular this will not happen if uppercase acronyms are included in the name
to_underscore_and_back = inflection.camelize(inflection.underscore(cls.__name__), uppercase_first_letter=True)
to_underscore_and_back = inflection.camelize(
inflection.underscore(cls.__name__), uppercase_first_letter=True
)
if cls.__name__ != to_underscore_and_back:
raise schema.Error(f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} "
f"instead of {to_underscore_and_back}")
if len({g for g in (getattr(b, f"{schema.inheritable_pragma_prefix}group", None)
for b in cls.__bases__) if g}) > 1:
raise schema.Error(
f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} "
f"instead of {to_underscore_and_back}"
)
if (
len(
{
g
for g in (
getattr(b, f"{schema.inheritable_pragma_prefix}group", None)
for b in cls.__bases__
)
if g
}
)
> 1
):
raise schema.Error(f"Bases with mixed groups for {cls.__name__}")
pragmas = {
# dir and getattr inherit from bases
a[len(schema.inheritable_pragma_prefix):]: getattr(cls, a)
for a in dir(cls) if a.startswith(schema.inheritable_pragma_prefix)
a[len(schema.inheritable_pragma_prefix) :]: getattr(cls, a)
for a in dir(cls)
if a.startswith(schema.inheritable_pragma_prefix)
}
pragmas |= cls.__dict__.get("_pragmas", {})
derived = {d.__name__ for d in cls.__subclasses__()}
if "null" in pragmas and derived:
raise schema.Error(f"Null class cannot be derived")
return schema.Class(name=cls.__name__,
bases=[b.__name__ for b in cls.__bases__ if b is not object],
derived=derived,
pragmas=pragmas,
cfg=cls.__cfg__ if hasattr(cls, "__cfg__") else False,
# in the following we don't use `getattr` to avoid inheriting
properties=[
a | _PropertyNamer(n)
for n, a in cls.__dict__.get("__annotations__", {}).items()
],
doc=schema.split_doc(cls.__doc__),
)
return schema.Class(
name=cls.__name__,
bases=[b.__name__ for b in cls.__bases__ if b is not object],
derived=derived,
pragmas=pragmas,
cfg=cls.__cfg__ if hasattr(cls, "__cfg__") else False,
# in the following we don't use `getattr` to avoid inheriting
properties=[
a | _PropertyNamer(n)
for n, a in cls.__dict__.get("__annotations__", {}).items()
],
doc=schema.split_doc(cls.__doc__),
)
def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typing.Dict[str, schema.Class]:
def _toposort_classes_by_group(
classes: typing.Dict[str, schema.Class],
) -> typing.Dict[str, schema.Class]:
groups = {}
ret = {}
@@ -79,7 +99,7 @@ def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typin
def _fill_synth_information(classes: typing.Dict[str, schema.Class]):
""" Take a dictionary where the `synth` field is filled for all explicitly synthesized classes
"""Take a dictionary where the `synth` field is filled for all explicitly synthesized classes
and update it so that all non-final classes that have only synthesized final descendants
get `True` as` value for the `synth` field
"""
@@ -109,7 +129,7 @@ def _fill_synth_information(classes: typing.Dict[str, schema.Class]):
def _fill_hideable_information(classes: typing.Dict[str, schema.Class]):
""" Update the class map propagating the `hideable` attribute upwards in the hierarchy """
"""Update the class map propagating the `hideable` attribute upwards in the hierarchy"""
todo = [cls for cls in classes.values() if "ql_hideable" in cls.pragmas]
while todo:
cls = todo.pop()
@@ -123,10 +143,14 @@ def _fill_hideable_information(classes: typing.Dict[str, schema.Class]):
def _check_test_with(classes: typing.Dict[str, schema.Class]):
for cls in classes.values():
test_with = typing.cast(str, cls.pragmas.get("qltest_test_with"))
transitive_test_with = test_with and classes[test_with].pragmas.get("qltest_test_with")
transitive_test_with = test_with and classes[test_with].pragmas.get(
"qltest_test_with"
)
if test_with and transitive_test_with:
raise schema.Error(f"{cls.name} has test_with {test_with} which in turn "
f"has test_with {transitive_test_with}, use that directly")
raise schema.Error(
f"{cls.name} has test_with {test_with} which in turn "
f"has test_with {transitive_test_with}, use that directly"
)
def load(m: types.ModuleType) -> schema.Schema:
@@ -136,6 +160,7 @@ def load(m: types.ModuleType) -> schema.Schema:
known = {"int", "string", "boolean"}
known.update(n for n in m.__dict__ if not n.startswith("__"))
import misc.codegen.lib.schemadefs as defs
null = None
for name, data in m.__dict__.items():
if hasattr(defs, name):
@@ -152,21 +177,26 @@ def load(m: types.ModuleType) -> schema.Schema:
continue
cls = _get_class(data)
if classes and not cls.bases:
raise schema.Error(
f"Only one root class allowed, found second root {name}")
raise schema.Error(f"Only one root class allowed, found second root {name}")
cls.check_types(known)
classes[name] = cls
if "null" in cls.pragmas:
del cls.pragmas["null"]
if null is not None:
raise schema.Error(f"Null class {null} already defined, second null class {name} not allowed")
raise schema.Error(
f"Null class {null} already defined, second null class {name} not allowed"
)
null = name
_fill_synth_information(classes)
_fill_hideable_information(classes)
_check_test_with(classes)
return schema.Schema(includes=includes, classes=imported_classes | _toposort_classes_by_group(classes), null=null)
return schema.Schema(
includes=includes,
classes=imported_classes | _toposort_classes_by_group(classes),
null=null,
)
def load_file(path: pathlib.Path) -> schema.Schema:

View File

@@ -17,34 +17,49 @@ def test_field_name():
assert f.field_name == "foo"
@pytest.mark.parametrize("type,expected", [
("std::string", "trapQuoted(value)"),
("bool", '(value ? "true" : "false")'),
("something_else", "value"),
])
@pytest.mark.parametrize(
"type,expected",
[
("std::string", "trapQuoted(value)"),
("bool", '(value ? "true" : "false")'),
("something_else", "value"),
],
)
def test_field_get_streamer(type, expected):
f = cpp.Field("name", type)
assert f.get_streamer()("value") == expected
@pytest.mark.parametrize("is_optional,is_repeated,is_predicate,expected", [
(False, False, False, True),
(True, False, False, False),
(False, True, False, False),
(True, True, False, False),
(False, False, True, False),
])
@pytest.mark.parametrize(
"is_optional,is_repeated,is_predicate,expected",
[
(False, False, False, True),
(True, False, False, False),
(False, True, False, False),
(True, True, False, False),
(False, False, True, False),
],
)
def test_field_is_single(is_optional, is_repeated, is_predicate, expected):
f = cpp.Field("name", "type", is_optional=is_optional, is_repeated=is_repeated, is_predicate=is_predicate)
f = cpp.Field(
"name",
"type",
is_optional=is_optional,
is_repeated=is_repeated,
is_predicate=is_predicate,
)
assert f.is_single is expected
@pytest.mark.parametrize("is_optional,is_repeated,expected", [
(False, False, "bar"),
(True, False, "std::optional<bar>"),
(False, True, "std::vector<bar>"),
(True, True, "std::vector<std::optional<bar>>"),
])
@pytest.mark.parametrize(
"is_optional,is_repeated,expected",
[
(False, False, "bar"),
(True, False, "std::optional<bar>"),
(False, True, "std::vector<bar>"),
(True, True, "std::vector<std::optional<bar>>"),
],
)
def test_field_modal_types(is_optional, is_repeated, expected):
f = cpp.Field("name", "bar", is_optional=is_optional, is_repeated=is_repeated)
assert f.type == expected
@@ -69,11 +84,9 @@ def test_tag_has_first_base_marked():
assert t.bases == expected
@pytest.mark.parametrize("bases,expected", [
([], False),
(["a"], True),
(["a", "b"], True)
])
@pytest.mark.parametrize(
"bases,expected", [([], False), (["a"], True), (["a", "b"], True)]
)
def test_tag_has_bases(bases, expected):
t = cpp.Tag("name", bases, "id")
assert t.has_bases is expected
@@ -91,11 +104,9 @@ def test_class_has_first_base_marked():
assert c.bases == expected
@pytest.mark.parametrize("bases,expected", [
([], False),
(["a"], True),
(["a", "b"], True)
])
@pytest.mark.parametrize(
"bases,expected", [([], False), (["a"], True), (["a", "b"], True)]
)
def test_class_has_bases(bases, expected):
t = cpp.Class("name", [cpp.Class(b) for b in bases])
assert t.has_bases is expected
@@ -113,5 +124,5 @@ def test_class_single_fields():
assert c.single_fields == fields[::2]
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -18,7 +18,10 @@ def generate_grouped(opts, renderer, input):
assert isinstance(g, cpp.ClassList), f
assert g.include_parent is (f.parent != output_dir)
assert f.name == "TrapClasses", f
return {str(f.parent.relative_to(output_dir)): g.classes for f, g in generated.items()}
return {
str(f.parent.relative_to(output_dir)): g.classes
for f, g in generated.items()
}
return ret
@@ -38,129 +41,193 @@ def test_empty(generate):
def test_empty_class(generate):
assert generate([
schema.Class(name="MyClass"),
]) == [
cpp.Class(name="MyClass", final=True, trap_name="MyClasses")
]
assert generate(
[
schema.Class(name="MyClass"),
]
) == [cpp.Class(name="MyClass", final=True, trap_name="MyClasses")]
def test_two_class_hierarchy(generate):
base = cpp.Class(name="A")
assert generate([
schema.Class(name="A", derived={"B"}),
schema.Class(name="B", bases=["A"]),
]) == [
assert generate(
[
schema.Class(name="A", derived={"B"}),
schema.Class(name="B", bases=["A"]),
]
) == [
base,
cpp.Class(name="B", bases=[base], final=True, trap_name="Bs"),
]
@pytest.mark.parametrize("type,expected", [
("a", "a"),
("string", "std::string"),
("boolean", "bool"),
("MyClass", "TrapLabel<MyClassTag>"),
])
@pytest.mark.parametrize("property_cls,optional,repeated,unordered,trap_name", [
(schema.SingleProperty, False, False, False, None),
(schema.OptionalProperty, True, False, False, "MyClassProps"),
(schema.RepeatedProperty, False, True, False, "MyClassProps"),
(schema.RepeatedOptionalProperty, True, True, False, "MyClassProps"),
(schema.RepeatedUnorderedProperty, False, True, True, "MyClassProps"),
])
def test_class_with_field(generate, type, expected, property_cls, optional, repeated, unordered, trap_name):
assert generate([
schema.Class(name="MyClass", properties=[property_cls("prop", type)]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field("prop", expected, is_optional=optional,
is_repeated=repeated, is_unordered=unordered, trap_name=trap_name)],
trap_name="MyClasses",
final=True)
@pytest.mark.parametrize(
"type,expected",
[
("a", "a"),
("string", "std::string"),
("boolean", "bool"),
("MyClass", "TrapLabel<MyClassTag>"),
],
)
@pytest.mark.parametrize(
"property_cls,optional,repeated,unordered,trap_name",
[
(schema.SingleProperty, False, False, False, None),
(schema.OptionalProperty, True, False, False, "MyClassProps"),
(schema.RepeatedProperty, False, True, False, "MyClassProps"),
(schema.RepeatedOptionalProperty, True, True, False, "MyClassProps"),
(schema.RepeatedUnorderedProperty, False, True, True, "MyClassProps"),
],
)
def test_class_with_field(
generate, type, expected, property_cls, optional, repeated, unordered, trap_name
):
assert generate(
[
schema.Class(name="MyClass", properties=[property_cls("prop", type)]),
]
) == [
cpp.Class(
name="MyClass",
fields=[
cpp.Field(
"prop",
expected,
is_optional=optional,
is_repeated=repeated,
is_unordered=unordered,
trap_name=trap_name,
)
],
trap_name="MyClasses",
final=True,
)
]
def test_class_field_with_null(generate, input):
input.null = "Null"
a = cpp.Class(name="A")
assert generate([
schema.Class(name="A", derived={"B"}),
schema.Class(name="B", bases=["A"], properties=[
schema.SingleProperty("x", "A"),
schema.SingleProperty("y", "B"),
])
]) == [
assert generate(
[
schema.Class(name="A", derived={"B"}),
schema.Class(
name="B",
bases=["A"],
properties=[
schema.SingleProperty("x", "A"),
schema.SingleProperty("y", "B"),
],
),
]
) == [
a,
cpp.Class(name="B", bases=[a], final=True, trap_name="Bs",
fields=[
cpp.Field("x", "TrapLabel<ATag>"),
cpp.Field("y", "TrapLabel<BOrNoneTag>"),
]),
cpp.Class(
name="B",
bases=[a],
final=True,
trap_name="Bs",
fields=[
cpp.Field("x", "TrapLabel<ATag>"),
cpp.Field("y", "TrapLabel<BOrNoneTag>"),
],
),
]
def test_class_with_predicate(generate):
assert generate([
schema.Class(name="MyClass", properties=[
schema.PredicateProperty("prop")]),
]) == [
cpp.Class(name="MyClass",
fields=[
cpp.Field("prop", "bool", trap_name="MyClassProp", is_predicate=True)],
trap_name="MyClasses",
final=True)
assert generate(
[
schema.Class(name="MyClass", properties=[schema.PredicateProperty("prop")]),
]
) == [
cpp.Class(
name="MyClass",
fields=[
cpp.Field("prop", "bool", trap_name="MyClassProp", is_predicate=True)
],
trap_name="MyClasses",
final=True,
)
]
@pytest.mark.parametrize("name",
["start_line", "start_column", "end_line", "end_column", "index", "num_whatever", "width"])
@pytest.mark.parametrize(
"name",
[
"start_line",
"start_column",
"end_line",
"end_column",
"index",
"num_whatever",
"width",
],
)
def test_class_with_overridden_unsigned_field(generate, name):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty(name, "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field(name, "unsigned")],
trap_name="MyClasses",
final=True)
assert generate(
[
schema.Class(
name="MyClass", properties=[schema.SingleProperty(name, "bar")]
),
]
) == [
cpp.Class(
name="MyClass",
fields=[cpp.Field(name, "unsigned")],
trap_name="MyClasses",
final=True,
)
]
def test_class_with_overridden_underscore_field(generate):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty("something_", "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field("something", "bar")],
trap_name="MyClasses",
final=True)
assert generate(
[
schema.Class(
name="MyClass", properties=[schema.SingleProperty("something_", "bar")]
),
]
) == [
cpp.Class(
name="MyClass",
fields=[cpp.Field("something", "bar")],
trap_name="MyClasses",
final=True,
)
]
@pytest.mark.parametrize("name", cpp.cpp_keywords)
def test_class_with_keyword_field(generate, name):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty(name, "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field(name + "_", "bar")],
trap_name="MyClasses",
final=True)
assert generate(
[
schema.Class(
name="MyClass", properties=[schema.SingleProperty(name, "bar")]
),
]
) == [
cpp.Class(
name="MyClass",
fields=[cpp.Field(name + "_", "bar")],
trap_name="MyClasses",
final=True,
)
]
def test_classes_with_dirs(generate_grouped):
cbase = cpp.Class(name="CBase")
assert generate_grouped([
schema.Class(name="A"),
schema.Class(name="B", pragmas={"group": "foo"}),
schema.Class(name="CBase", derived={"C"}, pragmas={"group": "bar"}),
schema.Class(name="C", bases=["CBase"], pragmas={"group": "bar"}),
schema.Class(name="D", pragmas={"group": "foo/bar/baz"}),
]) == {
assert generate_grouped(
[
schema.Class(name="A"),
schema.Class(name="B", pragmas={"group": "foo"}),
schema.Class(name="CBase", derived={"C"}, pragmas={"group": "bar"}),
schema.Class(name="C", bases=["CBase"], pragmas={"group": "bar"}),
schema.Class(name="D", pragmas={"group": "foo/bar/baz"}),
]
) == {
".": [cpp.Class(name="A", trap_name="As", final=True)],
"foo": [cpp.Class(name="B", trap_name="Bs", final=True)],
"bar": [cbase, cpp.Class(name="C", bases=[cbase], trap_name="Cs", final=True)],
@@ -169,81 +236,126 @@ def test_classes_with_dirs(generate_grouped):
def test_cpp_skip_pragma(generate):
assert generate([
schema.Class(name="A", properties=[
schema.SingleProperty("x", "foo"),
schema.SingleProperty("y", "bar", pragmas=["x", "cpp_skip", "y"]),
])
]) == [
cpp.Class(name="A", final=True, trap_name="As", fields=[
cpp.Field("x", "foo"),
]),
assert generate(
[
schema.Class(
name="A",
properties=[
schema.SingleProperty("x", "foo"),
schema.SingleProperty("y", "bar", pragmas=["x", "cpp_skip", "y"]),
],
)
]
) == [
cpp.Class(
name="A",
final=True,
trap_name="As",
fields=[
cpp.Field("x", "foo"),
],
),
]
def test_synth_classes_ignored(generate):
assert generate([
schema.Class(
name="W",
pragmas={"synth": schema.SynthInfo()},
),
schema.Class(
name="X",
pragmas={"synth": schema.SynthInfo(from_class="A")},
),
schema.Class(
name="Y",
pragmas={"synth": schema.SynthInfo(on_arguments={"a": "A", "b": "int"})},
),
schema.Class(
name="Z",
),
]) == [
assert generate(
[
schema.Class(
name="W",
pragmas={"synth": schema.SynthInfo()},
),
schema.Class(
name="X",
pragmas={"synth": schema.SynthInfo(from_class="A")},
),
schema.Class(
name="Y",
pragmas={
"synth": schema.SynthInfo(on_arguments={"a": "A", "b": "int"})
},
),
schema.Class(
name="Z",
),
]
) == [
cpp.Class(name="Z", final=True, trap_name="Zs"),
]
def test_synth_properties_ignored(generate):
assert generate([
schema.Class(
assert generate(
[
schema.Class(
name="X",
properties=[
schema.SingleProperty("x", "a"),
schema.SingleProperty("y", "b", synth=True),
schema.SingleProperty("z", "c"),
schema.OptionalProperty("foo", "bar", synth=True),
schema.RepeatedProperty("baz", "bazz", synth=True),
schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True),
schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True),
],
),
]
) == [
cpp.Class(
name="X",
properties=[
schema.SingleProperty("x", "a"),
schema.SingleProperty("y", "b", synth=True),
schema.SingleProperty("z", "c"),
schema.OptionalProperty("foo", "bar", synth=True),
schema.RepeatedProperty("baz", "bazz", synth=True),
schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True),
schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True),
final=True,
trap_name="Xes",
fields=[
cpp.Field("x", "a"),
cpp.Field("z", "c"),
],
),
]) == [
cpp.Class(name="X", final=True, trap_name="Xes", fields=[
cpp.Field("x", "a"),
cpp.Field("z", "c"),
]),
]
def test_properties_with_custom_db_table_names(generate):
assert generate([
schema.Class("Obj", properties=[
schema.OptionalProperty("x", "a", pragmas={"ql_db_table_name": "foo"}),
schema.RepeatedProperty("y", "b", pragmas={"ql_db_table_name": "bar"}),
schema.RepeatedOptionalProperty("z", "c", pragmas={"ql_db_table_name": "baz"}),
schema.PredicateProperty("p", pragmas={"ql_db_table_name": "hello"}),
schema.RepeatedUnorderedProperty("q", "d", pragmas={"ql_db_table_name": "world"}),
]),
]) == [
cpp.Class(name="Obj", final=True, trap_name="Objs", fields=[
cpp.Field("x", "a", is_optional=True, trap_name="Foo"),
cpp.Field("y", "b", is_repeated=True, trap_name="Bar"),
cpp.Field("z", "c", is_repeated=True, is_optional=True, trap_name="Baz"),
cpp.Field("p", "bool", is_predicate=True, trap_name="Hello"),
cpp.Field("q", "d", is_repeated=True, is_unordered=True, trap_name="World"),
]),
assert generate(
[
schema.Class(
"Obj",
properties=[
schema.OptionalProperty(
"x", "a", pragmas={"ql_db_table_name": "foo"}
),
schema.RepeatedProperty(
"y", "b", pragmas={"ql_db_table_name": "bar"}
),
schema.RepeatedOptionalProperty(
"z", "c", pragmas={"ql_db_table_name": "baz"}
),
schema.PredicateProperty(
"p", pragmas={"ql_db_table_name": "hello"}
),
schema.RepeatedUnorderedProperty(
"q", "d", pragmas={"ql_db_table_name": "world"}
),
],
),
]
) == [
cpp.Class(
name="Obj",
final=True,
trap_name="Objs",
fields=[
cpp.Field("x", "a", is_optional=True, trap_name="Foo"),
cpp.Field("y", "b", is_repeated=True, trap_name="Bar"),
cpp.Field(
"z", "c", is_repeated=True, is_optional=True, trap_name="Baz"
),
cpp.Field("p", "bool", is_predicate=True, trap_name="Hello"),
cpp.Field(
"q", "d", is_repeated=True, is_unordered=True, trap_name="World"
),
],
),
]
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -14,12 +14,15 @@ def test_dbcolumn_keyword_name(keyword):
assert dbscheme.Column(keyword, "some_type").name == keyword + "_"
@pytest.mark.parametrize("type,binding,lhstype,rhstype", [
("builtin_type", False, "builtin_type", "builtin_type ref"),
("builtin_type", True, "builtin_type", "builtin_type ref"),
("@at_type", False, "int", "@at_type ref"),
("@at_type", True, "unique int", "@at_type"),
])
@pytest.mark.parametrize(
"type,binding,lhstype,rhstype",
[
("builtin_type", False, "builtin_type", "builtin_type ref"),
("builtin_type", True, "builtin_type", "builtin_type ref"),
("@at_type", False, "int", "@at_type ref"),
("@at_type", True, "unique int", "@at_type"),
],
)
def test_dbcolumn_types(type, binding, lhstype, rhstype):
col = dbscheme.Column("foo", type, binding)
assert col.lhstype == lhstype
@@ -34,7 +37,11 @@ def test_keyset_has_first_id_marked():
def test_table_has_first_column_marked():
columns = [dbscheme.Column("a", "x"), dbscheme.Column("b", "y", binding=True), dbscheme.Column("c", "z")]
columns = [
dbscheme.Column("a", "x"),
dbscheme.Column("b", "y", binding=True),
dbscheme.Column("c", "z"),
]
expected = deepcopy(columns)
table = dbscheme.Table("foo", columns)
expected[0].first = True
@@ -48,5 +55,5 @@ def test_union_has_first_case_marked():
assert [c.type for c in u.rhs] == rhs
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -8,10 +8,12 @@ from misc.codegen.test.utils import *
InputExpectedPair = collections.namedtuple("InputExpectedPair", ("input", "expected"))
@pytest.fixture(params=[
InputExpectedPair(None, None),
InputExpectedPair("foodir", pathlib.Path("foodir")),
])
@pytest.fixture(
params=[
InputExpectedPair(None, None),
InputExpectedPair("foodir", pathlib.Path("foodir")),
]
)
def dir_param(request):
return request.param
@@ -21,7 +23,7 @@ def generate(opts, input, renderer):
def func(classes, null=None):
input.classes = {cls.name: cls for cls in classes}
input.null = null
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
((out, data),) = run_generation(dbschemegen.generate, opts, renderer).items()
assert out is opts.dbscheme
return data
@@ -48,23 +50,26 @@ def test_includes(input, opts, generate):
dbscheme.SchemeInclude(
src=pathlib.Path(i),
data=i + " data",
) for i in includes
)
for i in includes
],
declarations=[],
)
def test_empty_final_class(generate, dir_param):
assert generate([
schema.Class("Object", pragmas={"group": dir_param.input}),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class("Object", pragmas={"group": dir_param.input}),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
dbscheme.Column("id", "@object", binding=True),
],
dir=dir_param.expected,
)
@@ -73,218 +78,279 @@ def test_empty_final_class(generate, dir_param):
def test_final_class_with_single_scalar_field(generate, dir_param):
assert generate([
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
schema.SingleProperty("foo", "bar"),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
"Object",
pragmas={"group": dir_param.input},
properties=[
schema.SingleProperty("foo", "bar"),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
dbscheme.Column('foo', 'bar'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object", binding=True),
dbscheme.Column("foo", "bar"),
],
dir=dir_param.expected,
)
],
)
def test_final_class_with_single_class_field(generate, dir_param):
assert generate([
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
schema.SingleProperty("foo", "Bar"),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
"Object",
pragmas={"group": dir_param.input},
properties=[
schema.SingleProperty("foo", "Bar"),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
dbscheme.Column('foo', '@bar'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object", binding=True),
dbscheme.Column("foo", "@bar"),
],
dir=dir_param.expected,
)
],
)
def test_final_class_with_optional_field(generate, dir_param):
assert generate([
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
schema.OptionalProperty("foo", "bar"),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
"Object",
pragmas={"group": dir_param.input},
properties=[
schema.OptionalProperty("foo", "bar"),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
], dir=dir_param.expected,
dbscheme.Column("id", "@object", binding=True),
],
dir=dir_param.expected,
),
dbscheme.Table(
name="object_foos",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@object'),
dbscheme.Column('foo', 'bar'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object"),
dbscheme.Column("foo", "bar"),
],
dir=dir_param.expected,
),
],
)
@pytest.mark.parametrize("property_cls", [schema.RepeatedProperty, schema.RepeatedOptionalProperty])
@pytest.mark.parametrize(
"property_cls", [schema.RepeatedProperty, schema.RepeatedOptionalProperty]
)
def test_final_class_with_repeated_field(generate, property_cls, dir_param):
assert generate([
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
property_cls("foo", "bar"),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
"Object",
pragmas={"group": dir_param.input},
properties=[
property_cls("foo", "bar"),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
], dir=dir_param.expected,
dbscheme.Column("id", "@object", binding=True),
],
dir=dir_param.expected,
),
dbscheme.Table(
name="object_foos",
keyset=dbscheme.KeySet(["id", "index"]),
columns=[
dbscheme.Column('id', '@object'),
dbscheme.Column('index', 'int'),
dbscheme.Column('foo', 'bar'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object"),
dbscheme.Column("index", "int"),
dbscheme.Column("foo", "bar"),
],
dir=dir_param.expected,
),
],
)
def test_final_class_with_repeated_unordered_field(generate, dir_param):
assert generate([
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
schema.RepeatedUnorderedProperty("foo", "bar"),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
"Object",
pragmas={"group": dir_param.input},
properties=[
schema.RepeatedUnorderedProperty("foo", "bar"),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
], dir=dir_param.expected,
dbscheme.Column("id", "@object", binding=True),
],
dir=dir_param.expected,
),
dbscheme.Table(
name="object_foos",
columns=[
dbscheme.Column('id', '@object'),
dbscheme.Column('foo', 'bar'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object"),
dbscheme.Column("foo", "bar"),
],
dir=dir_param.expected,
),
],
)
def test_final_class_with_predicate_field(generate, dir_param):
assert generate([
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
schema.PredicateProperty("foo"),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
"Object",
pragmas={"group": dir_param.input},
properties=[
schema.PredicateProperty("foo"),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
], dir=dir_param.expected,
dbscheme.Column("id", "@object", binding=True),
],
dir=dir_param.expected,
),
dbscheme.Table(
name="object_foo",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@object'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object"),
],
dir=dir_param.expected,
),
],
)
def test_final_class_with_more_fields(generate, dir_param):
assert generate([
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
schema.SingleProperty("one", "x"),
schema.SingleProperty("two", "y"),
schema.OptionalProperty("three", "z"),
schema.RepeatedProperty("four", "u"),
schema.RepeatedOptionalProperty("five", "v"),
schema.PredicateProperty("six"),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
"Object",
pragmas={"group": dir_param.input},
properties=[
schema.SingleProperty("one", "x"),
schema.SingleProperty("two", "y"),
schema.OptionalProperty("three", "z"),
schema.RepeatedProperty("four", "u"),
schema.RepeatedOptionalProperty("five", "v"),
schema.PredicateProperty("six"),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
dbscheme.Column('one', 'x'),
dbscheme.Column('two', 'y'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object", binding=True),
dbscheme.Column("one", "x"),
dbscheme.Column("two", "y"),
],
dir=dir_param.expected,
),
dbscheme.Table(
name="object_threes",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@object'),
dbscheme.Column('three', 'z'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object"),
dbscheme.Column("three", "z"),
],
dir=dir_param.expected,
),
dbscheme.Table(
name="object_fours",
keyset=dbscheme.KeySet(["id", "index"]),
columns=[
dbscheme.Column('id', '@object'),
dbscheme.Column('index', 'int'),
dbscheme.Column('four', 'u'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object"),
dbscheme.Column("index", "int"),
dbscheme.Column("four", "u"),
],
dir=dir_param.expected,
),
dbscheme.Table(
name="object_fives",
keyset=dbscheme.KeySet(["id", "index"]),
columns=[
dbscheme.Column('id', '@object'),
dbscheme.Column('index', 'int'),
dbscheme.Column('five', 'v'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object"),
dbscheme.Column("index", "int"),
dbscheme.Column("five", "v"),
],
dir=dir_param.expected,
),
dbscheme.Table(
name="object_six",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@object'),
], dir=dir_param.expected,
dbscheme.Column("id", "@object"),
],
dir=dir_param.expected,
),
],
)
def test_empty_class_with_derived(generate):
assert generate([
schema.Class(name="Base", derived={"Left", "Right"}),
schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(name="Base", derived={"Left", "Right"}),
schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
@@ -305,17 +371,20 @@ def test_empty_class_with_derived(generate):
def test_class_with_derived_and_single_property(generate, dir_param):
assert generate([
schema.Class(
name="Base",
derived={"Left", "Right"},
pragmas={"group": dir_param.input},
properties=[
schema.SingleProperty("single", "Prop"),
]),
schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
name="Base",
derived={"Left", "Right"},
pragmas={"group": dir_param.input},
properties=[
schema.SingleProperty("single", "Prop"),
],
),
schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
@@ -327,8 +396,8 @@ def test_class_with_derived_and_single_property(generate, dir_param):
name="bases",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@base'),
dbscheme.Column('single', '@prop'),
dbscheme.Column("id", "@base"),
dbscheme.Column("single", "@prop"),
],
dir=dir_param.expected,
),
@@ -345,17 +414,20 @@ def test_class_with_derived_and_single_property(generate, dir_param):
def test_class_with_derived_and_optional_property(generate, dir_param):
assert generate([
schema.Class(
name="Base",
derived={"Left", "Right"},
pragmas={"group": dir_param.input},
properties=[
schema.OptionalProperty("opt", "Prop"),
]),
schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
name="Base",
derived={"Left", "Right"},
pragmas={"group": dir_param.input},
properties=[
schema.OptionalProperty("opt", "Prop"),
],
),
schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
@@ -367,8 +439,8 @@ def test_class_with_derived_and_optional_property(generate, dir_param):
name="base_opts",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@base'),
dbscheme.Column('opt', '@prop'),
dbscheme.Column("id", "@base"),
dbscheme.Column("opt", "@prop"),
],
dir=dir_param.expected,
),
@@ -385,17 +457,20 @@ def test_class_with_derived_and_optional_property(generate, dir_param):
def test_class_with_derived_and_repeated_property(generate, dir_param):
assert generate([
schema.Class(
name="Base",
pragmas={"group": dir_param.input},
derived={"Left", "Right"},
properties=[
schema.RepeatedProperty("rep", "Prop"),
]),
schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
name="Base",
pragmas={"group": dir_param.input},
derived={"Left", "Right"},
properties=[
schema.RepeatedProperty("rep", "Prop"),
],
),
schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
@@ -407,9 +482,9 @@ def test_class_with_derived_and_repeated_property(generate, dir_param):
name="base_reps",
keyset=dbscheme.KeySet(["id", "index"]),
columns=[
dbscheme.Column('id', '@base'),
dbscheme.Column('index', 'int'),
dbscheme.Column('rep', '@prop'),
dbscheme.Column("id", "@base"),
dbscheme.Column("index", "int"),
dbscheme.Column("rep", "@prop"),
],
dir=dir_param.expected,
),
@@ -426,38 +501,41 @@ def test_class_with_derived_and_repeated_property(generate, dir_param):
def test_null_class(generate):
assert generate([
schema.Class(
name="Base",
derived={"W", "X", "Y", "Z", "Null"},
),
schema.Class(
name="W",
bases=["Base"],
properties=[
schema.SingleProperty("w", "W"),
schema.SingleProperty("x", "X"),
schema.OptionalProperty("y", "Y"),
schema.RepeatedProperty("z", "Z"),
]
),
schema.Class(
name="X",
bases=["Base"],
),
schema.Class(
name="Y",
bases=["Base"],
),
schema.Class(
name="Z",
bases=["Base"],
),
schema.Class(
name="Null",
bases=["Base"],
),
], null="Null") == dbscheme.Scheme(
assert generate(
[
schema.Class(
name="Base",
derived={"W", "X", "Y", "Z", "Null"},
),
schema.Class(
name="W",
bases=["Base"],
properties=[
schema.SingleProperty("w", "W"),
schema.SingleProperty("x", "X"),
schema.OptionalProperty("y", "Y"),
schema.RepeatedProperty("z", "Z"),
],
),
schema.Class(
name="X",
bases=["Base"],
),
schema.Class(
name="Y",
bases=["Base"],
),
schema.Class(
name="Z",
bases=["Base"],
),
schema.Class(
name="Null",
bases=["Base"],
),
],
null="Null",
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
@@ -468,50 +546,50 @@ def test_null_class(generate):
dbscheme.Table(
name="ws",
columns=[
dbscheme.Column('id', '@w', binding=True),
dbscheme.Column('w', '@w_or_none'),
dbscheme.Column('x', '@x_or_none'),
dbscheme.Column("id", "@w", binding=True),
dbscheme.Column("w", "@w_or_none"),
dbscheme.Column("x", "@x_or_none"),
],
),
dbscheme.Table(
name="w_ies",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@w'),
dbscheme.Column('y', '@y_or_none'),
dbscheme.Column("id", "@w"),
dbscheme.Column("y", "@y_or_none"),
],
),
dbscheme.Table(
name="w_zs",
keyset=dbscheme.KeySet(["id", "index"]),
columns=[
dbscheme.Column('id', '@w'),
dbscheme.Column('index', 'int'),
dbscheme.Column('z', '@z_or_none'),
dbscheme.Column("id", "@w"),
dbscheme.Column("index", "int"),
dbscheme.Column("z", "@z_or_none"),
],
),
dbscheme.Table(
name="xes",
columns=[
dbscheme.Column('id', '@x', binding=True),
dbscheme.Column("id", "@x", binding=True),
],
),
dbscheme.Table(
name="ys",
columns=[
dbscheme.Column('id', '@y', binding=True),
dbscheme.Column("id", "@y", binding=True),
],
),
dbscheme.Table(
name="zs",
columns=[
dbscheme.Column('id', '@z', binding=True),
dbscheme.Column("id", "@z", binding=True),
],
),
dbscheme.Table(
name="nulls",
columns=[
dbscheme.Column('id', '@null', binding=True),
dbscheme.Column("id", "@null", binding=True),
],
),
dbscheme.Union(
@@ -535,11 +613,15 @@ def test_null_class(generate):
def test_synth_classes_ignored(generate):
assert generate([
schema.Class(name="A", pragmas={"synth": schema.SynthInfo()}),
schema.Class(name="B", pragmas={"synth": schema.SynthInfo(from_class="A")}),
schema.Class(name="C", pragmas={"synth": schema.SynthInfo(on_arguments={"x": "A"})}),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(name="A", pragmas={"synth": schema.SynthInfo()}),
schema.Class(name="B", pragmas={"synth": schema.SynthInfo(from_class="A")}),
schema.Class(
name="C", pragmas={"synth": schema.SynthInfo(on_arguments={"x": "A"})}
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[],
@@ -547,11 +629,13 @@ def test_synth_classes_ignored(generate):
def test_synth_derived_classes_ignored(generate):
assert generate([
schema.Class(name="A", derived={"B", "C"}),
schema.Class(name="B", bases=["A"], pragmas={"synth": schema.SynthInfo()}),
schema.Class(name="C", bases=["A"]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(name="A", derived={"B", "C"}),
schema.Class(name="B", bases=["A"], pragmas={"synth": schema.SynthInfo()}),
schema.Class(name="C", bases=["A"]),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
@@ -561,23 +645,28 @@ def test_synth_derived_classes_ignored(generate):
columns=[
dbscheme.Column("id", "@c", binding=True),
],
)
),
],
)
def test_synth_properties_ignored(generate):
assert generate([
schema.Class(name="A", properties=[
schema.SingleProperty("x", "a"),
schema.SingleProperty("y", "b", synth=True),
schema.SingleProperty("z", "c"),
schema.OptionalProperty("foo", "bar", synth=True),
schema.RepeatedProperty("baz", "bazz", synth=True),
schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True),
schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
name="A",
properties=[
schema.SingleProperty("x", "a"),
schema.SingleProperty("y", "b", synth=True),
schema.SingleProperty("z", "c"),
schema.OptionalProperty("foo", "bar", synth=True),
schema.RepeatedProperty("baz", "bazz", synth=True),
schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True),
schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
@@ -595,24 +684,44 @@ def test_synth_properties_ignored(generate):
def test_table_conflict(generate):
with pytest.raises(dbschemegen.Error):
generate([
schema.Class("Foo", properties=[
schema.OptionalProperty("bar", "FooBar"),
]),
schema.Class("FooBar"),
])
generate(
[
schema.Class(
"Foo",
properties=[
schema.OptionalProperty("bar", "FooBar"),
],
),
schema.Class("FooBar"),
]
)
def test_table_name_overrides(generate):
assert generate([
schema.Class("Obj", properties=[
schema.OptionalProperty("x", "a", pragmas={"ql_db_table_name": "foo"}),
schema.RepeatedProperty("y", "b", pragmas={"ql_db_table_name": "bar"}),
schema.RepeatedOptionalProperty("z", "c", pragmas={"ql_db_table_name": "baz"}),
schema.PredicateProperty("p", pragmas={"ql_db_table_name": "hello"}),
schema.RepeatedUnorderedProperty("q", "d", pragmas={"ql_db_table_name": "world"}),
]),
]) == dbscheme.Scheme(
assert generate(
[
schema.Class(
"Obj",
properties=[
schema.OptionalProperty(
"x", "a", pragmas={"ql_db_table_name": "foo"}
),
schema.RepeatedProperty(
"y", "b", pragmas={"ql_db_table_name": "bar"}
),
schema.RepeatedOptionalProperty(
"z", "c", pragmas={"ql_db_table_name": "baz"}
),
schema.PredicateProperty(
"p", pragmas={"ql_db_table_name": "hello"}
),
schema.RepeatedUnorderedProperty(
"q", "d", pragmas={"ql_db_table_name": "world"}
),
],
),
]
) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
declarations=[
@@ -666,5 +775,5 @@ def test_table_name_overrides(generate):
)
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -22,26 +22,42 @@ def test_load_empty(load):
def test_load_one_empty_table(load):
assert load("""
assert (
load(
"""
test_foos();
""") == [
dbscheme.Table(name="test_foos", columns=[])
]
"""
)
== [dbscheme.Table(name="test_foos", columns=[])]
)
def test_load_table_with_keyset(load):
assert load("""
assert (
load(
"""
#keyset[x, y,z]
test_foos();
""") == [
dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]))
]
"""
)
== [
dbscheme.Table(
name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"])
)
]
)
expected_columns = [
("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)),
(" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)),
("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)),
(
" int bar : int ref",
dbscheme.Column(schema_name="bar", type="int", binding=False),
),
(
"str baz_: str ref",
dbscheme.Column(schema_name="baz", type="str", binding=False),
),
("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)),
("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)),
("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)),
@@ -50,42 +66,58 @@ expected_columns = [
@pytest.mark.parametrize("column,expected", expected_columns)
def test_load_table_with_column(load, column, expected):
assert load(f"""
assert (
load(
f"""
foos(
{column}
);
""") == [
dbscheme.Table(name="foos", columns=[deepcopy(expected)])
]
"""
)
== [dbscheme.Table(name="foos", columns=[deepcopy(expected)])]
)
def test_load_table_with_multiple_columns(load):
columns = ",\n".join(c for c, _ in expected_columns)
expected = [deepcopy(e) for _, e in expected_columns]
assert load(f"""
assert (
load(
f"""
foos(
{columns}
);
""") == [
dbscheme.Table(name="foos", columns=expected)
]
"""
)
== [dbscheme.Table(name="foos", columns=expected)]
)
def test_load_table_with_multiple_columns_and_dir(load):
columns = ",\n".join(c for c, _ in expected_columns)
expected = [deepcopy(e) for _, e in expected_columns]
assert load(f"""
assert (
load(
f"""
foos( //dir=foo/bar/baz
{columns}
);
""") == [
dbscheme.Table(name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz"))
]
"""
)
== [
dbscheme.Table(
name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz")
)
]
)
def test_load_multiple_table_with_columns(load):
tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)]
expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)]
expected = [
dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)])
for i, (_, e) in enumerate(expected_columns)
]
assert load("\n".join(tables)) == expected
@@ -96,28 +128,41 @@ def test_union(load):
def test_table_and_union(load):
assert load("""
assert (
load(
"""
foos();
@foo = @bar | @baz | @bla;""") == [
dbscheme.Table(name="foos", columns=[]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
@foo = @bar | @baz | @bla;"""
)
== [
dbscheme.Table(name="foos", columns=[]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
)
def test_comments_ignored(load):
assert load("""
assert (
load(
"""
// fake_table();
foos(/* x */unique /*y*/int/*
z
*/ id/* */: /* * */ @bar/*,
int ignored: int ref*/);
@foo = @bar | @baz | @bla; // | @xxx""") == [
dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
@foo = @bar | @baz | @bla; // | @xxx"""
)
== [
dbscheme.Table(
name="foos",
columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)],
),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
)
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -34,37 +34,55 @@ def test_property_unordered_getter(name, expected_getter):
assert prop.getter == expected_getter
@pytest.mark.parametrize("plural,expected", [
(None, False),
("", False),
("X", True),
])
@pytest.mark.parametrize(
"plural,expected",
[
(None, False),
("", False),
("X", True),
],
)
def test_property_is_repeated(plural, expected):
prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural)
assert prop.is_repeated is expected
@pytest.mark.parametrize("plural,unordered,expected", [
(None, False, False),
("", False, False),
("X", False, True),
("X", True, False),
])
@pytest.mark.parametrize(
"plural,unordered,expected",
[
(None, False, False),
("", False, False),
("X", False, True),
("X", True, False),
],
)
def test_property_is_indexed(plural, unordered, expected):
prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural, is_unordered=unordered)
prop = ql.Property(
"foo", "Foo", "props", ["result"], plural=plural, is_unordered=unordered
)
assert prop.is_indexed is expected
@pytest.mark.parametrize("is_optional,is_predicate,plural,expected", [
(False, False, None, True),
(False, False, "", True),
(False, False, "X", False),
(True, False, None, False),
(False, True, None, False),
])
@pytest.mark.parametrize(
"is_optional,is_predicate,plural,expected",
[
(False, False, None, True),
(False, False, "", True),
(False, False, "X", False),
(True, False, None, False),
(False, True, None, False),
],
)
def test_property_is_single(is_optional, is_predicate, plural, expected):
prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural,
is_predicate=is_predicate, is_optional=is_optional)
prop = ql.Property(
"foo",
"Foo",
"props",
["result"],
plural=plural,
is_predicate=is_predicate,
is_optional=is_optional,
)
assert prop.is_single is expected
@@ -85,7 +103,12 @@ def test_property_predicate_getter():
def test_class_processes_bases():
bases = ["B", "Ab", "C", "Aa"]
expected = [ql.Base("B"), ql.Base("Ab", prev="B"), ql.Base("C", prev="Ab"), ql.Base("Aa", prev="C")]
expected = [
ql.Base("B"),
ql.Base("Ab", prev="B"),
ql.Base("C", prev="Ab"),
ql.Base("Aa", prev="C"),
]
cls = ql.Class("Foo", bases=bases)
assert cls.bases == expected
@@ -110,7 +133,9 @@ def test_non_root_class():
assert not cls.root
@pytest.mark.parametrize("prev_child,is_child", [(None, False), ("", True), ("x", True)])
@pytest.mark.parametrize(
"prev_child,is_child", [(None, False), ("", True), ("x", True)]
)
def test_is_child(prev_child, is_child):
p = ql.Property("Foo", "int", prev_child=prev_child)
assert p.is_child is is_child
@@ -122,22 +147,27 @@ def test_empty_class_no_children():
def test_class_no_children():
cls = ql.Class("Class", properties=[ql.Property("Foo", "int"), ql.Property("Bar", "string")])
cls = ql.Class(
"Class", properties=[ql.Property("Foo", "int"), ql.Property("Bar", "string")]
)
assert cls.has_children is False
def test_class_with_children():
cls = ql.Class("Class", properties=[ql.Property("Foo", "int"), ql.Property("Child", "x", prev_child=""),
ql.Property("Bar", "string")])
cls = ql.Class(
"Class",
properties=[
ql.Property("Foo", "int"),
ql.Property("Child", "x", prev_child=""),
ql.Property("Bar", "string"),
],
)
assert cls.has_children is True
@pytest.mark.parametrize("doc,expected",
[
(["foo", "bar"], True),
(["foo", "bar"], True),
([], False)
])
@pytest.mark.parametrize(
"doc,expected", [(["foo", "bar"], True), (["foo", "bar"], True), ([], False)]
)
def test_has_doc(doc, expected):
stub = ql.Stub("Class", base_import="foo", import_prefix="bar", doc=doc)
assert stub.has_qldoc is expected
@@ -150,5 +180,5 @@ def test_synth_accessor_has_first_constructor_param_marked():
assert [p.param for p in x.constructorparams] == params
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

File diff suppressed because it is too large Load Diff

View File

@@ -46,7 +46,10 @@ def write_registry(file, *files_and_hashes):
def assert_registry(file, *files_and_hashes):
assert_file(file, create_registry(files_and_hashes))
files = [file.name, ".gitattributes"] + [f for f, _, _ in files_and_hashes]
assert_file(file.parent / ".gitattributes", "\n".join(f"/{f} linguist-generated" for f in files) + "\n")
assert_file(
file.parent / ".gitattributes",
"\n".join(f"/{f} linguist-generated" for f in files) + "\n",
)
def hash(text):
@@ -56,11 +59,11 @@ def hash(text):
def test_constructor(pystache_renderer_cls, sut):
pystache_init, = pystache_renderer_cls.mock_calls
assert set(pystache_init.kwargs) == {'search_dirs', 'escape'}
assert pystache_init.kwargs['search_dirs'] == str(paths.templates_dir)
(pystache_init,) = pystache_renderer_cls.mock_calls
assert set(pystache_init.kwargs) == {"search_dirs", "escape"}
assert pystache_init.kwargs["search_dirs"] == str(paths.templates_dir)
an_object = object()
assert pystache_init.kwargs['escape'](an_object) is an_object
assert pystache_init.kwargs["escape"](an_object) is an_object
def test_render(pystache_renderer, sut):
@@ -218,7 +221,9 @@ def test_managed_render_with_skipping_of_stub_file(pystache_renderer, sut):
some_processed_output = "// generated some processed output"
registry = paths.root_dir / "a/registry.list"
write(stub, some_processed_output)
write_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)))
write_registry(
registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))
)
pystache_renderer.render_name.side_effect = (some_output,)
@@ -227,7 +232,9 @@ def test_managed_render_with_skipping_of_stub_file(pystache_renderer, sut):
assert renderer.written == set()
assert_file(stub, some_processed_output)
assert_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)))
assert_registry(
registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))
)
assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=generator),
]
@@ -238,13 +245,17 @@ def test_managed_render_with_modified_generated_file(pystache_renderer, sut):
some_processed_output = "// some processed output"
registry = paths.root_dir / "a/registry.list"
write(output, "// something else")
write_registry(registry, ("some/output.txt", "whatever", hash(some_processed_output)))
write_registry(
registry, ("some/output.txt", "whatever", hash(some_processed_output))
)
with pytest.raises(render.Error):
sut.manage(generated=(output,), stubs=(), registry=registry)
def test_managed_render_with_modified_stub_file_still_marked_as_generated(pystache_renderer, sut):
def test_managed_render_with_modified_stub_file_still_marked_as_generated(
pystache_renderer, sut
):
stub = paths.root_dir / "a/some/stub.txt"
some_processed_output = "// generated some processed output"
registry = paths.root_dir / "a/registry.list"
@@ -255,7 +266,9 @@ def test_managed_render_with_modified_stub_file_still_marked_as_generated(pystac
sut.manage(generated=(), stubs=(stub,), registry=registry)
def test_managed_render_with_modified_stub_file_not_marked_as_generated(pystache_renderer, sut):
def test_managed_render_with_modified_stub_file_not_marked_as_generated(
pystache_renderer, sut
):
stub = paths.root_dir / "a/some/stub.txt"
some_processed_output = "// generated some processed output"
registry = paths.root_dir / "a/registry.list"
@@ -272,7 +285,9 @@ class MyError(Exception):
pass
def test_managed_render_exception_drops_written_and_inexsistent_from_registry(pystache_renderer, sut):
def test_managed_render_exception_drops_written_and_inexsistent_from_registry(
pystache_renderer, sut
):
data = mock.Mock(spec=("template",))
text = "some text"
pystache_renderer.render_name.side_effect = (text,)
@@ -281,11 +296,9 @@ def test_managed_render_exception_drops_written_and_inexsistent_from_registry(py
write(output, text)
write(paths.root_dir / "a/a")
write(paths.root_dir / "a/c")
write_registry(registry,
"aaa",
("some/output.txt", "whatever", hash(text)),
"bbb",
"ccc")
write_registry(
registry, "aaa", ("some/output.txt", "whatever", hash(text)), "bbb", "ccc"
)
with pytest.raises(MyError):
with sut.manage(generated=(), stubs=(), registry=registry) as renderer:
@@ -299,17 +312,14 @@ def test_managed_render_drops_inexsistent_from_registry(pystache_renderer, sut):
registry = paths.root_dir / "a/registry.list"
write(paths.root_dir / "a/a")
write(paths.root_dir / "a/c")
write_registry(registry,
("a", hash(''), hash('')),
"bbb",
("c", hash(''), hash('')))
write_registry(
registry, ("a", hash(""), hash("")), "bbb", ("c", hash(""), hash(""))
)
with sut.manage(generated=(), stubs=(), registry=registry):
pass
assert_registry(registry,
("a", hash(''), hash('')),
("c", hash(''), hash('')))
assert_registry(registry, ("a", hash(""), hash("")), ("c", hash(""), hash("")))
def test_managed_render_exception_does_not_erase(pystache_renderer, sut):
@@ -321,7 +331,9 @@ def test_managed_render_exception_does_not_erase(pystache_renderer, sut):
write_registry(registry)
with pytest.raises(MyError):
with sut.manage(generated=(output,), stubs=(stub,), registry=registry) as renderer:
with sut.manage(
generated=(output,), stubs=(stub,), registry=registry
) as renderer:
raise MyError
assert output.is_file()
@@ -333,14 +345,15 @@ def test_render_with_extensions(pystache_renderer, sut):
data.template = "test_template"
data.extensions = ["foo", "bar", "baz"]
output = pathlib.Path("my", "test", "file")
expected_outputs = [pathlib.Path("my", "test", p) for p in ("file.foo", "file.bar", "file.baz")]
expected_outputs = [
pathlib.Path("my", "test", p) for p in ("file.foo", "file.bar", "file.baz")
]
rendered = [f"text{i}" for i in range(len(expected_outputs))]
pystache_renderer.render_name.side_effect = rendered
sut.render(data, output)
expected_templates = ["test_template_foo", "test_template_bar", "test_template_baz"]
assert pystache_renderer.mock_calls == [
mock.call.render_name(t, data, generator=generator)
for t in expected_templates
mock.call.render_name(t, data, generator=generator) for t in expected_templates
]
for expected_output, expected_contents in zip(expected_outputs, rendered):
assert_file(expected_output, expected_contents)
@@ -356,7 +369,9 @@ def test_managed_render_with_force_not_skipping_generated_file(pystache_renderer
pystache_renderer.render_name.side_effect = (some_output,)
with sut.manage(generated=(output,), stubs=(), registry=registry, force=True) as renderer:
with sut.manage(
generated=(output,), stubs=(), registry=registry, force=True
) as renderer:
renderer.render(data, output)
assert renderer.written == {output}
assert_file(output, some_output)
@@ -374,11 +389,15 @@ def test_managed_render_with_force_not_skipping_stub_file(pystache_renderer, sut
some_processed_output = "// generated some processed output"
registry = paths.root_dir / "a/registry.list"
write(stub, some_processed_output)
write_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)))
write_registry(
registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))
)
pystache_renderer.render_name.side_effect = (some_output,)
with sut.manage(generated=(), stubs=(stub,), registry=registry, force=True) as renderer:
with sut.manage(
generated=(), stubs=(stub,), registry=registry, force=True
) as renderer:
renderer.render(data, stub)
assert renderer.written == {stub}
assert_file(stub, some_output)
@@ -394,13 +413,17 @@ def test_managed_render_with_force_ignores_modified_generated_file(sut):
some_processed_output = "// some processed output"
registry = paths.root_dir / "a/registry.list"
write(output, "// something else")
write_registry(registry, ("some/output.txt", "whatever", hash(some_processed_output)))
write_registry(
registry, ("some/output.txt", "whatever", hash(some_processed_output))
)
with sut.manage(generated=(output,), stubs=(), registry=registry, force=True):
pass
def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_generated(sut):
def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_generated(
sut,
):
stub = paths.root_dir / "a/some/stub.txt"
some_processed_output = "// generated some processed output"
registry = paths.root_dir / "a/registry.list"
@@ -411,5 +434,5 @@ def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_ge
pass
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -26,9 +26,9 @@ def test_one_empty_class():
pass
assert data.classes == {
'MyClass': schema.Class('MyClass'),
"MyClass": schema.Class("MyClass"),
}
assert data.root_class is data.classes['MyClass']
assert data.root_class is data.classes["MyClass"]
def test_two_empty_classes():
@@ -41,10 +41,10 @@ def test_two_empty_classes():
pass
assert data.classes == {
'MyClass1': schema.Class('MyClass1', derived={'MyClass2'}),
'MyClass2': schema.Class('MyClass2', bases=['MyClass1']),
"MyClass1": schema.Class("MyClass1", derived={"MyClass2"}),
"MyClass2": schema.Class("MyClass2", bases=["MyClass1"]),
}
assert data.root_class is data.classes['MyClass1']
assert data.root_class is data.classes["MyClass1"]
def test_no_external_bases():
@@ -52,6 +52,7 @@ def test_no_external_bases():
pass
with pytest.raises(schema.Error):
@load
class data:
class MyClass(A):
@@ -60,6 +61,7 @@ def test_no_external_bases():
def test_no_multiple_roots():
with pytest.raises(schema.Error):
@load
class data:
class MyClass1:
@@ -85,10 +87,10 @@ def test_empty_classes_diamond():
pass
assert data.classes == {
'A': schema.Class('A', derived={'B', 'C'}),
'B': schema.Class('B', bases=['A'], derived={'D'}),
'C': schema.Class('C', bases=['A'], derived={'D'}),
'D': schema.Class('D', bases=['B', 'C']),
"A": schema.Class("A", derived={"B", "C"}),
"B": schema.Class("B", bases=["A"], derived={"D"}),
"C": schema.Class("C", bases=["A"], derived={"D"}),
"D": schema.Class("D", bases=["B", "C"]),
}
@@ -101,7 +103,7 @@ def test_group():
pass
assert data.classes == {
'A': schema.Class('A', pragmas={"group": "xxx"}),
"A": schema.Class("A", pragmas={"group": "xxx"}),
}
@@ -114,7 +116,7 @@ def test_group_is_inherited():
class B(A):
pass
@defs.group('xxx')
@defs.group("xxx")
class C(A):
pass
@@ -122,25 +124,26 @@ def test_group_is_inherited():
pass
assert data.classes == {
'A': schema.Class('A', derived={'B', 'C'}),
'B': schema.Class('B', bases=['A'], derived={'D'}),
'C': schema.Class('C', bases=['A'], derived={'D'}, pragmas={"group": "xxx"}),
'D': schema.Class('D', bases=['B', 'C'], pragmas={"group": "xxx"}),
"A": schema.Class("A", derived={"B", "C"}),
"B": schema.Class("B", bases=["A"], derived={"D"}),
"C": schema.Class("C", bases=["A"], derived={"D"}, pragmas={"group": "xxx"}),
"D": schema.Class("D", bases=["B", "C"], pragmas={"group": "xxx"}),
}
def test_no_mixed_groups_in_bases():
with pytest.raises(schema.Error):
@load
class data:
class A:
pass
@defs.group('x')
@defs.group("x")
class B(A):
pass
@defs.group('y')
@defs.group("y")
class C(A):
pass
@@ -153,6 +156,7 @@ def test_no_mixed_groups_in_bases():
def test_lowercase_rejected():
with pytest.raises(schema.Error):
@load
class data:
class aLowerCase:
@@ -171,14 +175,17 @@ def test_properties():
six: defs.set[defs.string]
assert data.classes == {
'A': schema.Class('A', properties=[
schema.SingleProperty('one', 'string'),
schema.OptionalProperty('two', 'int'),
schema.RepeatedProperty('three', 'boolean'),
schema.RepeatedOptionalProperty('four', 'string'),
schema.PredicateProperty('five'),
schema.RepeatedUnorderedProperty('six', 'string'),
]),
"A": schema.Class(
"A",
properties=[
schema.SingleProperty("one", "string"),
schema.OptionalProperty("two", "int"),
schema.RepeatedProperty("three", "boolean"),
schema.RepeatedOptionalProperty("four", "string"),
schema.PredicateProperty("five"),
schema.RepeatedUnorderedProperty("six", "string"),
],
),
}
@@ -199,14 +206,18 @@ def test_class_properties():
five: defs.set[A]
assert data.classes == {
'A': schema.Class('A', derived={'B'}),
'B': schema.Class('B', bases=['A'], properties=[
schema.SingleProperty('one', 'A'),
schema.OptionalProperty('two', 'A'),
schema.RepeatedProperty('three', 'A'),
schema.RepeatedOptionalProperty('four', 'A'),
schema.RepeatedUnorderedProperty('five', 'A'),
]),
"A": schema.Class("A", derived={"B"}),
"B": schema.Class(
"B",
bases=["A"],
properties=[
schema.SingleProperty("one", "A"),
schema.OptionalProperty("two", "A"),
schema.RepeatedProperty("three", "A"),
schema.RepeatedOptionalProperty("four", "A"),
schema.RepeatedUnorderedProperty("five", "A"),
],
),
}
@@ -221,20 +232,31 @@ def test_string_reference_class_properties():
five: defs.set["A"]
assert data.classes == {
'A': schema.Class('A', properties=[
schema.SingleProperty('one', 'A'),
schema.OptionalProperty('two', 'A'),
schema.RepeatedProperty('three', 'A'),
schema.RepeatedOptionalProperty('four', 'A'),
schema.RepeatedUnorderedProperty('five', 'A'),
]),
"A": schema.Class(
"A",
properties=[
schema.SingleProperty("one", "A"),
schema.OptionalProperty("two", "A"),
schema.RepeatedProperty("three", "A"),
schema.RepeatedOptionalProperty("four", "A"),
schema.RepeatedUnorderedProperty("five", "A"),
],
),
}
@pytest.mark.parametrize("spec", [lambda t: t, lambda t: defs.optional[t], lambda t: defs.list[t],
lambda t: defs.list[defs.optional[t]]])
@pytest.mark.parametrize(
"spec",
[
lambda t: t,
lambda t: defs.optional[t],
lambda t: defs.list[t],
lambda t: defs.list[defs.optional[t]],
],
)
def test_string_reference_dangling(spec):
with pytest.raises(schema.Error):
@load
class data:
class A:
@@ -251,18 +273,24 @@ def test_children():
four: defs.list[defs.optional["A"]] | defs.child
assert data.classes == {
'A': schema.Class('A', properties=[
schema.SingleProperty('one', 'A', is_child=True),
schema.OptionalProperty('two', 'A', is_child=True),
schema.RepeatedProperty('three', 'A', is_child=True),
schema.RepeatedOptionalProperty('four', 'A', is_child=True),
]),
"A": schema.Class(
"A",
properties=[
schema.SingleProperty("one", "A", is_child=True),
schema.OptionalProperty("two", "A", is_child=True),
schema.RepeatedProperty("three", "A", is_child=True),
schema.RepeatedOptionalProperty("four", "A", is_child=True),
],
),
}
@pytest.mark.parametrize("spec", [defs.string, defs.int, defs.boolean, defs.predicate, defs.set["A"]])
@pytest.mark.parametrize(
"spec", [defs.string, defs.int, defs.boolean, defs.predicate, defs.set["A"]]
)
def test_builtin_predicate_and_set_children_not_allowed(spec):
with pytest.raises(schema.Error):
@load
class data:
class A:
@@ -291,9 +319,12 @@ def test_property_with_pragma(pragma, expected):
x: defs.string | pragma
assert data.classes == {
'A': schema.Class('A', properties=[
schema.SingleProperty('x', 'string', pragmas=[expected]),
]),
"A": schema.Class(
"A",
properties=[
schema.SingleProperty("x", "string", pragmas=[expected]),
],
),
}
@@ -308,9 +339,16 @@ def test_property_with_pragmas():
x: spec
assert data.classes == {
'A': schema.Class('A', properties=[
schema.SingleProperty('x', 'string', pragmas=[expected for _, expected in _property_pragmas]),
]),
"A": schema.Class(
"A",
properties=[
schema.SingleProperty(
"x",
"string",
pragmas=[expected for _, expected in _property_pragmas],
),
],
),
}
@@ -323,7 +361,7 @@ def test_class_with_pragma(pragma, expected):
pass
assert data.classes == {
'A': schema.Class('A', pragmas=[expected]),
"A": schema.Class("A", pragmas=[expected]),
}
@@ -340,7 +378,7 @@ def test_class_with_pragmas():
apply_pragmas(A)
assert data.classes == {
'A': schema.Class('A', pragmas=[e for _, e in _pragmas]),
"A": schema.Class("A", pragmas=[e for _, e in _pragmas]),
}
@@ -355,8 +393,10 @@ def test_synth_from_class():
pass
assert data.classes == {
'A': schema.Class('A', derived={'B'}, pragmas={"synth": True}),
'B': schema.Class('B', bases=['A'], pragmas={"synth": schema.SynthInfo(from_class="A")}),
"A": schema.Class("A", derived={"B"}, pragmas={"synth": True}),
"B": schema.Class(
"B", bases=["A"], pragmas={"synth": schema.SynthInfo(from_class="A")}
),
}
@@ -371,13 +411,16 @@ def test_synth_from_class_ref():
pass
assert data.classes == {
'A': schema.Class('A', derived={'B'}, pragmas={"synth": schema.SynthInfo(from_class="B")}),
'B': schema.Class('B', bases=['A']),
"A": schema.Class(
"A", derived={"B"}, pragmas={"synth": schema.SynthInfo(from_class="B")}
),
"B": schema.Class("B", bases=["A"]),
}
def test_synth_from_class_dangling():
with pytest.raises(schema.Error):
@load
class data:
@defs.synth.from_class("X")
@@ -396,8 +439,12 @@ def test_synth_class_on():
pass
assert data.classes == {
'A': schema.Class('A', derived={'B'}, pragmas={"synth": True}),
'B': schema.Class('B', bases=['A'], pragmas={"synth": schema.SynthInfo(on_arguments={'a': 'A', 'i': 'int'})}),
"A": schema.Class("A", derived={"B"}, pragmas={"synth": True}),
"B": schema.Class(
"B",
bases=["A"],
pragmas={"synth": schema.SynthInfo(on_arguments={"a": "A", "i": "int"})},
),
}
@@ -415,13 +462,18 @@ def test_synth_class_on_ref():
pass
assert data.classes == {
'A': schema.Class('A', derived={'B'}, pragmas={"synth": schema.SynthInfo(on_arguments={'b': 'B', 'i': 'int'})}),
'B': schema.Class('B', bases=['A']),
"A": schema.Class(
"A",
derived={"B"},
pragmas={"synth": schema.SynthInfo(on_arguments={"b": "B", "i": "int"})},
),
"B": schema.Class("B", bases=["A"]),
}
def test_synth_class_on_dangling():
with pytest.raises(schema.Error):
@load
class data:
@defs.synth.on_arguments(s=defs.string, a="A", i=defs.int)
@@ -453,12 +505,25 @@ def test_synth_class_hierarchy():
pass
assert data.classes == {
'Root': schema.Class('Root', derived={'Base', 'C'}),
'Base': schema.Class('Base', bases=['Root'], derived={'Intermediate', 'B'}, pragmas={"synth": True}),
'Intermediate': schema.Class('Intermediate', bases=['Base'], derived={'A'}, pragmas={"synth": True}),
'A': schema.Class('A', bases=['Intermediate'], pragmas={"synth": schema.SynthInfo(on_arguments={'a': 'Base', 'i': 'int'})}),
'B': schema.Class('B', bases=['Base'], pragmas={"synth": schema.SynthInfo(from_class='Base')}),
'C': schema.Class('C', bases=['Root']),
"Root": schema.Class("Root", derived={"Base", "C"}),
"Base": schema.Class(
"Base",
bases=["Root"],
derived={"Intermediate", "B"},
pragmas={"synth": True},
),
"Intermediate": schema.Class(
"Intermediate", bases=["Base"], derived={"A"}, pragmas={"synth": True}
),
"A": schema.Class(
"A",
bases=["Intermediate"],
pragmas={"synth": schema.SynthInfo(on_arguments={"a": "Base", "i": "int"})},
),
"B": schema.Class(
"B", bases=["Base"], pragmas={"synth": schema.SynthInfo(from_class="Base")}
),
"C": schema.Class("C", bases=["Root"]),
}
@@ -479,9 +544,7 @@ def test_class_docstring():
class A:
"""Very important class."""
assert data.classes == {
'A': schema.Class('A', doc=["Very important class."])
}
assert data.classes == {"A": schema.Class("A", doc=["Very important class."])}
def test_property_docstring():
@@ -491,7 +554,14 @@ def test_property_docstring():
x: int | defs.desc("very important property.")
assert data.classes == {
'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important property."])])
"A": schema.Class(
"A",
properties=[
schema.SingleProperty(
"x", "int", description=["very important property."]
)
],
)
}
@@ -502,21 +572,27 @@ def test_class_docstring_newline():
"""Very important
class."""
assert data.classes == {
'A': schema.Class('A', doc=["Very important", "class."])
}
assert data.classes == {"A": schema.Class("A", doc=["Very important", "class."])}
def test_property_docstring_newline():
@load
class data:
class A:
x: int | defs.desc("""very important
property.""")
x: int | defs.desc(
"""very important
property."""
)
assert data.classes == {
'A': schema.Class('A',
properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
"A": schema.Class(
"A",
properties=[
schema.SingleProperty(
"x", "int", description=["very important", "property."]
)
],
)
}
@@ -530,23 +606,30 @@ def test_class_docstring_stripped():
"""
assert data.classes == {
'A': schema.Class('A', doc=["Very important class."])
}
assert data.classes == {"A": schema.Class("A", doc=["Very important class."])}
def test_property_docstring_stripped():
@load
class data:
class A:
x: int | defs.desc("""
x: int | defs.desc(
"""
very important property.
""")
"""
)
assert data.classes == {
'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important property."])])
"A": schema.Class(
"A",
properties=[
schema.SingleProperty(
"x", "int", description=["very important property."]
)
],
)
}
@@ -559,7 +642,9 @@ def test_class_docstring_split():
As said, very important."""
assert data.classes == {
'A': schema.Class('A', doc=["Very important class.", "", "As said, very important."])
"A": schema.Class(
"A", doc=["Very important class.", "", "As said, very important."]
)
}
@@ -567,13 +652,27 @@ def test_property_docstring_split():
@load
class data:
class A:
x: int | defs.desc("""very important property.
x: int | defs.desc(
"""very important property.
Very very important.""")
Very very important."""
)
assert data.classes == {
'A': schema.Class('A', properties=[
schema.SingleProperty('x', 'int', description=["very important property.", "", "Very very important."])])
"A": schema.Class(
"A",
properties=[
schema.SingleProperty(
"x",
"int",
description=[
"very important property.",
"",
"Very very important.",
],
)
],
)
}
@@ -587,7 +686,9 @@ def test_class_docstring_indent():
"""
assert data.classes == {
'A': schema.Class('A', doc=["Very important class.", " As said, very important."])
"A": schema.Class(
"A", doc=["Very important class.", " As said, very important."]
)
}
@@ -595,14 +696,24 @@ def test_property_docstring_indent():
@load
class data:
class A:
x: int | defs.desc("""
x: int | defs.desc(
"""
very important property.
Very very important.
""")
"""
)
assert data.classes == {
'A': schema.Class('A', properties=[
schema.SingleProperty('x', 'int', description=["very important property.", " Very very important."])])
"A": schema.Class(
"A",
properties=[
schema.SingleProperty(
"x",
"int",
description=["very important property.", " Very very important."],
)
],
)
}
@@ -613,13 +724,13 @@ def test_property_doc_override():
x: int | defs.doc("y")
assert data.classes == {
'A': schema.Class('A', properties=[
schema.SingleProperty('x', 'int', doc="y")]),
"A": schema.Class("A", properties=[schema.SingleProperty("x", "int", doc="y")]),
}
def test_property_doc_override_no_newlines():
with pytest.raises(schema.Error):
@load
class data:
class A:
@@ -628,6 +739,7 @@ def test_property_doc_override_no_newlines():
def test_property_doc_override_no_trailing_dot():
with pytest.raises(schema.Error):
@load
class data:
class A:
@@ -642,7 +754,7 @@ def test_class_default_doc_name():
pass
assert data.classes == {
'A': schema.Class('A', pragmas={"ql_default_doc_name": "b"}),
"A": schema.Class("A", pragmas={"ql_default_doc_name": "b"}),
}
@@ -653,7 +765,12 @@ def test_db_table_name():
x: optional[int] | defs.ql.db_table_name("foo")
assert data.classes == {
'A': schema.Class('A', properties=[schema.OptionalProperty("x", "int", pragmas={"ql_db_table_name": "foo"})]),
"A": schema.Class(
"A",
properties=[
schema.OptionalProperty("x", "int", pragmas={"ql_db_table_name": "foo"})
],
),
}
@@ -668,15 +785,16 @@ def test_null_class():
pass
assert data.classes == {
'Root': schema.Class('Root', derived={'Null'}),
'Null': schema.Class('Null', bases=['Root']),
"Root": schema.Class("Root", derived={"Null"}),
"Null": schema.Class("Null", bases=["Root"]),
}
assert data.null == 'Null'
assert data.null == "Null"
assert data.null_class is data.classes[data.null]
def test_null_class_cannot_be_derived():
with pytest.raises(schema.Error):
@load
class data:
class Root:
@@ -692,6 +810,7 @@ def test_null_class_cannot_be_derived():
def test_null_class_cannot_be_defined_multiple_times():
with pytest.raises(schema.Error):
@load
class data:
class Root:
@@ -708,6 +827,7 @@ def test_null_class_cannot_be_defined_multiple_times():
def test_uppercase_acronyms_are_rejected():
with pytest.raises(schema.Error):
@load
class data:
class Root:
@@ -737,10 +857,18 @@ def test_hideable():
pass
assert data.classes == {
"Root": schema.Class("Root", derived={"A", "IndirectlyHideable", "NonHideable"}, pragmas=["ql_hideable"]),
"Root": schema.Class(
"Root",
derived={"A", "IndirectlyHideable", "NonHideable"},
pragmas=["ql_hideable"],
),
"A": schema.Class("A", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"]),
"IndirectlyHideable": schema.Class("IndirectlyHideable", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"]),
"B": schema.Class("B", bases=["A", "IndirectlyHideable"], pragmas=["ql_hideable"]),
"IndirectlyHideable": schema.Class(
"IndirectlyHideable", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"]
),
"B": schema.Class(
"B", bases=["A", "IndirectlyHideable"], pragmas=["ql_hideable"]
),
"NonHideable": schema.Class("NonHideable", bases=["Root"]),
}
@@ -771,7 +899,9 @@ def test_test_with():
assert data.classes == {
"Root": schema.Class("Root", derived=set("ABCD")),
"A": schema.Class("A", bases=["Root"]),
"B": schema.Class("B", bases=["Root"], pragmas={"qltest_test_with": "A"}, derived={'E'}),
"B": schema.Class(
"B", bases=["Root"], pragmas={"qltest_test_with": "A"}, derived={"E"}
),
"C": schema.Class("C", bases=["Root"], pragmas={"qltest_test_with": "D"}),
"D": schema.Class("D", bases=["Root"]),
"E": schema.Class("E", bases=["B"], pragmas={"qltest_test_with": "A"}),
@@ -782,10 +912,10 @@ def test_annotate_docstring():
@load
class data:
class Root:
""" old docstring """
"""old docstring"""
class A(Root):
""" A docstring """
"""A docstring"""
@defs.annotate(Root)
class _:
@@ -819,7 +949,15 @@ def test_annotate_decorations():
pass
assert data.classes == {
"Root": schema.Class("Root", pragmas=["qltest_skip", "cpp_skip", "ql_hideable", "qltest_collapse_hierarchy"]),
"Root": schema.Class(
"Root",
pragmas=[
"qltest_skip",
"cpp_skip",
"ql_hideable",
"qltest_collapse_hierarchy",
],
),
}
@@ -837,11 +975,16 @@ def test_annotate_fields():
z: defs.string
assert data.classes == {
"Root": schema.Class("Root", properties=[
schema.SingleProperty("x", "int", doc="foo"),
schema.OptionalProperty("y", "Root", pragmas=["ql_internal"], is_child=True),
schema.SingleProperty("z", "string"),
]),
"Root": schema.Class(
"Root",
properties=[
schema.SingleProperty("x", "int", doc="foo"),
schema.OptionalProperty(
"y", "Root", pragmas=["ql_internal"], is_child=True
),
schema.SingleProperty("z", "string"),
],
),
}
@@ -860,16 +1003,20 @@ def test_annotate_fields_negations():
z: defs._ | ~defs.synth | ~defs.doc
assert data.classes == {
"Root": schema.Class("Root", properties=[
schema.SingleProperty("x", "int"),
schema.OptionalProperty("y", "Root"),
schema.SingleProperty("z", "string"),
]),
"Root": schema.Class(
"Root",
properties=[
schema.SingleProperty("x", "int"),
schema.OptionalProperty("y", "Root"),
schema.SingleProperty("z", "string"),
],
),
}
def test_annotate_non_existing_field():
with pytest.raises(schema.Error):
@load
class data:
class Root:
@@ -882,6 +1029,7 @@ def test_annotate_non_existing_field():
def test_annotate_not_underscore():
with pytest.raises(schema.Error):
@load
class data:
class Root:
@@ -916,6 +1064,7 @@ def test_annotate_replace_bases():
@defs.annotate(Derived, replace_bases={B: C})
class _:
pass
assert data.classes == {
"Root": schema.Class("Root", derived={"A", "B"}),
"A": schema.Class("A", bases=["Root"], derived={"Derived"}),
@@ -946,6 +1095,7 @@ def test_annotate_add_bases():
@defs.annotate(Derived, add_bases=(B, C))
class _:
pass
assert data.classes == {
"Root": schema.Class("Root", derived={"A", "B", "C"}),
"A": schema.Class("A", bases=["Root"], derived={"Derived"}),
@@ -968,15 +1118,19 @@ def test_annotate_drop_field():
y: defs.drop
assert data.classes == {
"Root": schema.Class("Root", properties=[
schema.SingleProperty("x", "int"),
schema.SingleProperty("z", "boolean"),
]),
"Root": schema.Class(
"Root",
properties=[
schema.SingleProperty("x", "int"),
schema.SingleProperty("z", "boolean"),
],
),
}
def test_test_with_unknown_string():
with pytest.raises(schema.Error):
@load
class data:
class Root:
@@ -989,6 +1143,7 @@ def test_test_with_unknown_string():
def test_test_with_unknown_class():
with pytest.raises(schema.Error):
class B:
pass
@@ -1004,6 +1159,7 @@ def test_test_with_unknown_class():
def test_test_with_double():
with pytest.raises(schema.Error):
class B:
pass
@@ -1024,5 +1180,5 @@ def test_test_with_double():
pass
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -17,10 +17,16 @@ def generate_grouped(opts, renderer, dbscheme_input):
dirs = {f.parent for f in generated}
assert all(isinstance(f, pathlib.Path) for f in generated)
assert all(f.name in ("TrapEntries", "TrapTags") for f in generated)
assert set(f for f in generated if f.name == "TrapTags") == {output_dir / "TrapTags"}
return ({
str(d.relative_to(output_dir)): generated[d / "TrapEntries"] for d in dirs
}, generated[output_dir / "TrapTags"])
assert set(f for f in generated if f.name == "TrapTags") == {
output_dir / "TrapTags"
}
return (
{
str(d.relative_to(output_dir)): generated[d / "TrapEntries"]
for d in dirs
},
generated[output_dir / "TrapTags"],
)
return ret
@@ -65,87 +71,130 @@ def test_empty_tags(generate_tags):
def test_one_empty_table_rejected(generate_traps):
with pytest.raises(AssertionError):
generate_traps([
dbscheme.Table(name="foos", columns=[]),
])
generate_traps(
[
dbscheme.Table(name="foos", columns=[]),
]
)
def test_one_table(generate_traps):
assert generate_traps([
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
]) == [
assert generate_traps(
[
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
]
) == [
cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]),
]
def test_one_table(generate_traps):
assert generate_traps([
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
]) == [
assert generate_traps(
[
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
]
) == [
cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]),
]
def test_one_table_with_id(generate_traps):
assert generate_traps([
dbscheme.Table(name="foos", columns=[
dbscheme.Column("bla", "int", binding=True)]),
]) == [
cpp.Trap("foos", name="Foos", fields=[cpp.Field(
"bla", "int")], id=cpp.Field("bla", "int")),
assert generate_traps(
[
dbscheme.Table(
name="foos", columns=[dbscheme.Column("bla", "int", binding=True)]
),
]
) == [
cpp.Trap(
"foos",
name="Foos",
fields=[cpp.Field("bla", "int")],
id=cpp.Field("bla", "int"),
),
]
def test_one_table_with_two_binding_first_is_id(generate_traps):
assert generate_traps([
dbscheme.Table(name="foos", columns=[
dbscheme.Column("x", "a", binding=True),
dbscheme.Column("y", "b", binding=True),
]),
]) == [
cpp.Trap("foos", name="Foos", fields=[
cpp.Field("x", "a"),
cpp.Field("y", "b"),
], id=cpp.Field("x", "a")),
assert generate_traps(
[
dbscheme.Table(
name="foos",
columns=[
dbscheme.Column("x", "a", binding=True),
dbscheme.Column("y", "b", binding=True),
],
),
]
) == [
cpp.Trap(
"foos",
name="Foos",
fields=[
cpp.Field("x", "a"),
cpp.Field("y", "b"),
],
id=cpp.Field("x", "a"),
),
]
@pytest.mark.parametrize("column,field", [
(dbscheme.Column("x", "string"), cpp.Field("x", "std::string")),
(dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")),
(dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel<DbTypeTag>")),
])
@pytest.mark.parametrize(
"column,field",
[
(dbscheme.Column("x", "string"), cpp.Field("x", "std::string")),
(dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")),
(dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel<DbTypeTag>")),
],
)
def test_one_table_special_types(generate_traps, column, field):
assert generate_traps([
dbscheme.Table(name="foos", columns=[column]),
]) == [
assert generate_traps(
[
dbscheme.Table(name="foos", columns=[column]),
]
) == [
cpp.Trap("foos", name="Foos", fields=[field]),
]
@pytest.mark.parametrize("name", ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"])
@pytest.mark.parametrize(
"name",
["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"],
)
def test_one_table_overridden_unsigned_field(generate_traps, name):
assert generate_traps([
dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]),
]) == [
assert generate_traps(
[
dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]),
]
) == [
cpp.Trap("foos", name="Foos", fields=[cpp.Field(name, "unsigned")]),
]
def test_one_table_overridden_underscore_named_field(generate_traps):
assert generate_traps([
dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]),
]) == [
assert generate_traps(
[
dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]),
]
) == [
cpp.Trap("foos", name="Foos", fields=[cpp.Field("whatever", "bar")]),
]
def test_tables_with_dir(generate_grouped_traps):
assert generate_grouped_traps([
dbscheme.Table(name="x", columns=[dbscheme.Column("i", "int")]),
dbscheme.Table(name="y", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo")),
dbscheme.Table(name="z", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo/bar")),
]) == {
assert generate_grouped_traps(
[
dbscheme.Table(name="x", columns=[dbscheme.Column("i", "int")]),
dbscheme.Table(
name="y", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo")
),
dbscheme.Table(
name="z",
columns=[dbscheme.Column("i", "int")],
dir=pathlib.Path("foo/bar"),
),
]
) == {
".": [cpp.Trap("x", name="X", fields=[cpp.Field("i", "int")])],
"foo": [cpp.Trap("y", name="Y", fields=[cpp.Field("i", "int")])],
"foo/bar": [cpp.Trap("z", name="Z", fields=[cpp.Field("i", "int")])],
@@ -153,15 +202,22 @@ def test_tables_with_dir(generate_grouped_traps):
def test_one_table_no_tags(generate_tags):
assert generate_tags([
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
]) == []
assert (
generate_tags(
[
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
]
)
== []
)
def test_one_union_tags(generate_tags):
assert generate_tags([
dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
]) == [
assert generate_tags(
[
dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
]
) == [
cpp.Tag(name="LeftHandSide", bases=[], id="@left_hand_side"),
cpp.Tag(name="A", bases=["LeftHandSide"], id="@a"),
cpp.Tag(name="B", bases=["LeftHandSide"], id="@b"),
@@ -170,11 +226,13 @@ def test_one_union_tags(generate_tags):
def test_multiple_union_tags(generate_tags):
assert generate_tags([
dbscheme.Union(lhs="@d", rhs=["@a"]),
dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
]) == [
assert generate_tags(
[
dbscheme.Union(lhs="@d", rhs=["@a"]),
dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
]
) == [
cpp.Tag(name="D", bases=[], id="@d"),
cpp.Tag(name="E", bases=[], id="@e"),
cpp.Tag(name="A", bases=["D"], id="@a"),
@@ -184,5 +242,5 @@ def test_multiple_union_tags(generate_tags):
]
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -39,8 +39,9 @@ def opts():
@pytest.fixture(autouse=True)
def override_paths(tmp_path):
with mock.patch("misc.codegen.lib.paths.root_dir", tmp_path), \
mock.patch("misc.codegen.lib.paths.exe_file", tmp_path / "exe"):
with mock.patch("misc.codegen.lib.paths.root_dir", tmp_path), mock.patch(
"misc.codegen.lib.paths.exe_file", tmp_path / "exe"
):
yield