mirror of
https://github.com/github/codeql.git
synced 2025-12-16 16:53:25 +01:00
Add black pre-commit hook
This switched `codegen` from the `autopep8` formatting to the `black` one, and applies it to `bulk_mad_generator.py` as well. We can enroll more python scripts to it in the future.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
""" Driver script to run all code generation """
|
||||
"""Driver script to run all code generation"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@@ -9,7 +9,7 @@ import pathlib
|
||||
import typing
|
||||
import shlex
|
||||
|
||||
if 'BUILD_WORKSPACE_DIRECTORY' not in os.environ:
|
||||
if "BUILD_WORKSPACE_DIRECTORY" not in os.environ:
|
||||
# we are not running with `bazel run`, set up module search path
|
||||
_repo_root = pathlib.Path(__file__).resolve().parents[2]
|
||||
sys.path.append(str(_repo_root))
|
||||
@@ -29,57 +29,105 @@ def _parse_args() -> argparse.Namespace:
|
||||
conf = None
|
||||
|
||||
p = argparse.ArgumentParser(description="Code generation suite")
|
||||
p.add_argument("--generate", type=lambda x: x.split(","),
|
||||
help="specify what targets to generate as a comma separated list, choosing among dbscheme, ql, "
|
||||
"trap, cpp and rust")
|
||||
p.add_argument("--verbose", "-v", action="store_true", help="print more information")
|
||||
p.add_argument(
|
||||
"--generate",
|
||||
type=lambda x: x.split(","),
|
||||
help="specify what targets to generate as a comma separated list, choosing among dbscheme, ql, "
|
||||
"trap, cpp and rust",
|
||||
)
|
||||
p.add_argument(
|
||||
"--verbose", "-v", action="store_true", help="print more information"
|
||||
)
|
||||
p.add_argument("--quiet", "-q", action="store_true", help="only print errors")
|
||||
p.add_argument("--configuration-file", "-c", type=_abspath, default=conf,
|
||||
help="A configuration file to load options from. By default, the first codegen.conf file found by "
|
||||
"going up directories from the current location. If present all paths provided in options are "
|
||||
"considered relative to its directory")
|
||||
p.add_argument("--root-dir", type=_abspath,
|
||||
help="the directory that should be regarded as the root of the language pack codebase. Used to "
|
||||
"compute QL imports and in some comments and as root for relative paths provided as options. "
|
||||
"If not provided it defaults to the directory of the configuration file, if any")
|
||||
p.add_argument(
|
||||
"--configuration-file",
|
||||
"-c",
|
||||
type=_abspath,
|
||||
default=conf,
|
||||
help="A configuration file to load options from. By default, the first codegen.conf file found by "
|
||||
"going up directories from the current location. If present all paths provided in options are "
|
||||
"considered relative to its directory",
|
||||
)
|
||||
p.add_argument(
|
||||
"--root-dir",
|
||||
type=_abspath,
|
||||
help="the directory that should be regarded as the root of the language pack codebase. Used to "
|
||||
"compute QL imports and in some comments and as root for relative paths provided as options. "
|
||||
"If not provided it defaults to the directory of the configuration file, if any",
|
||||
)
|
||||
path_arguments = [
|
||||
p.add_argument("--schema",
|
||||
help="input schema file (default schema.py)"),
|
||||
p.add_argument("--dbscheme",
|
||||
help="output file for dbscheme generation, input file for trap generation"),
|
||||
p.add_argument("--ql-output",
|
||||
help="output directory for generated QL files"),
|
||||
p.add_argument("--ql-stub-output",
|
||||
help="output directory for QL stub/customization files. Defines also the "
|
||||
"generated qll file importing every class file"),
|
||||
p.add_argument("--ql-test-output",
|
||||
help="output directory for QL generated extractor test files"),
|
||||
p.add_argument("--ql-cfg-output",
|
||||
help="output directory for QL CFG layer (optional)."),
|
||||
p.add_argument("--cpp-output",
|
||||
help="output directory for generated C++ files, required if trap or cpp is provided to "
|
||||
"--generate"),
|
||||
p.add_argument("--rust-output",
|
||||
help="output directory for generated Rust files, required if rust is provided to "
|
||||
"--generate"),
|
||||
p.add_argument("--generated-registry",
|
||||
help="registry file containing information about checked-in generated code. A .gitattributes"
|
||||
"file is generated besides it to mark those files with linguist-generated=true. Must"
|
||||
"be in a directory containing all generated code."),
|
||||
p.add_argument("--schema", help="input schema file (default schema.py)"),
|
||||
p.add_argument(
|
||||
"--dbscheme",
|
||||
help="output file for dbscheme generation, input file for trap generation",
|
||||
),
|
||||
p.add_argument("--ql-output", help="output directory for generated QL files"),
|
||||
p.add_argument(
|
||||
"--ql-stub-output",
|
||||
help="output directory for QL stub/customization files. Defines also the "
|
||||
"generated qll file importing every class file",
|
||||
),
|
||||
p.add_argument(
|
||||
"--ql-test-output",
|
||||
help="output directory for QL generated extractor test files",
|
||||
),
|
||||
p.add_argument(
|
||||
"--ql-cfg-output", help="output directory for QL CFG layer (optional)."
|
||||
),
|
||||
p.add_argument(
|
||||
"--cpp-output",
|
||||
help="output directory for generated C++ files, required if trap or cpp is provided to "
|
||||
"--generate",
|
||||
),
|
||||
p.add_argument(
|
||||
"--rust-output",
|
||||
help="output directory for generated Rust files, required if rust is provided to "
|
||||
"--generate",
|
||||
),
|
||||
p.add_argument(
|
||||
"--generated-registry",
|
||||
help="registry file containing information about checked-in generated code. A .gitattributes"
|
||||
"file is generated besides it to mark those files with linguist-generated=true. Must"
|
||||
"be in a directory containing all generated code.",
|
||||
),
|
||||
]
|
||||
p.add_argument("--script-name",
|
||||
help="script name to put in header comments of generated files. By default, the path of this "
|
||||
"script relative to the root directory")
|
||||
p.add_argument("--trap-library",
|
||||
help="path to the trap library from an include directory, required if generating C++ trap bindings"),
|
||||
p.add_argument("--ql-format", action="store_true", default=True,
|
||||
help="use codeql to autoformat QL files (which is the default)")
|
||||
p.add_argument("--no-ql-format", action="store_false", dest="ql_format", help="do not format QL files")
|
||||
p.add_argument("--codeql-binary", default="codeql", help="command to use for QL formatting (default %(default)s)")
|
||||
p.add_argument("--force", "-f", action="store_true",
|
||||
help="generate all files without skipping unchanged files and overwriting modified ones")
|
||||
p.add_argument("--use-current-directory", action="store_true",
|
||||
help="do not consider paths as relative to --root-dir or the configuration directory")
|
||||
p.add_argument(
|
||||
"--script-name",
|
||||
help="script name to put in header comments of generated files. By default, the path of this "
|
||||
"script relative to the root directory",
|
||||
)
|
||||
p.add_argument(
|
||||
"--trap-library",
|
||||
help="path to the trap library from an include directory, required if generating C++ trap bindings",
|
||||
),
|
||||
p.add_argument(
|
||||
"--ql-format",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="use codeql to autoformat QL files (which is the default)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--no-ql-format",
|
||||
action="store_false",
|
||||
dest="ql_format",
|
||||
help="do not format QL files",
|
||||
)
|
||||
p.add_argument(
|
||||
"--codeql-binary",
|
||||
default="codeql",
|
||||
help="command to use for QL formatting (default %(default)s)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--force",
|
||||
"-f",
|
||||
action="store_true",
|
||||
help="generate all files without skipping unchanged files and overwriting modified ones",
|
||||
)
|
||||
p.add_argument(
|
||||
"--use-current-directory",
|
||||
action="store_true",
|
||||
help="do not consider paths as relative to --root-dir or the configuration directory",
|
||||
)
|
||||
opts = p.parse_args()
|
||||
if opts.configuration_file is not None:
|
||||
with open(opts.configuration_file) as config:
|
||||
@@ -97,7 +145,15 @@ def _parse_args() -> argparse.Namespace:
|
||||
for arg in path_arguments:
|
||||
path = getattr(opts, arg.dest)
|
||||
if path is not None:
|
||||
setattr(opts, arg.dest, _abspath(path) if opts.use_current_directory else (opts.root_dir / path))
|
||||
setattr(
|
||||
opts,
|
||||
arg.dest,
|
||||
(
|
||||
_abspath(path)
|
||||
if opts.use_current_directory
|
||||
else (opts.root_dir / path)
|
||||
),
|
||||
)
|
||||
if not opts.script_name:
|
||||
opts.script_name = paths.exe_file.relative_to(opts.root_dir)
|
||||
return opts
|
||||
@@ -115,7 +171,7 @@ def run():
|
||||
log_level = logging.ERROR
|
||||
else:
|
||||
log_level = logging.INFO
|
||||
logging.basicConfig(format="{levelname} {message}", style='{', level=log_level)
|
||||
logging.basicConfig(format="{levelname} {message}", style="{", level=log_level)
|
||||
for target in opts.generate:
|
||||
generate(target, opts, render.Renderer(opts.script_name))
|
||||
|
||||
|
||||
@@ -49,7 +49,11 @@ def _get_trap_name(cls: schema.Class, p: schema.Property) -> str | None:
|
||||
return inflection.pluralize(trap_name)
|
||||
|
||||
|
||||
def _get_field(cls: schema.Class, p: schema.Property, add_or_none_except: typing.Optional[str] = None) -> cpp.Field:
|
||||
def _get_field(
|
||||
cls: schema.Class,
|
||||
p: schema.Property,
|
||||
add_or_none_except: typing.Optional[str] = None,
|
||||
) -> cpp.Field:
|
||||
args = dict(
|
||||
field_name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),
|
||||
base_type=_get_type(p.type, add_or_none_except),
|
||||
@@ -83,14 +87,15 @@ class Processor:
|
||||
bases=[self._get_class(b) for b in cls.bases],
|
||||
fields=[
|
||||
_get_field(cls, p, self._add_or_none_except)
|
||||
for p in cls.properties if "cpp_skip" not in p.pragmas and not p.synth
|
||||
for p in cls.properties
|
||||
if "cpp_skip" not in p.pragmas and not p.synth
|
||||
],
|
||||
final=not cls.derived,
|
||||
trap_name=trap_name,
|
||||
)
|
||||
|
||||
def get_classes(self):
|
||||
ret = {'': []}
|
||||
ret = {"": []}
|
||||
for k, cls in self._classmap.items():
|
||||
if not cls.synth:
|
||||
ret.setdefault(cls.group, []).append(self._get_class(cls.name))
|
||||
@@ -102,6 +107,12 @@ def generate(opts, renderer):
|
||||
processor = Processor(schemaloader.load_file(opts.schema))
|
||||
out = opts.cpp_output
|
||||
for dir, classes in processor.get_classes().items():
|
||||
renderer.render(cpp.ClassList(classes, opts.schema,
|
||||
include_parent=bool(dir),
|
||||
trap_library=opts.trap_library), out / dir / "TrapClasses")
|
||||
renderer.render(
|
||||
cpp.ClassList(
|
||||
classes,
|
||||
opts.schema,
|
||||
include_parent=bool(dir),
|
||||
trap_library=opts.trap_library,
|
||||
),
|
||||
out / dir / "TrapClasses",
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ Moreover:
|
||||
as columns
|
||||
The type hierarchy will be translated to corresponding `union` declarations.
|
||||
"""
|
||||
|
||||
import typing
|
||||
|
||||
import inflection
|
||||
@@ -29,7 +30,7 @@ class Error(Exception):
|
||||
|
||||
|
||||
def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> str:
|
||||
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.
|
||||
"""translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.
|
||||
For class types, appends an underscore followed by `null` if provided
|
||||
"""
|
||||
if typename[0].isupper():
|
||||
@@ -42,12 +43,18 @@ def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> st
|
||||
return typename
|
||||
|
||||
|
||||
def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], add_or_none_except: typing.Optional[str] = None):
|
||||
""" Yield all dbscheme entities needed to model class `cls` """
|
||||
def cls_to_dbscheme(
|
||||
cls: schema.Class,
|
||||
lookup: typing.Dict[str, schema.Class],
|
||||
add_or_none_except: typing.Optional[str] = None,
|
||||
):
|
||||
"""Yield all dbscheme entities needed to model class `cls`"""
|
||||
if cls.synth:
|
||||
return
|
||||
if cls.derived:
|
||||
yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived if not lookup[c].synth))
|
||||
yield Union(
|
||||
dbtype(cls.name), (dbtype(c) for c in cls.derived if not lookup[c].synth)
|
||||
)
|
||||
dir = pathlib.Path(cls.group) if cls.group else None
|
||||
# output a table specific to a class only if it is a leaf class or it has 1-to-1 properties
|
||||
# Leaf classes need a table to bind the `@` ids
|
||||
@@ -61,9 +68,11 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
|
||||
name=inflection.tableize(cls.name),
|
||||
columns=[
|
||||
Column("id", type=dbtype(cls.name), binding=binding),
|
||||
] + [
|
||||
]
|
||||
+ [
|
||||
Column(f.name, dbtype(f.type, add_or_none_except))
|
||||
for f in cls.properties if f.is_single and not f.synth
|
||||
for f in cls.properties
|
||||
if f.is_single and not f.synth
|
||||
],
|
||||
dir=dir,
|
||||
)
|
||||
@@ -74,28 +83,37 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
|
||||
continue
|
||||
if f.is_unordered:
|
||||
yield Table(
|
||||
name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
name=overridden_table_name
|
||||
or inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
columns=[
|
||||
Column("id", type=dbtype(cls.name)),
|
||||
Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
|
||||
Column(
|
||||
inflection.singularize(f.name),
|
||||
dbtype(f.type, add_or_none_except),
|
||||
),
|
||||
],
|
||||
dir=dir,
|
||||
)
|
||||
elif f.is_repeated:
|
||||
yield Table(
|
||||
keyset=KeySet(["id", "index"]),
|
||||
name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
name=overridden_table_name
|
||||
or inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
columns=[
|
||||
Column("id", type=dbtype(cls.name)),
|
||||
Column("index", type="int"),
|
||||
Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
|
||||
Column(
|
||||
inflection.singularize(f.name),
|
||||
dbtype(f.type, add_or_none_except),
|
||||
),
|
||||
],
|
||||
dir=dir,
|
||||
)
|
||||
elif f.is_optional:
|
||||
yield Table(
|
||||
keyset=KeySet(["id"]),
|
||||
name=overridden_table_name or inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
name=overridden_table_name
|
||||
or inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
columns=[
|
||||
Column("id", type=dbtype(cls.name)),
|
||||
Column(f.name, dbtype(f.type, add_or_none_except)),
|
||||
@@ -105,7 +123,8 @@ def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], a
|
||||
elif f.is_predicate:
|
||||
yield Table(
|
||||
keyset=KeySet(["id"]),
|
||||
name=overridden_table_name or inflection.underscore(f"{cls.name}_{f.name}"),
|
||||
name=overridden_table_name
|
||||
or inflection.underscore(f"{cls.name}_{f.name}"),
|
||||
columns=[
|
||||
Column("id", type=dbtype(cls.name)),
|
||||
],
|
||||
@@ -119,33 +138,46 @@ def check_name_conflicts(decls: list[Table | Union]):
|
||||
match decl:
|
||||
case Table(name=name):
|
||||
if name in names:
|
||||
raise Error(f"Duplicate table name: {
|
||||
name}, you can use `@ql.db_table_name` on a property to resolve this")
|
||||
raise Error(
|
||||
f"Duplicate table name: {
|
||||
name}, you can use `@ql.db_table_name` on a property to resolve this"
|
||||
)
|
||||
names.add(name)
|
||||
|
||||
|
||||
def get_declarations(data: schema.Schema):
|
||||
add_or_none_except = data.root_class.name if data.null else None
|
||||
declarations = [d for cls in data.classes.values() if not cls.imported for d in cls_to_dbscheme(cls,
|
||||
data.classes, add_or_none_except)]
|
||||
declarations = [
|
||||
d
|
||||
for cls in data.classes.values()
|
||||
if not cls.imported
|
||||
for d in cls_to_dbscheme(cls, data.classes, add_or_none_except)
|
||||
]
|
||||
if data.null:
|
||||
property_classes = {
|
||||
prop.type for cls in data.classes.values() for prop in cls.properties
|
||||
prop.type
|
||||
for cls in data.classes.values()
|
||||
for prop in cls.properties
|
||||
if cls.name != data.null and prop.type and prop.type[0].isupper()
|
||||
}
|
||||
declarations += [
|
||||
Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) for t in sorted(property_classes)
|
||||
Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)])
|
||||
for t in sorted(property_classes)
|
||||
]
|
||||
check_name_conflicts(declarations)
|
||||
return declarations
|
||||
|
||||
|
||||
def get_includes(data: schema.Schema, include_dir: pathlib.Path, root_dir: pathlib.Path):
|
||||
def get_includes(
|
||||
data: schema.Schema, include_dir: pathlib.Path, root_dir: pathlib.Path
|
||||
):
|
||||
includes = []
|
||||
for inc in data.includes:
|
||||
inc = include_dir / inc
|
||||
with open(inc) as inclusion:
|
||||
includes.append(SchemeInclude(src=inc.relative_to(root_dir), data=inclusion.read()))
|
||||
includes.append(
|
||||
SchemeInclude(src=inc.relative_to(root_dir), data=inclusion.read())
|
||||
)
|
||||
return includes
|
||||
|
||||
|
||||
@@ -155,8 +187,10 @@ def generate(opts, renderer):
|
||||
|
||||
data = schemaloader.load_file(input)
|
||||
|
||||
dbscheme = Scheme(src=input.name,
|
||||
includes=get_includes(data, include_dir=input.parent, root_dir=input.parent),
|
||||
declarations=get_declarations(data))
|
||||
dbscheme = Scheme(
|
||||
src=input.name,
|
||||
includes=get_includes(data, include_dir=input.parent, root_dir=input.parent),
|
||||
declarations=get_declarations(data),
|
||||
)
|
||||
|
||||
renderer.render(dbscheme, out)
|
||||
|
||||
@@ -19,6 +19,7 @@ Moreover in the test directory for each <Class> in <group> it will generate bene
|
||||
* one `<Class>.ql` test query for all single properties and on `<Class>_<property>.ql` test query for each optional or
|
||||
repeated property
|
||||
"""
|
||||
|
||||
# TODO this should probably be split in different generators now: ql, qltest, maybe qlsynth
|
||||
|
||||
import logging
|
||||
@@ -70,7 +71,7 @@ abbreviations = {
|
||||
|
||||
abbreviations.update({f"{k}s": f"{v}s" for k, v in abbreviations.items()})
|
||||
|
||||
_abbreviations_re = re.compile("|".join(fr"\b{abbr}\b" for abbr in abbreviations))
|
||||
_abbreviations_re = re.compile("|".join(rf"\b{abbr}\b" for abbr in abbreviations))
|
||||
|
||||
|
||||
def _humanize(s: str) -> str:
|
||||
@@ -98,11 +99,17 @@ def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
|
||||
return format.format(**{noun: transform(noun) for noun in nouns})
|
||||
|
||||
prop_name = _humanize(prop.name)
|
||||
class_name = cls.pragmas.get("ql_default_doc_name", _humanize(inflection.underscore(cls.name)))
|
||||
class_name = cls.pragmas.get(
|
||||
"ql_default_doc_name", _humanize(inflection.underscore(cls.name))
|
||||
)
|
||||
if prop.is_predicate:
|
||||
return f"this {class_name} {prop_name}"
|
||||
if plural is not None:
|
||||
prop_name = inflection.pluralize(prop_name) if plural else inflection.singularize(prop_name)
|
||||
prop_name = (
|
||||
inflection.pluralize(prop_name)
|
||||
if plural
|
||||
else inflection.singularize(prop_name)
|
||||
)
|
||||
return f"{prop_name} of this {class_name}"
|
||||
|
||||
|
||||
@@ -114,8 +121,12 @@ def _type_is_hideable(t: str, lookup: typing.Dict[str, schema.ClassBase]) -> boo
|
||||
return False
|
||||
|
||||
|
||||
def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dict[str, schema.ClassBase],
|
||||
prev_child: str = "") -> ql.Property:
|
||||
def get_ql_property(
|
||||
cls: schema.Class,
|
||||
prop: schema.Property,
|
||||
lookup: typing.Dict[str, schema.ClassBase],
|
||||
prev_child: str = "",
|
||||
) -> ql.Property:
|
||||
|
||||
args = dict(
|
||||
type=prop.type if not prop.is_predicate else "predicate",
|
||||
@@ -133,12 +144,15 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
|
||||
ql_name = prop.pragmas.get("ql_name", prop.name)
|
||||
db_table_name = prop.pragmas.get("ql_db_table_name")
|
||||
if db_table_name and prop.is_single:
|
||||
raise Error(f"`db_table_name` pragma is not supported for single properties, but {cls.name}.{prop.name} has it")
|
||||
raise Error(
|
||||
f"`db_table_name` pragma is not supported for single properties, but {cls.name}.{prop.name} has it"
|
||||
)
|
||||
if prop.is_single:
|
||||
args.update(
|
||||
singular=inflection.camelize(ql_name),
|
||||
tablename=inflection.tableize(cls.name),
|
||||
tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
|
||||
tableparams=["this"]
|
||||
+ ["result" if p is prop else "_" for p in cls.properties if p.is_single],
|
||||
doc=_get_doc(cls, prop),
|
||||
)
|
||||
elif prop.is_repeated:
|
||||
@@ -146,7 +160,11 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
|
||||
singular=inflection.singularize(inflection.camelize(ql_name)),
|
||||
plural=inflection.pluralize(inflection.camelize(ql_name)),
|
||||
tablename=db_table_name or inflection.tableize(f"{cls.name}_{prop.name}"),
|
||||
tableparams=["this", "index", "result"] if not prop.is_unordered else ["this", "result"],
|
||||
tableparams=(
|
||||
["this", "index", "result"]
|
||||
if not prop.is_unordered
|
||||
else ["this", "result"]
|
||||
),
|
||||
doc=_get_doc(cls, prop, plural=False),
|
||||
doc_plural=_get_doc(cls, prop, plural=True),
|
||||
)
|
||||
@@ -169,7 +187,9 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, lookup: typing.Dic
|
||||
return ql.Property(**args)
|
||||
|
||||
|
||||
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]) -> ql.Class:
|
||||
def get_ql_class(
|
||||
cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase]
|
||||
) -> ql.Class:
|
||||
if "ql_name" in cls.pragmas:
|
||||
raise Error("ql_name is not supported yet for classes, only for properties")
|
||||
prev_child = ""
|
||||
@@ -195,12 +215,14 @@ def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.ClassBase])
|
||||
)
|
||||
|
||||
|
||||
def get_ql_cfg_class(cls: schema.Class, lookup: typing.Dict[str, ql.Class]) -> ql.CfgClass:
|
||||
def get_ql_cfg_class(
|
||||
cls: schema.Class, lookup: typing.Dict[str, ql.Class]
|
||||
) -> ql.CfgClass:
|
||||
return ql.CfgClass(
|
||||
name=cls.name,
|
||||
bases=[base for base in cls.bases if lookup[base.base].cfg],
|
||||
properties=cls.properties,
|
||||
doc=cls.doc
|
||||
doc=cls.doc,
|
||||
)
|
||||
|
||||
|
||||
@@ -214,24 +236,33 @@ _final_db_class_lookup = {}
|
||||
|
||||
|
||||
def get_ql_synth_class_db(name: str) -> ql.Synth.FinalClassDb:
|
||||
return _final_db_class_lookup.setdefault(name, ql.Synth.FinalClassDb(name=name,
|
||||
params=[
|
||||
ql.Synth.Param("id", _to_db_type(name))]))
|
||||
return _final_db_class_lookup.setdefault(
|
||||
name,
|
||||
ql.Synth.FinalClassDb(
|
||||
name=name, params=[ql.Synth.Param("id", _to_db_type(name))]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_ql_synth_class(cls: schema.Class):
|
||||
if cls.derived:
|
||||
return ql.Synth.NonFinalClass(name=cls.name, derived=sorted(cls.derived),
|
||||
root=not cls.bases)
|
||||
return ql.Synth.NonFinalClass(
|
||||
name=cls.name, derived=sorted(cls.derived), root=not cls.bases
|
||||
)
|
||||
if cls.synth and cls.synth.from_class is not None:
|
||||
source = cls.synth.from_class
|
||||
get_ql_synth_class_db(source).subtract_type(cls.name)
|
||||
return ql.Synth.FinalClassDerivedSynth(name=cls.name,
|
||||
params=[ql.Synth.Param("id", _to_db_type(source))])
|
||||
return ql.Synth.FinalClassDerivedSynth(
|
||||
name=cls.name, params=[ql.Synth.Param("id", _to_db_type(source))]
|
||||
)
|
||||
if cls.synth and cls.synth.on_arguments is not None:
|
||||
return ql.Synth.FinalClassFreshSynth(name=cls.name,
|
||||
params=[ql.Synth.Param(k, _to_db_type(v))
|
||||
for k, v in cls.synth.on_arguments.items()])
|
||||
return ql.Synth.FinalClassFreshSynth(
|
||||
name=cls.name,
|
||||
params=[
|
||||
ql.Synth.Param(k, _to_db_type(v))
|
||||
for k, v in cls.synth.on_arguments.items()
|
||||
],
|
||||
)
|
||||
return get_ql_synth_class_db(cls.name)
|
||||
|
||||
|
||||
@@ -250,7 +281,13 @@ def get_types_used_by(cls: ql.Class, is_impl: bool) -> typing.Iterable[str]:
|
||||
|
||||
|
||||
def get_classes_used_by(cls: ql.Class, is_impl: bool) -> typing.List[str]:
|
||||
return sorted(set(t for t in get_types_used_by(cls, is_impl) if t[0].isupper() and (is_impl or t != cls.name)))
|
||||
return sorted(
|
||||
set(
|
||||
t
|
||||
for t in get_types_used_by(cls, is_impl)
|
||||
if t[0].isupper() and (is_impl or t != cls.name)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def format(codeql, files):
|
||||
@@ -265,7 +302,8 @@ def format(codeql, files):
|
||||
codeql_path = shutil.which(codeql)
|
||||
if not codeql_path:
|
||||
raise FormatError(
|
||||
f"`{codeql}` not found in PATH. Either install it, or pass `-- --codeql-binary` with a full path")
|
||||
f"`{codeql}` not found in PATH. Either install it, or pass `-- --codeql-binary` with a full path"
|
||||
)
|
||||
codeql = codeql_path
|
||||
res = subprocess.run(format_cmd, stderr=subprocess.PIPE, text=True)
|
||||
if res.returncode:
|
||||
@@ -281,16 +319,22 @@ def _get_path(cls: schema.Class) -> pathlib.Path:
|
||||
|
||||
|
||||
def _get_path_impl(cls: schema.Class) -> pathlib.Path:
|
||||
return pathlib.Path(cls.group or "", "internal", cls.name+"Impl").with_suffix(".qll")
|
||||
return pathlib.Path(cls.group or "", "internal", cls.name + "Impl").with_suffix(
|
||||
".qll"
|
||||
)
|
||||
|
||||
|
||||
def _get_path_public(cls: schema.Class) -> pathlib.Path:
|
||||
return pathlib.Path(cls.group or "", "internal" if "ql_internal" in cls.pragmas else "", cls.name).with_suffix(".qll")
|
||||
return pathlib.Path(
|
||||
cls.group or "", "internal" if "ql_internal" in cls.pragmas else "", cls.name
|
||||
).with_suffix(".qll")
|
||||
|
||||
|
||||
def _get_all_properties(cls: schema.Class, lookup: typing.Dict[str, schema.Class],
|
||||
already_seen: typing.Optional[typing.Set[int]] = None) -> \
|
||||
typing.Iterable[typing.Tuple[schema.Class, schema.Property]]:
|
||||
def _get_all_properties(
|
||||
cls: schema.Class,
|
||||
lookup: typing.Dict[str, schema.Class],
|
||||
already_seen: typing.Optional[typing.Set[int]] = None,
|
||||
) -> typing.Iterable[typing.Tuple[schema.Class, schema.Property]]:
|
||||
# deduplicate using ids
|
||||
if already_seen is None:
|
||||
already_seen = set()
|
||||
@@ -304,14 +348,19 @@ def _get_all_properties(cls: schema.Class, lookup: typing.Dict[str, schema.Class
|
||||
yield cls, p
|
||||
|
||||
|
||||
def _get_all_properties_to_be_tested(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> \
|
||||
typing.Iterable[ql.PropertyForTest]:
|
||||
def _get_all_properties_to_be_tested(
|
||||
cls: schema.Class, lookup: typing.Dict[str, schema.Class]
|
||||
) -> typing.Iterable[ql.PropertyForTest]:
|
||||
for c, p in _get_all_properties(cls, lookup):
|
||||
if not ("qltest_skip" in c.pragmas or "qltest_skip" in p.pragmas):
|
||||
# TODO here operations are duplicated, but should be better if we split ql and qltest generation
|
||||
p = get_ql_property(c, p, lookup)
|
||||
yield ql.PropertyForTest(p.getter, is_total=p.is_single or p.is_predicate,
|
||||
type=p.type if not p.is_predicate else None, is_indexed=p.is_indexed)
|
||||
yield ql.PropertyForTest(
|
||||
p.getter,
|
||||
is_total=p.is_single or p.is_predicate,
|
||||
type=p.type if not p.is_predicate else None,
|
||||
is_indexed=p.is_indexed,
|
||||
)
|
||||
if p.is_repeated and not p.is_optional:
|
||||
yield ql.PropertyForTest(f"getNumberOf{p.plural}", type="int")
|
||||
elif p.is_optional and not p.is_repeated:
|
||||
@@ -324,33 +373,45 @@ def _partition_iter(x, pred):
|
||||
|
||||
|
||||
def _partition(l, pred):
|
||||
""" partitions a list according to boolean predicate """
|
||||
"""partitions a list according to boolean predicate"""
|
||||
return map(list, _partition_iter(l, pred))
|
||||
|
||||
|
||||
def _is_in_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
|
||||
return "qltest_collapse_hierarchy" in cls.pragmas or _is_under_qltest_collapsed_hierarchy(cls, lookup)
|
||||
def _is_in_qltest_collapsed_hierarchy(
|
||||
cls: schema.Class, lookup: typing.Dict[str, schema.Class]
|
||||
):
|
||||
return (
|
||||
"qltest_collapse_hierarchy" in cls.pragmas
|
||||
or _is_under_qltest_collapsed_hierarchy(cls, lookup)
|
||||
)
|
||||
|
||||
|
||||
def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
|
||||
def _is_under_qltest_collapsed_hierarchy(
|
||||
cls: schema.Class, lookup: typing.Dict[str, schema.Class]
|
||||
):
|
||||
return "qltest_uncollapse_hierarchy" not in cls.pragmas and any(
|
||||
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)
|
||||
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases
|
||||
)
|
||||
|
||||
|
||||
def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
|
||||
return "qltest_skip" in cls.pragmas or not (
|
||||
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
|
||||
cls, lookup)
|
||||
return (
|
||||
"qltest_skip" in cls.pragmas
|
||||
or not (cls.final or "qltest_collapse_hierarchy" in cls.pragmas)
|
||||
or _is_under_qltest_collapsed_hierarchy(cls, lookup)
|
||||
)
|
||||
|
||||
|
||||
def _get_stub(cls: schema.Class, base_import: str, generated_import_prefix: str) -> ql.Stub:
|
||||
def _get_stub(
|
||||
cls: schema.Class, base_import: str, generated_import_prefix: str
|
||||
) -> ql.Stub:
|
||||
if isinstance(cls.synth, schema.SynthInfo):
|
||||
if cls.synth.from_class is not None:
|
||||
accessors = [
|
||||
ql.SynthUnderlyingAccessor(
|
||||
argument="Entity",
|
||||
type=_to_db_type(cls.synth.from_class),
|
||||
constructorparams=["result"]
|
||||
constructorparams=["result"],
|
||||
)
|
||||
]
|
||||
elif cls.synth.on_arguments is not None:
|
||||
@@ -358,28 +419,39 @@ def _get_stub(cls: schema.Class, base_import: str, generated_import_prefix: str)
|
||||
ql.SynthUnderlyingAccessor(
|
||||
argument=inflection.camelize(arg),
|
||||
type=_to_db_type(type),
|
||||
constructorparams=["result" if a == arg else "_" for a in cls.synth.on_arguments]
|
||||
) for arg, type in cls.synth.on_arguments.items()
|
||||
constructorparams=[
|
||||
"result" if a == arg else "_" for a in cls.synth.on_arguments
|
||||
],
|
||||
)
|
||||
for arg, type in cls.synth.on_arguments.items()
|
||||
]
|
||||
else:
|
||||
accessors = []
|
||||
return ql.Stub(name=cls.name, base_import=base_import, import_prefix=generated_import_prefix,
|
||||
doc=cls.doc, synth_accessors=accessors)
|
||||
return ql.Stub(
|
||||
name=cls.name,
|
||||
base_import=base_import,
|
||||
import_prefix=generated_import_prefix,
|
||||
doc=cls.doc,
|
||||
synth_accessors=accessors,
|
||||
)
|
||||
|
||||
|
||||
def _get_class_public(cls: schema.Class) -> ql.ClassPublic:
|
||||
return ql.ClassPublic(name=cls.name, doc=cls.doc, internal="ql_internal" in cls.pragmas)
|
||||
return ql.ClassPublic(
|
||||
name=cls.name, doc=cls.doc, internal="ql_internal" in cls.pragmas
|
||||
)
|
||||
|
||||
|
||||
_stub_qldoc_header = "// the following QLdoc is generated: if you need to edit it, do it in the schema file\n "
|
||||
|
||||
_class_qldoc_re = re.compile(
|
||||
rf"(?P<qldoc>(?:{re.escape(_stub_qldoc_header)})?/\*\*.*?\*/\s*|^\s*)(?:class\s+(?P<class>\w+))?",
|
||||
re.MULTILINE | re.DOTALL)
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def _patch_class_qldoc(cls: str, qldoc: str, stub_file: pathlib.Path):
|
||||
""" Replace or insert `qldoc` as the QLdoc of class `cls` in `stub_file` """
|
||||
"""Replace or insert `qldoc` as the QLdoc of class `cls` in `stub_file`"""
|
||||
if not qldoc or not stub_file.exists():
|
||||
return
|
||||
qldoc = "\n ".join(l.rstrip() for l in qldoc.splitlines())
|
||||
@@ -415,7 +487,11 @@ def generate(opts, renderer):
|
||||
|
||||
data = schemaloader.load_file(input)
|
||||
|
||||
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items() if not cls.imported}
|
||||
classes = {
|
||||
name: get_ql_class(cls, data.classes)
|
||||
for name, cls in data.classes.items()
|
||||
if not cls.imported
|
||||
}
|
||||
if not classes:
|
||||
raise NoClasses
|
||||
root = next(iter(classes.values()))
|
||||
@@ -429,28 +505,47 @@ def generate(opts, renderer):
|
||||
cfg_classes = []
|
||||
generated_import_prefix = get_import(out, opts.root_dir)
|
||||
registry = opts.generated_registry or pathlib.Path(
|
||||
os.path.commonpath((out, stub_out, test_out)), ".generated.list")
|
||||
os.path.commonpath((out, stub_out, test_out)), ".generated.list"
|
||||
)
|
||||
|
||||
with renderer.manage(generated=generated, stubs=stubs, registry=registry,
|
||||
force=opts.force) as renderer:
|
||||
with renderer.manage(
|
||||
generated=generated, stubs=stubs, registry=registry, force=opts.force
|
||||
) as renderer:
|
||||
|
||||
db_classes = [cls for name, cls in classes.items() if not data.classes[name].synth]
|
||||
renderer.render(ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))), out / "Raw.qll")
|
||||
db_classes = [
|
||||
cls for name, cls in classes.items() if not data.classes[name].synth
|
||||
]
|
||||
renderer.render(
|
||||
ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))),
|
||||
out / "Raw.qll",
|
||||
)
|
||||
|
||||
classes_by_dir_and_name = sorted(classes.values(), key=lambda cls: (cls.dir, cls.name))
|
||||
classes_by_dir_and_name = sorted(
|
||||
classes.values(), key=lambda cls: (cls.dir, cls.name)
|
||||
)
|
||||
for c in classes_by_dir_and_name:
|
||||
path = get_import(stub_out / c.dir / "internal" /
|
||||
c.name if c.internal else stub_out / c.path, opts.root_dir)
|
||||
path = get_import(
|
||||
(
|
||||
stub_out / c.dir / "internal" / c.name
|
||||
if c.internal
|
||||
else stub_out / c.path
|
||||
),
|
||||
opts.root_dir,
|
||||
)
|
||||
imports[c.name] = path
|
||||
path_impl = get_import(stub_out / c.dir / "internal" / c.name, opts.root_dir)
|
||||
path_impl = get_import(
|
||||
stub_out / c.dir / "internal" / c.name, opts.root_dir
|
||||
)
|
||||
imports_impl[c.name + "Impl"] = path_impl + "Impl"
|
||||
if c.cfg:
|
||||
cfg_classes.append(get_ql_cfg_class(c, classes))
|
||||
|
||||
for c in classes.values():
|
||||
qll = out / c.path.with_suffix(".qll")
|
||||
c.imports = [imports[t] if t in imports else imports_impl[t] +
|
||||
"::Impl as " + t for t in get_classes_used_by(c, is_impl=True)]
|
||||
c.imports = [
|
||||
imports[t] if t in imports else imports_impl[t] + "::Impl as " + t
|
||||
for t in get_classes_used_by(c, is_impl=True)
|
||||
]
|
||||
classes_used_by[c.name] = get_classes_used_by(c, is_impl=False)
|
||||
c.import_prefix = generated_import_prefix
|
||||
renderer.render(c, qll)
|
||||
@@ -458,7 +553,7 @@ def generate(opts, renderer):
|
||||
if cfg_out:
|
||||
cfg_classes_val = ql.CfgClasses(
|
||||
include_file_import=get_import(include_file, opts.root_dir),
|
||||
classes=cfg_classes
|
||||
classes=cfg_classes,
|
||||
)
|
||||
cfg_qll = cfg_out / "CfgNodes.qll"
|
||||
renderer.render(cfg_classes_val, cfg_qll)
|
||||
@@ -475,7 +570,7 @@ def generate(opts, renderer):
|
||||
if not renderer.is_customized_stub(stub_file):
|
||||
renderer.render(stub, stub_file)
|
||||
else:
|
||||
qldoc = renderer.render_str(stub, template='ql_stub_class_qldoc')
|
||||
qldoc = renderer.render_str(stub, template="ql_stub_class_qldoc")
|
||||
_patch_class_qldoc(c.name, qldoc, stub_file)
|
||||
class_public = _get_class_public(c)
|
||||
path_public = _get_path_public(c)
|
||||
@@ -484,18 +579,31 @@ def generate(opts, renderer):
|
||||
renderer.render(class_public, class_public_file)
|
||||
|
||||
# for example path/to/elements -> path/to/elements.qll
|
||||
renderer.render(ql.ImportList([i for name, i in imports.items() if name not in classes or not classes[name].internal]),
|
||||
include_file)
|
||||
renderer.render(
|
||||
ql.ImportList(
|
||||
[
|
||||
i
|
||||
for name, i in imports.items()
|
||||
if name not in classes or not classes[name].internal
|
||||
]
|
||||
),
|
||||
include_file,
|
||||
)
|
||||
|
||||
elements_module = get_import(include_file, opts.root_dir)
|
||||
|
||||
renderer.render(
|
||||
ql.GetParentImplementation(
|
||||
classes=list(classes.values()),
|
||||
imports=[elements_module] + [i for name,
|
||||
i in imports.items() if name in classes and classes[name].internal],
|
||||
imports=[elements_module]
|
||||
+ [
|
||||
i
|
||||
for name, i in imports.items()
|
||||
if name in classes and classes[name].internal
|
||||
],
|
||||
),
|
||||
out / 'ParentChild.qll')
|
||||
out / "ParentChild.qll",
|
||||
)
|
||||
|
||||
if test_out:
|
||||
for c in data.classes.values():
|
||||
@@ -507,39 +615,61 @@ def generate(opts, renderer):
|
||||
test_with = data.classes[test_with_name] if test_with_name else c
|
||||
test_dir = test_out / test_with.group / test_with.name
|
||||
test_dir.mkdir(parents=True, exist_ok=True)
|
||||
if all(f.suffix in (".txt", ".ql", ".actual", ".expected") for f in test_dir.glob("*.*")):
|
||||
if all(
|
||||
f.suffix in (".txt", ".ql", ".actual", ".expected")
|
||||
for f in test_dir.glob("*.*")
|
||||
):
|
||||
log.warning(f"no test source in {test_dir.relative_to(test_out)}")
|
||||
renderer.render(ql.MissingTestInstructions(),
|
||||
test_dir / missing_test_source_filename)
|
||||
renderer.render(
|
||||
ql.MissingTestInstructions(),
|
||||
test_dir / missing_test_source_filename,
|
||||
)
|
||||
continue
|
||||
total_props, partial_props = _partition(_get_all_properties_to_be_tested(c, data.classes),
|
||||
lambda p: p.is_total)
|
||||
renderer.render(ql.ClassTester(class_name=c.name,
|
||||
properties=total_props,
|
||||
elements_module=elements_module,
|
||||
# in case of collapsed hierarchies we want to see the actual QL class in results
|
||||
show_ql_class="qltest_collapse_hierarchy" in c.pragmas),
|
||||
test_dir / f"{c.name}.ql")
|
||||
total_props, partial_props = _partition(
|
||||
_get_all_properties_to_be_tested(c, data.classes),
|
||||
lambda p: p.is_total,
|
||||
)
|
||||
renderer.render(
|
||||
ql.ClassTester(
|
||||
class_name=c.name,
|
||||
properties=total_props,
|
||||
elements_module=elements_module,
|
||||
# in case of collapsed hierarchies we want to see the actual QL class in results
|
||||
show_ql_class="qltest_collapse_hierarchy" in c.pragmas,
|
||||
),
|
||||
test_dir / f"{c.name}.ql",
|
||||
)
|
||||
for p in partial_props:
|
||||
renderer.render(ql.PropertyTester(class_name=c.name,
|
||||
elements_module=elements_module,
|
||||
property=p), test_dir / f"{c.name}_{p.getter}.ql")
|
||||
renderer.render(
|
||||
ql.PropertyTester(
|
||||
class_name=c.name,
|
||||
elements_module=elements_module,
|
||||
property=p,
|
||||
),
|
||||
test_dir / f"{c.name}_{p.getter}.ql",
|
||||
)
|
||||
|
||||
final_synth_types = []
|
||||
non_final_synth_types = []
|
||||
constructor_imports = []
|
||||
synth_constructor_imports = []
|
||||
stubs = {}
|
||||
for cls in sorted((cls for cls in data.classes.values() if not cls.imported),
|
||||
key=lambda cls: (cls.group, cls.name)):
|
||||
for cls in sorted(
|
||||
(cls for cls in data.classes.values() if not cls.imported),
|
||||
key=lambda cls: (cls.group, cls.name),
|
||||
):
|
||||
synth_type = get_ql_synth_class(cls)
|
||||
if synth_type.is_final:
|
||||
final_synth_types.append(synth_type)
|
||||
if synth_type.has_params:
|
||||
stub_file = stub_out / cls.group / "internal" / f"{cls.name}Constructor.qll"
|
||||
stub_file = (
|
||||
stub_out / cls.group / "internal" / f"{cls.name}Constructor.qll"
|
||||
)
|
||||
if not renderer.is_customized_stub(stub_file):
|
||||
# stub rendering must be postponed as we might not have yet all subtracted synth types in `synth_type`
|
||||
stubs[stub_file] = ql.Synth.ConstructorStub(synth_type, import_prefix=generated_import_prefix)
|
||||
stubs[stub_file] = ql.Synth.ConstructorStub(
|
||||
synth_type, import_prefix=generated_import_prefix
|
||||
)
|
||||
constructor_import = get_import(stub_file, opts.root_dir)
|
||||
constructor_imports.append(constructor_import)
|
||||
if synth_type.is_synth:
|
||||
@@ -549,9 +679,20 @@ def generate(opts, renderer):
|
||||
|
||||
for stub_file, data in stubs.items():
|
||||
renderer.render(data, stub_file)
|
||||
renderer.render(ql.Synth.Types(root.name, generated_import_prefix,
|
||||
final_synth_types, non_final_synth_types), out / "Synth.qll")
|
||||
renderer.render(ql.ImportList(constructor_imports), out / "SynthConstructors.qll")
|
||||
renderer.render(ql.ImportList(synth_constructor_imports), out / "PureSynthConstructors.qll")
|
||||
renderer.render(
|
||||
ql.Synth.Types(
|
||||
root.name,
|
||||
generated_import_prefix,
|
||||
final_synth_types,
|
||||
non_final_synth_types,
|
||||
),
|
||||
out / "Synth.qll",
|
||||
)
|
||||
renderer.render(
|
||||
ql.ImportList(constructor_imports), out / "SynthConstructors.qll"
|
||||
)
|
||||
renderer.render(
|
||||
ql.ImportList(synth_constructor_imports), out / "PureSynthConstructors.qll"
|
||||
)
|
||||
if opts.ql_format:
|
||||
format(opts.codeql_binary, renderer.written)
|
||||
|
||||
@@ -55,7 +55,8 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
|
||||
|
||||
|
||||
def _get_properties(
|
||||
cls: schema.Class, lookup: dict[str, schema.ClassBase],
|
||||
cls: schema.Class,
|
||||
lookup: dict[str, schema.ClassBase],
|
||||
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
|
||||
for b in cls.bases:
|
||||
yield from _get_properties(lookup[b], lookup)
|
||||
@@ -92,8 +93,9 @@ class Processor:
|
||||
# only generate detached fields in the actual class defining them, not the derived ones
|
||||
if c is cls:
|
||||
# TODO lift this restriction if required (requires change in dbschemegen as well)
|
||||
assert c.derived or not p.is_single, \
|
||||
f"property {p.name} in concrete class marked as detached but not optional"
|
||||
assert (
|
||||
c.derived or not p.is_single
|
||||
), f"property {p.name} in concrete class marked as detached but not optional"
|
||||
detached_fields.append(_get_field(c, p))
|
||||
elif not cls.derived:
|
||||
# for non-detached ones, only generate fields in the concrete classes
|
||||
@@ -123,10 +125,12 @@ def generate(opts, renderer):
|
||||
processor = Processor(schemaloader.load_file(opts.schema))
|
||||
out = opts.rust_output
|
||||
groups = set()
|
||||
with renderer.manage(generated=out.rglob("*.rs"),
|
||||
stubs=(),
|
||||
registry=out / ".generated.list",
|
||||
force=opts.force) as renderer:
|
||||
with renderer.manage(
|
||||
generated=out.rglob("*.rs"),
|
||||
stubs=(),
|
||||
registry=out / ".generated.list",
|
||||
force=opts.force,
|
||||
) as renderer:
|
||||
for group, classes in processor.get_classes().items():
|
||||
group = group or "top"
|
||||
groups.add(group)
|
||||
|
||||
@@ -42,7 +42,9 @@ def _get_code(doc: list[str]) -> list[str]:
|
||||
code.append(f"// {line}")
|
||||
case _, True:
|
||||
code.append(line)
|
||||
assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(doc)
|
||||
assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(
|
||||
doc
|
||||
)
|
||||
if has_code:
|
||||
return code
|
||||
return []
|
||||
@@ -51,15 +53,19 @@ def _get_code(doc: list[str]) -> list[str]:
|
||||
def generate(opts, renderer):
|
||||
assert opts.ql_test_output
|
||||
schema = schemaloader.load_file(opts.schema)
|
||||
with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
|
||||
stubs=(),
|
||||
registry=opts.ql_test_output / ".generated_tests.list",
|
||||
force=opts.force) as renderer:
|
||||
with renderer.manage(
|
||||
generated=opts.ql_test_output.rglob("gen_*.rs"),
|
||||
stubs=(),
|
||||
registry=opts.ql_test_output / ".generated_tests.list",
|
||||
force=opts.force,
|
||||
) as renderer:
|
||||
for cls in schema.classes.values():
|
||||
if cls.imported:
|
||||
continue
|
||||
if (qlgen.should_skip_qltest(cls, schema.classes) or
|
||||
"rust_skip_doc_test" in cls.pragmas):
|
||||
if (
|
||||
qlgen.should_skip_qltest(cls, schema.classes)
|
||||
or "rust_skip_doc_test" in cls.pragmas
|
||||
):
|
||||
continue
|
||||
code = _get_code(cls.doc)
|
||||
for p in schema.iter_properties(cls.name):
|
||||
@@ -79,5 +85,10 @@ def generate(opts, renderer):
|
||||
code = [indent + l for l in code]
|
||||
test_with_name = typing.cast(str, cls.pragmas.get("qltest_test_with"))
|
||||
test_with = schema.classes[test_with_name] if test_with_name else cls
|
||||
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
|
||||
test = (
|
||||
opts.ql_test_output
|
||||
/ test_with.group
|
||||
/ test_with.name
|
||||
/ f"gen_{test_name}.rs"
|
||||
)
|
||||
renderer.render(TestCode(code="\n".join(code), function=fn), test)
|
||||
|
||||
@@ -86,13 +86,18 @@ def generate(opts, renderer):
|
||||
for dir, entries in traps.items():
|
||||
dir = dir or pathlib.Path()
|
||||
relative_gen_dir = pathlib.Path(*[".." for _ in dir.parents])
|
||||
renderer.render(cpp.TrapList(entries, opts.dbscheme, trap_library, relative_gen_dir), out / dir / "TrapEntries")
|
||||
renderer.render(
|
||||
cpp.TrapList(entries, opts.dbscheme, trap_library, relative_gen_dir),
|
||||
out / dir / "TrapEntries",
|
||||
)
|
||||
|
||||
tags = []
|
||||
for tag in toposort_flatten(tag_graph):
|
||||
tags.append(cpp.Tag(
|
||||
name=get_tag_name(tag),
|
||||
bases=[get_tag_name(b) for b in sorted(tag_graph[tag])],
|
||||
id=tag,
|
||||
))
|
||||
tags.append(
|
||||
cpp.Tag(
|
||||
name=get_tag_name(tag),
|
||||
bases=[get_tag_name(b) for b in sorted(tag_graph[tag])],
|
||||
id=tag,
|
||||
)
|
||||
)
|
||||
renderer.render(cpp.TagList(tags, opts.dbscheme), out / "TrapTags")
|
||||
|
||||
@@ -4,20 +4,111 @@ from dataclasses import dataclass, field
|
||||
from typing import List, ClassVar
|
||||
|
||||
# taken from https://en.cppreference.com/w/cpp/keyword
|
||||
cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "atomic_commit", "atomic_noexcept",
|
||||
"auto", "bitand", "bitor", "bool", "break", "case", "catch", "char", "char8_t", "char16_t", "char32_t",
|
||||
"class", "compl", "concept", "const", "consteval", "constexpr", "constinit", "const_cast", "continue",
|
||||
"co_await", "co_return", "co_yield", "decltype", "default", "delete", "do", "double", "dynamic_cast",
|
||||
"else", "enum", "explicit", "export", "extern", "false", "float", "for", "friend", "goto", "if",
|
||||
"inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr",
|
||||
"operator", "or", "or_eq", "private", "protected", "public", "reflexpr", "register", "reinterpret_cast",
|
||||
"requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct",
|
||||
"switch", "synchronized", "template", "this", "thread_local", "throw", "true", "try", "typedef",
|
||||
"typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while",
|
||||
"xor", "xor_eq"}
|
||||
cpp_keywords = {
|
||||
"alignas",
|
||||
"alignof",
|
||||
"and",
|
||||
"and_eq",
|
||||
"asm",
|
||||
"atomic_cancel",
|
||||
"atomic_commit",
|
||||
"atomic_noexcept",
|
||||
"auto",
|
||||
"bitand",
|
||||
"bitor",
|
||||
"bool",
|
||||
"break",
|
||||
"case",
|
||||
"catch",
|
||||
"char",
|
||||
"char8_t",
|
||||
"char16_t",
|
||||
"char32_t",
|
||||
"class",
|
||||
"compl",
|
||||
"concept",
|
||||
"const",
|
||||
"consteval",
|
||||
"constexpr",
|
||||
"constinit",
|
||||
"const_cast",
|
||||
"continue",
|
||||
"co_await",
|
||||
"co_return",
|
||||
"co_yield",
|
||||
"decltype",
|
||||
"default",
|
||||
"delete",
|
||||
"do",
|
||||
"double",
|
||||
"dynamic_cast",
|
||||
"else",
|
||||
"enum",
|
||||
"explicit",
|
||||
"export",
|
||||
"extern",
|
||||
"false",
|
||||
"float",
|
||||
"for",
|
||||
"friend",
|
||||
"goto",
|
||||
"if",
|
||||
"inline",
|
||||
"int",
|
||||
"long",
|
||||
"mutable",
|
||||
"namespace",
|
||||
"new",
|
||||
"noexcept",
|
||||
"not",
|
||||
"not_eq",
|
||||
"nullptr",
|
||||
"operator",
|
||||
"or",
|
||||
"or_eq",
|
||||
"private",
|
||||
"protected",
|
||||
"public",
|
||||
"reflexpr",
|
||||
"register",
|
||||
"reinterpret_cast",
|
||||
"requires",
|
||||
"return",
|
||||
"short",
|
||||
"signed",
|
||||
"sizeof",
|
||||
"static",
|
||||
"static_assert",
|
||||
"static_cast",
|
||||
"struct",
|
||||
"switch",
|
||||
"synchronized",
|
||||
"template",
|
||||
"this",
|
||||
"thread_local",
|
||||
"throw",
|
||||
"true",
|
||||
"try",
|
||||
"typedef",
|
||||
"typeid",
|
||||
"typename",
|
||||
"union",
|
||||
"unsigned",
|
||||
"using",
|
||||
"virtual",
|
||||
"void",
|
||||
"volatile",
|
||||
"wchar_t",
|
||||
"while",
|
||||
"xor",
|
||||
"xor_eq",
|
||||
}
|
||||
|
||||
_field_overrides = [
|
||||
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": "unsigned"}),
|
||||
(
|
||||
re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"),
|
||||
{"base_type": "unsigned"},
|
||||
),
|
||||
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
|
||||
]
|
||||
|
||||
@@ -108,7 +199,7 @@ class Tag:
|
||||
|
||||
@dataclass
|
||||
class TrapList:
|
||||
template: ClassVar = 'trap_traps'
|
||||
template: ClassVar = "trap_traps"
|
||||
extensions = ["h", "cpp"]
|
||||
traps: List[Trap]
|
||||
source: str
|
||||
@@ -118,7 +209,7 @@ class TrapList:
|
||||
|
||||
@dataclass
|
||||
class TagList:
|
||||
template: ClassVar = 'trap_tags'
|
||||
template: ClassVar = "trap_tags"
|
||||
extensions = ["h"]
|
||||
|
||||
tags: List[Tag]
|
||||
@@ -127,7 +218,7 @@ class TagList:
|
||||
|
||||
@dataclass
|
||||
class ClassBase:
|
||||
ref: 'Class'
|
||||
ref: "Class"
|
||||
first: bool = False
|
||||
|
||||
|
||||
@@ -140,7 +231,9 @@ class Class:
|
||||
trap_name: str = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.bases = [ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name)]
|
||||
self.bases = [
|
||||
ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name)
|
||||
]
|
||||
if self.bases:
|
||||
self.bases[0].first = True
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" dbscheme format representation """
|
||||
"""dbscheme format representation"""
|
||||
|
||||
import logging
|
||||
import pathlib
|
||||
@@ -100,7 +100,7 @@ class SchemeInclude:
|
||||
|
||||
@dataclass
|
||||
class Scheme:
|
||||
template: ClassVar = 'dbscheme'
|
||||
template: ClassVar = "dbscheme"
|
||||
|
||||
src: str
|
||||
includes: List[SchemeInclude]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" module providing useful filesystem paths """
|
||||
"""module providing useful filesystem paths"""
|
||||
|
||||
import pathlib
|
||||
import sys
|
||||
@@ -7,13 +7,15 @@ import os
|
||||
_this_file = pathlib.Path(__file__).resolve()
|
||||
|
||||
try:
|
||||
workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve() # <- means we are using bazel run
|
||||
root_dir = workspace_dir / 'swift'
|
||||
workspace_dir = pathlib.Path(
|
||||
os.environ["BUILD_WORKSPACE_DIRECTORY"]
|
||||
).resolve() # <- means we are using bazel run
|
||||
root_dir = workspace_dir / "swift"
|
||||
except KeyError:
|
||||
root_dir = _this_file.parents[2]
|
||||
workspace_dir = root_dir.parent
|
||||
|
||||
lib_dir = _this_file.parents[2] / 'codegen' / 'lib'
|
||||
templates_dir = _this_file.parents[2] / 'codegen' / 'templates'
|
||||
lib_dir = _this_file.parents[2] / "codegen" / "lib"
|
||||
templates_dir = _this_file.parents[2] / "codegen" / "templates"
|
||||
|
||||
exe_file = pathlib.Path(sys.argv[0]).resolve()
|
||||
|
||||
@@ -100,7 +100,7 @@ class Base:
|
||||
|
||||
@dataclass
|
||||
class Class:
|
||||
template: ClassVar = 'ql_class'
|
||||
template: ClassVar = "ql_class"
|
||||
|
||||
name: str
|
||||
bases: List[Base] = field(default_factory=list)
|
||||
@@ -116,7 +116,12 @@ class Class:
|
||||
cfg: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def get_bases(bases): return [Base(str(b), str(prev)) for b, prev in zip(bases, itertools.chain([""], bases))]
|
||||
def get_bases(bases):
|
||||
return [
|
||||
Base(str(b), str(prev))
|
||||
for b, prev in zip(bases, itertools.chain([""], bases))
|
||||
]
|
||||
|
||||
self.bases = get_bases(self.bases)
|
||||
self.bases_impl = get_bases(self.bases_impl)
|
||||
if self.properties:
|
||||
@@ -164,7 +169,7 @@ class SynthUnderlyingAccessor:
|
||||
|
||||
@dataclass
|
||||
class Stub:
|
||||
template: ClassVar = 'ql_stub'
|
||||
template: ClassVar = "ql_stub"
|
||||
|
||||
name: str
|
||||
base_import: str
|
||||
@@ -183,7 +188,7 @@ class Stub:
|
||||
|
||||
@dataclass
|
||||
class ClassPublic:
|
||||
template: ClassVar = 'ql_class_public'
|
||||
template: ClassVar = "ql_class_public"
|
||||
|
||||
name: str
|
||||
imports: List[str] = field(default_factory=list)
|
||||
@@ -197,7 +202,7 @@ class ClassPublic:
|
||||
|
||||
@dataclass
|
||||
class DbClasses:
|
||||
template: ClassVar = 'ql_db'
|
||||
template: ClassVar = "ql_db"
|
||||
|
||||
classes: List[Class] = field(default_factory=list)
|
||||
imports: List[str] = field(default_factory=list)
|
||||
@@ -205,14 +210,14 @@ class DbClasses:
|
||||
|
||||
@dataclass
|
||||
class ImportList:
|
||||
template: ClassVar = 'ql_imports'
|
||||
template: ClassVar = "ql_imports"
|
||||
|
||||
imports: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetParentImplementation:
|
||||
template: ClassVar = 'ql_parent'
|
||||
template: ClassVar = "ql_parent"
|
||||
|
||||
classes: List[Class] = field(default_factory=list)
|
||||
imports: List[str] = field(default_factory=list)
|
||||
@@ -234,7 +239,7 @@ class TesterBase:
|
||||
|
||||
@dataclass
|
||||
class ClassTester(TesterBase):
|
||||
template: ClassVar = 'ql_test_class'
|
||||
template: ClassVar = "ql_test_class"
|
||||
|
||||
properties: List[PropertyForTest] = field(default_factory=list)
|
||||
show_ql_class: bool = False
|
||||
@@ -242,14 +247,14 @@ class ClassTester(TesterBase):
|
||||
|
||||
@dataclass
|
||||
class PropertyTester(TesterBase):
|
||||
template: ClassVar = 'ql_test_property'
|
||||
template: ClassVar = "ql_test_property"
|
||||
|
||||
property: PropertyForTest
|
||||
|
||||
|
||||
@dataclass
|
||||
class MissingTestInstructions:
|
||||
template: ClassVar = 'ql_test_missing'
|
||||
template: ClassVar = "ql_test_missing"
|
||||
|
||||
|
||||
class Synth:
|
||||
@@ -306,7 +311,9 @@ class Synth:
|
||||
subtracted_synth_types: List["Synth.Class"] = field(default_factory=list)
|
||||
|
||||
def subtract_type(self, type: str):
|
||||
self.subtracted_synth_types.append(Synth.Class(type, first=not self.subtracted_synth_types))
|
||||
self.subtracted_synth_types.append(
|
||||
Synth.Class(type, first=not self.subtracted_synth_types)
|
||||
)
|
||||
|
||||
@property
|
||||
def has_subtracted_synth_types(self) -> bool:
|
||||
@@ -357,6 +364,6 @@ class CfgClass:
|
||||
|
||||
@dataclass
|
||||
class CfgClasses:
|
||||
template: ClassVar = 'ql_cfg_nodes'
|
||||
template: ClassVar = "ql_cfg_nodes"
|
||||
include_file_import: Optional[str] = None
|
||||
classes: List[CfgClass] = field(default_factory=list)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
""" template renderer module, wrapping around `pystache.Renderer`
|
||||
"""template renderer module, wrapping around `pystache.Renderer`
|
||||
|
||||
`pystache` is a python mustache engine, and mustache is a template language. More information on
|
||||
|
||||
@@ -23,14 +23,21 @@ class Error(Exception):
|
||||
|
||||
|
||||
class Renderer:
|
||||
""" Template renderer using mustache templates in the `templates` directory """
|
||||
"""Template renderer using mustache templates in the `templates` directory"""
|
||||
|
||||
def __init__(self, generator: pathlib.Path):
|
||||
self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u)
|
||||
self._r = pystache.Renderer(
|
||||
search_dirs=str(paths.templates_dir), escape=lambda u: u
|
||||
)
|
||||
self._generator = generator
|
||||
|
||||
def render(self, data: object, output: typing.Optional[pathlib.Path], template: typing.Optional[str] = None):
|
||||
""" Render `data` to `output`.
|
||||
def render(
|
||||
self,
|
||||
data: object,
|
||||
output: typing.Optional[pathlib.Path],
|
||||
template: typing.Optional[str] = None,
|
||||
):
|
||||
"""Render `data` to `output`.
|
||||
|
||||
`data` must have a `template` attribute denoting which template to use from the template directory.
|
||||
|
||||
@@ -58,13 +65,18 @@ class Renderer:
|
||||
out.write(contents)
|
||||
log.debug(f"{mnemonic}: generated {output.name}")
|
||||
|
||||
def manage(self, generated: typing.Iterable[pathlib.Path], stubs: typing.Iterable[pathlib.Path],
|
||||
registry: pathlib.Path, force: bool = False) -> "RenderManager":
|
||||
def manage(
|
||||
self,
|
||||
generated: typing.Iterable[pathlib.Path],
|
||||
stubs: typing.Iterable[pathlib.Path],
|
||||
registry: pathlib.Path,
|
||||
force: bool = False,
|
||||
) -> "RenderManager":
|
||||
return RenderManager(self._generator, generated, stubs, registry, force)
|
||||
|
||||
|
||||
class RenderManager(Renderer):
|
||||
""" A context manager allowing to manage checked in generated files and their cleanup, able
|
||||
"""A context manager allowing to manage checked in generated files and their cleanup, able
|
||||
to skip unneeded writes.
|
||||
|
||||
This is done by using and updating a checked in list of generated files that assigns two
|
||||
@@ -74,6 +86,7 @@ class RenderManager(Renderer):
|
||||
* the other is the hash of the actual file after code generation has finished. This will be
|
||||
different from the above because of post-processing like QL formatting. This hash is used
|
||||
to detect invalid modification of generated files"""
|
||||
|
||||
written: typing.Set[pathlib.Path]
|
||||
|
||||
@dataclass
|
||||
@@ -82,12 +95,18 @@ class RenderManager(Renderer):
|
||||
pre contains the hash of a file as rendered, post is the hash after
|
||||
postprocessing (for example QL formatting)
|
||||
"""
|
||||
|
||||
pre: str
|
||||
post: typing.Optional[str] = None
|
||||
|
||||
def __init__(self, generator: pathlib.Path, generated: typing.Iterable[pathlib.Path],
|
||||
stubs: typing.Iterable[pathlib.Path],
|
||||
registry: pathlib.Path, force: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
generator: pathlib.Path,
|
||||
generated: typing.Iterable[pathlib.Path],
|
||||
stubs: typing.Iterable[pathlib.Path],
|
||||
registry: pathlib.Path,
|
||||
force: bool = False,
|
||||
):
|
||||
super().__init__(generator)
|
||||
self._registry_path = registry
|
||||
self._force = force
|
||||
@@ -142,10 +161,14 @@ class RenderManager(Renderer):
|
||||
if self._force:
|
||||
pass
|
||||
elif rel_path not in self._hashes:
|
||||
log.warning(f"{rel_path} marked as generated but absent from the registry")
|
||||
log.warning(
|
||||
f"{rel_path} marked as generated but absent from the registry"
|
||||
)
|
||||
elif self._hashes[rel_path].post != self._hash_file(f):
|
||||
raise Error(f"{rel_path} is generated but was modified, please revert the file "
|
||||
"or pass --force to overwrite")
|
||||
raise Error(
|
||||
f"{rel_path} is generated but was modified, please revert the file "
|
||||
"or pass --force to overwrite"
|
||||
)
|
||||
|
||||
def _process_stubs(self, stubs: typing.Iterable[pathlib.Path]):
|
||||
for f in stubs:
|
||||
@@ -159,8 +182,10 @@ class RenderManager(Renderer):
|
||||
elif rel_path not in self._hashes:
|
||||
log.warning(f"{rel_path} marked as stub but absent from the registry")
|
||||
elif self._hashes[rel_path].post != self._hash_file(f):
|
||||
raise Error(f"{rel_path} is a stub marked as generated, but it was modified, "
|
||||
"please remove the `// generated` header, revert the file or pass --force to overwrite it")
|
||||
raise Error(
|
||||
f"{rel_path} is a stub marked as generated, but it was modified, "
|
||||
"please remove the `// generated` header, revert the file or pass --force to overwrite it"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_customized_stub(file: pathlib.Path) -> bool:
|
||||
@@ -191,13 +216,17 @@ class RenderManager(Renderer):
|
||||
for line in reg:
|
||||
if line.strip():
|
||||
filename, prehash, posthash = line.split()
|
||||
self._hashes[pathlib.Path(filename)] = self.Hashes(prehash, posthash)
|
||||
self._hashes[pathlib.Path(filename)] = self.Hashes(
|
||||
prehash, posthash
|
||||
)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def _dump_registry(self):
|
||||
self._registry_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self._registry_path, 'w') as out, open(self._registry_path.parent / ".gitattributes", "w") as attrs:
|
||||
with open(self._registry_path, "w") as out, open(
|
||||
self._registry_path.parent / ".gitattributes", "w"
|
||||
) as attrs:
|
||||
print(f"/{self._registry_path.name}", "linguist-generated", file=attrs)
|
||||
print("/.gitattributes", "linguist-generated", file=attrs)
|
||||
for f, hashes in sorted(self._hashes.items()):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
""" schema format representation """
|
||||
"""schema format representation"""
|
||||
|
||||
import abc
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
@@ -52,7 +53,11 @@ class Property:
|
||||
|
||||
@property
|
||||
def is_repeated(self) -> bool:
|
||||
return self.kind in (self.Kind.REPEATED, self.Kind.REPEATED_OPTIONAL, self.Kind.REPEATED_UNORDERED)
|
||||
return self.kind in (
|
||||
self.Kind.REPEATED,
|
||||
self.Kind.REPEATED_OPTIONAL,
|
||||
self.Kind.REPEATED_UNORDERED,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_unordered(self) -> bool:
|
||||
@@ -74,10 +79,11 @@ class Property:
|
||||
SingleProperty = functools.partial(Property, Property.Kind.SINGLE)
|
||||
OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL)
|
||||
RepeatedProperty = functools.partial(Property, Property.Kind.REPEATED)
|
||||
RepeatedOptionalProperty = functools.partial(
|
||||
Property, Property.Kind.REPEATED_OPTIONAL)
|
||||
RepeatedOptionalProperty = functools.partial(Property, Property.Kind.REPEATED_OPTIONAL)
|
||||
PredicateProperty = functools.partial(Property, Property.Kind.PREDICATE)
|
||||
RepeatedUnorderedProperty = functools.partial(Property, Property.Kind.REPEATED_UNORDERED)
|
||||
RepeatedUnorderedProperty = functools.partial(
|
||||
Property, Property.Kind.REPEATED_UNORDERED
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -197,9 +203,9 @@ def _make_property(arg: object) -> Property:
|
||||
|
||||
|
||||
class PropertyModifier(abc.ABC):
|
||||
""" Modifier of `Property` objects.
|
||||
Being on the right of `|` it will trigger construction of a `Property` from
|
||||
the left operand.
|
||||
"""Modifier of `Property` objects.
|
||||
Being on the right of `|` it will trigger construction of a `Property` from
|
||||
the left operand.
|
||||
"""
|
||||
|
||||
def __ror__(self, other: object) -> Property:
|
||||
@@ -210,11 +216,9 @@ class PropertyModifier(abc.ABC):
|
||||
def __invert__(self) -> "PropertyModifier":
|
||||
return self.negate()
|
||||
|
||||
def modify(self, prop: Property):
|
||||
...
|
||||
def modify(self, prop: Property): ...
|
||||
|
||||
def negate(self) -> "PropertyModifier":
|
||||
...
|
||||
def negate(self) -> "PropertyModifier": ...
|
||||
|
||||
|
||||
def split_doc(doc):
|
||||
@@ -224,7 +228,11 @@ def split_doc(doc):
|
||||
lines = doc.splitlines()
|
||||
# Determine minimum indentation (first line doesn't count):
|
||||
strippedlines = (line.lstrip() for line in lines[1:])
|
||||
indents = [len(line) - len(stripped) for line, stripped in zip(lines[1:], strippedlines) if stripped]
|
||||
indents = [
|
||||
len(line) - len(stripped)
|
||||
for line, stripped in zip(lines[1:], strippedlines)
|
||||
if stripped
|
||||
]
|
||||
# Remove indentation (first line is special):
|
||||
trimmed = [lines[0].strip()]
|
||||
if indents:
|
||||
|
||||
@@ -39,7 +39,9 @@ class _DocModifier(_schema.PropertyModifier, metaclass=_DocModifierMetaclass):
|
||||
|
||||
def modify(self, prop: _schema.Property):
|
||||
if self.doc and ("\n" in self.doc or self.doc[-1] == "."):
|
||||
raise _schema.Error("No newlines or trailing dots are allowed in doc, did you intend to use desc?")
|
||||
raise _schema.Error(
|
||||
"No newlines or trailing dots are allowed in doc, did you intend to use desc?"
|
||||
)
|
||||
prop.doc = self.doc
|
||||
|
||||
def negate(self) -> _schema.PropertyModifier:
|
||||
@@ -73,10 +75,13 @@ imported = _schema.ImportedClass
|
||||
|
||||
@_dataclass
|
||||
class _Namespace:
|
||||
""" simple namespacing mechanism """
|
||||
"""simple namespacing mechanism"""
|
||||
|
||||
_name: str
|
||||
|
||||
def add(self, pragma: _Union["_PragmaBase", "_Parametrized"], key: str | None = None):
|
||||
def add(
|
||||
self, pragma: _Union["_PragmaBase", "_Parametrized"], key: str | None = None
|
||||
):
|
||||
self.__dict__[pragma.pragma] = pragma
|
||||
pragma.pragma = key or f"{self._name}_{pragma.pragma}"
|
||||
|
||||
@@ -110,15 +115,18 @@ class _PragmaBase:
|
||||
|
||||
@_dataclass
|
||||
class _ClassPragma(_PragmaBase):
|
||||
""" A class pragma.
|
||||
"""A class pragma.
|
||||
For schema classes it acts as a python decorator with `@`.
|
||||
"""
|
||||
|
||||
inherited: bool = False
|
||||
|
||||
def __call__(self, cls: type) -> type:
|
||||
""" use this pragma as a decorator on classes """
|
||||
"""use this pragma as a decorator on classes"""
|
||||
if self.inherited:
|
||||
setattr(cls, f"{_schema.inheritable_pragma_prefix}{self.pragma}", self.value)
|
||||
setattr(
|
||||
cls, f"{_schema.inheritable_pragma_prefix}{self.pragma}", self.value
|
||||
)
|
||||
else:
|
||||
# not using hasattr as we don't want to land on inherited pragmas
|
||||
if "_pragmas" not in cls.__dict__:
|
||||
@@ -129,9 +137,10 @@ class _ClassPragma(_PragmaBase):
|
||||
|
||||
@_dataclass
|
||||
class _PropertyPragma(_PragmaBase, _schema.PropertyModifier):
|
||||
""" A property pragma.
|
||||
"""A property pragma.
|
||||
It functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
|
||||
"""
|
||||
|
||||
remove: bool = False
|
||||
|
||||
def modify(self, prop: _schema.Property):
|
||||
@@ -149,21 +158,23 @@ class _PropertyPragma(_PragmaBase, _schema.PropertyModifier):
|
||||
|
||||
@_dataclass
|
||||
class _Pragma(_ClassPragma, _PropertyPragma):
|
||||
""" A class or property pragma.
|
||||
"""A class or property pragma.
|
||||
For properties, it functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
|
||||
For schema classes it acts as a python decorator with `@`.
|
||||
"""
|
||||
|
||||
|
||||
class _Parametrized[P, **Q, T]:
|
||||
""" A parametrized pragma.
|
||||
"""A parametrized pragma.
|
||||
Needs to be applied to a parameter to give a pragma.
|
||||
"""
|
||||
|
||||
def __init__(self, pragma_instance: P, factory: _Callable[Q, T]):
|
||||
self.pragma_instance = pragma_instance
|
||||
self.factory = factory
|
||||
self.__signature__ = _inspect.signature(self.factory).replace(return_annotation=type(self.pragma_instance))
|
||||
self.__signature__ = _inspect.signature(self.factory).replace(
|
||||
return_annotation=type(self.pragma_instance)
|
||||
)
|
||||
|
||||
@property
|
||||
def pragma(self):
|
||||
@@ -187,7 +198,8 @@ class _Optionalizer(_schema.PropertyModifier):
|
||||
K = _schema.Property.Kind
|
||||
if prop.kind != K.SINGLE:
|
||||
raise _schema.Error(
|
||||
"optional should only be applied to simple property types")
|
||||
"optional should only be applied to simple property types"
|
||||
)
|
||||
prop.kind = K.OPTIONAL
|
||||
|
||||
|
||||
@@ -200,7 +212,8 @@ class _Listifier(_schema.PropertyModifier):
|
||||
prop.kind = K.REPEATED_OPTIONAL
|
||||
else:
|
||||
raise _schema.Error(
|
||||
"list should only be applied to simple or optional property types")
|
||||
"list should only be applied to simple or optional property types"
|
||||
)
|
||||
|
||||
|
||||
class _Setifier(_schema.PropertyModifier):
|
||||
@@ -212,7 +225,7 @@ class _Setifier(_schema.PropertyModifier):
|
||||
|
||||
|
||||
class _TypeModifier:
|
||||
""" Modifies types using get item notation """
|
||||
"""Modifies types using get item notation"""
|
||||
|
||||
def __init__(self, modifier: _schema.PropertyModifier):
|
||||
self.modifier = modifier
|
||||
@@ -242,7 +255,11 @@ use_for_null = _ClassPragma("null")
|
||||
qltest.add(_ClassPragma("skip"))
|
||||
qltest.add(_ClassPragma("collapse_hierarchy"))
|
||||
qltest.add(_ClassPragma("uncollapse_hierarchy"))
|
||||
qltest.add(_Parametrized(_ClassPragma("test_with", inherited=True), factory=_schema.get_type_name))
|
||||
qltest.add(
|
||||
_Parametrized(
|
||||
_ClassPragma("test_with", inherited=True), factory=_schema.get_type_name
|
||||
)
|
||||
)
|
||||
|
||||
ql.add(_Parametrized(_ClassPragma("default_doc_name"), factory=lambda doc: doc))
|
||||
ql.add(_ClassPragma("hideable", inherited=True))
|
||||
@@ -255,15 +272,33 @@ cpp.add(_Pragma("skip"))
|
||||
rust.add(_PropertyPragma("detach"))
|
||||
rust.add(_Pragma("skip_doc_test"))
|
||||
|
||||
rust.add(_Parametrized(_ClassPragma("doc_test_signature"), factory=lambda signature: signature))
|
||||
rust.add(
|
||||
_Parametrized(
|
||||
_ClassPragma("doc_test_signature"), factory=lambda signature: signature
|
||||
)
|
||||
)
|
||||
|
||||
group = _Parametrized(_ClassPragma("group", inherited=True), factory=lambda group: group)
|
||||
group = _Parametrized(
|
||||
_ClassPragma("group", inherited=True), factory=lambda group: group
|
||||
)
|
||||
|
||||
|
||||
synth.add(_Parametrized(_ClassPragma("from_class"), factory=lambda ref: _schema.SynthInfo(
|
||||
from_class=_schema.get_type_name(ref))), key="synth")
|
||||
synth.add(_Parametrized(_ClassPragma("on_arguments"), factory=lambda **kwargs:
|
||||
_schema.SynthInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()})), key="synth")
|
||||
synth.add(
|
||||
_Parametrized(
|
||||
_ClassPragma("from_class"),
|
||||
factory=lambda ref: _schema.SynthInfo(from_class=_schema.get_type_name(ref)),
|
||||
),
|
||||
key="synth",
|
||||
)
|
||||
synth.add(
|
||||
_Parametrized(
|
||||
_ClassPragma("on_arguments"),
|
||||
factory=lambda **kwargs: _schema.SynthInfo(
|
||||
on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()}
|
||||
),
|
||||
),
|
||||
key="synth",
|
||||
)
|
||||
|
||||
|
||||
@_dataclass(frozen=True)
|
||||
@@ -283,7 +318,12 @@ _ = _PropertyModifierList(())
|
||||
drop = object()
|
||||
|
||||
|
||||
def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, replace_bases: _Dict[type, type] | None = None, cfg: bool = False) -> _Callable[[type], _PropertyModifierList]:
|
||||
def annotate(
|
||||
annotated_cls: type,
|
||||
add_bases: _Iterable[type] | None = None,
|
||||
replace_bases: _Dict[type, type] | None = None,
|
||||
cfg: bool = False,
|
||||
) -> _Callable[[type], _PropertyModifierList]:
|
||||
"""
|
||||
Add or modify schema annotations after a class has been defined previously.
|
||||
|
||||
@@ -291,6 +331,7 @@ def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, repl
|
||||
|
||||
`replace_bases` can be used to replace bases on the annotated class.
|
||||
"""
|
||||
|
||||
def decorator(cls: type) -> _PropertyModifierList:
|
||||
if cls.__name__ != "_":
|
||||
raise _schema.Error("Annotation classes must be named _")
|
||||
@@ -299,7 +340,9 @@ def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, repl
|
||||
for p, v in cls.__dict__.get("_pragmas", {}).items():
|
||||
_ClassPragma(p, value=v)(annotated_cls)
|
||||
if replace_bases:
|
||||
annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__)
|
||||
annotated_cls.__bases__ = tuple(
|
||||
replace_bases.get(b, b) for b in annotated_cls.__bases__
|
||||
)
|
||||
if add_bases:
|
||||
annotated_cls.__bases__ += tuple(add_bases)
|
||||
annotated_cls.__cfg__ = cfg
|
||||
@@ -312,9 +355,12 @@ def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, repl
|
||||
elif p in annotated_cls.__annotations__:
|
||||
annotated_cls.__annotations__[p] |= a
|
||||
elif isinstance(a, (_PropertyModifierList, _PropertyModifierList)):
|
||||
raise _schema.Error(f"annotated property {p} not present in annotated class "
|
||||
f"{annotated_cls.__name__}")
|
||||
raise _schema.Error(
|
||||
f"annotated property {p} not present in annotated class "
|
||||
f"{annotated_cls.__name__}"
|
||||
)
|
||||
else:
|
||||
annotated_cls.__annotations__[p] = a
|
||||
return _
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -12,9 +12,13 @@ class _Re:
|
||||
"|"
|
||||
r"^(?P<union>@\w+)\s*=\s*(?P<unionbody>@\w+(?:\s*\|\s*@\w+)*)\s*;?"
|
||||
)
|
||||
field = re.compile(r"(?m)[\w\s]*\s(?P<field>\w+)\s*:\s*(?P<type>@?\w+)(?P<ref>\s+ref)?")
|
||||
field = re.compile(
|
||||
r"(?m)[\w\s]*\s(?P<field>\w+)\s*:\s*(?P<type>@?\w+)(?P<ref>\s+ref)?"
|
||||
)
|
||||
key = re.compile(r"@\w+")
|
||||
comment = re.compile(r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$") # lookahead avoid ignoring metadata like //dir=foo
|
||||
comment = re.compile(
|
||||
r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$"
|
||||
) # lookahead avoid ignoring metadata like //dir=foo
|
||||
|
||||
|
||||
def _get_column(match):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
""" schema loader """
|
||||
"""schema loader"""
|
||||
|
||||
import sys
|
||||
|
||||
import inflection
|
||||
@@ -33,37 +34,56 @@ def _get_class(cls: type) -> schema.Class:
|
||||
raise schema.Error(f"Only class definitions allowed in schema, found {cls}")
|
||||
# we must check that going to dbscheme names and back is preserved
|
||||
# In particular this will not happen if uppercase acronyms are included in the name
|
||||
to_underscore_and_back = inflection.camelize(inflection.underscore(cls.__name__), uppercase_first_letter=True)
|
||||
to_underscore_and_back = inflection.camelize(
|
||||
inflection.underscore(cls.__name__), uppercase_first_letter=True
|
||||
)
|
||||
if cls.__name__ != to_underscore_and_back:
|
||||
raise schema.Error(f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} "
|
||||
f"instead of {to_underscore_and_back}")
|
||||
if len({g for g in (getattr(b, f"{schema.inheritable_pragma_prefix}group", None)
|
||||
for b in cls.__bases__) if g}) > 1:
|
||||
raise schema.Error(
|
||||
f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} "
|
||||
f"instead of {to_underscore_and_back}"
|
||||
)
|
||||
if (
|
||||
len(
|
||||
{
|
||||
g
|
||||
for g in (
|
||||
getattr(b, f"{schema.inheritable_pragma_prefix}group", None)
|
||||
for b in cls.__bases__
|
||||
)
|
||||
if g
|
||||
}
|
||||
)
|
||||
> 1
|
||||
):
|
||||
raise schema.Error(f"Bases with mixed groups for {cls.__name__}")
|
||||
pragmas = {
|
||||
# dir and getattr inherit from bases
|
||||
a[len(schema.inheritable_pragma_prefix):]: getattr(cls, a)
|
||||
for a in dir(cls) if a.startswith(schema.inheritable_pragma_prefix)
|
||||
a[len(schema.inheritable_pragma_prefix) :]: getattr(cls, a)
|
||||
for a in dir(cls)
|
||||
if a.startswith(schema.inheritable_pragma_prefix)
|
||||
}
|
||||
pragmas |= cls.__dict__.get("_pragmas", {})
|
||||
derived = {d.__name__ for d in cls.__subclasses__()}
|
||||
if "null" in pragmas and derived:
|
||||
raise schema.Error(f"Null class cannot be derived")
|
||||
return schema.Class(name=cls.__name__,
|
||||
bases=[b.__name__ for b in cls.__bases__ if b is not object],
|
||||
derived=derived,
|
||||
pragmas=pragmas,
|
||||
cfg=cls.__cfg__ if hasattr(cls, "__cfg__") else False,
|
||||
# in the following we don't use `getattr` to avoid inheriting
|
||||
properties=[
|
||||
a | _PropertyNamer(n)
|
||||
for n, a in cls.__dict__.get("__annotations__", {}).items()
|
||||
],
|
||||
doc=schema.split_doc(cls.__doc__),
|
||||
)
|
||||
return schema.Class(
|
||||
name=cls.__name__,
|
||||
bases=[b.__name__ for b in cls.__bases__ if b is not object],
|
||||
derived=derived,
|
||||
pragmas=pragmas,
|
||||
cfg=cls.__cfg__ if hasattr(cls, "__cfg__") else False,
|
||||
# in the following we don't use `getattr` to avoid inheriting
|
||||
properties=[
|
||||
a | _PropertyNamer(n)
|
||||
for n, a in cls.__dict__.get("__annotations__", {}).items()
|
||||
],
|
||||
doc=schema.split_doc(cls.__doc__),
|
||||
)
|
||||
|
||||
|
||||
def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typing.Dict[str, schema.Class]:
|
||||
def _toposort_classes_by_group(
|
||||
classes: typing.Dict[str, schema.Class],
|
||||
) -> typing.Dict[str, schema.Class]:
|
||||
groups = {}
|
||||
ret = {}
|
||||
|
||||
@@ -79,7 +99,7 @@ def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typin
|
||||
|
||||
|
||||
def _fill_synth_information(classes: typing.Dict[str, schema.Class]):
|
||||
""" Take a dictionary where the `synth` field is filled for all explicitly synthesized classes
|
||||
"""Take a dictionary where the `synth` field is filled for all explicitly synthesized classes
|
||||
and update it so that all non-final classes that have only synthesized final descendants
|
||||
get `True` as` value for the `synth` field
|
||||
"""
|
||||
@@ -109,7 +129,7 @@ def _fill_synth_information(classes: typing.Dict[str, schema.Class]):
|
||||
|
||||
|
||||
def _fill_hideable_information(classes: typing.Dict[str, schema.Class]):
|
||||
""" Update the class map propagating the `hideable` attribute upwards in the hierarchy """
|
||||
"""Update the class map propagating the `hideable` attribute upwards in the hierarchy"""
|
||||
todo = [cls for cls in classes.values() if "ql_hideable" in cls.pragmas]
|
||||
while todo:
|
||||
cls = todo.pop()
|
||||
@@ -123,10 +143,14 @@ def _fill_hideable_information(classes: typing.Dict[str, schema.Class]):
|
||||
def _check_test_with(classes: typing.Dict[str, schema.Class]):
|
||||
for cls in classes.values():
|
||||
test_with = typing.cast(str, cls.pragmas.get("qltest_test_with"))
|
||||
transitive_test_with = test_with and classes[test_with].pragmas.get("qltest_test_with")
|
||||
transitive_test_with = test_with and classes[test_with].pragmas.get(
|
||||
"qltest_test_with"
|
||||
)
|
||||
if test_with and transitive_test_with:
|
||||
raise schema.Error(f"{cls.name} has test_with {test_with} which in turn "
|
||||
f"has test_with {transitive_test_with}, use that directly")
|
||||
raise schema.Error(
|
||||
f"{cls.name} has test_with {test_with} which in turn "
|
||||
f"has test_with {transitive_test_with}, use that directly"
|
||||
)
|
||||
|
||||
|
||||
def load(m: types.ModuleType) -> schema.Schema:
|
||||
@@ -136,6 +160,7 @@ def load(m: types.ModuleType) -> schema.Schema:
|
||||
known = {"int", "string", "boolean"}
|
||||
known.update(n for n in m.__dict__ if not n.startswith("__"))
|
||||
import misc.codegen.lib.schemadefs as defs
|
||||
|
||||
null = None
|
||||
for name, data in m.__dict__.items():
|
||||
if hasattr(defs, name):
|
||||
@@ -152,21 +177,26 @@ def load(m: types.ModuleType) -> schema.Schema:
|
||||
continue
|
||||
cls = _get_class(data)
|
||||
if classes and not cls.bases:
|
||||
raise schema.Error(
|
||||
f"Only one root class allowed, found second root {name}")
|
||||
raise schema.Error(f"Only one root class allowed, found second root {name}")
|
||||
cls.check_types(known)
|
||||
classes[name] = cls
|
||||
if "null" in cls.pragmas:
|
||||
del cls.pragmas["null"]
|
||||
if null is not None:
|
||||
raise schema.Error(f"Null class {null} already defined, second null class {name} not allowed")
|
||||
raise schema.Error(
|
||||
f"Null class {null} already defined, second null class {name} not allowed"
|
||||
)
|
||||
null = name
|
||||
|
||||
_fill_synth_information(classes)
|
||||
_fill_hideable_information(classes)
|
||||
_check_test_with(classes)
|
||||
|
||||
return schema.Schema(includes=includes, classes=imported_classes | _toposort_classes_by_group(classes), null=null)
|
||||
return schema.Schema(
|
||||
includes=includes,
|
||||
classes=imported_classes | _toposort_classes_by_group(classes),
|
||||
null=null,
|
||||
)
|
||||
|
||||
|
||||
def load_file(path: pathlib.Path) -> schema.Schema:
|
||||
|
||||
@@ -17,34 +17,49 @@ def test_field_name():
|
||||
assert f.field_name == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type,expected", [
|
||||
("std::string", "trapQuoted(value)"),
|
||||
("bool", '(value ? "true" : "false")'),
|
||||
("something_else", "value"),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"type,expected",
|
||||
[
|
||||
("std::string", "trapQuoted(value)"),
|
||||
("bool", '(value ? "true" : "false")'),
|
||||
("something_else", "value"),
|
||||
],
|
||||
)
|
||||
def test_field_get_streamer(type, expected):
|
||||
f = cpp.Field("name", type)
|
||||
assert f.get_streamer()("value") == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_optional,is_repeated,is_predicate,expected", [
|
||||
(False, False, False, True),
|
||||
(True, False, False, False),
|
||||
(False, True, False, False),
|
||||
(True, True, False, False),
|
||||
(False, False, True, False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"is_optional,is_repeated,is_predicate,expected",
|
||||
[
|
||||
(False, False, False, True),
|
||||
(True, False, False, False),
|
||||
(False, True, False, False),
|
||||
(True, True, False, False),
|
||||
(False, False, True, False),
|
||||
],
|
||||
)
|
||||
def test_field_is_single(is_optional, is_repeated, is_predicate, expected):
|
||||
f = cpp.Field("name", "type", is_optional=is_optional, is_repeated=is_repeated, is_predicate=is_predicate)
|
||||
f = cpp.Field(
|
||||
"name",
|
||||
"type",
|
||||
is_optional=is_optional,
|
||||
is_repeated=is_repeated,
|
||||
is_predicate=is_predicate,
|
||||
)
|
||||
assert f.is_single is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_optional,is_repeated,expected", [
|
||||
(False, False, "bar"),
|
||||
(True, False, "std::optional<bar>"),
|
||||
(False, True, "std::vector<bar>"),
|
||||
(True, True, "std::vector<std::optional<bar>>"),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"is_optional,is_repeated,expected",
|
||||
[
|
||||
(False, False, "bar"),
|
||||
(True, False, "std::optional<bar>"),
|
||||
(False, True, "std::vector<bar>"),
|
||||
(True, True, "std::vector<std::optional<bar>>"),
|
||||
],
|
||||
)
|
||||
def test_field_modal_types(is_optional, is_repeated, expected):
|
||||
f = cpp.Field("name", "bar", is_optional=is_optional, is_repeated=is_repeated)
|
||||
assert f.type == expected
|
||||
@@ -69,11 +84,9 @@ def test_tag_has_first_base_marked():
|
||||
assert t.bases == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bases,expected", [
|
||||
([], False),
|
||||
(["a"], True),
|
||||
(["a", "b"], True)
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"bases,expected", [([], False), (["a"], True), (["a", "b"], True)]
|
||||
)
|
||||
def test_tag_has_bases(bases, expected):
|
||||
t = cpp.Tag("name", bases, "id")
|
||||
assert t.has_bases is expected
|
||||
@@ -91,11 +104,9 @@ def test_class_has_first_base_marked():
|
||||
assert c.bases == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bases,expected", [
|
||||
([], False),
|
||||
(["a"], True),
|
||||
(["a", "b"], True)
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"bases,expected", [([], False), (["a"], True), (["a", "b"], True)]
|
||||
)
|
||||
def test_class_has_bases(bases, expected):
|
||||
t = cpp.Class("name", [cpp.Class(b) for b in bases])
|
||||
assert t.has_bases is expected
|
||||
@@ -113,5 +124,5 @@ def test_class_single_fields():
|
||||
assert c.single_fields == fields[::2]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -18,7 +18,10 @@ def generate_grouped(opts, renderer, input):
|
||||
assert isinstance(g, cpp.ClassList), f
|
||||
assert g.include_parent is (f.parent != output_dir)
|
||||
assert f.name == "TrapClasses", f
|
||||
return {str(f.parent.relative_to(output_dir)): g.classes for f, g in generated.items()}
|
||||
return {
|
||||
str(f.parent.relative_to(output_dir)): g.classes
|
||||
for f, g in generated.items()
|
||||
}
|
||||
|
||||
return ret
|
||||
|
||||
@@ -38,129 +41,193 @@ def test_empty(generate):
|
||||
|
||||
|
||||
def test_empty_class(generate):
|
||||
assert generate([
|
||||
schema.Class(name="MyClass"),
|
||||
]) == [
|
||||
cpp.Class(name="MyClass", final=True, trap_name="MyClasses")
|
||||
]
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(name="MyClass"),
|
||||
]
|
||||
) == [cpp.Class(name="MyClass", final=True, trap_name="MyClasses")]
|
||||
|
||||
|
||||
def test_two_class_hierarchy(generate):
|
||||
base = cpp.Class(name="A")
|
||||
assert generate([
|
||||
schema.Class(name="A", derived={"B"}),
|
||||
schema.Class(name="B", bases=["A"]),
|
||||
]) == [
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(name="A", derived={"B"}),
|
||||
schema.Class(name="B", bases=["A"]),
|
||||
]
|
||||
) == [
|
||||
base,
|
||||
cpp.Class(name="B", bases=[base], final=True, trap_name="Bs"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type,expected", [
|
||||
("a", "a"),
|
||||
("string", "std::string"),
|
||||
("boolean", "bool"),
|
||||
("MyClass", "TrapLabel<MyClassTag>"),
|
||||
])
|
||||
@pytest.mark.parametrize("property_cls,optional,repeated,unordered,trap_name", [
|
||||
(schema.SingleProperty, False, False, False, None),
|
||||
(schema.OptionalProperty, True, False, False, "MyClassProps"),
|
||||
(schema.RepeatedProperty, False, True, False, "MyClassProps"),
|
||||
(schema.RepeatedOptionalProperty, True, True, False, "MyClassProps"),
|
||||
(schema.RepeatedUnorderedProperty, False, True, True, "MyClassProps"),
|
||||
])
|
||||
def test_class_with_field(generate, type, expected, property_cls, optional, repeated, unordered, trap_name):
|
||||
assert generate([
|
||||
schema.Class(name="MyClass", properties=[property_cls("prop", type)]),
|
||||
]) == [
|
||||
cpp.Class(name="MyClass",
|
||||
fields=[cpp.Field("prop", expected, is_optional=optional,
|
||||
is_repeated=repeated, is_unordered=unordered, trap_name=trap_name)],
|
||||
trap_name="MyClasses",
|
||||
final=True)
|
||||
@pytest.mark.parametrize(
|
||||
"type,expected",
|
||||
[
|
||||
("a", "a"),
|
||||
("string", "std::string"),
|
||||
("boolean", "bool"),
|
||||
("MyClass", "TrapLabel<MyClassTag>"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"property_cls,optional,repeated,unordered,trap_name",
|
||||
[
|
||||
(schema.SingleProperty, False, False, False, None),
|
||||
(schema.OptionalProperty, True, False, False, "MyClassProps"),
|
||||
(schema.RepeatedProperty, False, True, False, "MyClassProps"),
|
||||
(schema.RepeatedOptionalProperty, True, True, False, "MyClassProps"),
|
||||
(schema.RepeatedUnorderedProperty, False, True, True, "MyClassProps"),
|
||||
],
|
||||
)
|
||||
def test_class_with_field(
|
||||
generate, type, expected, property_cls, optional, repeated, unordered, trap_name
|
||||
):
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(name="MyClass", properties=[property_cls("prop", type)]),
|
||||
]
|
||||
) == [
|
||||
cpp.Class(
|
||||
name="MyClass",
|
||||
fields=[
|
||||
cpp.Field(
|
||||
"prop",
|
||||
expected,
|
||||
is_optional=optional,
|
||||
is_repeated=repeated,
|
||||
is_unordered=unordered,
|
||||
trap_name=trap_name,
|
||||
)
|
||||
],
|
||||
trap_name="MyClasses",
|
||||
final=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_class_field_with_null(generate, input):
|
||||
input.null = "Null"
|
||||
a = cpp.Class(name="A")
|
||||
assert generate([
|
||||
schema.Class(name="A", derived={"B"}),
|
||||
schema.Class(name="B", bases=["A"], properties=[
|
||||
schema.SingleProperty("x", "A"),
|
||||
schema.SingleProperty("y", "B"),
|
||||
])
|
||||
]) == [
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(name="A", derived={"B"}),
|
||||
schema.Class(
|
||||
name="B",
|
||||
bases=["A"],
|
||||
properties=[
|
||||
schema.SingleProperty("x", "A"),
|
||||
schema.SingleProperty("y", "B"),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == [
|
||||
a,
|
||||
cpp.Class(name="B", bases=[a], final=True, trap_name="Bs",
|
||||
fields=[
|
||||
cpp.Field("x", "TrapLabel<ATag>"),
|
||||
cpp.Field("y", "TrapLabel<BOrNoneTag>"),
|
||||
]),
|
||||
cpp.Class(
|
||||
name="B",
|
||||
bases=[a],
|
||||
final=True,
|
||||
trap_name="Bs",
|
||||
fields=[
|
||||
cpp.Field("x", "TrapLabel<ATag>"),
|
||||
cpp.Field("y", "TrapLabel<BOrNoneTag>"),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def test_class_with_predicate(generate):
|
||||
assert generate([
|
||||
schema.Class(name="MyClass", properties=[
|
||||
schema.PredicateProperty("prop")]),
|
||||
]) == [
|
||||
cpp.Class(name="MyClass",
|
||||
fields=[
|
||||
cpp.Field("prop", "bool", trap_name="MyClassProp", is_predicate=True)],
|
||||
trap_name="MyClasses",
|
||||
final=True)
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(name="MyClass", properties=[schema.PredicateProperty("prop")]),
|
||||
]
|
||||
) == [
|
||||
cpp.Class(
|
||||
name="MyClass",
|
||||
fields=[
|
||||
cpp.Field("prop", "bool", trap_name="MyClassProp", is_predicate=True)
|
||||
],
|
||||
trap_name="MyClasses",
|
||||
final=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name",
|
||||
["start_line", "start_column", "end_line", "end_column", "index", "num_whatever", "width"])
|
||||
@pytest.mark.parametrize(
|
||||
"name",
|
||||
[
|
||||
"start_line",
|
||||
"start_column",
|
||||
"end_line",
|
||||
"end_column",
|
||||
"index",
|
||||
"num_whatever",
|
||||
"width",
|
||||
],
|
||||
)
|
||||
def test_class_with_overridden_unsigned_field(generate, name):
|
||||
assert generate([
|
||||
schema.Class(name="MyClass", properties=[
|
||||
schema.SingleProperty(name, "bar")]),
|
||||
]) == [
|
||||
cpp.Class(name="MyClass",
|
||||
fields=[cpp.Field(name, "unsigned")],
|
||||
trap_name="MyClasses",
|
||||
final=True)
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="MyClass", properties=[schema.SingleProperty(name, "bar")]
|
||||
),
|
||||
]
|
||||
) == [
|
||||
cpp.Class(
|
||||
name="MyClass",
|
||||
fields=[cpp.Field(name, "unsigned")],
|
||||
trap_name="MyClasses",
|
||||
final=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_class_with_overridden_underscore_field(generate):
|
||||
assert generate([
|
||||
schema.Class(name="MyClass", properties=[
|
||||
schema.SingleProperty("something_", "bar")]),
|
||||
]) == [
|
||||
cpp.Class(name="MyClass",
|
||||
fields=[cpp.Field("something", "bar")],
|
||||
trap_name="MyClasses",
|
||||
final=True)
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="MyClass", properties=[schema.SingleProperty("something_", "bar")]
|
||||
),
|
||||
]
|
||||
) == [
|
||||
cpp.Class(
|
||||
name="MyClass",
|
||||
fields=[cpp.Field("something", "bar")],
|
||||
trap_name="MyClasses",
|
||||
final=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", cpp.cpp_keywords)
|
||||
def test_class_with_keyword_field(generate, name):
|
||||
assert generate([
|
||||
schema.Class(name="MyClass", properties=[
|
||||
schema.SingleProperty(name, "bar")]),
|
||||
]) == [
|
||||
cpp.Class(name="MyClass",
|
||||
fields=[cpp.Field(name + "_", "bar")],
|
||||
trap_name="MyClasses",
|
||||
final=True)
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="MyClass", properties=[schema.SingleProperty(name, "bar")]
|
||||
),
|
||||
]
|
||||
) == [
|
||||
cpp.Class(
|
||||
name="MyClass",
|
||||
fields=[cpp.Field(name + "_", "bar")],
|
||||
trap_name="MyClasses",
|
||||
final=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_classes_with_dirs(generate_grouped):
|
||||
cbase = cpp.Class(name="CBase")
|
||||
assert generate_grouped([
|
||||
schema.Class(name="A"),
|
||||
schema.Class(name="B", pragmas={"group": "foo"}),
|
||||
schema.Class(name="CBase", derived={"C"}, pragmas={"group": "bar"}),
|
||||
schema.Class(name="C", bases=["CBase"], pragmas={"group": "bar"}),
|
||||
schema.Class(name="D", pragmas={"group": "foo/bar/baz"}),
|
||||
]) == {
|
||||
assert generate_grouped(
|
||||
[
|
||||
schema.Class(name="A"),
|
||||
schema.Class(name="B", pragmas={"group": "foo"}),
|
||||
schema.Class(name="CBase", derived={"C"}, pragmas={"group": "bar"}),
|
||||
schema.Class(name="C", bases=["CBase"], pragmas={"group": "bar"}),
|
||||
schema.Class(name="D", pragmas={"group": "foo/bar/baz"}),
|
||||
]
|
||||
) == {
|
||||
".": [cpp.Class(name="A", trap_name="As", final=True)],
|
||||
"foo": [cpp.Class(name="B", trap_name="Bs", final=True)],
|
||||
"bar": [cbase, cpp.Class(name="C", bases=[cbase], trap_name="Cs", final=True)],
|
||||
@@ -169,81 +236,126 @@ def test_classes_with_dirs(generate_grouped):
|
||||
|
||||
|
||||
def test_cpp_skip_pragma(generate):
|
||||
assert generate([
|
||||
schema.Class(name="A", properties=[
|
||||
schema.SingleProperty("x", "foo"),
|
||||
schema.SingleProperty("y", "bar", pragmas=["x", "cpp_skip", "y"]),
|
||||
])
|
||||
]) == [
|
||||
cpp.Class(name="A", final=True, trap_name="As", fields=[
|
||||
cpp.Field("x", "foo"),
|
||||
]),
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="A",
|
||||
properties=[
|
||||
schema.SingleProperty("x", "foo"),
|
||||
schema.SingleProperty("y", "bar", pragmas=["x", "cpp_skip", "y"]),
|
||||
],
|
||||
)
|
||||
]
|
||||
) == [
|
||||
cpp.Class(
|
||||
name="A",
|
||||
final=True,
|
||||
trap_name="As",
|
||||
fields=[
|
||||
cpp.Field("x", "foo"),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def test_synth_classes_ignored(generate):
|
||||
assert generate([
|
||||
schema.Class(
|
||||
name="W",
|
||||
pragmas={"synth": schema.SynthInfo()},
|
||||
),
|
||||
schema.Class(
|
||||
name="X",
|
||||
pragmas={"synth": schema.SynthInfo(from_class="A")},
|
||||
),
|
||||
schema.Class(
|
||||
name="Y",
|
||||
pragmas={"synth": schema.SynthInfo(on_arguments={"a": "A", "b": "int"})},
|
||||
),
|
||||
schema.Class(
|
||||
name="Z",
|
||||
),
|
||||
]) == [
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="W",
|
||||
pragmas={"synth": schema.SynthInfo()},
|
||||
),
|
||||
schema.Class(
|
||||
name="X",
|
||||
pragmas={"synth": schema.SynthInfo(from_class="A")},
|
||||
),
|
||||
schema.Class(
|
||||
name="Y",
|
||||
pragmas={
|
||||
"synth": schema.SynthInfo(on_arguments={"a": "A", "b": "int"})
|
||||
},
|
||||
),
|
||||
schema.Class(
|
||||
name="Z",
|
||||
),
|
||||
]
|
||||
) == [
|
||||
cpp.Class(name="Z", final=True, trap_name="Zs"),
|
||||
]
|
||||
|
||||
|
||||
def test_synth_properties_ignored(generate):
|
||||
assert generate([
|
||||
schema.Class(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="X",
|
||||
properties=[
|
||||
schema.SingleProperty("x", "a"),
|
||||
schema.SingleProperty("y", "b", synth=True),
|
||||
schema.SingleProperty("z", "c"),
|
||||
schema.OptionalProperty("foo", "bar", synth=True),
|
||||
schema.RepeatedProperty("baz", "bazz", synth=True),
|
||||
schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True),
|
||||
schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == [
|
||||
cpp.Class(
|
||||
name="X",
|
||||
properties=[
|
||||
schema.SingleProperty("x", "a"),
|
||||
schema.SingleProperty("y", "b", synth=True),
|
||||
schema.SingleProperty("z", "c"),
|
||||
schema.OptionalProperty("foo", "bar", synth=True),
|
||||
schema.RepeatedProperty("baz", "bazz", synth=True),
|
||||
schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True),
|
||||
schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True),
|
||||
final=True,
|
||||
trap_name="Xes",
|
||||
fields=[
|
||||
cpp.Field("x", "a"),
|
||||
cpp.Field("z", "c"),
|
||||
],
|
||||
),
|
||||
]) == [
|
||||
cpp.Class(name="X", final=True, trap_name="Xes", fields=[
|
||||
cpp.Field("x", "a"),
|
||||
cpp.Field("z", "c"),
|
||||
]),
|
||||
]
|
||||
|
||||
|
||||
def test_properties_with_custom_db_table_names(generate):
|
||||
assert generate([
|
||||
schema.Class("Obj", properties=[
|
||||
schema.OptionalProperty("x", "a", pragmas={"ql_db_table_name": "foo"}),
|
||||
schema.RepeatedProperty("y", "b", pragmas={"ql_db_table_name": "bar"}),
|
||||
schema.RepeatedOptionalProperty("z", "c", pragmas={"ql_db_table_name": "baz"}),
|
||||
schema.PredicateProperty("p", pragmas={"ql_db_table_name": "hello"}),
|
||||
schema.RepeatedUnorderedProperty("q", "d", pragmas={"ql_db_table_name": "world"}),
|
||||
]),
|
||||
]) == [
|
||||
cpp.Class(name="Obj", final=True, trap_name="Objs", fields=[
|
||||
cpp.Field("x", "a", is_optional=True, trap_name="Foo"),
|
||||
cpp.Field("y", "b", is_repeated=True, trap_name="Bar"),
|
||||
cpp.Field("z", "c", is_repeated=True, is_optional=True, trap_name="Baz"),
|
||||
cpp.Field("p", "bool", is_predicate=True, trap_name="Hello"),
|
||||
cpp.Field("q", "d", is_repeated=True, is_unordered=True, trap_name="World"),
|
||||
]),
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Obj",
|
||||
properties=[
|
||||
schema.OptionalProperty(
|
||||
"x", "a", pragmas={"ql_db_table_name": "foo"}
|
||||
),
|
||||
schema.RepeatedProperty(
|
||||
"y", "b", pragmas={"ql_db_table_name": "bar"}
|
||||
),
|
||||
schema.RepeatedOptionalProperty(
|
||||
"z", "c", pragmas={"ql_db_table_name": "baz"}
|
||||
),
|
||||
schema.PredicateProperty(
|
||||
"p", pragmas={"ql_db_table_name": "hello"}
|
||||
),
|
||||
schema.RepeatedUnorderedProperty(
|
||||
"q", "d", pragmas={"ql_db_table_name": "world"}
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == [
|
||||
cpp.Class(
|
||||
name="Obj",
|
||||
final=True,
|
||||
trap_name="Objs",
|
||||
fields=[
|
||||
cpp.Field("x", "a", is_optional=True, trap_name="Foo"),
|
||||
cpp.Field("y", "b", is_repeated=True, trap_name="Bar"),
|
||||
cpp.Field(
|
||||
"z", "c", is_repeated=True, is_optional=True, trap_name="Baz"
|
||||
),
|
||||
cpp.Field("p", "bool", is_predicate=True, trap_name="Hello"),
|
||||
cpp.Field(
|
||||
"q", "d", is_repeated=True, is_unordered=True, trap_name="World"
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -14,12 +14,15 @@ def test_dbcolumn_keyword_name(keyword):
|
||||
assert dbscheme.Column(keyword, "some_type").name == keyword + "_"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type,binding,lhstype,rhstype", [
|
||||
("builtin_type", False, "builtin_type", "builtin_type ref"),
|
||||
("builtin_type", True, "builtin_type", "builtin_type ref"),
|
||||
("@at_type", False, "int", "@at_type ref"),
|
||||
("@at_type", True, "unique int", "@at_type"),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"type,binding,lhstype,rhstype",
|
||||
[
|
||||
("builtin_type", False, "builtin_type", "builtin_type ref"),
|
||||
("builtin_type", True, "builtin_type", "builtin_type ref"),
|
||||
("@at_type", False, "int", "@at_type ref"),
|
||||
("@at_type", True, "unique int", "@at_type"),
|
||||
],
|
||||
)
|
||||
def test_dbcolumn_types(type, binding, lhstype, rhstype):
|
||||
col = dbscheme.Column("foo", type, binding)
|
||||
assert col.lhstype == lhstype
|
||||
@@ -34,7 +37,11 @@ def test_keyset_has_first_id_marked():
|
||||
|
||||
|
||||
def test_table_has_first_column_marked():
|
||||
columns = [dbscheme.Column("a", "x"), dbscheme.Column("b", "y", binding=True), dbscheme.Column("c", "z")]
|
||||
columns = [
|
||||
dbscheme.Column("a", "x"),
|
||||
dbscheme.Column("b", "y", binding=True),
|
||||
dbscheme.Column("c", "z"),
|
||||
]
|
||||
expected = deepcopy(columns)
|
||||
table = dbscheme.Table("foo", columns)
|
||||
expected[0].first = True
|
||||
@@ -48,5 +55,5 @@ def test_union_has_first_case_marked():
|
||||
assert [c.type for c in u.rhs] == rhs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -8,10 +8,12 @@ from misc.codegen.test.utils import *
|
||||
InputExpectedPair = collections.namedtuple("InputExpectedPair", ("input", "expected"))
|
||||
|
||||
|
||||
@pytest.fixture(params=[
|
||||
InputExpectedPair(None, None),
|
||||
InputExpectedPair("foodir", pathlib.Path("foodir")),
|
||||
])
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
InputExpectedPair(None, None),
|
||||
InputExpectedPair("foodir", pathlib.Path("foodir")),
|
||||
]
|
||||
)
|
||||
def dir_param(request):
|
||||
return request.param
|
||||
|
||||
@@ -21,7 +23,7 @@ def generate(opts, input, renderer):
|
||||
def func(classes, null=None):
|
||||
input.classes = {cls.name: cls for cls in classes}
|
||||
input.null = null
|
||||
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
|
||||
((out, data),) = run_generation(dbschemegen.generate, opts, renderer).items()
|
||||
assert out is opts.dbscheme
|
||||
return data
|
||||
|
||||
@@ -48,23 +50,26 @@ def test_includes(input, opts, generate):
|
||||
dbscheme.SchemeInclude(
|
||||
src=pathlib.Path(i),
|
||||
data=i + " data",
|
||||
) for i in includes
|
||||
)
|
||||
for i in includes
|
||||
],
|
||||
declarations=[],
|
||||
)
|
||||
|
||||
|
||||
def test_empty_final_class(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
dbscheme.Column("id", "@object", binding=True),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
)
|
||||
@@ -73,218 +78,279 @@ def test_empty_final_class(generate, dir_param):
|
||||
|
||||
|
||||
def test_final_class_with_single_scalar_field(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
|
||||
schema.SingleProperty("foo", "bar"),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Object",
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.SingleProperty("foo", "bar"),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
dbscheme.Column('foo', 'bar'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object", binding=True),
|
||||
dbscheme.Column("foo", "bar"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_final_class_with_single_class_field(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
|
||||
schema.SingleProperty("foo", "Bar"),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Object",
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.SingleProperty("foo", "Bar"),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
dbscheme.Column('foo', '@bar'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object", binding=True),
|
||||
dbscheme.Column("foo", "@bar"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_final_class_with_optional_field(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
|
||||
schema.OptionalProperty("foo", "bar"),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Object",
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.OptionalProperty("foo", "bar"),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object", binding=True),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="object_foos",
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('foo', 'bar'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object"),
|
||||
dbscheme.Column("foo", "bar"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("property_cls", [schema.RepeatedProperty, schema.RepeatedOptionalProperty])
|
||||
@pytest.mark.parametrize(
|
||||
"property_cls", [schema.RepeatedProperty, schema.RepeatedOptionalProperty]
|
||||
)
|
||||
def test_final_class_with_repeated_field(generate, property_cls, dir_param):
|
||||
assert generate([
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
|
||||
property_cls("foo", "bar"),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Object",
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
property_cls("foo", "bar"),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object", binding=True),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="object_foos",
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('foo', 'bar'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object"),
|
||||
dbscheme.Column("index", "int"),
|
||||
dbscheme.Column("foo", "bar"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_final_class_with_repeated_unordered_field(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
|
||||
schema.RepeatedUnorderedProperty("foo", "bar"),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Object",
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.RepeatedUnorderedProperty("foo", "bar"),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object", binding=True),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="object_foos",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('foo', 'bar'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object"),
|
||||
dbscheme.Column("foo", "bar"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_final_class_with_predicate_field(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
|
||||
schema.PredicateProperty("foo"),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Object",
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.PredicateProperty("foo"),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object", binding=True),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="object_foo",
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_final_class_with_more_fields(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class("Object", pragmas={"group": dir_param.input}, properties=[
|
||||
schema.SingleProperty("one", "x"),
|
||||
schema.SingleProperty("two", "y"),
|
||||
schema.OptionalProperty("three", "z"),
|
||||
schema.RepeatedProperty("four", "u"),
|
||||
schema.RepeatedOptionalProperty("five", "v"),
|
||||
schema.PredicateProperty("six"),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Object",
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.SingleProperty("one", "x"),
|
||||
schema.SingleProperty("two", "y"),
|
||||
schema.OptionalProperty("three", "z"),
|
||||
schema.RepeatedProperty("four", "u"),
|
||||
schema.RepeatedOptionalProperty("five", "v"),
|
||||
schema.PredicateProperty("six"),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
dbscheme.Column('one', 'x'),
|
||||
dbscheme.Column('two', 'y'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object", binding=True),
|
||||
dbscheme.Column("one", "x"),
|
||||
dbscheme.Column("two", "y"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="object_threes",
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('three', 'z'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object"),
|
||||
dbscheme.Column("three", "z"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="object_fours",
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('four', 'u'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object"),
|
||||
dbscheme.Column("index", "int"),
|
||||
dbscheme.Column("four", "u"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="object_fives",
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('five', 'v'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object"),
|
||||
dbscheme.Column("index", "int"),
|
||||
dbscheme.Column("five", "v"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="object_six",
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@object'),
|
||||
], dir=dir_param.expected,
|
||||
dbscheme.Column("id", "@object"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_empty_class_with_derived(generate):
|
||||
assert generate([
|
||||
schema.Class(name="Base", derived={"Left", "Right"}),
|
||||
schema.Class(name="Left", bases=["Base"]),
|
||||
schema.Class(name="Right", bases=["Base"]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(name="Base", derived={"Left", "Right"}),
|
||||
schema.Class(name="Left", bases=["Base"]),
|
||||
schema.Class(name="Right", bases=["Base"]),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
@@ -305,17 +371,20 @@ def test_empty_class_with_derived(generate):
|
||||
|
||||
|
||||
def test_class_with_derived_and_single_property(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class(
|
||||
name="Base",
|
||||
derived={"Left", "Right"},
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.SingleProperty("single", "Prop"),
|
||||
]),
|
||||
schema.Class(name="Left", bases=["Base"]),
|
||||
schema.Class(name="Right", bases=["Base"]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="Base",
|
||||
derived={"Left", "Right"},
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.SingleProperty("single", "Prop"),
|
||||
],
|
||||
),
|
||||
schema.Class(name="Left", bases=["Base"]),
|
||||
schema.Class(name="Right", bases=["Base"]),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
@@ -327,8 +396,8 @@ def test_class_with_derived_and_single_property(generate, dir_param):
|
||||
name="bases",
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@base'),
|
||||
dbscheme.Column('single', '@prop'),
|
||||
dbscheme.Column("id", "@base"),
|
||||
dbscheme.Column("single", "@prop"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
@@ -345,17 +414,20 @@ def test_class_with_derived_and_single_property(generate, dir_param):
|
||||
|
||||
|
||||
def test_class_with_derived_and_optional_property(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class(
|
||||
name="Base",
|
||||
derived={"Left", "Right"},
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.OptionalProperty("opt", "Prop"),
|
||||
]),
|
||||
schema.Class(name="Left", bases=["Base"]),
|
||||
schema.Class(name="Right", bases=["Base"]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="Base",
|
||||
derived={"Left", "Right"},
|
||||
pragmas={"group": dir_param.input},
|
||||
properties=[
|
||||
schema.OptionalProperty("opt", "Prop"),
|
||||
],
|
||||
),
|
||||
schema.Class(name="Left", bases=["Base"]),
|
||||
schema.Class(name="Right", bases=["Base"]),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
@@ -367,8 +439,8 @@ def test_class_with_derived_and_optional_property(generate, dir_param):
|
||||
name="base_opts",
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@base'),
|
||||
dbscheme.Column('opt', '@prop'),
|
||||
dbscheme.Column("id", "@base"),
|
||||
dbscheme.Column("opt", "@prop"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
@@ -385,17 +457,20 @@ def test_class_with_derived_and_optional_property(generate, dir_param):
|
||||
|
||||
|
||||
def test_class_with_derived_and_repeated_property(generate, dir_param):
|
||||
assert generate([
|
||||
schema.Class(
|
||||
name="Base",
|
||||
pragmas={"group": dir_param.input},
|
||||
derived={"Left", "Right"},
|
||||
properties=[
|
||||
schema.RepeatedProperty("rep", "Prop"),
|
||||
]),
|
||||
schema.Class(name="Left", bases=["Base"]),
|
||||
schema.Class(name="Right", bases=["Base"]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="Base",
|
||||
pragmas={"group": dir_param.input},
|
||||
derived={"Left", "Right"},
|
||||
properties=[
|
||||
schema.RepeatedProperty("rep", "Prop"),
|
||||
],
|
||||
),
|
||||
schema.Class(name="Left", bases=["Base"]),
|
||||
schema.Class(name="Right", bases=["Base"]),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
@@ -407,9 +482,9 @@ def test_class_with_derived_and_repeated_property(generate, dir_param):
|
||||
name="base_reps",
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@base'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('rep', '@prop'),
|
||||
dbscheme.Column("id", "@base"),
|
||||
dbscheme.Column("index", "int"),
|
||||
dbscheme.Column("rep", "@prop"),
|
||||
],
|
||||
dir=dir_param.expected,
|
||||
),
|
||||
@@ -426,38 +501,41 @@ def test_class_with_derived_and_repeated_property(generate, dir_param):
|
||||
|
||||
|
||||
def test_null_class(generate):
|
||||
assert generate([
|
||||
schema.Class(
|
||||
name="Base",
|
||||
derived={"W", "X", "Y", "Z", "Null"},
|
||||
),
|
||||
schema.Class(
|
||||
name="W",
|
||||
bases=["Base"],
|
||||
properties=[
|
||||
schema.SingleProperty("w", "W"),
|
||||
schema.SingleProperty("x", "X"),
|
||||
schema.OptionalProperty("y", "Y"),
|
||||
schema.RepeatedProperty("z", "Z"),
|
||||
]
|
||||
),
|
||||
schema.Class(
|
||||
name="X",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Y",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Z",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Null",
|
||||
bases=["Base"],
|
||||
),
|
||||
], null="Null") == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="Base",
|
||||
derived={"W", "X", "Y", "Z", "Null"},
|
||||
),
|
||||
schema.Class(
|
||||
name="W",
|
||||
bases=["Base"],
|
||||
properties=[
|
||||
schema.SingleProperty("w", "W"),
|
||||
schema.SingleProperty("x", "X"),
|
||||
schema.OptionalProperty("y", "Y"),
|
||||
schema.RepeatedProperty("z", "Z"),
|
||||
],
|
||||
),
|
||||
schema.Class(
|
||||
name="X",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Y",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Z",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Null",
|
||||
bases=["Base"],
|
||||
),
|
||||
],
|
||||
null="Null",
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
@@ -468,50 +546,50 @@ def test_null_class(generate):
|
||||
dbscheme.Table(
|
||||
name="ws",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@w', binding=True),
|
||||
dbscheme.Column('w', '@w_or_none'),
|
||||
dbscheme.Column('x', '@x_or_none'),
|
||||
dbscheme.Column("id", "@w", binding=True),
|
||||
dbscheme.Column("w", "@w_or_none"),
|
||||
dbscheme.Column("x", "@x_or_none"),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="w_ies",
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@w'),
|
||||
dbscheme.Column('y', '@y_or_none'),
|
||||
dbscheme.Column("id", "@w"),
|
||||
dbscheme.Column("y", "@y_or_none"),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="w_zs",
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@w'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('z', '@z_or_none'),
|
||||
dbscheme.Column("id", "@w"),
|
||||
dbscheme.Column("index", "int"),
|
||||
dbscheme.Column("z", "@z_or_none"),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="xes",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@x', binding=True),
|
||||
dbscheme.Column("id", "@x", binding=True),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="ys",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@y', binding=True),
|
||||
dbscheme.Column("id", "@y", binding=True),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="zs",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@z', binding=True),
|
||||
dbscheme.Column("id", "@z", binding=True),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="nulls",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@null', binding=True),
|
||||
dbscheme.Column("id", "@null", binding=True),
|
||||
],
|
||||
),
|
||||
dbscheme.Union(
|
||||
@@ -535,11 +613,15 @@ def test_null_class(generate):
|
||||
|
||||
|
||||
def test_synth_classes_ignored(generate):
|
||||
assert generate([
|
||||
schema.Class(name="A", pragmas={"synth": schema.SynthInfo()}),
|
||||
schema.Class(name="B", pragmas={"synth": schema.SynthInfo(from_class="A")}),
|
||||
schema.Class(name="C", pragmas={"synth": schema.SynthInfo(on_arguments={"x": "A"})}),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(name="A", pragmas={"synth": schema.SynthInfo()}),
|
||||
schema.Class(name="B", pragmas={"synth": schema.SynthInfo(from_class="A")}),
|
||||
schema.Class(
|
||||
name="C", pragmas={"synth": schema.SynthInfo(on_arguments={"x": "A"})}
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[],
|
||||
@@ -547,11 +629,13 @@ def test_synth_classes_ignored(generate):
|
||||
|
||||
|
||||
def test_synth_derived_classes_ignored(generate):
|
||||
assert generate([
|
||||
schema.Class(name="A", derived={"B", "C"}),
|
||||
schema.Class(name="B", bases=["A"], pragmas={"synth": schema.SynthInfo()}),
|
||||
schema.Class(name="C", bases=["A"]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(name="A", derived={"B", "C"}),
|
||||
schema.Class(name="B", bases=["A"], pragmas={"synth": schema.SynthInfo()}),
|
||||
schema.Class(name="C", bases=["A"]),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
@@ -561,23 +645,28 @@ def test_synth_derived_classes_ignored(generate):
|
||||
columns=[
|
||||
dbscheme.Column("id", "@c", binding=True),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_synth_properties_ignored(generate):
|
||||
assert generate([
|
||||
schema.Class(name="A", properties=[
|
||||
schema.SingleProperty("x", "a"),
|
||||
schema.SingleProperty("y", "b", synth=True),
|
||||
schema.SingleProperty("z", "c"),
|
||||
schema.OptionalProperty("foo", "bar", synth=True),
|
||||
schema.RepeatedProperty("baz", "bazz", synth=True),
|
||||
schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True),
|
||||
schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
name="A",
|
||||
properties=[
|
||||
schema.SingleProperty("x", "a"),
|
||||
schema.SingleProperty("y", "b", synth=True),
|
||||
schema.SingleProperty("z", "c"),
|
||||
schema.OptionalProperty("foo", "bar", synth=True),
|
||||
schema.RepeatedProperty("baz", "bazz", synth=True),
|
||||
schema.RepeatedOptionalProperty("bazzz", "bazzzz", synth=True),
|
||||
schema.RepeatedUnorderedProperty("bazzzzz", "bazzzzzz", synth=True),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
@@ -595,24 +684,44 @@ def test_synth_properties_ignored(generate):
|
||||
|
||||
def test_table_conflict(generate):
|
||||
with pytest.raises(dbschemegen.Error):
|
||||
generate([
|
||||
schema.Class("Foo", properties=[
|
||||
schema.OptionalProperty("bar", "FooBar"),
|
||||
]),
|
||||
schema.Class("FooBar"),
|
||||
])
|
||||
generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Foo",
|
||||
properties=[
|
||||
schema.OptionalProperty("bar", "FooBar"),
|
||||
],
|
||||
),
|
||||
schema.Class("FooBar"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_table_name_overrides(generate):
|
||||
assert generate([
|
||||
schema.Class("Obj", properties=[
|
||||
schema.OptionalProperty("x", "a", pragmas={"ql_db_table_name": "foo"}),
|
||||
schema.RepeatedProperty("y", "b", pragmas={"ql_db_table_name": "bar"}),
|
||||
schema.RepeatedOptionalProperty("z", "c", pragmas={"ql_db_table_name": "baz"}),
|
||||
schema.PredicateProperty("p", pragmas={"ql_db_table_name": "hello"}),
|
||||
schema.RepeatedUnorderedProperty("q", "d", pragmas={"ql_db_table_name": "world"}),
|
||||
]),
|
||||
]) == dbscheme.Scheme(
|
||||
assert generate(
|
||||
[
|
||||
schema.Class(
|
||||
"Obj",
|
||||
properties=[
|
||||
schema.OptionalProperty(
|
||||
"x", "a", pragmas={"ql_db_table_name": "foo"}
|
||||
),
|
||||
schema.RepeatedProperty(
|
||||
"y", "b", pragmas={"ql_db_table_name": "bar"}
|
||||
),
|
||||
schema.RepeatedOptionalProperty(
|
||||
"z", "c", pragmas={"ql_db_table_name": "baz"}
|
||||
),
|
||||
schema.PredicateProperty(
|
||||
"p", pragmas={"ql_db_table_name": "hello"}
|
||||
),
|
||||
schema.RepeatedUnorderedProperty(
|
||||
"q", "d", pragmas={"ql_db_table_name": "world"}
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == dbscheme.Scheme(
|
||||
src=schema_file.name,
|
||||
includes=[],
|
||||
declarations=[
|
||||
@@ -666,5 +775,5 @@ def test_table_name_overrides(generate):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -22,26 +22,42 @@ def test_load_empty(load):
|
||||
|
||||
|
||||
def test_load_one_empty_table(load):
|
||||
assert load("""
|
||||
assert (
|
||||
load(
|
||||
"""
|
||||
test_foos();
|
||||
""") == [
|
||||
dbscheme.Table(name="test_foos", columns=[])
|
||||
]
|
||||
"""
|
||||
)
|
||||
== [dbscheme.Table(name="test_foos", columns=[])]
|
||||
)
|
||||
|
||||
|
||||
def test_load_table_with_keyset(load):
|
||||
assert load("""
|
||||
assert (
|
||||
load(
|
||||
"""
|
||||
#keyset[x, y,z]
|
||||
test_foos();
|
||||
""") == [
|
||||
dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]))
|
||||
]
|
||||
"""
|
||||
)
|
||||
== [
|
||||
dbscheme.Table(
|
||||
name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"])
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
expected_columns = [
|
||||
("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)),
|
||||
(" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)),
|
||||
("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)),
|
||||
(
|
||||
" int bar : int ref",
|
||||
dbscheme.Column(schema_name="bar", type="int", binding=False),
|
||||
),
|
||||
(
|
||||
"str baz_: str ref",
|
||||
dbscheme.Column(schema_name="baz", type="str", binding=False),
|
||||
),
|
||||
("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)),
|
||||
("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)),
|
||||
("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)),
|
||||
@@ -50,42 +66,58 @@ expected_columns = [
|
||||
|
||||
@pytest.mark.parametrize("column,expected", expected_columns)
|
||||
def test_load_table_with_column(load, column, expected):
|
||||
assert load(f"""
|
||||
assert (
|
||||
load(
|
||||
f"""
|
||||
foos(
|
||||
{column}
|
||||
);
|
||||
""") == [
|
||||
dbscheme.Table(name="foos", columns=[deepcopy(expected)])
|
||||
]
|
||||
"""
|
||||
)
|
||||
== [dbscheme.Table(name="foos", columns=[deepcopy(expected)])]
|
||||
)
|
||||
|
||||
|
||||
def test_load_table_with_multiple_columns(load):
|
||||
columns = ",\n".join(c for c, _ in expected_columns)
|
||||
expected = [deepcopy(e) for _, e in expected_columns]
|
||||
assert load(f"""
|
||||
assert (
|
||||
load(
|
||||
f"""
|
||||
foos(
|
||||
{columns}
|
||||
);
|
||||
""") == [
|
||||
dbscheme.Table(name="foos", columns=expected)
|
||||
]
|
||||
"""
|
||||
)
|
||||
== [dbscheme.Table(name="foos", columns=expected)]
|
||||
)
|
||||
|
||||
|
||||
def test_load_table_with_multiple_columns_and_dir(load):
|
||||
columns = ",\n".join(c for c, _ in expected_columns)
|
||||
expected = [deepcopy(e) for _, e in expected_columns]
|
||||
assert load(f"""
|
||||
assert (
|
||||
load(
|
||||
f"""
|
||||
foos( //dir=foo/bar/baz
|
||||
{columns}
|
||||
);
|
||||
""") == [
|
||||
dbscheme.Table(name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz"))
|
||||
]
|
||||
"""
|
||||
)
|
||||
== [
|
||||
dbscheme.Table(
|
||||
name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_load_multiple_table_with_columns(load):
|
||||
tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)]
|
||||
expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)]
|
||||
expected = [
|
||||
dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)])
|
||||
for i, (_, e) in enumerate(expected_columns)
|
||||
]
|
||||
assert load("\n".join(tables)) == expected
|
||||
|
||||
|
||||
@@ -96,28 +128,41 @@ def test_union(load):
|
||||
|
||||
|
||||
def test_table_and_union(load):
|
||||
assert load("""
|
||||
assert (
|
||||
load(
|
||||
"""
|
||||
foos();
|
||||
|
||||
@foo = @bar | @baz | @bla;""") == [
|
||||
dbscheme.Table(name="foos", columns=[]),
|
||||
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
|
||||
]
|
||||
@foo = @bar | @baz | @bla;"""
|
||||
)
|
||||
== [
|
||||
dbscheme.Table(name="foos", columns=[]),
|
||||
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_comments_ignored(load):
|
||||
assert load("""
|
||||
assert (
|
||||
load(
|
||||
"""
|
||||
// fake_table();
|
||||
foos(/* x */unique /*y*/int/*
|
||||
z
|
||||
*/ id/* */: /* * */ @bar/*,
|
||||
int ignored: int ref*/);
|
||||
|
||||
@foo = @bar | @baz | @bla; // | @xxx""") == [
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]),
|
||||
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
|
||||
]
|
||||
@foo = @bar | @baz | @bla; // | @xxx"""
|
||||
)
|
||||
== [
|
||||
dbscheme.Table(
|
||||
name="foos",
|
||||
columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)],
|
||||
),
|
||||
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -34,37 +34,55 @@ def test_property_unordered_getter(name, expected_getter):
|
||||
assert prop.getter == expected_getter
|
||||
|
||||
|
||||
@pytest.mark.parametrize("plural,expected", [
|
||||
(None, False),
|
||||
("", False),
|
||||
("X", True),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"plural,expected",
|
||||
[
|
||||
(None, False),
|
||||
("", False),
|
||||
("X", True),
|
||||
],
|
||||
)
|
||||
def test_property_is_repeated(plural, expected):
|
||||
prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural)
|
||||
assert prop.is_repeated is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("plural,unordered,expected", [
|
||||
(None, False, False),
|
||||
("", False, False),
|
||||
("X", False, True),
|
||||
("X", True, False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"plural,unordered,expected",
|
||||
[
|
||||
(None, False, False),
|
||||
("", False, False),
|
||||
("X", False, True),
|
||||
("X", True, False),
|
||||
],
|
||||
)
|
||||
def test_property_is_indexed(plural, unordered, expected):
|
||||
prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural, is_unordered=unordered)
|
||||
prop = ql.Property(
|
||||
"foo", "Foo", "props", ["result"], plural=plural, is_unordered=unordered
|
||||
)
|
||||
assert prop.is_indexed is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_optional,is_predicate,plural,expected", [
|
||||
(False, False, None, True),
|
||||
(False, False, "", True),
|
||||
(False, False, "X", False),
|
||||
(True, False, None, False),
|
||||
(False, True, None, False),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"is_optional,is_predicate,plural,expected",
|
||||
[
|
||||
(False, False, None, True),
|
||||
(False, False, "", True),
|
||||
(False, False, "X", False),
|
||||
(True, False, None, False),
|
||||
(False, True, None, False),
|
||||
],
|
||||
)
|
||||
def test_property_is_single(is_optional, is_predicate, plural, expected):
|
||||
prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural,
|
||||
is_predicate=is_predicate, is_optional=is_optional)
|
||||
prop = ql.Property(
|
||||
"foo",
|
||||
"Foo",
|
||||
"props",
|
||||
["result"],
|
||||
plural=plural,
|
||||
is_predicate=is_predicate,
|
||||
is_optional=is_optional,
|
||||
)
|
||||
assert prop.is_single is expected
|
||||
|
||||
|
||||
@@ -85,7 +103,12 @@ def test_property_predicate_getter():
|
||||
|
||||
def test_class_processes_bases():
|
||||
bases = ["B", "Ab", "C", "Aa"]
|
||||
expected = [ql.Base("B"), ql.Base("Ab", prev="B"), ql.Base("C", prev="Ab"), ql.Base("Aa", prev="C")]
|
||||
expected = [
|
||||
ql.Base("B"),
|
||||
ql.Base("Ab", prev="B"),
|
||||
ql.Base("C", prev="Ab"),
|
||||
ql.Base("Aa", prev="C"),
|
||||
]
|
||||
cls = ql.Class("Foo", bases=bases)
|
||||
assert cls.bases == expected
|
||||
|
||||
@@ -110,7 +133,9 @@ def test_non_root_class():
|
||||
assert not cls.root
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prev_child,is_child", [(None, False), ("", True), ("x", True)])
|
||||
@pytest.mark.parametrize(
|
||||
"prev_child,is_child", [(None, False), ("", True), ("x", True)]
|
||||
)
|
||||
def test_is_child(prev_child, is_child):
|
||||
p = ql.Property("Foo", "int", prev_child=prev_child)
|
||||
assert p.is_child is is_child
|
||||
@@ -122,22 +147,27 @@ def test_empty_class_no_children():
|
||||
|
||||
|
||||
def test_class_no_children():
|
||||
cls = ql.Class("Class", properties=[ql.Property("Foo", "int"), ql.Property("Bar", "string")])
|
||||
cls = ql.Class(
|
||||
"Class", properties=[ql.Property("Foo", "int"), ql.Property("Bar", "string")]
|
||||
)
|
||||
assert cls.has_children is False
|
||||
|
||||
|
||||
def test_class_with_children():
|
||||
cls = ql.Class("Class", properties=[ql.Property("Foo", "int"), ql.Property("Child", "x", prev_child=""),
|
||||
ql.Property("Bar", "string")])
|
||||
cls = ql.Class(
|
||||
"Class",
|
||||
properties=[
|
||||
ql.Property("Foo", "int"),
|
||||
ql.Property("Child", "x", prev_child=""),
|
||||
ql.Property("Bar", "string"),
|
||||
],
|
||||
)
|
||||
assert cls.has_children is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("doc,expected",
|
||||
[
|
||||
(["foo", "bar"], True),
|
||||
(["foo", "bar"], True),
|
||||
([], False)
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"doc,expected", [(["foo", "bar"], True), (["foo", "bar"], True), ([], False)]
|
||||
)
|
||||
def test_has_doc(doc, expected):
|
||||
stub = ql.Stub("Class", base_import="foo", import_prefix="bar", doc=doc)
|
||||
assert stub.has_qldoc is expected
|
||||
@@ -150,5 +180,5 @@ def test_synth_accessor_has_first_constructor_param_marked():
|
||||
assert [p.param for p in x.constructorparams] == params
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -46,7 +46,10 @@ def write_registry(file, *files_and_hashes):
|
||||
def assert_registry(file, *files_and_hashes):
|
||||
assert_file(file, create_registry(files_and_hashes))
|
||||
files = [file.name, ".gitattributes"] + [f for f, _, _ in files_and_hashes]
|
||||
assert_file(file.parent / ".gitattributes", "\n".join(f"/{f} linguist-generated" for f in files) + "\n")
|
||||
assert_file(
|
||||
file.parent / ".gitattributes",
|
||||
"\n".join(f"/{f} linguist-generated" for f in files) + "\n",
|
||||
)
|
||||
|
||||
|
||||
def hash(text):
|
||||
@@ -56,11 +59,11 @@ def hash(text):
|
||||
|
||||
|
||||
def test_constructor(pystache_renderer_cls, sut):
|
||||
pystache_init, = pystache_renderer_cls.mock_calls
|
||||
assert set(pystache_init.kwargs) == {'search_dirs', 'escape'}
|
||||
assert pystache_init.kwargs['search_dirs'] == str(paths.templates_dir)
|
||||
(pystache_init,) = pystache_renderer_cls.mock_calls
|
||||
assert set(pystache_init.kwargs) == {"search_dirs", "escape"}
|
||||
assert pystache_init.kwargs["search_dirs"] == str(paths.templates_dir)
|
||||
an_object = object()
|
||||
assert pystache_init.kwargs['escape'](an_object) is an_object
|
||||
assert pystache_init.kwargs["escape"](an_object) is an_object
|
||||
|
||||
|
||||
def test_render(pystache_renderer, sut):
|
||||
@@ -218,7 +221,9 @@ def test_managed_render_with_skipping_of_stub_file(pystache_renderer, sut):
|
||||
some_processed_output = "// generated some processed output"
|
||||
registry = paths.root_dir / "a/registry.list"
|
||||
write(stub, some_processed_output)
|
||||
write_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)))
|
||||
write_registry(
|
||||
registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))
|
||||
)
|
||||
|
||||
pystache_renderer.render_name.side_effect = (some_output,)
|
||||
|
||||
@@ -227,7 +232,9 @@ def test_managed_render_with_skipping_of_stub_file(pystache_renderer, sut):
|
||||
assert renderer.written == set()
|
||||
assert_file(stub, some_processed_output)
|
||||
|
||||
assert_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)))
|
||||
assert_registry(
|
||||
registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))
|
||||
)
|
||||
assert pystache_renderer.mock_calls == [
|
||||
mock.call.render_name(data.template, data, generator=generator),
|
||||
]
|
||||
@@ -238,13 +245,17 @@ def test_managed_render_with_modified_generated_file(pystache_renderer, sut):
|
||||
some_processed_output = "// some processed output"
|
||||
registry = paths.root_dir / "a/registry.list"
|
||||
write(output, "// something else")
|
||||
write_registry(registry, ("some/output.txt", "whatever", hash(some_processed_output)))
|
||||
write_registry(
|
||||
registry, ("some/output.txt", "whatever", hash(some_processed_output))
|
||||
)
|
||||
|
||||
with pytest.raises(render.Error):
|
||||
sut.manage(generated=(output,), stubs=(), registry=registry)
|
||||
|
||||
|
||||
def test_managed_render_with_modified_stub_file_still_marked_as_generated(pystache_renderer, sut):
|
||||
def test_managed_render_with_modified_stub_file_still_marked_as_generated(
|
||||
pystache_renderer, sut
|
||||
):
|
||||
stub = paths.root_dir / "a/some/stub.txt"
|
||||
some_processed_output = "// generated some processed output"
|
||||
registry = paths.root_dir / "a/registry.list"
|
||||
@@ -255,7 +266,9 @@ def test_managed_render_with_modified_stub_file_still_marked_as_generated(pystac
|
||||
sut.manage(generated=(), stubs=(stub,), registry=registry)
|
||||
|
||||
|
||||
def test_managed_render_with_modified_stub_file_not_marked_as_generated(pystache_renderer, sut):
|
||||
def test_managed_render_with_modified_stub_file_not_marked_as_generated(
|
||||
pystache_renderer, sut
|
||||
):
|
||||
stub = paths.root_dir / "a/some/stub.txt"
|
||||
some_processed_output = "// generated some processed output"
|
||||
registry = paths.root_dir / "a/registry.list"
|
||||
@@ -272,7 +285,9 @@ class MyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def test_managed_render_exception_drops_written_and_inexsistent_from_registry(pystache_renderer, sut):
|
||||
def test_managed_render_exception_drops_written_and_inexsistent_from_registry(
|
||||
pystache_renderer, sut
|
||||
):
|
||||
data = mock.Mock(spec=("template",))
|
||||
text = "some text"
|
||||
pystache_renderer.render_name.side_effect = (text,)
|
||||
@@ -281,11 +296,9 @@ def test_managed_render_exception_drops_written_and_inexsistent_from_registry(py
|
||||
write(output, text)
|
||||
write(paths.root_dir / "a/a")
|
||||
write(paths.root_dir / "a/c")
|
||||
write_registry(registry,
|
||||
"aaa",
|
||||
("some/output.txt", "whatever", hash(text)),
|
||||
"bbb",
|
||||
"ccc")
|
||||
write_registry(
|
||||
registry, "aaa", ("some/output.txt", "whatever", hash(text)), "bbb", "ccc"
|
||||
)
|
||||
|
||||
with pytest.raises(MyError):
|
||||
with sut.manage(generated=(), stubs=(), registry=registry) as renderer:
|
||||
@@ -299,17 +312,14 @@ def test_managed_render_drops_inexsistent_from_registry(pystache_renderer, sut):
|
||||
registry = paths.root_dir / "a/registry.list"
|
||||
write(paths.root_dir / "a/a")
|
||||
write(paths.root_dir / "a/c")
|
||||
write_registry(registry,
|
||||
("a", hash(''), hash('')),
|
||||
"bbb",
|
||||
("c", hash(''), hash('')))
|
||||
write_registry(
|
||||
registry, ("a", hash(""), hash("")), "bbb", ("c", hash(""), hash(""))
|
||||
)
|
||||
|
||||
with sut.manage(generated=(), stubs=(), registry=registry):
|
||||
pass
|
||||
|
||||
assert_registry(registry,
|
||||
("a", hash(''), hash('')),
|
||||
("c", hash(''), hash('')))
|
||||
assert_registry(registry, ("a", hash(""), hash("")), ("c", hash(""), hash("")))
|
||||
|
||||
|
||||
def test_managed_render_exception_does_not_erase(pystache_renderer, sut):
|
||||
@@ -321,7 +331,9 @@ def test_managed_render_exception_does_not_erase(pystache_renderer, sut):
|
||||
write_registry(registry)
|
||||
|
||||
with pytest.raises(MyError):
|
||||
with sut.manage(generated=(output,), stubs=(stub,), registry=registry) as renderer:
|
||||
with sut.manage(
|
||||
generated=(output,), stubs=(stub,), registry=registry
|
||||
) as renderer:
|
||||
raise MyError
|
||||
|
||||
assert output.is_file()
|
||||
@@ -333,14 +345,15 @@ def test_render_with_extensions(pystache_renderer, sut):
|
||||
data.template = "test_template"
|
||||
data.extensions = ["foo", "bar", "baz"]
|
||||
output = pathlib.Path("my", "test", "file")
|
||||
expected_outputs = [pathlib.Path("my", "test", p) for p in ("file.foo", "file.bar", "file.baz")]
|
||||
expected_outputs = [
|
||||
pathlib.Path("my", "test", p) for p in ("file.foo", "file.bar", "file.baz")
|
||||
]
|
||||
rendered = [f"text{i}" for i in range(len(expected_outputs))]
|
||||
pystache_renderer.render_name.side_effect = rendered
|
||||
sut.render(data, output)
|
||||
expected_templates = ["test_template_foo", "test_template_bar", "test_template_baz"]
|
||||
assert pystache_renderer.mock_calls == [
|
||||
mock.call.render_name(t, data, generator=generator)
|
||||
for t in expected_templates
|
||||
mock.call.render_name(t, data, generator=generator) for t in expected_templates
|
||||
]
|
||||
for expected_output, expected_contents in zip(expected_outputs, rendered):
|
||||
assert_file(expected_output, expected_contents)
|
||||
@@ -356,7 +369,9 @@ def test_managed_render_with_force_not_skipping_generated_file(pystache_renderer
|
||||
|
||||
pystache_renderer.render_name.side_effect = (some_output,)
|
||||
|
||||
with sut.manage(generated=(output,), stubs=(), registry=registry, force=True) as renderer:
|
||||
with sut.manage(
|
||||
generated=(output,), stubs=(), registry=registry, force=True
|
||||
) as renderer:
|
||||
renderer.render(data, output)
|
||||
assert renderer.written == {output}
|
||||
assert_file(output, some_output)
|
||||
@@ -374,11 +389,15 @@ def test_managed_render_with_force_not_skipping_stub_file(pystache_renderer, sut
|
||||
some_processed_output = "// generated some processed output"
|
||||
registry = paths.root_dir / "a/registry.list"
|
||||
write(stub, some_processed_output)
|
||||
write_registry(registry, ("some/stub.txt", hash(some_output), hash(some_processed_output)))
|
||||
write_registry(
|
||||
registry, ("some/stub.txt", hash(some_output), hash(some_processed_output))
|
||||
)
|
||||
|
||||
pystache_renderer.render_name.side_effect = (some_output,)
|
||||
|
||||
with sut.manage(generated=(), stubs=(stub,), registry=registry, force=True) as renderer:
|
||||
with sut.manage(
|
||||
generated=(), stubs=(stub,), registry=registry, force=True
|
||||
) as renderer:
|
||||
renderer.render(data, stub)
|
||||
assert renderer.written == {stub}
|
||||
assert_file(stub, some_output)
|
||||
@@ -394,13 +413,17 @@ def test_managed_render_with_force_ignores_modified_generated_file(sut):
|
||||
some_processed_output = "// some processed output"
|
||||
registry = paths.root_dir / "a/registry.list"
|
||||
write(output, "// something else")
|
||||
write_registry(registry, ("some/output.txt", "whatever", hash(some_processed_output)))
|
||||
write_registry(
|
||||
registry, ("some/output.txt", "whatever", hash(some_processed_output))
|
||||
)
|
||||
|
||||
with sut.manage(generated=(output,), stubs=(), registry=registry, force=True):
|
||||
pass
|
||||
|
||||
|
||||
def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_generated(sut):
|
||||
def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_generated(
|
||||
sut,
|
||||
):
|
||||
stub = paths.root_dir / "a/some/stub.txt"
|
||||
some_processed_output = "// generated some processed output"
|
||||
registry = paths.root_dir / "a/registry.list"
|
||||
@@ -411,5 +434,5 @@ def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_ge
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -26,9 +26,9 @@ def test_one_empty_class():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'MyClass': schema.Class('MyClass'),
|
||||
"MyClass": schema.Class("MyClass"),
|
||||
}
|
||||
assert data.root_class is data.classes['MyClass']
|
||||
assert data.root_class is data.classes["MyClass"]
|
||||
|
||||
|
||||
def test_two_empty_classes():
|
||||
@@ -41,10 +41,10 @@ def test_two_empty_classes():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'MyClass1': schema.Class('MyClass1', derived={'MyClass2'}),
|
||||
'MyClass2': schema.Class('MyClass2', bases=['MyClass1']),
|
||||
"MyClass1": schema.Class("MyClass1", derived={"MyClass2"}),
|
||||
"MyClass2": schema.Class("MyClass2", bases=["MyClass1"]),
|
||||
}
|
||||
assert data.root_class is data.classes['MyClass1']
|
||||
assert data.root_class is data.classes["MyClass1"]
|
||||
|
||||
|
||||
def test_no_external_bases():
|
||||
@@ -52,6 +52,7 @@ def test_no_external_bases():
|
||||
pass
|
||||
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class MyClass(A):
|
||||
@@ -60,6 +61,7 @@ def test_no_external_bases():
|
||||
|
||||
def test_no_multiple_roots():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class MyClass1:
|
||||
@@ -85,10 +87,10 @@ def test_empty_classes_diamond():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', derived={'B', 'C'}),
|
||||
'B': schema.Class('B', bases=['A'], derived={'D'}),
|
||||
'C': schema.Class('C', bases=['A'], derived={'D'}),
|
||||
'D': schema.Class('D', bases=['B', 'C']),
|
||||
"A": schema.Class("A", derived={"B", "C"}),
|
||||
"B": schema.Class("B", bases=["A"], derived={"D"}),
|
||||
"C": schema.Class("C", bases=["A"], derived={"D"}),
|
||||
"D": schema.Class("D", bases=["B", "C"]),
|
||||
}
|
||||
|
||||
|
||||
@@ -101,7 +103,7 @@ def test_group():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', pragmas={"group": "xxx"}),
|
||||
"A": schema.Class("A", pragmas={"group": "xxx"}),
|
||||
}
|
||||
|
||||
|
||||
@@ -114,7 +116,7 @@ def test_group_is_inherited():
|
||||
class B(A):
|
||||
pass
|
||||
|
||||
@defs.group('xxx')
|
||||
@defs.group("xxx")
|
||||
class C(A):
|
||||
pass
|
||||
|
||||
@@ -122,25 +124,26 @@ def test_group_is_inherited():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', derived={'B', 'C'}),
|
||||
'B': schema.Class('B', bases=['A'], derived={'D'}),
|
||||
'C': schema.Class('C', bases=['A'], derived={'D'}, pragmas={"group": "xxx"}),
|
||||
'D': schema.Class('D', bases=['B', 'C'], pragmas={"group": "xxx"}),
|
||||
"A": schema.Class("A", derived={"B", "C"}),
|
||||
"B": schema.Class("B", bases=["A"], derived={"D"}),
|
||||
"C": schema.Class("C", bases=["A"], derived={"D"}, pragmas={"group": "xxx"}),
|
||||
"D": schema.Class("D", bases=["B", "C"], pragmas={"group": "xxx"}),
|
||||
}
|
||||
|
||||
|
||||
def test_no_mixed_groups_in_bases():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
pass
|
||||
|
||||
@defs.group('x')
|
||||
@defs.group("x")
|
||||
class B(A):
|
||||
pass
|
||||
|
||||
@defs.group('y')
|
||||
@defs.group("y")
|
||||
class C(A):
|
||||
pass
|
||||
|
||||
@@ -153,6 +156,7 @@ def test_no_mixed_groups_in_bases():
|
||||
|
||||
def test_lowercase_rejected():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class aLowerCase:
|
||||
@@ -171,14 +175,17 @@ def test_properties():
|
||||
six: defs.set[defs.string]
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[
|
||||
schema.SingleProperty('one', 'string'),
|
||||
schema.OptionalProperty('two', 'int'),
|
||||
schema.RepeatedProperty('three', 'boolean'),
|
||||
schema.RepeatedOptionalProperty('four', 'string'),
|
||||
schema.PredicateProperty('five'),
|
||||
schema.RepeatedUnorderedProperty('six', 'string'),
|
||||
]),
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty("one", "string"),
|
||||
schema.OptionalProperty("two", "int"),
|
||||
schema.RepeatedProperty("three", "boolean"),
|
||||
schema.RepeatedOptionalProperty("four", "string"),
|
||||
schema.PredicateProperty("five"),
|
||||
schema.RepeatedUnorderedProperty("six", "string"),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -199,14 +206,18 @@ def test_class_properties():
|
||||
five: defs.set[A]
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', derived={'B'}),
|
||||
'B': schema.Class('B', bases=['A'], properties=[
|
||||
schema.SingleProperty('one', 'A'),
|
||||
schema.OptionalProperty('two', 'A'),
|
||||
schema.RepeatedProperty('three', 'A'),
|
||||
schema.RepeatedOptionalProperty('four', 'A'),
|
||||
schema.RepeatedUnorderedProperty('five', 'A'),
|
||||
]),
|
||||
"A": schema.Class("A", derived={"B"}),
|
||||
"B": schema.Class(
|
||||
"B",
|
||||
bases=["A"],
|
||||
properties=[
|
||||
schema.SingleProperty("one", "A"),
|
||||
schema.OptionalProperty("two", "A"),
|
||||
schema.RepeatedProperty("three", "A"),
|
||||
schema.RepeatedOptionalProperty("four", "A"),
|
||||
schema.RepeatedUnorderedProperty("five", "A"),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -221,20 +232,31 @@ def test_string_reference_class_properties():
|
||||
five: defs.set["A"]
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[
|
||||
schema.SingleProperty('one', 'A'),
|
||||
schema.OptionalProperty('two', 'A'),
|
||||
schema.RepeatedProperty('three', 'A'),
|
||||
schema.RepeatedOptionalProperty('four', 'A'),
|
||||
schema.RepeatedUnorderedProperty('five', 'A'),
|
||||
]),
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty("one", "A"),
|
||||
schema.OptionalProperty("two", "A"),
|
||||
schema.RepeatedProperty("three", "A"),
|
||||
schema.RepeatedOptionalProperty("four", "A"),
|
||||
schema.RepeatedUnorderedProperty("five", "A"),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", [lambda t: t, lambda t: defs.optional[t], lambda t: defs.list[t],
|
||||
lambda t: defs.list[defs.optional[t]]])
|
||||
@pytest.mark.parametrize(
|
||||
"spec",
|
||||
[
|
||||
lambda t: t,
|
||||
lambda t: defs.optional[t],
|
||||
lambda t: defs.list[t],
|
||||
lambda t: defs.list[defs.optional[t]],
|
||||
],
|
||||
)
|
||||
def test_string_reference_dangling(spec):
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
@@ -251,18 +273,24 @@ def test_children():
|
||||
four: defs.list[defs.optional["A"]] | defs.child
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[
|
||||
schema.SingleProperty('one', 'A', is_child=True),
|
||||
schema.OptionalProperty('two', 'A', is_child=True),
|
||||
schema.RepeatedProperty('three', 'A', is_child=True),
|
||||
schema.RepeatedOptionalProperty('four', 'A', is_child=True),
|
||||
]),
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty("one", "A", is_child=True),
|
||||
schema.OptionalProperty("two", "A", is_child=True),
|
||||
schema.RepeatedProperty("three", "A", is_child=True),
|
||||
schema.RepeatedOptionalProperty("four", "A", is_child=True),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec", [defs.string, defs.int, defs.boolean, defs.predicate, defs.set["A"]])
|
||||
@pytest.mark.parametrize(
|
||||
"spec", [defs.string, defs.int, defs.boolean, defs.predicate, defs.set["A"]]
|
||||
)
|
||||
def test_builtin_predicate_and_set_children_not_allowed(spec):
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
@@ -291,9 +319,12 @@ def test_property_with_pragma(pragma, expected):
|
||||
x: defs.string | pragma
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[
|
||||
schema.SingleProperty('x', 'string', pragmas=[expected]),
|
||||
]),
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty("x", "string", pragmas=[expected]),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -308,9 +339,16 @@ def test_property_with_pragmas():
|
||||
x: spec
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[
|
||||
schema.SingleProperty('x', 'string', pragmas=[expected for _, expected in _property_pragmas]),
|
||||
]),
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty(
|
||||
"x",
|
||||
"string",
|
||||
pragmas=[expected for _, expected in _property_pragmas],
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -323,7 +361,7 @@ def test_class_with_pragma(pragma, expected):
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', pragmas=[expected]),
|
||||
"A": schema.Class("A", pragmas=[expected]),
|
||||
}
|
||||
|
||||
|
||||
@@ -340,7 +378,7 @@ def test_class_with_pragmas():
|
||||
apply_pragmas(A)
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', pragmas=[e for _, e in _pragmas]),
|
||||
"A": schema.Class("A", pragmas=[e for _, e in _pragmas]),
|
||||
}
|
||||
|
||||
|
||||
@@ -355,8 +393,10 @@ def test_synth_from_class():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', derived={'B'}, pragmas={"synth": True}),
|
||||
'B': schema.Class('B', bases=['A'], pragmas={"synth": schema.SynthInfo(from_class="A")}),
|
||||
"A": schema.Class("A", derived={"B"}, pragmas={"synth": True}),
|
||||
"B": schema.Class(
|
||||
"B", bases=["A"], pragmas={"synth": schema.SynthInfo(from_class="A")}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -371,13 +411,16 @@ def test_synth_from_class_ref():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', derived={'B'}, pragmas={"synth": schema.SynthInfo(from_class="B")}),
|
||||
'B': schema.Class('B', bases=['A']),
|
||||
"A": schema.Class(
|
||||
"A", derived={"B"}, pragmas={"synth": schema.SynthInfo(from_class="B")}
|
||||
),
|
||||
"B": schema.Class("B", bases=["A"]),
|
||||
}
|
||||
|
||||
|
||||
def test_synth_from_class_dangling():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
@defs.synth.from_class("X")
|
||||
@@ -396,8 +439,12 @@ def test_synth_class_on():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', derived={'B'}, pragmas={"synth": True}),
|
||||
'B': schema.Class('B', bases=['A'], pragmas={"synth": schema.SynthInfo(on_arguments={'a': 'A', 'i': 'int'})}),
|
||||
"A": schema.Class("A", derived={"B"}, pragmas={"synth": True}),
|
||||
"B": schema.Class(
|
||||
"B",
|
||||
bases=["A"],
|
||||
pragmas={"synth": schema.SynthInfo(on_arguments={"a": "A", "i": "int"})},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -415,13 +462,18 @@ def test_synth_class_on_ref():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', derived={'B'}, pragmas={"synth": schema.SynthInfo(on_arguments={'b': 'B', 'i': 'int'})}),
|
||||
'B': schema.Class('B', bases=['A']),
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
derived={"B"},
|
||||
pragmas={"synth": schema.SynthInfo(on_arguments={"b": "B", "i": "int"})},
|
||||
),
|
||||
"B": schema.Class("B", bases=["A"]),
|
||||
}
|
||||
|
||||
|
||||
def test_synth_class_on_dangling():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
@defs.synth.on_arguments(s=defs.string, a="A", i=defs.int)
|
||||
@@ -453,12 +505,25 @@ def test_synth_class_hierarchy():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'Root': schema.Class('Root', derived={'Base', 'C'}),
|
||||
'Base': schema.Class('Base', bases=['Root'], derived={'Intermediate', 'B'}, pragmas={"synth": True}),
|
||||
'Intermediate': schema.Class('Intermediate', bases=['Base'], derived={'A'}, pragmas={"synth": True}),
|
||||
'A': schema.Class('A', bases=['Intermediate'], pragmas={"synth": schema.SynthInfo(on_arguments={'a': 'Base', 'i': 'int'})}),
|
||||
'B': schema.Class('B', bases=['Base'], pragmas={"synth": schema.SynthInfo(from_class='Base')}),
|
||||
'C': schema.Class('C', bases=['Root']),
|
||||
"Root": schema.Class("Root", derived={"Base", "C"}),
|
||||
"Base": schema.Class(
|
||||
"Base",
|
||||
bases=["Root"],
|
||||
derived={"Intermediate", "B"},
|
||||
pragmas={"synth": True},
|
||||
),
|
||||
"Intermediate": schema.Class(
|
||||
"Intermediate", bases=["Base"], derived={"A"}, pragmas={"synth": True}
|
||||
),
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
bases=["Intermediate"],
|
||||
pragmas={"synth": schema.SynthInfo(on_arguments={"a": "Base", "i": "int"})},
|
||||
),
|
||||
"B": schema.Class(
|
||||
"B", bases=["Base"], pragmas={"synth": schema.SynthInfo(from_class="Base")}
|
||||
),
|
||||
"C": schema.Class("C", bases=["Root"]),
|
||||
}
|
||||
|
||||
|
||||
@@ -479,9 +544,7 @@ def test_class_docstring():
|
||||
class A:
|
||||
"""Very important class."""
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', doc=["Very important class."])
|
||||
}
|
||||
assert data.classes == {"A": schema.Class("A", doc=["Very important class."])}
|
||||
|
||||
|
||||
def test_property_docstring():
|
||||
@@ -491,7 +554,14 @@ def test_property_docstring():
|
||||
x: int | defs.desc("very important property.")
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important property."])])
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty(
|
||||
"x", "int", description=["very important property."]
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -502,21 +572,27 @@ def test_class_docstring_newline():
|
||||
"""Very important
|
||||
class."""
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', doc=["Very important", "class."])
|
||||
}
|
||||
assert data.classes == {"A": schema.Class("A", doc=["Very important", "class."])}
|
||||
|
||||
|
||||
def test_property_docstring_newline():
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
x: int | defs.desc("""very important
|
||||
property.""")
|
||||
x: int | defs.desc(
|
||||
"""very important
|
||||
property."""
|
||||
)
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A',
|
||||
properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty(
|
||||
"x", "int", description=["very important", "property."]
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -530,23 +606,30 @@ def test_class_docstring_stripped():
|
||||
|
||||
"""
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', doc=["Very important class."])
|
||||
}
|
||||
assert data.classes == {"A": schema.Class("A", doc=["Very important class."])}
|
||||
|
||||
|
||||
def test_property_docstring_stripped():
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
x: int | defs.desc("""
|
||||
x: int | defs.desc(
|
||||
"""
|
||||
|
||||
very important property.
|
||||
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important property."])])
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty(
|
||||
"x", "int", description=["very important property."]
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -559,7 +642,9 @@ def test_class_docstring_split():
|
||||
As said, very important."""
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', doc=["Very important class.", "", "As said, very important."])
|
||||
"A": schema.Class(
|
||||
"A", doc=["Very important class.", "", "As said, very important."]
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -567,13 +652,27 @@ def test_property_docstring_split():
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
x: int | defs.desc("""very important property.
|
||||
x: int | defs.desc(
|
||||
"""very important property.
|
||||
|
||||
Very very important.""")
|
||||
Very very important."""
|
||||
)
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[
|
||||
schema.SingleProperty('x', 'int', description=["very important property.", "", "Very very important."])])
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty(
|
||||
"x",
|
||||
"int",
|
||||
description=[
|
||||
"very important property.",
|
||||
"",
|
||||
"Very very important.",
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -587,7 +686,9 @@ def test_class_docstring_indent():
|
||||
"""
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', doc=["Very important class.", " As said, very important."])
|
||||
"A": schema.Class(
|
||||
"A", doc=["Very important class.", " As said, very important."]
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -595,14 +696,24 @@ def test_property_docstring_indent():
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
x: int | defs.desc("""
|
||||
x: int | defs.desc(
|
||||
"""
|
||||
very important property.
|
||||
Very very important.
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[
|
||||
schema.SingleProperty('x', 'int', description=["very important property.", " Very very important."])])
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.SingleProperty(
|
||||
"x",
|
||||
"int",
|
||||
description=["very important property.", " Very very important."],
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -613,13 +724,13 @@ def test_property_doc_override():
|
||||
x: int | defs.doc("y")
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[
|
||||
schema.SingleProperty('x', 'int', doc="y")]),
|
||||
"A": schema.Class("A", properties=[schema.SingleProperty("x", "int", doc="y")]),
|
||||
}
|
||||
|
||||
|
||||
def test_property_doc_override_no_newlines():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
@@ -628,6 +739,7 @@ def test_property_doc_override_no_newlines():
|
||||
|
||||
def test_property_doc_override_no_trailing_dot():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class A:
|
||||
@@ -642,7 +754,7 @@ def test_class_default_doc_name():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', pragmas={"ql_default_doc_name": "b"}),
|
||||
"A": schema.Class("A", pragmas={"ql_default_doc_name": "b"}),
|
||||
}
|
||||
|
||||
|
||||
@@ -653,7 +765,12 @@ def test_db_table_name():
|
||||
x: optional[int] | defs.ql.db_table_name("foo")
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[schema.OptionalProperty("x", "int", pragmas={"ql_db_table_name": "foo"})]),
|
||||
"A": schema.Class(
|
||||
"A",
|
||||
properties=[
|
||||
schema.OptionalProperty("x", "int", pragmas={"ql_db_table_name": "foo"})
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -668,15 +785,16 @@ def test_null_class():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'Root': schema.Class('Root', derived={'Null'}),
|
||||
'Null': schema.Class('Null', bases=['Root']),
|
||||
"Root": schema.Class("Root", derived={"Null"}),
|
||||
"Null": schema.Class("Null", bases=["Root"]),
|
||||
}
|
||||
assert data.null == 'Null'
|
||||
assert data.null == "Null"
|
||||
assert data.null_class is data.classes[data.null]
|
||||
|
||||
|
||||
def test_null_class_cannot_be_derived():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class Root:
|
||||
@@ -692,6 +810,7 @@ def test_null_class_cannot_be_derived():
|
||||
|
||||
def test_null_class_cannot_be_defined_multiple_times():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class Root:
|
||||
@@ -708,6 +827,7 @@ def test_null_class_cannot_be_defined_multiple_times():
|
||||
|
||||
def test_uppercase_acronyms_are_rejected():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class Root:
|
||||
@@ -737,10 +857,18 @@ def test_hideable():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", derived={"A", "IndirectlyHideable", "NonHideable"}, pragmas=["ql_hideable"]),
|
||||
"Root": schema.Class(
|
||||
"Root",
|
||||
derived={"A", "IndirectlyHideable", "NonHideable"},
|
||||
pragmas=["ql_hideable"],
|
||||
),
|
||||
"A": schema.Class("A", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"]),
|
||||
"IndirectlyHideable": schema.Class("IndirectlyHideable", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"]),
|
||||
"B": schema.Class("B", bases=["A", "IndirectlyHideable"], pragmas=["ql_hideable"]),
|
||||
"IndirectlyHideable": schema.Class(
|
||||
"IndirectlyHideable", bases=["Root"], derived={"B"}, pragmas=["ql_hideable"]
|
||||
),
|
||||
"B": schema.Class(
|
||||
"B", bases=["A", "IndirectlyHideable"], pragmas=["ql_hideable"]
|
||||
),
|
||||
"NonHideable": schema.Class("NonHideable", bases=["Root"]),
|
||||
}
|
||||
|
||||
@@ -771,7 +899,9 @@ def test_test_with():
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", derived=set("ABCD")),
|
||||
"A": schema.Class("A", bases=["Root"]),
|
||||
"B": schema.Class("B", bases=["Root"], pragmas={"qltest_test_with": "A"}, derived={'E'}),
|
||||
"B": schema.Class(
|
||||
"B", bases=["Root"], pragmas={"qltest_test_with": "A"}, derived={"E"}
|
||||
),
|
||||
"C": schema.Class("C", bases=["Root"], pragmas={"qltest_test_with": "D"}),
|
||||
"D": schema.Class("D", bases=["Root"]),
|
||||
"E": schema.Class("E", bases=["B"], pragmas={"qltest_test_with": "A"}),
|
||||
@@ -782,10 +912,10 @@ def test_annotate_docstring():
|
||||
@load
|
||||
class data:
|
||||
class Root:
|
||||
""" old docstring """
|
||||
"""old docstring"""
|
||||
|
||||
class A(Root):
|
||||
""" A docstring """
|
||||
"""A docstring"""
|
||||
|
||||
@defs.annotate(Root)
|
||||
class _:
|
||||
@@ -819,7 +949,15 @@ def test_annotate_decorations():
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", pragmas=["qltest_skip", "cpp_skip", "ql_hideable", "qltest_collapse_hierarchy"]),
|
||||
"Root": schema.Class(
|
||||
"Root",
|
||||
pragmas=[
|
||||
"qltest_skip",
|
||||
"cpp_skip",
|
||||
"ql_hideable",
|
||||
"qltest_collapse_hierarchy",
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -837,11 +975,16 @@ def test_annotate_fields():
|
||||
z: defs.string
|
||||
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", properties=[
|
||||
schema.SingleProperty("x", "int", doc="foo"),
|
||||
schema.OptionalProperty("y", "Root", pragmas=["ql_internal"], is_child=True),
|
||||
schema.SingleProperty("z", "string"),
|
||||
]),
|
||||
"Root": schema.Class(
|
||||
"Root",
|
||||
properties=[
|
||||
schema.SingleProperty("x", "int", doc="foo"),
|
||||
schema.OptionalProperty(
|
||||
"y", "Root", pragmas=["ql_internal"], is_child=True
|
||||
),
|
||||
schema.SingleProperty("z", "string"),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -860,16 +1003,20 @@ def test_annotate_fields_negations():
|
||||
z: defs._ | ~defs.synth | ~defs.doc
|
||||
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", properties=[
|
||||
schema.SingleProperty("x", "int"),
|
||||
schema.OptionalProperty("y", "Root"),
|
||||
schema.SingleProperty("z", "string"),
|
||||
]),
|
||||
"Root": schema.Class(
|
||||
"Root",
|
||||
properties=[
|
||||
schema.SingleProperty("x", "int"),
|
||||
schema.OptionalProperty("y", "Root"),
|
||||
schema.SingleProperty("z", "string"),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_annotate_non_existing_field():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class Root:
|
||||
@@ -882,6 +1029,7 @@ def test_annotate_non_existing_field():
|
||||
|
||||
def test_annotate_not_underscore():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class Root:
|
||||
@@ -916,6 +1064,7 @@ def test_annotate_replace_bases():
|
||||
@defs.annotate(Derived, replace_bases={B: C})
|
||||
class _:
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", derived={"A", "B"}),
|
||||
"A": schema.Class("A", bases=["Root"], derived={"Derived"}),
|
||||
@@ -946,6 +1095,7 @@ def test_annotate_add_bases():
|
||||
@defs.annotate(Derived, add_bases=(B, C))
|
||||
class _:
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", derived={"A", "B", "C"}),
|
||||
"A": schema.Class("A", bases=["Root"], derived={"Derived"}),
|
||||
@@ -968,15 +1118,19 @@ def test_annotate_drop_field():
|
||||
y: defs.drop
|
||||
|
||||
assert data.classes == {
|
||||
"Root": schema.Class("Root", properties=[
|
||||
schema.SingleProperty("x", "int"),
|
||||
schema.SingleProperty("z", "boolean"),
|
||||
]),
|
||||
"Root": schema.Class(
|
||||
"Root",
|
||||
properties=[
|
||||
schema.SingleProperty("x", "int"),
|
||||
schema.SingleProperty("z", "boolean"),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_test_with_unknown_string():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
@load
|
||||
class data:
|
||||
class Root:
|
||||
@@ -989,6 +1143,7 @@ def test_test_with_unknown_string():
|
||||
|
||||
def test_test_with_unknown_class():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
class B:
|
||||
pass
|
||||
|
||||
@@ -1004,6 +1159,7 @@ def test_test_with_unknown_class():
|
||||
|
||||
def test_test_with_double():
|
||||
with pytest.raises(schema.Error):
|
||||
|
||||
class B:
|
||||
pass
|
||||
|
||||
@@ -1024,5 +1180,5 @@ def test_test_with_double():
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -17,10 +17,16 @@ def generate_grouped(opts, renderer, dbscheme_input):
|
||||
dirs = {f.parent for f in generated}
|
||||
assert all(isinstance(f, pathlib.Path) for f in generated)
|
||||
assert all(f.name in ("TrapEntries", "TrapTags") for f in generated)
|
||||
assert set(f for f in generated if f.name == "TrapTags") == {output_dir / "TrapTags"}
|
||||
return ({
|
||||
str(d.relative_to(output_dir)): generated[d / "TrapEntries"] for d in dirs
|
||||
}, generated[output_dir / "TrapTags"])
|
||||
assert set(f for f in generated if f.name == "TrapTags") == {
|
||||
output_dir / "TrapTags"
|
||||
}
|
||||
return (
|
||||
{
|
||||
str(d.relative_to(output_dir)): generated[d / "TrapEntries"]
|
||||
for d in dirs
|
||||
},
|
||||
generated[output_dir / "TrapTags"],
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
@@ -65,87 +71,130 @@ def test_empty_tags(generate_tags):
|
||||
|
||||
def test_one_empty_table_rejected(generate_traps):
|
||||
with pytest.raises(AssertionError):
|
||||
generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[]),
|
||||
])
|
||||
generate_traps(
|
||||
[
|
||||
dbscheme.Table(name="foos", columns=[]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_one_table(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]) == [
|
||||
assert generate_traps(
|
||||
[
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]
|
||||
) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]),
|
||||
]
|
||||
|
||||
|
||||
def test_one_table(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]) == [
|
||||
assert generate_traps(
|
||||
[
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]
|
||||
) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]),
|
||||
]
|
||||
|
||||
|
||||
def test_one_table_with_id(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[
|
||||
dbscheme.Column("bla", "int", binding=True)]),
|
||||
]) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[cpp.Field(
|
||||
"bla", "int")], id=cpp.Field("bla", "int")),
|
||||
assert generate_traps(
|
||||
[
|
||||
dbscheme.Table(
|
||||
name="foos", columns=[dbscheme.Column("bla", "int", binding=True)]
|
||||
),
|
||||
]
|
||||
) == [
|
||||
cpp.Trap(
|
||||
"foos",
|
||||
name="Foos",
|
||||
fields=[cpp.Field("bla", "int")],
|
||||
id=cpp.Field("bla", "int"),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def test_one_table_with_two_binding_first_is_id(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[
|
||||
dbscheme.Column("x", "a", binding=True),
|
||||
dbscheme.Column("y", "b", binding=True),
|
||||
]),
|
||||
]) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[
|
||||
cpp.Field("x", "a"),
|
||||
cpp.Field("y", "b"),
|
||||
], id=cpp.Field("x", "a")),
|
||||
assert generate_traps(
|
||||
[
|
||||
dbscheme.Table(
|
||||
name="foos",
|
||||
columns=[
|
||||
dbscheme.Column("x", "a", binding=True),
|
||||
dbscheme.Column("y", "b", binding=True),
|
||||
],
|
||||
),
|
||||
]
|
||||
) == [
|
||||
cpp.Trap(
|
||||
"foos",
|
||||
name="Foos",
|
||||
fields=[
|
||||
cpp.Field("x", "a"),
|
||||
cpp.Field("y", "b"),
|
||||
],
|
||||
id=cpp.Field("x", "a"),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("column,field", [
|
||||
(dbscheme.Column("x", "string"), cpp.Field("x", "std::string")),
|
||||
(dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")),
|
||||
(dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel<DbTypeTag>")),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"column,field",
|
||||
[
|
||||
(dbscheme.Column("x", "string"), cpp.Field("x", "std::string")),
|
||||
(dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")),
|
||||
(dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel<DbTypeTag>")),
|
||||
],
|
||||
)
|
||||
def test_one_table_special_types(generate_traps, column, field):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[column]),
|
||||
]) == [
|
||||
assert generate_traps(
|
||||
[
|
||||
dbscheme.Table(name="foos", columns=[column]),
|
||||
]
|
||||
) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[field]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"])
|
||||
@pytest.mark.parametrize(
|
||||
"name",
|
||||
["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"],
|
||||
)
|
||||
def test_one_table_overridden_unsigned_field(generate_traps, name):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]),
|
||||
]) == [
|
||||
assert generate_traps(
|
||||
[
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]),
|
||||
]
|
||||
) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[cpp.Field(name, "unsigned")]),
|
||||
]
|
||||
|
||||
|
||||
def test_one_table_overridden_underscore_named_field(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]),
|
||||
]) == [
|
||||
assert generate_traps(
|
||||
[
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]),
|
||||
]
|
||||
) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[cpp.Field("whatever", "bar")]),
|
||||
]
|
||||
|
||||
|
||||
def test_tables_with_dir(generate_grouped_traps):
|
||||
assert generate_grouped_traps([
|
||||
dbscheme.Table(name="x", columns=[dbscheme.Column("i", "int")]),
|
||||
dbscheme.Table(name="y", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo")),
|
||||
dbscheme.Table(name="z", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo/bar")),
|
||||
]) == {
|
||||
assert generate_grouped_traps(
|
||||
[
|
||||
dbscheme.Table(name="x", columns=[dbscheme.Column("i", "int")]),
|
||||
dbscheme.Table(
|
||||
name="y", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo")
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="z",
|
||||
columns=[dbscheme.Column("i", "int")],
|
||||
dir=pathlib.Path("foo/bar"),
|
||||
),
|
||||
]
|
||||
) == {
|
||||
".": [cpp.Trap("x", name="X", fields=[cpp.Field("i", "int")])],
|
||||
"foo": [cpp.Trap("y", name="Y", fields=[cpp.Field("i", "int")])],
|
||||
"foo/bar": [cpp.Trap("z", name="Z", fields=[cpp.Field("i", "int")])],
|
||||
@@ -153,15 +202,22 @@ def test_tables_with_dir(generate_grouped_traps):
|
||||
|
||||
|
||||
def test_one_table_no_tags(generate_tags):
|
||||
assert generate_tags([
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]) == []
|
||||
assert (
|
||||
generate_tags(
|
||||
[
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
|
||||
def test_one_union_tags(generate_tags):
|
||||
assert generate_tags([
|
||||
dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
|
||||
]) == [
|
||||
assert generate_tags(
|
||||
[
|
||||
dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
|
||||
]
|
||||
) == [
|
||||
cpp.Tag(name="LeftHandSide", bases=[], id="@left_hand_side"),
|
||||
cpp.Tag(name="A", bases=["LeftHandSide"], id="@a"),
|
||||
cpp.Tag(name="B", bases=["LeftHandSide"], id="@b"),
|
||||
@@ -170,11 +226,13 @@ def test_one_union_tags(generate_tags):
|
||||
|
||||
|
||||
def test_multiple_union_tags(generate_tags):
|
||||
assert generate_tags([
|
||||
dbscheme.Union(lhs="@d", rhs=["@a"]),
|
||||
dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
|
||||
dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
|
||||
]) == [
|
||||
assert generate_tags(
|
||||
[
|
||||
dbscheme.Union(lhs="@d", rhs=["@a"]),
|
||||
dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
|
||||
dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
|
||||
]
|
||||
) == [
|
||||
cpp.Tag(name="D", bases=[], id="@d"),
|
||||
cpp.Tag(name="E", bases=[], id="@e"),
|
||||
cpp.Tag(name="A", bases=["D"], id="@a"),
|
||||
@@ -184,5 +242,5 @@ def test_multiple_union_tags(generate_tags):
|
||||
]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -39,8 +39,9 @@ def opts():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def override_paths(tmp_path):
|
||||
with mock.patch("misc.codegen.lib.paths.root_dir", tmp_path), \
|
||||
mock.patch("misc.codegen.lib.paths.exe_file", tmp_path / "exe"):
|
||||
with mock.patch("misc.codegen.lib.paths.root_dir", tmp_path), mock.patch(
|
||||
"misc.codegen.lib.paths.exe_file", tmp_path / "exe"
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user