Move swift/codegen to misc/codegen

This commit is contained in:
Paolo Tranquilli
2023-02-24 13:44:29 +01:00
parent 6d192cdcc1
commit cdd4e8021b
65 changed files with 116 additions and 86 deletions

View File

@@ -1,2 +0,0 @@
[pep8]
max_line_length = 120

View File

@@ -1,19 +1,15 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")
load("@bazel_skylib//rules:native_binary.bzl", "native_binary")
py_binary(
native_binary(
name = "codegen",
srcs = ["codegen.py"],
out = "codegen",
src = "//misc/codegen",
data = [
"//swift:schema",
"//swift:codegen_conf",
"//swift/codegen/templates:cpp",
"//swift/codegen/templates:trap",
],
args = [
"--configuration-file=$(location //swift:codegen_conf)",
],
visibility = ["//swift:__subpackages__"],
deps = [
"//swift/codegen/generators",
],
)

View File

@@ -1,44 +0,0 @@
# Code generation suite
This directory contains the code generation suite used by the Swift extractor and the QL library. This suite will use
the abstract class specification of [`schema.yml`](schema.yml) to generate:
* [the `dbscheme` file](../ql/lib/swift.dbscheme) (see [`dbschemegen.py`](generators/dbschemegen.py))
* [the QL generated code](../ql/lib/codeql/swift/generated) and when
appropriate [the corresponding stubs](../ql/lib/codeql/swift/elements) (see [`qlgen.py`](generators/qlgen.py))
* C++ tags and trap entries (see [`trapgen.py`](generators/trapgen.py))
* C++ structured classes (see [`cppgen.py`](generators/cppgen.py))
## Usage
By default `bazel run //swift/codegen` will update all checked-in generated files (`dbscheme` and QL sources). You can
append `--` followed by other options to tweak the behaviour, which is mainly intended for debugging.
See `bazel run //swift/codegen -- --help` for a list of all options. In particular `--generate` can be used with a comma
separated list to select what to generate (choosing among `dbscheme`, `ql`, `trap` and `cpp`).
C++ code is generated during build (see [`swift/extractor/trap/BUILD.bazel`](../extractor/trap/BUILD.bazel)). After a
build you can browse the generated code in `bazel-bin/swift/extractor/trap/generated`.
For debugging you can also run `./codegen.py` directly. You must then ensure dependencies are installed, which you can
do with the command
```bash
pip3 install -r ./requirements.txt
```
## Implementation notes
The suite uses [mustache templating](https://mustache.github.io/) for generation. Templates are
in [the `templates` directory](templates), prefixed with the generation target they are used for.
Rather than passing dictionaries to the templating engine, python dataclasses are used as defined
in [the `lib` directory](lib). For each of the four generation targets the entry point for the implementation is
specified as the `generate` function in the modules within [the `generators` directory](generators).
Finally, [`codegen.py`](codegen.py) is the driver script gluing everything together and specifying the command line
options.
Unit tests are in [the `test` directory](test) and can be run via `bazel test //swift/codegen/test`.
For more details about each specific generation target, please refer to the module docstrings
in [the `generators` directory](generators).

View File

@@ -1,115 +0,0 @@
#!/usr/bin/env python3
""" Driver script to run all code generation """
import argparse
import logging
import os
import sys
import pathlib
import typing
import shlex
if 'BUILD_WORKSPACE_DIRECTORY' not in os.environ:
# we are not running with `bazel run`, set up module search path
_repo_root = pathlib.Path(__file__).resolve().parents[2]
sys.path.append(str(_repo_root))
from swift.codegen.lib import render, paths
from swift.codegen.generators import generate
def _parse_args() -> argparse.Namespace:
    """ Parse command line options, merging in defaults from a `codegen.conf` file if one is found.

    The configuration file is searched for in the current directory and all its ancestors.
    All path options are turned into absolute paths before returning.
    """
    # look for codegen.conf in the current directory or any ancestor
    dirs = [pathlib.Path().resolve()]
    dirs.extend(dirs[0].parents)
    for dir in dirs:
        conf = dir / "codegen.conf"
        if conf.exists():
            break
    else:
        conf = None
    p = argparse.ArgumentParser(description="Code generation suite")
    p.add_argument("--generate", type=lambda x: x.split(","),
                   help="specify what targets to generate as a comma separated list, choosing among dbscheme, ql, trap "
                        "and cpp")
    p.add_argument("--verbose", "-v", action="store_true", help="print more information")
    p.add_argument("--quiet", "-q", action="store_true", help="only print errors")
    p.add_argument("--configuration-file", "-c", type=_abspath, default=conf,
                   help="A configuration file to load options from. By default, the first codegen.conf file found by "
                        "going up directories from the current location. If present all paths provided in options are "
                        "considered relative to its directory")
    p.add_argument("--root-dir", type=_abspath,
                   help="the directory that should be regarded as the root of the language pack codebase. Used to "
                        "compute QL imports and in some comments and as root for relative paths provided as options. "
                        "If not provided it defaults to the directory of the configuration file, if any")
    # the actions are collected so their destinations can be absolutized after parsing
    path_arguments = [
        p.add_argument("--schema", default="schema.py",
                       help="input schema file (default %(default)s)"),
        p.add_argument("--dbscheme",
                       help="output file for dbscheme generation, input file for trap generation"),
        p.add_argument("--ql-output",
                       help="output directory for generated QL files"),
        p.add_argument("--ql-stub-output",
                       help="output directory for QL stub/customization files. Defines also the "
                            "generated qll file importing every class file"),
        p.add_argument("--ql-test-output",
                       help="output directory for QL generated extractor test files"),
        p.add_argument("--cpp-output",
                       help="output directory for generated C++ files, required if trap or cpp is provided to "
                            "--generate"),
        p.add_argument("--generated-registry",
                       help="registry file containing information about checked-in generated code"),
    ]
    p.add_argument("--script-name",
                   help="script name to put in header comments of generated files. By default, the path of this "
                        "script relative to the root directory")
    # NOTE(review): the trailing comma below turns this statement into a 1-tuple; harmless but likely unintended
    p.add_argument("--trap-library",
                   help="path to the trap library from an include directory, required if generating C++ trap bindings"),
    p.add_argument("--ql-format", action="store_true", default=True,
                   help="use codeql to autoformat QL files (which is the default)")
    p.add_argument("--no-ql-format", action="store_false", dest="ql_format", help="do not format QL files")
    p.add_argument("--codeql-binary", default="codeql", help="command to use for QL formatting (default %(default)s)")
    p.add_argument("--force", "-f", action="store_true",
                   help="generate all files without skipping unchanged files and overwriting modified ones")
    p.add_argument("--use-current-directory", action="store_true",
                   help="do not consider paths as relative to --root-dir or the configuration directory")
    opts = p.parse_args()
    if opts.configuration_file is not None:
        # options from the configuration file act as defaults for anything not given on the command line
        with open(opts.configuration_file) as config:
            defaults = p.parse_args(shlex.split(config.read(), comments=True))
        for flag, value in opts._get_kwargs():
            if value is None:
                setattr(opts, flag, getattr(defaults, flag))
        if opts.root_dir is None:
            opts.root_dir = opts.configuration_file.parent
    if not opts.generate:
        p.error("Nothing to do, specify --generate")
    # absolutize all paths
    for arg in path_arguments:
        path = getattr(opts, arg.dest)
        if path is not None:
            setattr(opts, arg.dest, _abspath(path) if opts.use_current_directory else (opts.root_dir / path))
    if not opts.script_name:
        opts.script_name = paths.exe_file.relative_to(opts.root_dir)
    return opts
def _abspath(x: str) -> typing.Optional[pathlib.Path]:
    """ Turn a non-empty string into a resolved absolute `Path`; empty or missing input gives `None`. """
    if not x:
        return None
    return pathlib.Path(x).resolve()
def run():
    """ Entry point: parse options, configure logging, then run each requested generator. """
    opts = _parse_args()
    if opts.verbose:
        level = logging.DEBUG
    else:
        level = logging.ERROR if opts.quiet else logging.INFO
    logging.basicConfig(format="{levelname} {message}", style='{', level=level)
    for target in opts.generate:
        # each target gets its own renderer, stamped with the script name for file headers
        generate(target, opts, render.Renderer(opts.script_name, opts.root_dir))
if __name__ == "__main__":
run()

View File

@@ -1,11 +0,0 @@
# Package of code generation entry points (one module per generation target).
# NOTE: the previous `load("@swift_codegen_deps//:requirements.bzl", "requirement")`
# was removed because `requirement` was never used in this file.
py_library(
    name = "generators",
    srcs = glob(["*.py"]),
    visibility = ["//swift/codegen:__subpackages__"],
    deps = [
        "//swift/codegen/lib",
        "//swift/codegen/loaders",
    ],
)

View File

@@ -1,6 +0,0 @@
from . import dbschemegen, qlgen, trapgen, cppgen
def generate(target, opts, renderer):
    """ Dispatch generation for `target` (dbscheme, ql, trap or cpp) to the `<target>gen` module. """
    generator_module = globals()[f"{target}gen"]
    generator_module.generate(opts, renderer)

View File

@@ -1,99 +0,0 @@
"""
C++ trap class generation
`generate(opts, renderer)` will generate `TrapClasses.h` out of a `yml` schema file.
Each class in the schema gets a corresponding `struct` in `TrapClasses.h`, where:
* inheritance is preserved
* each property will be a corresponding field in the `struct` (with repeated properties mapping to `std::vector` and
optional ones to `std::optional`)
* final classes get a streaming operator that serializes the whole class into the corresponding trap emissions (using
`TrapEntries.h` from `trapgen`).
"""
import functools
import typing
import inflection
from swift.codegen.lib import cpp, schema
from swift.codegen.loaders import schemaloader
def _get_type(t: str, add_or_none_except: typing.Optional[str] = None) -> str:
    """ Map a schema type to the C++ type used in trap structs.

    Class types become trap labels; when `add_or_none_except` is given, labels for
    classes other than it use the `OrNone` tag variant. `None` (a predicate) and
    `boolean` map to `bool`, `string` to `std::string`.
    """
    builtin_map = {None: "bool", "string": "std::string", "boolean": "bool"}
    if t in builtin_map:
        return builtin_map[t]
    if t[0].isupper():
        suffix = "OrNone" if add_or_none_except is not None and t != add_or_none_except else ""
        return f"TrapLabel<{t}{suffix}Tag>"
    return t
def _get_field(cls: schema.Class, p: schema.Property, add_or_none_except: typing.Optional[str] = None) -> cpp.Field:
    """ Build the C++ field modelling property `p` of class `cls`. """
    trap_name = None
    if not p.is_single:
        # non 1-to-1 properties get their own trap table, named after class and property
        trap_name = inflection.camelize(f"{cls.name}_{p.name}")
        if not p.is_predicate:
            trap_name = inflection.pluralize(trap_name)
    args = dict(
        field_name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),  # dodge C++ keyword clashes
        base_type=_get_type(p.type, add_or_none_except),
        is_optional=p.is_optional,
        is_repeated=p.is_repeated,
        is_predicate=p.is_predicate,
        trap_name=trap_name,
    )
    # name-based overrides (e.g. numeric-looking fields becoming `unsigned`) take precedence
    args.update(cpp.get_field_override(p.name))
    return cpp.Field(**args)
class Processor:
    """ Transforms a loaded schema into the C++ class model used for trap emission. """

    def __init__(self, data: schema.Schema):
        self._classmap = data.classes
        # per-instance memoization for `_get_class`. The previous implementation used
        # `functools.lru_cache` on the method, which keys on `self` and keeps every
        # instance alive for the cache's lifetime (flake8-bugbear B019).
        self._class_cache = {}
        if data.null:
            # presumably the first schema class is the root one — TODO confirm against schema loader
            root_type = next(iter(data.classes))
            self._add_or_none_except = root_type
        else:
            self._add_or_none_except = None

    def _get_class(self, name: str) -> cpp.Class:
        """ Return the C++ class for schema class `name`, building it (and its bases) on first use. """
        cached = self._class_cache.get(name)
        if cached is not None:
            return cached
        cls = self._classmap[name]
        trap_name = None
        # only leaf classes and classes with 1-to-1 properties have their own trap table
        if not cls.derived or any(p.is_single for p in cls.properties):
            trap_name = inflection.pluralize(cls.name)
        ret = cpp.Class(
            name=name,
            bases=[self._get_class(b) for b in cls.bases],
            fields=[
                _get_field(cls, p, self._add_or_none_except)
                for p in cls.properties if "cpp_skip" not in p.pragmas
            ],
            final=not cls.derived,
            trap_name=trap_name,
        )
        self._class_cache[name] = ret
        return ret

    def get_classes(self):
        """ Return a `group dir -> [cpp.Class]` map of all non-synthesized classes. """
        ret = {'': []}
        for k, cls in self._classmap.items():
            if not cls.ipa:
                ret.setdefault(cls.group, []).append(self._get_class(cls.name))
        return ret
def generate(opts, renderer):
    """ Generate one `TrapClasses.h` per schema group out of the schema file given in the options. """
    assert opts.cpp_output
    processor = Processor(schemaloader.load_file(opts.schema))
    out = opts.cpp_output
    for dir, classes in processor.get_classes().items():
        renderer.render(cpp.ClassList(classes, opts.schema,
                                      # grouped headers include the parent (root) header
                                      include_parent=bool(dir),
                                      trap_library=opts.trap_library), out / dir / "TrapClasses")

View File

@@ -1,132 +0,0 @@
"""
dbscheme file generation
`generate(opts, renderer)` will generate a `dbscheme` file out of a `yml` schema file.
Each final class in the schema file will get a corresponding defining DB table with the id and single properties as
columns.
Moreover:
* single properties in non-final classes will also trigger generation of a table with an id reference and all single
properties as columns
* each optional property will trigger generation of a table with an id reference and the property value as columns
* each repeated property will trigger generation of a table with an id reference, an `int` index and the property value
as columns
The type hierarchy will be translated to corresponding `union` declarations.
"""
import typing
import inflection
from swift.codegen.lib import schema
from swift.codegen.loaders import schemaloader
from swift.codegen.lib.dbscheme import *
log = logging.getLogger(__name__)
def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> str:
    """ translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.

    For class types other than `add_or_none_except` (when given), an `_or_none`
    suffix is appended to the dbscheme name.
    """
    if not typename[0].isupper():
        # builtin types (string, int, boolean, ...) pass through unchanged
        return typename
    suffix = "_or_none" if add_or_none_except is not None and typename != add_or_none_except else ""
    return f"@{inflection.underscore(typename)}{suffix}"
def cls_to_dbscheme(cls: schema.Class, lookup: typing.Dict[str, schema.Class], add_or_none_except: typing.Optional[str] = None):
    """ Yield all dbscheme entities needed to model class `cls` """
    if cls.ipa:
        # synthesized classes have no database representation
        return
    if cls.derived:
        # the hierarchy is modelled by a union over all non-synthesized derived classes
        yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived if not lookup[c].ipa))
    dir = pathlib.Path(cls.group) if cls.group else None
    # output a table specific to a class only if it is a leaf class or it has 1-to-1 properties
    # Leaf classes need a table to bind the `@` ids
    # 1-to-1 properties are added to a class specific table
    # in other cases, separate tables are used for the properties, and a class specific table is unneeded
    if not cls.derived or any(f.is_single for f in cls.properties):
        binding = not cls.derived
        keyset = KeySet(["id"]) if cls.derived else None
        yield Table(
            keyset=keyset,
            name=inflection.tableize(cls.name),
            columns=[
                Column("id", type=dbtype(cls.name), binding=binding),
            ] + [
                Column(f.name, dbtype(f.type, add_or_none_except)) for f in cls.properties if f.is_single
            ],
            dir=dir,
        )
    # use property-specific tables for 1-to-many and 1-to-at-most-1 properties
    for f in cls.properties:
        if f.is_repeated:
            # 1-to-many: (id, index, value), keyed on (id, index)
            yield Table(
                keyset=KeySet(["id", "index"]),
                name=inflection.tableize(f"{cls.name}_{f.name}"),
                columns=[
                    Column("id", type=dbtype(cls.name)),
                    Column("index", type="int"),
                    Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
                ],
                dir=dir,
            )
        elif f.is_optional:
            # 1-to-at-most-1: (id, value), keyed on id
            yield Table(
                keyset=KeySet(["id"]),
                name=inflection.tableize(f"{cls.name}_{f.name}"),
                columns=[
                    Column("id", type=dbtype(cls.name)),
                    Column(f.name, dbtype(f.type, add_or_none_except)),
                ],
                dir=dir,
            )
        elif f.is_predicate:
            # predicates: mere presence of the id means the predicate holds
            yield Table(
                keyset=KeySet(["id"]),
                name=inflection.underscore(f"{cls.name}_{f.name}"),
                columns=[
                    Column("id", type=dbtype(cls.name)),
                ],
                dir=dir,
            )
def get_declarations(data: schema.Schema):
    """ Return all dbscheme declarations (unions and tables) modelling the classes in `data`. """
    add_or_none_except = data.root_class.name if data.null else None
    declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, data.classes, add_or_none_except)]
    if data.null:
        # for each class type appearing in a property, declare a `<type>_or_none` union of
        # that type with the configured null class
        property_classes = {
            prop.type for cls in data.classes.values() for prop in cls.properties
            if cls.name != data.null and prop.type and prop.type[0].isupper()
        }
        declarations += [
            Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) for t in sorted(property_classes)
        ]
    return declarations
def get_includes(data: schema.Schema, include_dir: pathlib.Path, root_dir: pathlib.Path):
    """ Load the schema's included files, recording their paths relative to `root_dir`. """
    results = []
    for name in data.includes:
        path = include_dir / name
        with open(path) as included:
            results.append(SchemeInclude(src=path.relative_to(root_dir), data=included.read()))
    return results
def generate(opts, renderer):
    """ Generate the dbscheme file at `opts.dbscheme` from the schema file at `opts.schema`. """
    # local was previously named `input`, shadowing the builtin
    schema_file = opts.schema
    out = opts.dbscheme
    data = schemaloader.load_file(schema_file)
    dbscheme = Scheme(src=schema_file.name,
                      includes=get_includes(data, include_dir=schema_file.parent, root_dir=schema_file.parent),
                      declarations=get_declarations(data))
    renderer.render(dbscheme, out)

View File

@@ -1,427 +0,0 @@
"""
QL code generation
`generate(opts, renderer)` will generate in the library directory:
* generated/Raw.qll with thin class wrappers around DB types
* generated/Synth.qll with the base algebraic datatypes for AST entities
* generated/<group>/<Class>.qll with generated properties for each class
* if not already modified, a elements/<group>/<Class>.qll stub to customize the above classes
* elements.qll importing all the above stubs
* if not already modified, a elements/<group>/<Class>Constructor.qll stub to customize the algebraic datatype
characteristic predicate
* generated/SynthConstructors.qll importing all the above constructor stubs
* generated/PureSynthConstructors.qll importing constructor stubs for pure synthesized types (that is, not
corresponding to raw types)
Moreover in the test directory for each <Class> in <group> it will generate beneath the
extractor-tests/generated/<group>/<Class> directory either
* a `MISSING_SOURCE.txt` explanation file if no source is present, or
* one `<Class>.ql` test query for all single properties and one `<Class>_<property>.ql` test query for each optional or
repeated property
"""
# TODO this should probably be split in different generators now: ql, qltest, maybe qlsynth
import logging
import pathlib
import re
import subprocess
import typing
import itertools
import inflection
from swift.codegen.lib import schema, ql
from swift.codegen.loaders import schemaloader
log = logging.getLogger(__name__)
class Error(Exception):
    """ Base class for all errors raised by this generator. """

    def __str__(self):
        return self.args[0]


class FormatError(Error):
    """ Raised when `codeql query format` fails on the generated files. """
    pass


class RootElementHasChildren(Error):
    """ Raised when the root schema class unexpectedly has child properties. """
    pass


class NoClasses(Error):
    """ Raised when the schema defines no classes at all. """
    pass
# abbreviation expansions used when building human readable property/class descriptions
abbreviations = {
    "expr": "expression",
    "arg": "argument",
    "stmt": "statement",
    "decl": "declaration",
    "repr": "representation",
    "param": "parameter",
    "int": "integer",
    "var": "variable",
    "ref": "reference",
}

# also map plural abbreviations to their plural expansions
abbreviations.update({f"{k}s": f"{v}s" for k, v in abbreviations.items()})

# matches any of the above abbreviations as a whole word
_abbreviations_re = re.compile("|".join(fr"\b{abbr}\b" for abbr in abbreviations))
def _humanize(s: str) -> str:
    """ Turn an identifier into a lower-case human readable phrase, expanding known abbreviations. """
    ret = inflection.humanize(s)
    ret = ret[0].lower() + ret[1:]
    ret = _abbreviations_re.sub(lambda m: abbreviations[m[0]], ret)
    return ret
# matches `{noun}` placeholders inside property doc strings
_format_re = re.compile(r"\{(\w+)\}")


def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
    """ Build the doc phrase for `prop` of `cls`.

    `plural` selects singular (False) or plural (True) phrasing for repeated
    properties; `None` means the property is not repeated.
    """
    if prop.doc:
        if plural is None:
            # for consistency, ignore format in non repeated properties
            return _format_re.sub(lambda m: m[1], prop.doc)
        format = prop.doc
        nouns = [m[1] for m in _format_re.finditer(prop.doc)]
        if not nouns:
            # no explicit placeholder: treat the first word as the noun to inflect
            noun, _, rest = prop.doc.partition(" ")
            format = f"{{{noun}}} {rest}"
            nouns = [noun]
        transform = inflection.pluralize if plural else inflection.singularize
        return format.format(**{noun: transform(noun) for noun in nouns})
    # no explicit doc: build "<property> of this <class>" style phrasing
    prop_name = _humanize(prop.name)
    class_name = cls.default_doc_name or _humanize(inflection.underscore(cls.name))
    if prop.is_predicate:
        return f"this {class_name} {prop_name}"
    if plural is not None:
        prop_name = inflection.pluralize(prop_name) if plural else inflection.singularize(prop_name)
    return f"{prop_name} of this {class_name}"
def get_ql_property(cls: schema.Class, prop: schema.Property, prev_child: str = "") -> ql.Property:
    """ Build the QL model of property `prop` of class `cls`.

    `prev_child` is the accessor name of the previous child property, used to
    chain child ordering in the generated QL.

    Raises:
        ValueError: if the property is neither single, repeated, optional nor a predicate.
    """
    args = dict(
        type=prop.type if not prop.is_predicate else "predicate",
        qltest_skip="qltest_skip" in prop.pragmas,
        prev_child=prev_child if prop.is_child else None,
        is_optional=prop.is_optional,
        is_predicate=prop.is_predicate,
        description=prop.description
    )
    if prop.is_single:
        args.update(
            singular=inflection.camelize(prop.name),
            tablename=inflection.tableize(cls.name),
            # the class table has one column per single property; bind `result` on ours only
            tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
            doc=_get_doc(cls, prop),
        )
    elif prop.is_repeated:
        args.update(
            singular=inflection.singularize(inflection.camelize(prop.name)),
            plural=inflection.pluralize(inflection.camelize(prop.name)),
            tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
            tableparams=["this", "index", "result"],
            doc=_get_doc(cls, prop, plural=False),
            doc_plural=_get_doc(cls, prop, plural=True),
        )
    elif prop.is_optional:
        args.update(
            singular=inflection.camelize(prop.name),
            tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
            tableparams=["this", "result"],
            doc=_get_doc(cls, prop),
        )
    elif prop.is_predicate:
        args.update(
            singular=inflection.camelize(prop.name, uppercase_first_letter=False),
            tablename=inflection.underscore(f"{cls.name}_{prop.name}"),
            tableparams=["this"],
            doc=_get_doc(cls, prop),
        )
    else:
        raise ValueError(f"unknown property kind for {prop.name} from {cls.name}")
    return ql.Property(**args)
def get_ql_class(cls: schema.Class) -> ql.Class:
    """ Convert schema class `cls` into its QL generation model. """
    pragmas = {pragma: True for pragma in cls.pragmas if pragma.startswith("ql")}
    properties = []
    last_child = ""
    for spec in cls.properties:
        prop = get_ql_property(cls, spec, last_child)
        if prop.is_child:
            # remember the accessor of the last child seen, to chain child ordering
            last_child = prop.singular
        properties.append(prop)
    return ql.Class(
        name=cls.name,
        bases=cls.bases,
        final=not cls.derived,
        properties=properties,
        dir=pathlib.Path(cls.group or ""),
        ipa=bool(cls.ipa),
        doc=cls.doc,
        **pragmas,
    )
def _to_db_type(x: str) -> str:
    """ Qualify class type names with the `Raw::` module; leave builtin types untouched. """
    return "Raw::" + x if x[0].isupper() else x
# cache of db-backed final classes, so that synthesized-type subtraction accumulates
# on a single shared instance per class name
_final_db_class_lookup = {}


def get_ql_ipa_class_db(name: str) -> ql.Synth.FinalClassDb:
    """ Get (or create and cache) the synthesized model of db-backed final class `name`. """
    return _final_db_class_lookup.setdefault(name, ql.Synth.FinalClassDb(name=name,
                                                                         params=[
                                                                             ql.Synth.Param("id", _to_db_type(name))]))
def get_ql_ipa_class(cls: schema.Class):
    """ Build the synthesized-type model for `cls`.

    Derived classes become non-final synthesized types; final classes become
    db-backed, ipa-derived or ipa-fresh types depending on their `ipa` configuration.
    """
    if cls.derived:
        return ql.Synth.NonFinalClass(name=cls.name, derived=sorted(cls.derived),
                                      root=not cls.bases)
    if cls.ipa and cls.ipa.from_class is not None:
        source = cls.ipa.from_class
        # this class is synthesized from db class `source`, so carve it out of `source`'s own type
        get_ql_ipa_class_db(source).subtract_type(cls.name)
        return ql.Synth.FinalClassDerivedIpa(name=cls.name,
                                             params=[ql.Synth.Param("id", _to_db_type(source))])
    if cls.ipa and cls.ipa.on_arguments is not None:
        return ql.Synth.FinalClassFreshIpa(name=cls.name,
                                           params=[ql.Synth.Param(k, _to_db_type(v))
                                                   for k, v in cls.ipa.on_arguments.items()])
    return get_ql_ipa_class_db(cls.name)
def get_import(file: pathlib.Path, root_dir: pathlib.Path):
    """ Compute the QL import name of `file`, relative to the `ql/lib` directory under `root_dir`. """
    relative = file.relative_to(root_dir / "ql/lib")
    return str(relative.with_suffix("")).replace("/", ".")
def get_types_used_by(cls: ql.Class) -> typing.Iterable[str]:
for b in cls.bases:
yield b.base
for p in cls.properties:
yield p.type
def get_classes_used_by(cls: ql.Class) -> typing.List[str]:
    """ Return the sorted, deduplicated class names used by `cls`, excluding `cls` itself. """
    used = {t for t in get_types_used_by(cls) if t[0].isupper() and t != cls.name}
    return sorted(used)
def format(codeql, files):
    """ Autoformat the given QL files in place using the `codeql` executable.

    Raises:
        FormatError: if the formatter exits with a non-zero status.
    """
    to_format = [str(f) for f in files if f.suffix in (".qll", ".ql")]
    if not to_format:
        return
    cmd = [codeql, "query", "format", "--in-place", "--"] + to_format
    res = subprocess.run(cmd, stderr=subprocess.PIPE, text=True)
    if res.returncode:
        for line in res.stderr.splitlines():
            log.error(line.strip())
        raise FormatError("QL format failed")
    # on success, still surface the formatter's diagnostics at debug level
    for line in res.stderr.splitlines():
        log.debug(line.strip())
def _get_path(cls: schema.Class) -> pathlib.Path:
    """ Path of the QL file for `cls`, relative to the output root. """
    return (pathlib.Path(cls.group or "") / cls.name).with_suffix(".qll")
def _get_all_properties(cls: schema.Class, lookup: typing.Dict[str, schema.Class],
                        already_seen: typing.Optional[typing.Set[int]] = None) -> \
        typing.Iterable[typing.Tuple[schema.Class, schema.Property]]:
    """ Yield `(declaring class, property)` pairs for `cls` and all its bases, bases first. """
    # deduplicate using ids
    if already_seen is None:
        already_seen = set()
    for b in sorted(cls.bases):
        base = lookup[b]
        for item in _get_all_properties(base, lookup, already_seen):
            yield item
    for p in cls.properties:
        if id(p) not in already_seen:
            already_seen.add(id(p))
            yield cls, p
def _get_all_properties_to_be_tested(cls: schema.Class, lookup: typing.Dict[str, schema.Class]) -> \
        typing.Iterable[ql.PropertyForTest]:
    """ Yield test descriptors for all inherited and own properties of `cls` not marked `qltest_skip`. """
    for c, p in _get_all_properties(cls, lookup):
        if not ("qltest_skip" in c.pragmas or "qltest_skip" in p.pragmas):
            # TODO here operations are duplicated, but should be better if we split ql and qltest generation
            p = get_ql_property(c, p)
            yield ql.PropertyForTest(p.getter, is_total=p.is_single or p.is_predicate,
                                     type=p.type if not p.is_predicate else None, is_repeated=p.is_repeated)
            # repeated and optional properties get an extra counting/existence test
            if p.is_repeated and not p.is_optional:
                yield ql.PropertyForTest(f"getNumberOf{p.plural}", type="int")
            elif p.is_optional and not p.is_repeated:
                yield ql.PropertyForTest(f"has{p.singular}")
def _partition_iter(x, pred):
    """ Split iterable `x` into two iterators: items satisfying `pred`, then the rest. """
    x1, x2 = itertools.tee(x)
    return filter(pred, x1), itertools.filterfalse(pred, x2)


def _partition(l, pred):
    """ partitions a list according to boolean predicate """
    return map(list, _partition_iter(l, pred))
def _is_in_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
    """ True if `cls` starts a collapsed test hierarchy or lies beneath one. """
    return "qltest_collapse_hierarchy" in cls.pragmas or _is_under_qltest_collapsed_hierarchy(cls, lookup)


def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
    """ True if some base of `cls` is in a collapsed test hierarchy and `cls` does not opt out. """
    return "qltest_uncollapse_hierarchy" not in cls.pragmas and any(
        _is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)


def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
    """ Tests are generated only for final or collapsed-hierarchy classes that are not skipped. """
    return "qltest_skip" in cls.pragmas or not (
        cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
        cls, lookup)
def _get_stub(cls: schema.Class, base_import: str, generated_import_prefix: str) -> ql.Stub:
    """ Build the customization stub for `cls`, with accessors for its synthesizing inputs (if any). """
    if isinstance(cls.ipa, schema.IpaInfo):
        # assumes exactly one of from_class/on_arguments is set for synthesized classes — TODO confirm
        if cls.ipa.from_class is not None:
            # synthesized from a single db class: expose it as `Entity`
            accessors = [
                ql.IpaUnderlyingAccessor(
                    argument="Entity",
                    type=_to_db_type(cls.ipa.from_class),
                    constructorparams=["result"]
                )
            ]
        elif cls.ipa.on_arguments is not None:
            # fresh synthesized type: expose one accessor per constructor argument
            accessors = [
                ql.IpaUnderlyingAccessor(
                    argument=inflection.camelize(arg),
                    type=_to_db_type(type),
                    constructorparams=["result" if a == arg else "_" for a in cls.ipa.on_arguments]
                ) for arg, type in cls.ipa.on_arguments.items()
            ]
    else:
        accessors = []
    return ql.Stub(name=cls.name, base_import=base_import, import_prefix=generated_import_prefix, ipa_accessors=accessors)
def generate(opts, renderer):
    """ Generate all QL sources and QL tests for the schema file given in the options.

    Raises:
        NoClasses: if the schema contains no classes.
        RootElementHasChildren: if the root class has child properties.
    """
    input = opts.schema
    out = opts.ql_output
    stub_out = opts.ql_stub_output
    test_out = opts.ql_test_output
    missing_test_source_filename = "MISSING_SOURCE.txt"
    include_file = stub_out.with_suffix(".qll")
    # collect pre-existing generated and stub files, so the renderer can manage them
    generated = {q for q in out.rglob("*.qll")}
    generated.add(include_file)
    generated.update(q for q in test_out.rglob("*.ql"))
    generated.update(q for q in test_out.rglob(missing_test_source_filename))
    stubs = {q for q in stub_out.rglob("*.qll")}
    data = schemaloader.load_file(input)
    classes = {name: get_ql_class(cls) for name, cls in data.classes.items()}
    if not classes:
        raise NoClasses
    root = next(iter(classes.values()))
    if root.has_children:
        raise RootElementHasChildren(root)
    imports = {}
    generated_import_prefix = get_import(out, opts.root_dir)
    with renderer.manage(generated=generated, stubs=stubs, registry=opts.generated_registry,
                         force=opts.force) as renderer:
        # Raw.qll: thin wrappers around db types
        db_classes = [cls for cls in classes.values() if not cls.ipa]
        renderer.render(ql.DbClasses(db_classes), out / "Raw.qll")
        classes_by_dir_and_name = sorted(classes.values(), key=lambda cls: (cls.dir, cls.name))
        for c in classes_by_dir_and_name:
            imports[c.name] = get_import(stub_out / c.path, opts.root_dir)
        # one generated qll per class
        for c in classes.values():
            qll = out / c.path.with_suffix(".qll")
            c.imports = [imports[t] for t in get_classes_used_by(c)]
            c.import_prefix = generated_import_prefix
            renderer.render(c, qll)
        # one customization stub per class, unless already modified by hand
        for c in data.classes.values():
            path = _get_path(c)
            stub_file = stub_out / path
            if not renderer.is_customized_stub(stub_file):
                base_import = get_import(out / path, opts.root_dir)
                renderer.render(_get_stub(c, base_import, generated_import_prefix), stub_file)
        # for example path/to/elements -> path/to/elements.qll
        renderer.render(ql.ImportList([i for name, i in imports.items() if not classes[name].ql_internal]),
                        include_file)
        elements_module = get_import(include_file, opts.root_dir)
        renderer.render(
            ql.GetParentImplementation(
                classes=list(classes.values()),
                imports=[elements_module] + [i for name, i in imports.items() if classes[name].ql_internal],
            ),
            out / 'ParentChild.qll')
        # extractor test files, one directory per tested class
        for c in data.classes.values():
            if _should_skip_qltest(c, data.classes):
                continue
            test_dir = test_out / c.group / c.name
            test_dir.mkdir(parents=True, exist_ok=True)
            # a directory with only generated-style files has no hand-written test source
            if all(f.suffix in (".txt", ".ql", ".actual", ".expected") for f in test_dir.glob("*.*")):
                log.warning(f"no test source in {test_dir.relative_to(test_out)}")
                renderer.render(ql.MissingTestInstructions(),
                                test_dir / missing_test_source_filename)
                continue
            total_props, partial_props = _partition(_get_all_properties_to_be_tested(c, data.classes),
                                                    lambda p: p.is_total)
            renderer.render(ql.ClassTester(class_name=c.name,
                                           properties=total_props,
                                           elements_module=elements_module,
                                           # in case of collapsed hierarchies we want to see the actual QL class in results
                                           show_ql_class="qltest_collapse_hierarchy" in c.pragmas),
                            test_dir / f"{c.name}.ql")
            for p in partial_props:
                renderer.render(ql.PropertyTester(class_name=c.name,
                                                  elements_module=elements_module,
                                                  property=p), test_dir / f"{c.name}_{p.getter}.ql")
        # synthesized (ipa) types and their constructor stubs
        final_ipa_types = []
        non_final_ipa_types = []
        constructor_imports = []
        ipa_constructor_imports = []
        stubs = {}
        for cls in sorted(data.classes.values(), key=lambda cls: (cls.group, cls.name)):
            ipa_type = get_ql_ipa_class(cls)
            if ipa_type.is_final:
                final_ipa_types.append(ipa_type)
                if ipa_type.has_params:
                    stub_file = stub_out / cls.group / f"{cls.name}Constructor.qll"
                    if not renderer.is_customized_stub(stub_file):
                        # stub rendering must be postponed as we might not have yet all subtracted ipa types in `ipa_type`
                        stubs[stub_file] = ql.Synth.ConstructorStub(ipa_type, import_prefix=generated_import_prefix)
                    constructor_import = get_import(stub_file, opts.root_dir)
                    constructor_imports.append(constructor_import)
                    if ipa_type.is_ipa:
                        ipa_constructor_imports.append(constructor_import)
            else:
                non_final_ipa_types.append(ipa_type)
        # NOTE(review): this loop variable shadows the schema `data` above; harmless here as
        # the schema is no longer needed, but worth renaming
        for stub_file, data in stubs.items():
            renderer.render(data, stub_file)
        renderer.render(ql.Synth.Types(root.name, generated_import_prefix,
                                       final_ipa_types, non_final_ipa_types), out / "Synth.qll")
        renderer.render(ql.ImportList(constructor_imports), out / "SynthConstructors.qll")
        renderer.render(ql.ImportList(ipa_constructor_imports), out / "PureSynthConstructors.qll")
        if opts.ql_format:
            # formatting happens while still managed, so registry checksums see formatted files
            format(opts.codeql_binary, renderer.written)

View File

@@ -1,98 +0,0 @@
"""
C++ trap entry generation
`generate(opts, renderer)` will generate `TrapTags.h` (for types of labels) and `TrapEntries.h` (for trap emission) out
of a dbscheme file.
Each table in the `dbscheme` gets a corresponding `struct` defined in `TrapEntries.h` with a field for each column and
an appropriate streaming operator for the trap emission.
Unions in the `dbscheme` are used to populate a hierarchy of tags (empty structs) in `TrapTags.h` that is used to
enforce a type system on trap labels (see `TrapLabel.h`).
"""
import logging
import pathlib
import inflection
from toposort import toposort_flatten
from swift.codegen.lib import dbscheme, cpp
from swift.codegen.loaders import dbschemeloader
log = logging.getLogger(__name__)
def get_tag_name(s):
    """ Convert a dbscheme `@type` name into its CamelCase C++ tag name. """
    assert s.startswith("@")
    bare_name = s[1:]
    return inflection.camelize(bare_name)
def get_cpp_type(schema_type: str):
    """ Map a dbscheme column type to the C++ type used in trap entry structs. """
    if schema_type.startswith("@"):
        # references to dbscheme types become typed trap labels
        return f"TrapLabel<{get_tag_name(schema_type)}Tag>"
    builtin_map = {"string": "std::string", "boolean": "bool"}
    return builtin_map.get(schema_type, schema_type)
def get_field(c: dbscheme.Column):
    """ Build the C++ field corresponding to dbscheme column `c`. """
    args = {
        "field_name": c.schema_name,
        "base_type": c.type,
    }
    # name-based overrides may adjust the field name or its (pre-conversion) base type
    args.update(cpp.get_field_override(c.schema_name))
    args["base_type"] = get_cpp_type(args["base_type"])
    return cpp.Field(**args)
def get_binding_column(t: dbscheme.Table):
try:
return next(c for c in t.columns if c.binding)
except StopIteration:
return None
def get_trap(t: dbscheme.Table):
    """ Build the C++ trap entry struct for dbscheme table `t`. """
    id = get_binding_column(t)
    if id:
        # the binding column, if any, becomes the struct's id field
        id = get_field(id)
    return cpp.Trap(
        table_name=t.name,
        name=inflection.camelize(t.name),
        fields=[get_field(c) for c in t.columns],
        id=id,
    )
def generate(opts, renderer):
    """ Generate `TrapEntries.h` (one per dbscheme directory) and `TrapTags.h` from the dbscheme file. """
    assert opts.cpp_output
    tag_graph = {}
    out = opts.cpp_output
    trap_library = opts.trap_library
    traps = {pathlib.Path(): []}
    for e in dbschemeloader.iterload(opts.dbscheme):
        if e.is_table:
            traps.setdefault(e.dir, []).append(get_trap(e))
        elif e.is_union:
            # unions define the tag hierarchy: each rhs tag inherits from the lhs tag
            tag_graph.setdefault(e.lhs, set())
            for d in e.rhs:
                tag_graph.setdefault(d.type, set()).add(e.lhs)
    for dir, entries in traps.items():
        dir = dir or pathlib.Path()
        # relative path from the per-directory output back up to the generated root
        relative_gen_dir = pathlib.Path(*[".." for _ in dir.parents])
        renderer.render(cpp.TrapList(entries, opts.dbscheme, trap_library, relative_gen_dir), out / dir / "TrapEntries")
    tags = []
    # topological order ensures base tags are emitted before the tags deriving from them
    for tag in toposort_flatten(tag_graph):
        tags.append(cpp.Tag(
            name=get_tag_name(tag),
            bases=[get_tag_name(b) for b in sorted(tag_graph[tag])],
            id=tag,
        ))
    renderer.render(cpp.TagList(tags, opts.dbscheme), out / "TrapTags")

View File

@@ -1,11 +0,0 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")

# Support library shared by the codegen generators and loaders.
py_library(
    name = "lib",
    srcs = glob(["*.py"]),
    visibility = ["//swift/codegen:__subpackages__"],
    deps = [
        # mustache templating engine used by render.py
        requirement("pystache"),
        # string casing helpers (camelize/underscore)
        requirement("inflection"),
    ],
)

View File

@@ -1,163 +0,0 @@
import pathlib
import re
from dataclasses import dataclass, field
from typing import List, ClassVar
# taken from https://en.cppreference.com/w/cpp/keyword
# used by `Field.__post_init__` to avoid emitting C++ field names that
# collide with language keywords (they get a trailing underscore instead)
cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "atomic_commit", "atomic_noexcept",
                "auto", "bitand", "bitor", "bool", "break", "case", "catch", "char", "char8_t", "char16_t", "char32_t",
                "class", "compl", "concept", "const", "consteval", "constexpr", "constinit", "const_cast", "continue",
                "co_await", "co_return", "co_yield", "decltype", "default", "delete", "do", "double", "dynamic_cast",
                "else", "enum", "explicit", "export", "extern", "false", "float", "for", "friend", "goto", "if",
                "inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr",
                "operator", "or", "or_eq", "private", "protected", "public", "reflexpr", "register", "reinterpret_cast",
                "requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct",
                "switch", "synchronized", "template", "this", "thread_local", "throw", "true", "try", "typedef",
                "typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while",
                "xor", "xor_eq"}
# ordered list of (pattern, override) pairs; the first full match wins.
# An override is either a kwargs dict or a callable building one from the match.
_field_overrides = [
    # numeric-looking columns are emitted as `unsigned` in C++
    (re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": "unsigned"}),
    # a trailing underscore marks names escaped in the dbscheme; drop it
    (re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
]


def get_field_override(field: str):
    """ Return `cpp.Field` constructor overrides for special column names, or {}. """
    for pattern, override in _field_overrides:
        match = pattern.fullmatch(field)
        if not match:
            continue
        return override(match) if callable(override) else override
    return {}
@dataclass
class Field:
    """ A field of a generated C++ trap or class. """
    field_name: str
    base_type: str
    is_optional: bool = False
    is_repeated: bool = False
    is_predicate: bool = False
    trap_name: str = None
    first: bool = False

    def __post_init__(self):
        # escape C++ keywords with a trailing underscore
        if self.field_name in cpp_keywords:
            self.field_name += "_"

    @property
    def type(self) -> str:
        """ Full C++ type, wrapping the base type in optional/vector as needed. """
        type = self.base_type
        if self.is_optional:
            type = f"std::optional<{type}>"
        if self.is_repeated:
            type = f"std::vector<{type}>"
        return type

    # using @property breaks pystache internals here
    def get_streamer(self):
        """ Return a function mapping an expression to its trap-emission form. """
        if self.type == "std::string":
            return lambda x: f"trapQuoted({x})"
        elif self.type == "bool":
            return lambda x: f'({x} ? "true" : "false")'
        else:
            return lambda x: x

    @property
    def is_single(self):
        # single-valued: neither optional, repeated nor predicate
        return not (self.is_optional or self.is_repeated or self.is_predicate)

    @property
    def is_label(self):
        # true for trap-label (reference) typed fields
        return self.base_type.startswith("TrapLabel<")
@dataclass
class Trap:
    """ A generated trap entry corresponding to one dbscheme table. """
    table_name: str
    name: str
    fields: List[Field]
    # field built from the table's binding column, if any
    id: Field = None

    def __post_init__(self):
        assert self.fields
        # mustache templates use `first` to special-case the leading item
        self.fields[0].first = True
@dataclass
class TagBase:
    """ A base of a `Tag`, with a `first` marker for template rendering. """
    base: str
    first: bool = False


@dataclass
class Tag:
    """ A C++ trap tag generated from a dbscheme union member. """
    name: str
    bases: List[TagBase]
    id: str

    def __post_init__(self):
        # wrap plain base names and mark the first for template rendering
        if self.bases:
            self.bases = [TagBase(b) for b in self.bases]
            self.bases[0].first = True

    @property
    def has_bases(self):
        return bool(self.bases)
@dataclass
class TrapList:
    """ Template model for the generated `TrapEntries.h/.cpp` files. """
    template: ClassVar = 'trap_traps'
    # annotated as ClassVar (consistently with ClassList) so that dataclasses
    # keeps it out of the generated __init__
    extensions: ClassVar = ["h", "cpp"]
    traps: List[Trap]
    source: str
    trap_library_dir: pathlib.Path
    gen_dir: pathlib.Path


@dataclass
class TagList:
    """ Template model for the generated `TrapTags.h` file. """
    template: ClassVar = 'trap_tags'
    extensions: ClassVar = ["h"]
    tags: List[Tag]
    source: str
@dataclass
class ClassBase:
    """ A base of a `Class`, with a `first` marker for template rendering. """
    ref: 'Class'
    first: bool = False


@dataclass
class Class:
    """ A generated C++ structured class. """
    name: str
    bases: List[ClassBase] = field(default_factory=list)
    final: bool = False
    fields: List[Field] = field(default_factory=list)
    trap_name: str = None

    def __post_init__(self):
        # sort bases by name for deterministic output, then mark the first
        self.bases = [ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name)]
        if self.bases:
            self.bases[0].first = True

    @property
    def has_bases(self):
        return bool(self.bases)

    @property
    def single_fields(self):
        """ Fields that are neither optional, repeated nor predicate. """
        return [f for f in self.fields if f.is_single]


@dataclass
class ClassList:
    """ Template model for the generated C++ classes files. """
    template: ClassVar = "cpp_classes"
    extensions: ClassVar = ["h", "cpp"]
    classes: List[Class]
    source: str
    trap_library: str
    include_parent: bool = False

View File

@@ -1,107 +0,0 @@
""" dbscheme format representation """
import logging
import pathlib
import re
from dataclasses import dataclass
from typing import ClassVar, List
log = logging.getLogger(__name__)
# identifiers reserved by the dbscheme language; column names colliding with
# these are escaped with a trailing underscore
dbscheme_keywords = {"case", "boolean", "int", "string", "type"}


@dataclass
class Column:
    """ A single column of a dbscheme table. """
    schema_name: str
    type: str
    binding: bool = False
    first: bool = False

    @property
    def name(self):
        """ Column name as emitted, escaping dbscheme keywords. """
        suffix = "_" if self.schema_name in dbscheme_keywords else ""
        return self.schema_name + suffix

    @property
    def lhstype(self):
        """ Type as written in the table declaration (left-hand side). """
        if self.type[0] != "@":
            return self.type
        return "unique int" if self.binding else "int"

    @property
    def rhstype(self):
        """ Type as written in the union/reference position (right-hand side). """
        is_binding_ref = self.type[0] == "@" and self.binding
        return self.type if is_binding_ref else self.type + " ref"
@dataclass
class KeySetId:
    """ A single id inside a `#keyset[...]` annotation. """
    id: str
    first: bool = False


@dataclass
class KeySet:
    """ A `#keyset[a, b, ...]` annotation; wraps ids and marks the first one. """
    ids: List[KeySetId]

    def __post_init__(self):
        assert self.ids
        wrapped = [KeySetId(i) for i in self.ids]
        wrapped[0].first = True
        self.ids = wrapped
class Decl:
    """ Base for dbscheme declarations; subclasses flip the discriminator flags. """
    is_table = False
    is_union = False


@dataclass
class Table(Decl):
    """ A dbscheme table declaration. """
    is_table: ClassVar = True
    name: str
    columns: List[Column]
    keyset: KeySet = None
    # output subdirectory from the `//dir=` annotation, if any
    dir: pathlib.Path = None

    def __post_init__(self):
        # mark the first column for template rendering
        if self.columns:
            self.columns[0].first = True


@dataclass
class UnionCase:
    """ A single right-hand-side member of a union declaration. """
    type: str
    first: bool = False
@dataclass
class Union(Decl):
    """ A dbscheme union declaration (`@lhs = @a | @b | ...`). """
    is_union: ClassVar = True
    lhs: str
    rhs: List[UnionCase]

    def __post_init__(self):
        assert self.rhs
        # `rhs` may be any iterable of type names; wrap and sort for
        # deterministic output
        self.rhs = [UnionCase(x) for x in self.rhs]
        self.rhs.sort(key=lambda c: c.type)
        self.rhs[0].first = True


@dataclass
class SchemeInclude:
    """ The contents of an included dbscheme fragment. """
    src: str
    data: str


@dataclass
class Scheme:
    """ Template model for the generated dbscheme file. """
    template: ClassVar = 'dbscheme'
    src: str
    includes: List[SchemeInclude]
    declarations: List[Decl]

View File

@@ -1,18 +0,0 @@
""" module providing useful filesystem paths """
import pathlib
import sys
import os
try:
    # BUILD_WORKSPACE_DIRECTORY is set by `bazel run`; prefer it when present
    workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve()  # <- means we are using bazel run
    root_dir = workspace_dir / 'swift'
except KeyError:
    # not under bazel: derive the roots from this file's location
    # (this file lives at <root>/codegen/lib/paths.py)
    _this_file = pathlib.Path(__file__).resolve()
    root_dir = _this_file.parents[2]
    workspace_dir = root_dir.parent
lib_dir = root_dir / 'codegen' / 'lib'
templates_dir = root_dir / 'codegen' / 'templates'
# path of the currently running program, recorded in generated file headers
exe_file = pathlib.Path(sys.argv[0]).resolve()

View File

@@ -1,312 +0,0 @@
"""
QL files generation
`generate(opts, renderer)` will generate QL classes and manage stub files out of a `yml` schema file.
Each class (for example, `Foo`) in the schema triggers:
* generation of a `FooBase` class implementation translating all properties into appropriate getters
* if not created or already customized, generation of a stub file which defines `Foo` as extending `FooBase`. This can
be used to add hand-written code to `Foo`, which requires removal of the `// generated` header comment in that file.
All generated base classes actually import these customizations when referencing other classes.
Generated files that do not correspond any more to any class in the schema are deleted. Customized stubs are however
left behind and must be dealt with by hand.
"""
import pathlib
from dataclasses import dataclass, field
import itertools
from typing import List, ClassVar, Union, Optional
import inflection
@dataclass
class Param:
    """ A table parameter, with a `first` marker for template rendering. """
    param: str
    first: bool = False


@dataclass
class Property:
    """ A QL class property backed by a dbscheme table. """
    singular: str
    type: Optional[str] = None
    tablename: Optional[str] = None
    tableparams: List[Param] = field(default_factory=list)
    plural: Optional[str] = None
    first: bool = False
    is_optional: bool = False
    is_predicate: bool = False
    prev_child: Optional[str] = None
    qltest_skip: bool = False
    description: List[str] = field(default_factory=list)
    doc: Optional[str] = None
    doc_plural: Optional[str] = None

    def __post_init__(self):
        # wrap plain parameter names and mark the first for template rendering
        if self.tableparams:
            wrapped = [Param(p) for p in self.tableparams]
            wrapped[0].first = True
            self.tableparams = wrapped

    @property
    def getter(self):
        """ Name of the generated accessor (`getFoo`, or just `foo` for predicates). """
        if self.is_predicate:
            return self.singular
        return "get" + self.singular

    @property
    def indefinite_getter(self):
        """ `getAFoo`/`getAnApple` style accessor; None for non-repeated properties. """
        if not self.plural:
            return None
        starts_with_vowel = self.singular[0] in "AEIO"
        article = "An" if starts_with_vowel else "A"
        return f"get{article}{self.singular}"

    @property
    def type_is_class(self):
        # class types start with an uppercase letter, builtins are lowercase
        return bool(self.type and self.type[0].isupper())

    @property
    def is_repeated(self):
        return bool(self.plural)

    @property
    def is_single(self):
        return not any((self.is_optional, self.is_repeated, self.is_predicate))

    @property
    def is_child(self):
        return self.prev_child is not None

    @property
    def has_description(self) -> bool:
        return bool(self.description)
@dataclass
class Base:
    """ A base class reference, remembering the previous base in declaration order. """
    base: str
    prev: str = ""

    def __str__(self):
        return self.base
@dataclass
class Class:
    """ Template model for a generated QL class. """
    template: ClassVar = 'ql_class'
    name: str
    bases: List[Base] = field(default_factory=list)
    final: bool = False
    properties: List[Property] = field(default_factory=list)
    dir: pathlib.Path = pathlib.Path()
    imports: List[str] = field(default_factory=list)
    import_prefix: Optional[str] = None
    qltest_skip: bool = False
    qltest_collapse_hierarchy: bool = False
    qltest_uncollapse_hierarchy: bool = False
    ql_internal: bool = False
    ipa: bool = False
    doc: List[str] = field(default_factory=list)

    def __post_init__(self):
        # pair each base with the one preceding it (the first gets ""),
        # by zipping the list with itself shifted one place to the right
        self.bases = [Base(str(b), str(prev)) for b, prev in zip(self.bases, itertools.chain([""], self.bases))]
        if self.properties:
            self.properties[0].first = True

    @property
    def root(self) -> bool:
        """ True for the single class with no bases. """
        return not self.bases

    @property
    def path(self) -> pathlib.Path:
        return self.dir / self.name

    @property
    def db_id(self) -> str:
        """ dbscheme id for this class (`@foo_bar`). """
        return "@" + inflection.underscore(self.name)

    @property
    def has_children(self) -> bool:
        return any(p.is_child for p in self.properties)

    @property
    def last_base(self) -> str:
        return self.bases[-1].base if self.bases else ""

    @property
    def has_doc(self) -> bool:
        # internal classes always get a (generated) doc comment
        return bool(self.doc) or self.ql_internal
@dataclass
class IpaUnderlyingAccessor:
    """ Accessor to the underlying entity of a synthesized (IPA) class. """
    argument: str
    type: str
    constructorparams: List[Param]

    def __post_init__(self):
        # wrap plain parameter names and mark the first for template rendering
        if self.constructorparams:
            self.constructorparams = [Param(x) for x in self.constructorparams]
            self.constructorparams[0].first = True


@dataclass
class Stub:
    """ Template model for a customizable QL stub file. """
    template: ClassVar = 'ql_stub'
    name: str
    base_import: str
    import_prefix: str
    ipa_accessors: List[IpaUnderlyingAccessor] = field(default_factory=list)

    @property
    def has_ipa_accessors(self) -> bool:
        return bool(self.ipa_accessors)


@dataclass
class DbClasses:
    """ Template model for the file collecting all db-backed classes. """
    template: ClassVar = 'ql_db'
    classes: List[Class] = field(default_factory=list)


@dataclass
class ImportList:
    """ Template model for a file that is a plain list of imports. """
    template: ClassVar = 'ql_imports'
    imports: List[str] = field(default_factory=list)


@dataclass
class GetParentImplementation:
    """ Template model for the generated `getImmediateParent` implementation. """
    template: ClassVar = 'ql_parent'
    classes: List[Class] = field(default_factory=list)
    imports: List[str] = field(default_factory=list)
@dataclass
class PropertyForTest:
    """ Simplified property view used when generating QL tests. """
    getter: str
    is_total: bool = True
    type: Optional[str] = None
    is_repeated: bool = False


@dataclass
class TesterBase:
    """ Common data for generated test queries. """
    class_name: str
    elements_module: str


@dataclass
class ClassTester(TesterBase):
    """ Template model for a generated per-class test query. """
    template: ClassVar = 'ql_test_class'
    properties: List[PropertyForTest] = field(default_factory=list)
    show_ql_class: bool = False


@dataclass
class PropertyTester(TesterBase):
    """ Template model for a generated per-property test query. """
    template: ClassVar = 'ql_test_property'
    property: PropertyForTest


@dataclass
class MissingTestInstructions:
    """ Template model for the placeholder emitted when a test is missing. """
    template: ClassVar = 'ql_test_missing'
class Synth:
    """ Namespace grouping the models used to generate IPA (synthesized) types. """

    @dataclass
    class Class:
        """ Bare class reference with a `first` marker for template rendering. """
        is_final: ClassVar = False
        name: str
        first: bool = False

    @dataclass
    class Param:
        """ A constructor parameter of a synthesized class. """
        param: str
        type: str
        first: bool = False

    @dataclass
    class FinalClass(Class):
        """ A leaf class; the `is_*` ClassVars discriminate the concrete kind. """
        is_final: ClassVar = True
        is_derived_ipa: ClassVar = False
        is_fresh_ipa: ClassVar = False
        is_db: ClassVar = False
        params: List["Synth.Param"] = field(default_factory=list)

        def __post_init__(self):
            if self.params:
                self.params[0].first = True

        @property
        def is_ipa(self):
            return self.is_fresh_ipa or self.is_derived_ipa

        @property
        def has_params(self) -> bool:
            return bool(self.params)

    @dataclass
    class FinalClassIpa(FinalClass):
        """ Common base of the two synthesized leaf kinds. """
        pass

    @dataclass
    class FinalClassDerivedIpa(FinalClassIpa):
        """ Leaf synthesized from an existing class. """
        is_derived_ipa: ClassVar = True

    @dataclass
    class FinalClassFreshIpa(FinalClassIpa):
        """ Leaf synthesized from scratch (on arguments). """
        is_fresh_ipa: ClassVar = True

    @dataclass
    class FinalClassDb(FinalClass):
        """ Leaf backed by a dbscheme table. """
        is_db: ClassVar = True
        subtracted_ipa_types: List["Synth.Class"] = field(default_factory=list)

        def subtract_type(self, type: str):
            # the first subtracted type gets the `first` marker
            self.subtracted_ipa_types.append(Synth.Class(type, first=not self.subtracted_ipa_types))

        @property
        def has_subtracted_ipa_types(self) -> bool:
            return bool(self.subtracted_ipa_types)

        @property
        def db_id(self) -> str:
            """ dbscheme id for this class (`@foo_bar`). """
            return "@" + inflection.underscore(self.name)

    @dataclass
    class NonFinalClass(Class):
        """ An intermediate class with derived classes. """
        derived: List["Synth.Class"] = field(default_factory=list)
        root: bool = False

        def __post_init__(self):
            self.derived = [Synth.Class(c) for c in self.derived]
            if self.derived:
                self.derived[0].first = True

    @dataclass
    class Types:
        """ Template model for the generated IPA types file. """
        template: ClassVar = "ql_ipa_types"
        root: str
        import_prefix: str
        final_classes: List["Synth.FinalClass"] = field(default_factory=list)
        non_final_classes: List["Synth.NonFinalClass"] = field(default_factory=list)

        def __post_init__(self):
            if self.final_classes:
                self.final_classes[0].first = True

    @dataclass
    class ConstructorStub:
        """ Template model for a generated IPA constructor stub. """
        template: ClassVar = "ql_ipa_constructor_stub"
        cls: "Synth.FinalClass"
        import_prefix: str

View File

@@ -1,198 +0,0 @@
""" template renderer module, wrapping around `pystache.Renderer`
`pystache` is a python mustache engine, and mustache is a template language. More information on
https://mustache.github.io/
"""
import logging
import pathlib
import typing
import hashlib
from dataclasses import dataclass
import pystache
from . import paths
log = logging.getLogger(__name__)
class Error(Exception):
    """ Error raised while managing generated files. """
    pass
class Renderer:
    """ Template renderer using mustache templates in the `templates` directory """

    def __init__(self, generator: pathlib.Path, root_dir: pathlib.Path):
        # escape is the identity: we generate code, not HTML
        self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u)
        self._root_dir = root_dir
        self._generator = generator

    def _get_path(self, file: pathlib.Path):
        """ Return `file` relative to the managed root directory. """
        return file.relative_to(self._root_dir)

    def render(self, data: object, output: pathlib.Path):
        """ Render `data` to `output`.
        `data` must have a `template` attribute denoting which template to use from the template directory.
        Optionally, `data` can also have an `extensions` attribute denoting list of file extensions: they will all be
        appended to the template name with an underscore and be generated in turn.
        """
        mnemonic = type(data).__name__
        output.parent.mkdir(parents=True, exist_ok=True)
        # without `extensions`, render once using `output` as-is
        extensions = getattr(data, "extensions", [None])
        for ext in extensions:
            output_filename = output
            template = data.template
            if ext:
                # e.g. template `trap_traps` + ext `h` -> `trap_traps_h` -> `output.h`
                output_filename = output_filename.with_suffix(f".{ext}")
                template += f"_{ext}"
            contents = self._r.render_name(template, data, generator=self._generator)
            self._do_write(mnemonic, contents, output_filename)

    def _do_write(self, mnemonic: str, contents: str, output: pathlib.Path):
        # overridden in RenderManager to skip unneeded writes
        with open(output, "w") as out:
            out.write(contents)
        log.debug(f"{mnemonic}: generated {output.name}")

    def manage(self, generated: typing.Iterable[pathlib.Path], stubs: typing.Iterable[pathlib.Path],
               registry: pathlib.Path, force: bool = False) -> "RenderManager":
        """ Create a `RenderManager` context for checked-in generated files. """
        return RenderManager(self._generator, self._root_dir, generated, stubs, registry, force)
class RenderManager(Renderer):
    """ A context manager allowing to manage checked in generated files and their cleanup, able
    to skip unneeded writes.
    This is done by using and updating a checked in list of generated files that assigns two
    hashes to each file:
    * one is the hash of the mustache rendered contents, that can be used to quickly check whether a
    write is needed
    * the other is the hash of the actual file after code generation has finished. This will be
    different from the above because of post-processing like QL formatting. This hash is used
    to detect invalid modification of generated files"""

    # files actually (re)written during this run
    written: typing.Set[pathlib.Path]

    @dataclass
    class Hashes:
        """
        pre contains the hash of a file as rendered, post is the hash after
        postprocessing (for example QL formatting)
        """
        pre: str
        post: typing.Optional[str] = None

    def __init__(self, generator: pathlib.Path, root_dir: pathlib.Path, generated: typing.Iterable[pathlib.Path],
                 stubs: typing.Iterable[pathlib.Path],
                 registry: pathlib.Path, force: bool = False):
        super().__init__(generator, root_dir)
        self._registry_path = registry
        self._force = force
        # registry contents: relative path -> Hashes
        self._hashes = {}
        self.written = set()
        # files previously generated, candidates for cleanup if not re-emitted
        self._existing = set()
        # files whose rendered contents were unchanged (write skipped)
        self._skipped = set()
        self._load_registry()
        self._process_generated(generated)
        self._process_stubs(stubs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_val is None:
            # remove stale files that were neither rewritten nor skipped
            for f in self._existing - self._skipped - self.written:
                f.unlink(missing_ok=True)
                log.info(f"removed {f.name}")
            # record post-processing hashes of what was written
            for f in self.written:
                self._hashes[self._get_path(f)].post = self._hash_file(f)
        else:
            # if an error was encountered, drop already written files from the registry
            # so that they get the chance to be regenerated again during the next run
            for f in self.written:
                self._hashes.pop(self._get_path(f), None)
        # clean up the registry from files that do not exist any more
        for f in list(self._hashes):
            if not (self._root_dir / f).exists():
                self._hashes.pop(f)
        self._dump_registry()

    def _do_write(self, mnemonic: str, contents: str, output: pathlib.Path):
        # skip the write when the rendered contents match the registry's pre-hash
        hash = self._hash_string(contents)
        rel_output = self._get_path(output)
        if rel_output in self._hashes and self._hashes[rel_output].pre == hash:
            self._skipped.add(output)
            log.debug(f"{mnemonic}: skipped {output.name}")
        else:
            self.written.add(output)
            super()._do_write(mnemonic, contents, output)
            self._hashes[rel_output] = self.Hashes(pre=hash)

    def _process_generated(self, generated: typing.Iterable[pathlib.Path]):
        """ Validate files marked as generated against the registry. """
        for f in generated:
            self._existing.add(f)
            rel_path = self._get_path(f)
            if self._force:
                pass
            elif rel_path not in self._hashes:
                log.warning(f"{rel_path} marked as generated but absent from the registry")
            elif self._hashes[rel_path].post != self._hash_file(f):
                raise Error(f"{rel_path} is generated but was modified, please revert the file "
                            "or pass --force to overwrite")

    def _process_stubs(self, stubs: typing.Iterable[pathlib.Path]):
        """ Validate stub files; customized stubs are dropped from management. """
        for f in stubs:
            rel_path = self._get_path(f)
            if self.is_customized_stub(f):
                self._hashes.pop(rel_path, None)
                continue
            self._existing.add(f)
            if self._force:
                pass
            elif rel_path not in self._hashes:
                log.warning(f"{rel_path} marked as stub but absent from the registry")
            elif self._hashes[rel_path].post != self._hash_file(f):
                raise Error(f"{rel_path} is a stub marked as generated, but it was modified, "
                            "please remove the `// generated` header, revert the file or pass --force to overwrite it")

    @staticmethod
    def is_customized_stub(file: pathlib.Path) -> bool:
        """ A stub is customized if its first line lacks the `// generated` marker. """
        if not file.is_file():
            return False
        with open(file) as contents:
            for line in contents:
                return not line.startswith("// generated")
            # no lines
            return True

    @staticmethod
    def _hash_file(filename: pathlib.Path) -> str:
        with open(filename) as inp:
            return RenderManager._hash_string(inp.read())

    @staticmethod
    def _hash_string(data: str) -> str:
        h = hashlib.sha256()
        h.update(data.encode())
        return h.hexdigest()

    def _load_registry(self):
        # with --force, start from an empty registry
        if self._force:
            return
        try:
            with open(self._registry_path) as reg:
                for line in reg:
                    filename, prehash, posthash = line.split()
                    self._hashes[pathlib.Path(filename)] = self.Hashes(prehash, posthash)
        except FileNotFoundError:
            pass

    def _dump_registry(self):
        self._registry_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self._registry_path, 'w') as out:
            for f, hashes in sorted(self._hashes.items()):
                print(f, hashes.pre, hashes.post, file=out)

View File

@@ -1,194 +0,0 @@
""" schema format representation """
import typing
from dataclasses import dataclass, field
from typing import List, Set, Union, Dict, Optional
from enum import Enum, auto
import functools
class Error(Exception):
def __str__(self):
return self.args[0]
def _check_type(t: Optional[str], known: typing.Iterable[str]):
if t is not None and t not in known:
raise Error(f"Unknown type {t}")
@dataclass
class Property:
    """ A property of a schema class, tagged by its multiplicity kind. """

    class Kind(Enum):
        SINGLE = auto()
        REPEATED = auto()
        OPTIONAL = auto()
        REPEATED_OPTIONAL = auto()
        PREDICATE = auto()

    kind: Kind
    name: Optional[str] = None
    type: Optional[str] = None
    is_child: bool = False
    pragmas: List[str] = field(default_factory=list)
    doc: Optional[str] = None
    description: List[str] = field(default_factory=list)

    @property
    def is_single(self) -> bool:
        return self.kind is self.Kind.SINGLE

    @property
    def is_optional(self) -> bool:
        return self.kind in {self.Kind.OPTIONAL, self.Kind.REPEATED_OPTIONAL}

    @property
    def is_repeated(self) -> bool:
        return self.kind in {self.Kind.REPEATED, self.Kind.REPEATED_OPTIONAL}

    @property
    def is_predicate(self) -> bool:
        return self.kind is self.Kind.PREDICATE

    @property
    def has_class_type(self) -> bool:
        # class types start with an uppercase letter
        return self.type[0].isupper() if self.type else False

    @property
    def has_builtin_type(self) -> bool:
        # builtin types (int, string, boolean) start with a lowercase letter
        return self.type[0].islower() if self.type else False
# convenience constructors pre-binding the property kind
SingleProperty = functools.partial(Property, Property.Kind.SINGLE)
OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL)
RepeatedProperty = functools.partial(Property, Property.Kind.REPEATED)
RepeatedOptionalProperty = functools.partial(
    Property, Property.Kind.REPEATED_OPTIONAL)
PredicateProperty = functools.partial(Property, Property.Kind.PREDICATE)


@dataclass
class IpaInfo:
    """ Synthesis information: exactly one of the two fields is expected to be set. """
    from_class: Optional[str] = None
    on_arguments: Optional[Dict[str, str]] = None
@dataclass
class Class:
    """ A class in the schema, with its hierarchy links and properties. """
    name: str
    bases: List[str] = field(default_factory=list)
    derived: Set[str] = field(default_factory=set)
    properties: List[Property] = field(default_factory=list)
    group: str = ""
    pragmas: List[str] = field(default_factory=list)
    ipa: Optional[Union[IpaInfo, bool]] = None
    """^^^ filled with `True` for non-final classes with only synthesized final descendants """
    doc: List[str] = field(default_factory=list)
    default_doc_name: Optional[str] = None

    @property
    def final(self):
        """ True if no class derives from this one. """
        return len(self.derived) == 0

    def check_types(self, known: typing.Iterable[str]):
        """ Check every type name referenced by this class against `known`,
        raising `Error` for unknown ones. """
        for base in self.bases:
            _check_type(base, known)
        for derived_class in self.derived:
            _check_type(derived_class, known)
        for prop in self.properties:
            _check_type(prop.type, known)
        if self.ipa is None:
            return
        _check_type(self.ipa.from_class, known)
        if self.ipa.on_arguments is not None:
            for argument_type in self.ipa.on_arguments.values():
                _check_type(argument_type, known)
@dataclass
class Schema:
    """ A complete schema: classes (root first), includes and optional null class. """
    classes: Dict[str, Class] = field(default_factory=dict)
    includes: Set[str] = field(default_factory=set)
    null: Optional[str] = None

    @property
    def root_class(self):
        # always the first in the dictionary
        return next(iter(self.classes.values()))

    @property
    def null_class(self):
        return self.classes[self.null] if self.null else None


# sentinel marking a predicate property in schema definitions
predicate_marker = object()

# a schema type reference: either a class object or its name
TypeRef = Union[type, str]
@functools.singledispatch
def get_type_name(arg: TypeRef) -> str:
    """ Return the schema type name of `arg` (a class object or a plain string). """
    raise Error(f"Not a schema type or string ({arg})")


@get_type_name.register
def _(arg: str):
    # strings already are type names
    return arg


@get_type_name.register
def _(arg: type):
    return arg.__name__
@functools.singledispatch
def _make_property(arg: object) -> Property:
    """ Turn a schema annotation value into a `Property`. """
    if arg is predicate_marker:
        return PredicateProperty()
    raise Error(f"Illegal property specifier {arg}")


@_make_property.register(str)
@_make_property.register(type)
def _(arg: TypeRef):
    # a bare type reference means a single property of that type
    return SingleProperty(type=get_type_name(arg))


@_make_property.register
def _(arg: Property):
    # an explicit Property is passed through unchanged
    return arg
class PropertyModifier:
    """ Modifier of `Property` objects.
    Being on the right of `|` it will trigger construction of a `Property` from
    the left operand.
    """

    def __ror__(self, other: object) -> Property:
        # builds a Property from `other` and applies this modifier to it
        ret = _make_property(other)
        self.modify(ret)
        return ret

    def modify(self, prop: Property):
        # to be overridden by concrete modifiers
        raise NotImplementedError
def split_doc(doc):
    """ Split a docstring into trimmed lines, removing the common indentation
    of the continuation lines and leading/trailing blank lines.

    Implementation inspired from https://peps.python.org/pep-0257/
    """
    if not doc:
        return []
    first, *body = doc.splitlines()
    # minimum indentation over non-blank continuation lines (first line is special)
    margins = [len(line) - len(line.lstrip()) for line in body if line.lstrip()]
    out = [first.strip()]
    if margins:
        margin = min(margins)
        out.extend(line[margin:].rstrip() for line in body)
    # drop leading and trailing blank lines
    while out and not out[0]:
        out.pop(0)
    while out and not out[-1]:
        out.pop()
    return out

View File

@@ -1,149 +0,0 @@
from typing import Callable as _Callable
from swift.codegen.lib import schema as _schema
import inspect as _inspect
from dataclasses import dataclass as _dataclass
class _ChildModifier(_schema.PropertyModifier):
    """ `| child` modifier marking a class-typed property as an AST child. """

    def modify(self, prop: _schema.Property):
        # builtin (lowercase) types cannot be AST children
        if prop.type is None or prop.type[0].islower():
            raise _schema.Error("Non-class properties cannot be children")
        prop.is_child = True
@_dataclass
class _DocModifier(_schema.PropertyModifier):
    """ `| doc("...")` modifier attaching a one-line doc string to a property. """
    doc: str

    def modify(self, prop: _schema.Property):
        # `endswith` also tolerates an empty doc string, which previously
        # raised IndexError via `self.doc[-1]`
        if "\n" in self.doc or self.doc.endswith("."):
            raise _schema.Error("No newlines or trailing dots are allowed in doc, did you intend to use desc?")
        prop.doc = self.doc
@_dataclass
class _DescModifier(_schema.PropertyModifier):
    """ `| desc("...")` modifier attaching a multi-line description to a property. """
    description: str

    def modify(self, prop: _schema.Property):
        prop.description = _schema.split_doc(self.description)
def include(source: str):
    """ Register `source` as a dbscheme fragment to include verbatim. """
    # add to `includes` variable in calling context
    # NOTE: frame introspection mutates the caller's (module-level) namespace;
    # this only works reliably when called at schema module top level
    _inspect.currentframe().f_back.f_locals.setdefault(
        "__includes", []).append(source)


class _Namespace:
    """ simple namespacing mechanism """

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


# namespaces into which `_Pragma` instances register themselves by prefix
qltest = _Namespace()
ql = _Namespace()
cpp = _Namespace()
synth = _Namespace()
@_dataclass
class _Pragma(_schema.PropertyModifier):
    """ A class or property pragma.
    For properties, it functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
    For schema classes it acts as a python decorator with `@`.
    """
    pragma: str

    def __post_init__(self):
        # `qltest_skip` registers itself as `qltest.skip`, etc.
        namespace, _, name = self.pragma.partition('_')
        setattr(globals()[namespace], name, self)

    def modify(self, prop: _schema.Property):
        prop.pragmas.append(self.pragma)

    def __call__(self, cls: type) -> type:
        """ use this pragma as a decorator on classes """
        if "_pragmas" in cls.__dict__:  # not using hasattr as we don't want to land on inherited pragmas
            cls._pragmas.append(self.pragma)
        else:
            cls._pragmas = [self.pragma]
        return cls
class _Optionalizer(_schema.PropertyModifier):
    """ Modifier behind `optional[...]`, turning SINGLE into OPTIONAL. """

    def modify(self, prop: _schema.Property):
        K = _schema.Property.Kind
        if prop.kind != K.SINGLE:
            raise _schema.Error(
                "Optional should only be applied to simple property types")
        prop.kind = K.OPTIONAL


class _Listifier(_schema.PropertyModifier):
    """ Modifier behind `list[...]`, turning SINGLE/OPTIONAL into their repeated forms. """

    def modify(self, prop: _schema.Property):
        K = _schema.Property.Kind
        if prop.kind == K.SINGLE:
            prop.kind = K.REPEATED
        elif prop.kind == K.OPTIONAL:
            prop.kind = K.REPEATED_OPTIONAL
        else:
            raise _schema.Error(
                "Repeated should only be applied to simple or optional property types")
class _TypeModifier:
    """ Modifies types using get item notation """

    def __init__(self, modifier: _schema.PropertyModifier):
        self.modifier = modifier

    def __getitem__(self, item):
        # `optional[Foo]` desugars to `Foo | _Optionalizer()`
        return item | self.modifier
_ClassDecorator = _Callable[[type], type]
def _annotate(**kwargs) -> _ClassDecorator:
def f(cls: type) -> type:
for k, v in kwargs.items():
setattr(cls, f"_{k}", v)
return cls
return f
# builtin schema type names
boolean = "boolean"
# NOTE: `int`, `string`, `list` deliberately shadow builtins inside schema
# definition modules; this module is meant to be star-imported by schemas only
int = "int"
string = "string"
predicate = _schema.predicate_marker
optional = _TypeModifier(_Optionalizer())
list = _TypeModifier(_Listifier())
child = _ChildModifier()
doc = _DocModifier
desc = _DescModifier
# class decorator marking the schema's null class
use_for_null = _annotate(null=True)
# instantiating a _Pragma registers it on the matching namespace
# (e.g. `qltest.skip`)
_Pragma("qltest_skip")
_Pragma("qltest_collapse_hierarchy")
_Pragma("qltest_uncollapse_hierarchy")
ql.default_doc_name = lambda doc: _annotate(doc_name=doc)
_Pragma("ql_internal")
_Pragma("cpp_skip")


def group(name: str = "") -> _ClassDecorator:
    """ Class decorator assigning the class to a named group (output subdirectory). """
    return _annotate(group=name)


# class decorators declaring synthesized (IPA) classes
synth.from_class = lambda ref: _annotate(ipa=_schema.IpaInfo(
    from_class=_schema.get_type_name(ref)))
synth.on_arguments = lambda **kwargs: _annotate(
    ipa=_schema.IpaInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()}))

View File

@@ -1,11 +0,0 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")

# Loaders parsing dbscheme files and schema definition modules.
py_library(
    name = "loaders",
    srcs = glob(["*.py"]),
    visibility = ["//swift/codegen:__subpackages__"],
    deps = [
        # topological sorting of the class/tag hierarchies
        requirement("toposort"),
        # string casing helpers
        requirement("inflection"),
    ],
)

View File

@@ -1,54 +0,0 @@
import pathlib
import re
from swift.codegen.lib import dbscheme
class _Re:
    """ Compiled regexes for parsing a dbscheme file. """
    # matches either a table declaration (optionally preceded by a `#keyset`
    # annotation and carrying an optional `//dir=` marker) or a union declaration
    entity = re.compile(
        "(?m)"
        r"(?:^#keyset\[(?P<tablekeys>[\w\s,]+)\][\s\n]*)?^(?P<table>\w+)\("
        r"(?:\s*//dir=(?P<tabledir>\S*))?(?P<tablebody>[^\)]*)"
        r"\);?"
        "|"
        r"^(?P<union>@\w+)\s*=\s*(?P<unionbody>@\w+(?:\s*\|\s*@\w+)*)\s*;?"
    )
    # matches one `name : type [ref]` column inside a table body
    field = re.compile(r"(?m)[\w\s]*\s(?P<field>\w+)\s*:\s*(?P<type>@?\w+)(?P<ref>\s+ref)?")
    # matches one `@id` inside a keyset or union body
    key = re.compile(r"@\w+")
    comment = re.compile(r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$")  # lookahead avoid ignoring metadata like //dir=foo
def _get_column(match):
    """ Build a `dbscheme.Column` from a `_Re.field` match.

    A column is binding unless it is marked with `ref`; escaped names lose
    their trailing underscore.
    """
    schema_name = match["field"].rstrip("_")
    is_binding = match["ref"] is None
    return dbscheme.Column(
        schema_name=schema_name,
        type=match["type"],
        binding=is_binding,
    )
def _get_table(match):
    """ Build a `dbscheme.Table` from a `_Re.entity` match on a table. """
    keyset = None
    if match["tablekeys"]:
        keyset = dbscheme.KeySet(k.strip() for k in match["tablekeys"].split(","))
    return dbscheme.Table(
        name=match["table"],
        columns=[_get_column(f) for f in _Re.field.finditer(match["tablebody"])],
        keyset=keyset,
        # the optional `//dir=` marker selects the output subdirectory
        dir=pathlib.PosixPath(match["tabledir"]) if match["tabledir"] else None,
    )
def _get_union(match):
    """ Build a `dbscheme.Union` from a `_Re.entity` match on a union. """
    return dbscheme.Union(
        lhs=match["union"],
        # Union.__post_init__ materializes and sorts this generator
        rhs=(d[0] for d in _Re.key.finditer(match["unionbody"])),
    )
def iterload(file):
    """ Parse the dbscheme at path `file`, yielding `dbscheme.Table` and
    `dbscheme.Union` declarations in file order. """
    # read and close eagerly (the handle no longer shadows the `file`
    # parameter, and is not kept open for the generator's whole lifetime)
    with open(file) as stream:
        data = _Re.comment.sub("", stream.read())
    for e in _Re.entity.finditer(data):
        if e["table"]:
            yield _get_table(e)
        elif e["union"]:
            yield _get_union(e)

View File

@@ -1,133 +0,0 @@
""" schema loader """
import inflection
import typing
import types
import pathlib
import importlib.util
from dataclasses import dataclass
from toposort import toposort_flatten
from swift.codegen.lib import schema, schemadefs
@dataclass
class _PropertyNamer(schema.PropertyModifier):
    """ Modifier assigning a property its attribute name from the schema class. """
    name: str

    def modify(self, prop: schema.Property):
        # trailing underscores escape keywords/builtins in schema definitions
        prop.name = self.name.rstrip("_")
def _get_class(cls: type) -> schema.Class:
    """ Translate one schema definition class into a `schema.Class`. """
    if not isinstance(cls, type):
        raise schema.Error(f"Only class definitions allowed in schema, found {cls}")
    # we must check that going to dbscheme names and back is preserved
    # In particular this will not happen if uppercase acronyms are included in the name
    to_underscore_and_back = inflection.camelize(inflection.underscore(cls.__name__), uppercase_first_letter=True)
    if cls.__name__ != to_underscore_and_back:
        raise schema.Error(f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} "
                           f"instead of {to_underscore_and_back}")
    if len({b._group for b in cls.__bases__ if hasattr(b, "_group")}) > 1:
        raise schema.Error(f"Bases with mixed groups for {cls.__name__}")
    if any(getattr(b, "_null", False) for b in cls.__bases__):
        raise schema.Error(f"Null class cannot be derived")
    return schema.Class(name=cls.__name__,
                        bases=[b.__name__ for b in cls.__bases__ if b is not object],
                        derived={d.__name__ for d in cls.__subclasses__()},
                        # getattr to inherit from bases
                        group=getattr(cls, "_group", ""),
                        # in the following we don't use `getattr` to avoid inheriting
                        pragmas=cls.__dict__.get("_pragmas", []),
                        ipa=cls.__dict__.get("_ipa", None),
                        # each annotated attribute of the class becomes a property
                        properties=[
                            a | _PropertyNamer(n)
                            for n, a in cls.__dict__.get("__annotations__", {}).items()
                        ],
                        doc=schema.split_doc(cls.__doc__),
                        default_doc_name=cls.__dict__.get("_doc_name"),
                        )
def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typing.Dict[str, schema.Class]:
    """ Return `classes` reordered: groups in alphabetical order, and within
    each group classes in topological (bases-first) order. """
    groups = {}
    ret = {}
    for name, cls in classes.items():
        groups.setdefault(cls.group, []).append(name)
    for group, grouped in sorted(groups.items()):
        inheritance = {name: classes[name].bases for name in grouped}
        for name in toposort_flatten(inheritance):
            ret[name] = classes[name]
    return ret
def _fill_ipa_information(classes: typing.Dict[str, schema.Class]):
    """ Take a dictionary where the `ipa` field is filled for all explicitly synthesized classes
    and update it so that all non-final classes that have only synthesized final descendants
    get `True` as` value for the `ipa` field
    """
    if not classes:
        return

    # memoized "is this class (transitively) fully synthesized?" per class name
    is_ipa: typing.Dict[str, bool] = {}

    def fill_is_ipa(name: str):
        # post-order recursion over the derivation tree, memoized via `is_ipa`
        if name not in is_ipa:
            cls = classes[name]
            for d in cls.derived:
                fill_is_ipa(d)
            if cls.ipa is not None:
                is_ipa[name] = True
            elif not cls.derived:
                is_ipa[name] = False
            else:
                is_ipa[name] = all(is_ipa[d] for d in cls.derived)

    # the root class is the first one (see Schema.root_class)
    root = next(iter(classes))
    fill_is_ipa(root)

    for name, cls in classes.items():
        if cls.ipa is None and is_ipa[name]:
            cls.ipa = True
def load(m: types.ModuleType) -> schema.Schema:
    """ Turn a schema definition module into a `schema.Schema`.

    Module attributes are scanned in definition order: names also exported by
    the `schemadefs` DSL module and dunder names are skipped, `__includes` is
    collected into the include set, and every remaining attribute must be a
    schema class. Exactly one root class (a class with no bases) and at most
    one null class are allowed; violations raise `schema.Error`.
    """
    includes = set()
    classes = {}
    # names usable as property types: the builtin ones plus anything already
    # present in the module namespace (e.g. imported names)
    known = {"int", "string", "boolean"}
    known.update(n for n in m.__dict__ if not n.startswith("__"))
    import swift.codegen.lib.schemadefs as defs
    null = None
    for name, data in m.__dict__.items():
        if hasattr(defs, name):
            # skip DSL helpers re-exported by `from schemadefs import *`
            continue
        if name == "__includes":
            includes = set(data)
            continue
        if name.startswith("__"):
            continue
        cls = _get_class(data)
        if classes and not cls.bases:
            raise schema.Error(
                f"Only one root class allowed, found second root {name}")
        cls.check_types(known)
        classes[name] = cls
        if getattr(data, "_null", False):
            if null is not None:
                raise schema.Error(f"Null class {null} already defined, second null class {name} not allowed")
            null = name
            cls.is_null_class = True
    _fill_ipa_information(classes)
    return schema.Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
def load_file(path: pathlib.Path) -> schema.Schema:
    """ Execute the python file at `path` and load the schema it defines. """
    spec = importlib.util.spec_from_file_location("schema", path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return load(mod)

View File

@@ -1,5 +0,0 @@
inflection
pystache
pytest
pyyaml
toposort

View File

@@ -1,149 +0,0 @@
from typing import Callable as _Callable
from swift.codegen.lib import schema as _schema
import inspect as _inspect
from dataclasses import dataclass as _dataclass
class _ChildModifier(_schema.PropertyModifier):
    """ Marks a property as a child; only class-typed properties qualify. """

    def modify(self, prop: _schema.Property):
        # class type names start with an upper-case letter; anything else
        # (builtin types, or no type at all) cannot be a child
        type_is_class = prop.type is not None and not prop.type[0].islower()
        if not type_is_class:
            raise _schema.Error("Non-class properties cannot be children")
        prop.is_child = True
@_dataclass
class _DocModifier(_schema.PropertyModifier):
    """ Attaches a one-line documentation string to a property. """
    doc: str

    def modify(self, prop: _schema.Property):
        # docs must be single-line and must not end with a period; longer
        # texts belong in `desc`. `endswith` (rather than `self.doc[-1]`)
        # also copes with an empty string, which previously raised an
        # unrelated IndexError instead of being handled.
        if "\n" in self.doc or self.doc.endswith("."):
            raise _schema.Error("No newlines or trailing dots are allowed in doc, did you intend to use desc?")
        prop.doc = self.doc
@_dataclass
class _DescModifier(_schema.PropertyModifier):
    """ Attaches a long-form, possibly multi-line description to a property. """
    description: str

    def modify(self, prop: _schema.Property):
        # the schema helper splits the text into normalized lines
        prop.description = _schema.split_doc(self.description)
def include(source: str):
    """ Register `source` for inclusion by appending it to the `__includes`
    list in the caller's scope, creating that list on first use. """
    caller_locals = _inspect.currentframe().f_back.f_locals
    caller_locals.setdefault("__includes", []).append(source)
class _Namespace:
""" simple namespacing mechanism """
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
# namespaces onto which `_Pragma` instances below register themselves
qltest = _Namespace()
ql = _Namespace()
cpp = _Namespace()
synth = _Namespace()
@_dataclass
class _Pragma(_schema.PropertyModifier):
    """ A class or property pragma.
    For properties, it functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
    For schema classes it acts as a python decorator with `@`.
    """
    pragma: str

    def __post_init__(self):
        # register this instance on its namespace, e.g. "qltest_skip"
        # becomes available as `qltest.skip`
        namespace, _, name = self.pragma.partition('_')
        setattr(globals()[namespace], name, self)

    def modify(self, prop: _schema.Property):
        prop.pragmas.append(self.pragma)

    def __call__(self, cls: type) -> type:
        """ use this pragma as a decorator on classes """
        # look only at the class's own dict so that a `_pragmas` list
        # inherited from a base class is never mutated
        own_pragmas = cls.__dict__.get("_pragmas")
        if own_pragmas is not None:
            own_pragmas.append(self.pragma)
        else:
            cls._pragmas = [self.pragma]
        return cls
class _Optionalizer(_schema.PropertyModifier):
    """ Turns a plain single property into an optional one. """

    def modify(self, prop: _schema.Property):
        kind = _schema.Property.Kind
        if prop.kind != kind.SINGLE:
            raise _schema.Error(
                "Optional should only be applied to simple property types")
        prop.kind = kind.OPTIONAL
class _Listifier(_schema.PropertyModifier):
    """ Turns a single or optional property into a repeated one. """

    def modify(self, prop: _schema.Property):
        kind = _schema.Property.Kind
        # map the only two admissible input kinds to their repeated variants
        promotions = {
            kind.SINGLE: kind.REPEATED,
            kind.OPTIONAL: kind.REPEATED_OPTIONAL,
        }
        if prop.kind not in promotions:
            raise _schema.Error(
                "Repeated should only be applied to simple or optional property types")
        prop.kind = promotions[prop.kind]
class _TypeModifier:
    """ Modifies types using get item notation """

    def __init__(self, modifier: _schema.PropertyModifier):
        self.modifier = modifier

    def __getitem__(self, item):
        # `modifier[T]` is sugar for `T | modifier`
        return item | self.modifier
_ClassDecorator = _Callable[[type], type]
def _annotate(**kwargs) -> _ClassDecorator:
def f(cls: type) -> type:
for k, v in kwargs.items():
setattr(cls, f"_{k}", v)
return cls
return f
# builtin property types of the schema DSL; `int` and `list` deliberately
# shadow the python builtins within schema definition files
boolean = "boolean"
int = "int"
string = "string"
predicate = _schema.predicate_marker
# type modifiers, used as `optional[T]` and `list[T]`
optional = _TypeModifier(_Optionalizer())
list = _TypeModifier(_Listifier())
# property modifiers, used as `T | child`, `T | doc("...")`, `T | desc("...")`
child = _ChildModifier()
doc = _DocModifier
desc = _DescModifier
# class decorator marking the class used for null values
use_for_null = _annotate(null=True)
# instantiating a _Pragma registers it on the matching namespace,
# e.g. "qltest_skip" becomes `qltest.skip`
_Pragma("qltest_skip")
_Pragma("qltest_collapse_hierarchy")
_Pragma("qltest_uncollapse_hierarchy")
ql.default_doc_name = lambda doc: _annotate(doc_name=doc)
_Pragma("ql_internal")
_Pragma("cpp_skip")


def group(name: str = "") -> _ClassDecorator:
    # class decorator assigning the class to a named output group
    return _annotate(group=name)


# decorators marking a class as synthesized (IPA), either from a single
# source class or from explicit constructor arguments
synth.from_class = lambda ref: _annotate(ipa=_schema.IpaInfo(
    from_class=_schema.get_type_name(ref)))
synth.on_arguments = lambda **kwargs: _annotate(
    ipa=_schema.IpaInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()}))

View File

@@ -1,11 +0,0 @@
package(default_visibility = ["//swift:__subpackages__"])
filegroup(
name = "trap",
srcs = glob(["trap_*.mustache"]),
)
filegroup(
name = "cpp",
srcs = glob(["cpp_*.mustache"]),
)

View File

@@ -1,37 +0,0 @@
// generated by {{generator}} from {{source}}
// clang-format off
#include "./TrapClasses.h"
namespace codeql {
{{#classes}}
void {{name}}::emit({{^final}}TrapLabel<{{name}}Tag> id, {{/final}}std::ostream& out) const {
{{#trap_name}}
out << {{.}}Trap{id{{#single_fields}}, {{field_name}}{{/single_fields}}} << '\n';
{{/trap_name}}
{{#bases}}
{{ref.name}}::emit(id, out);
{{/bases}}
{{#fields}}
{{#is_predicate}}
if ({{field_name}}) out << {{trap_name}}Trap{id} << '\n';
{{/is_predicate}}
{{#is_optional}}
{{^is_repeated}}
if ({{field_name}}) out << {{trap_name}}Trap{id, *{{field_name}}} << '\n';
{{/is_repeated}}
{{/is_optional}}
{{#is_repeated}}
for (auto i = 0u; i < {{field_name}}.size(); ++i) {
{{^is_optional}}
out << {{trap_name}}Trap{id, i, {{field_name}}[i]} << '\n';
{{/is_optional}}
{{#is_optional}}
if ({{field_name}}[i]) out << {{trap_name}}Trap{id, i, *{{field_name}}[i]} << '\n';
{{/is_optional}}
}
{{/is_repeated}}
{{/fields}}
}
{{/classes}}
}

View File

@@ -1,82 +0,0 @@
// generated by {{generator}} from {{source}}
// clang-format off
#pragma once
#include <iostream>
#include <optional>
#include <vector>
#include "{{trap_library}}/TrapLabel.h"
#include "{{trap_library}}/TrapTagTraits.h"
#include "./TrapEntries.h"
{{#include_parent}}
#include "../TrapClasses.h"
{{/include_parent}}
namespace codeql {
{{#classes}}
struct {{name}}{{#has_bases}} : {{#bases}}{{^first}}, {{/first}}{{ref.name}}{{/bases}}{{/has_bases}} {
static constexpr const char* NAME = "{{name}}";
{{#final}}
explicit {{name}}(TrapLabel<{{name}}Tag> id) : id{id} {}
TrapLabel<{{name}}Tag> id{};
{{/final}}
{{#fields}}
{{type}} {{field_name}}{};
{{/fields}}
{{#final}}
friend std::ostream& operator<<(std::ostream& out, const {{name}}& x) {
x.emit(out);
return out;
}
{{/final}}
{{^final}}
protected:
{{/final}}
template <typename F>
void forEachLabel(F f) {
{{#final}}
f("id", -1, id);
{{/final}}
{{#bases}}
{{ref.name}}::forEachLabel(f);
{{/bases}}
{{#fields}}
{{#is_label}}
{{#is_repeated}}
for (auto i = 0u; i < {{field_name}}.size(); ++i) {
{{#is_optional}}
if ({{field_name}}[i]) f("{{field_name}}", i, *{{field_name}}[i]);
{{/is_optional}}
{{^is_optional}}
f("{{field_name}}", i, {{field_name}}[i]);
{{/is_optional}}
}
{{/is_repeated}}
{{^is_repeated}}
{{#is_optional}}
if ({{field_name}}) f("{{field_name}}", -1, *{{field_name}});
{{/is_optional}}
{{^is_optional}}
f("{{field_name}}", -1, {{field_name}});
{{/is_optional}}
{{/is_repeated}}
{{/is_label}}
{{/fields}}
}
protected:
void emit({{^final}}TrapLabel<{{name}}Tag> id, {{/final}}std::ostream& out) const;
};
template <>
struct detail::ToTrapClassFunctor<{{name}}Tag> {
using type = {{name}};
};
{{/classes}}
}

View File

@@ -1,25 +0,0 @@
// generated by {{generator}}
{{#includes}}
// from {{src}}
{{data}}
{{/includes}}
// from {{src}}
{{#declarations}}
{{#is_union}}
{{lhs}} =
{{#rhs}}
{{#first}} {{/first}}{{^first}}| {{/first}}{{type}}
{{/rhs}};
{{/is_union}}
{{#is_table}}
{{#keyset}}
#keyset[{{#ids}}{{^first}}, {{/first}}{{id}}{{/ids}}]
{{/keyset}}
{{name}}({{#dir}} //dir={{.}}{{/dir}}{{#columns}}{{^first}},{{/first}}
{{lhstype}} {{name}}: {{rhstype}}{{/columns}}
);
{{/is_table}}
{{/declarations}}

View File

@@ -1,138 +0,0 @@
// generated by {{generator}}
private import {{import_prefix}}.Synth
private import {{import_prefix}}.Raw
{{#imports}}
import {{.}}
{{/imports}}
module Generated {
{{#has_doc}}
/**
{{#ql_internal}}
* INTERNAL: Do not use.
{{/ql_internal}}
{{#doc}}
* {{.}}
{{/doc}}
*/
{{/has_doc}}
class {{name}} extends Synth::T{{name}}{{#bases}}, {{.}}{{/bases}} {
{{#root}}
/**
* Gets the string representation of this element.
*/
string toString() { none() } // overridden by subclasses
/**
* Gets the name of a primary CodeQL class to which this element belongs.
*
* This is the most precise syntactic category to which they belong; for
* example, `CallExpr` is a primary class, but `ApplyExpr` is not.
*
* There might be some corner cases when this returns multiple classes, or none.
*/
string getAPrimaryQlClass() { none() } // overridden by subclasses
/**
* Gets a comma-separated list of the names of the primary CodeQL classes to which this element belongs.
*/
final string getPrimaryQlClasses() { result = concat(this.getAPrimaryQlClass(), ",") }
/**
* Gets the most immediate element that should substitute this element in the explicit AST, if any.
* Classes can override this to indicate this node should be in the "hidden" AST, mostly reserved
* for conversions and syntactic sugar nodes like parentheses.
*/
{{name}} getResolveStep() { none() } // overridden by subclasses
/**
* Gets the element that should substitute this element in the explicit AST, applying `getResolveStep`
* transitively.
*/
final {{name}} resolve() {
not exists(getResolveStep()) and result = this
or
result = getResolveStep().resolve()
}
{{/root}}
{{#final}}
override string getAPrimaryQlClass() { result = "{{name}}" }
{{/final}}
{{#properties}}
{{#type_is_class}}
/**
* {{>ql_property_doc}} *
* This includes nodes from the "hidden" AST. It can be overridden in subclasses to change the
* behavior of both the `Immediate` and non-`Immediate` versions.
*/
{{type}} getImmediate{{singular}}({{#is_repeated}}int index{{/is_repeated}}) {
{{^ipa}}
result = Synth::convert{{type}}FromRaw(Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_repeated}}index{{/is_repeated}}))
{{/ipa}}
{{#ipa}}
none()
{{/ipa}}
}
/**
* {{>ql_property_doc}} *
{{#has_description}}
{{#description}}
* {{.}}
{{/description}}
{{/has_description}}
*/
final {{type}} {{getter}}({{#is_repeated}}int index{{/is_repeated}}) {
result = getImmediate{{singular}}({{#is_repeated}}index{{/is_repeated}}).resolve()
}
{{/type_is_class}}
{{^type_is_class}}
/**
* {{>ql_property_doc}} *
{{#has_description}}
{{#description}}
* {{.}}
{{/description}}
{{/has_description}}
*/
{{type}} {{getter}}({{#is_repeated}}int index{{/is_repeated}}) {
{{^ipa}}
{{^is_predicate}}result = {{/is_predicate}}Synth::convert{{name}}ToRaw(this){{^root}}.(Raw::{{name}}){{/root}}.{{getter}}({{#is_repeated}}index{{/is_repeated}})
{{/ipa}}
{{#ipa}}
none()
{{/ipa}}
}
{{/type_is_class}}
{{#is_optional}}
/**
* Holds if `{{getter}}({{#is_repeated}}index{{/is_repeated}})` exists.
*/
final predicate has{{singular}}({{#is_repeated}}int index{{/is_repeated}}) {
exists({{getter}}({{#is_repeated}}index{{/is_repeated}}))
}
{{/is_optional}}
{{#is_repeated}}
/**
* Gets any of the {{doc_plural}}.
*/
final {{type}} {{indefinite_getter}}() {
result = {{getter}}(_)
}
{{^is_optional}}
/**
* Gets the number of {{doc_plural}}.
*/
final int getNumberOf{{plural}}() {
result = count(int i | exists({{getter}}(i)))
}
{{/is_optional}}
{{/is_repeated}}
{{/properties}}
}
}

View File

@@ -1,15 +0,0 @@
module Raw {
{{#classes}}
class {{name}} extends {{db_id}}{{#bases}}, {{.}}{{/bases}} {
{{#root}}string toString() { none() }{{/root}}
{{#final}}override string toString() { result = "{{name}}" }{{/final}}
{{#properties}}
{{type}} {{getter}}({{#is_repeated}}int index{{/is_repeated}}) {
{{tablename}}({{#tableparams}}{{^first}}, {{/first}}{{param}}{{/tableparams}})
}
{{/properties}}
}
{{/classes}}
}

View File

@@ -1,4 +0,0 @@
// generated by {{generator}}
{{#imports}}
import {{.}}
{{/imports}}

View File

@@ -1,19 +0,0 @@
// generated by {{generator}}, remove this comment if you wish to edit this file
private import {{import_prefix}}.Raw
{{#cls}}
{{#is_db}}
{{#has_subtracted_ipa_types}}
private import {{import_prefix}}.PureSynthConstructors
{{/has_subtracted_ipa_types}}
{{/is_db}}
predicate construct{{name}}({{#params}}{{^first}}, {{/first}}{{type}} {{param}}{{/params}}) {
{{#is_db}}
{{#subtracted_ipa_types}}{{^first}} and {{/first}}not construct{{name}}(id){{/subtracted_ipa_types}}
{{^subtracted_ipa_types}}any(){{/subtracted_ipa_types}}
{{/is_db}}
{{^is_db}}
none()
{{/is_db}}
}
{{/cls}}

View File

@@ -1,62 +0,0 @@
private import {{import_prefix}}.SynthConstructors
private import {{import_prefix}}.Raw
cached module Synth {
cached newtype T{{root}} =
{{#final_classes}}
{{^first}}
or
{{/first}}
T{{name}}({{#params}}{{^first}}, {{/first}}{{type}} {{param}}{{/params}}){{#has_params}} { construct{{name}}({{#params}}{{^first}}, {{/first}}{{param}}{{/params}}) }{{/has_params}}
{{/final_classes}}
{{#non_final_classes}}
{{^root}}
class T{{name}} = {{#derived}}{{^first}} or {{/first}}T{{name}}{{/derived}};
{{/root}}
{{/non_final_classes}}
{{#final_classes}}
cached T{{name}} convert{{name}}FromRaw(Raw::Element e) {
{{^is_fresh_ipa}}
result = T{{name}}(e)
{{/is_fresh_ipa}}
{{#is_fresh_ipa}}
none()
{{/is_fresh_ipa}}
}
{{/final_classes}}
{{#non_final_classes}}
cached T{{name}} convert{{name}}FromRaw(Raw::Element e) {
{{#derived}}
{{^first}}
or
{{/first}}
result = convert{{name}}FromRaw(e)
{{/derived}}
}
{{/non_final_classes}}
{{#final_classes}}
cached Raw::Element convert{{name}}ToRaw(T{{name}} e) {
{{^is_fresh_ipa}}
e = T{{name}}(result)
{{/is_fresh_ipa}}
{{#is_fresh_ipa}}
none()
{{/is_fresh_ipa}}
}
{{/final_classes}}
{{#non_final_classes}}
cached Raw::Element convert{{name}}ToRaw(T{{name}} e) {
{{#derived}}
{{^first}}
or
{{/first}}
result = convert{{name}}ToRaw(e)
{{/derived}}
}
{{/non_final_classes}}
}

View File

@@ -1,91 +0,0 @@
// generated by {{generator}}
{{#imports}}
import {{.}}
{{/imports}}
private module Impl {
{{#classes}}
private Element getImmediateChildOf{{name}}({{name}} e, int index, string partialPredicateCall) {
{{! avoid unused argument warnings on root element, assuming the root element has no children }}
{{#root}}none(){{/root}}
{{^root}}
{{! b is the base offset 0, for ease of generation }}
{{! b<base> is constructed to be strictly greater than the indexes required for children coming from <base> }}
{{! n is the base offset for direct children, equal to the last base offset from above }}
{{! n<child> is constructed to be strictly greater than the indexes for <child> children }}
exists(int b{{#bases}}, int b{{.}}{{/bases}}, int n{{#properties}}{{#is_child}}, int n{{singular}}{{/is_child}}{{/properties}} |
b = 0
{{#bases}}
and
b{{.}} = b{{prev}} + 1 + max(int i | i = -1 or exists(getImmediateChildOf{{.}}(e, i, _)) | i)
{{/bases}}
and
n = b{{last_base}}
{{#properties}}
{{#is_child}}
{{! n<child> is defined on top of the previous definition }}
{{! for single and optional properties it adds 1 (regardless of whether the optional property exists) }}
{{! for repeated it adds 1 + the maximum index (which works for repeated optional as well) }}
and
n{{singular}} = n{{prev_child}} + 1{{#is_repeated}}+ max(int i | i = -1 or exists(e.getImmediate{{singular}}(i)) | i){{/is_repeated}}
{{/is_child}}
{{/properties}} and (
none()
{{#bases}}
or
result = getImmediateChildOf{{.}}(e, index - b{{prev}}, partialPredicateCall)
{{/bases}}
{{#properties}}
{{#is_child}}
or
{{#is_repeated}}
result = e.getImmediate{{singular}}(index - n{{prev_child}}) and partialPredicateCall = "{{singular}}(" + (index - n{{prev_child}}).toString() + ")"
{{/is_repeated}}
{{^is_repeated}}
index = n{{prev_child}} and result = e.getImmediate{{singular}}() and partialPredicateCall = "{{singular}}()"
{{/is_repeated}}
{{/is_child}}
{{/properties}}
))
{{/root}}
}
{{/classes}}
cached
Element getImmediateChild(Element e, int index, string partialAccessor) {
// why does this look more complicated than it should?
// * none() simplifies generation, as we can append `or ...` without a special case for the first item
none()
{{#classes}}
{{#final}}
or
result = getImmediateChildOf{{name}}(e, index, partialAccessor)
{{/final}}
{{/classes}}
}
}
/**
* Gets the "immediate" parent of `e`. "Immediate" means not taking into account node resolution: for example
* if `e` has conversions, `getImmediateParent(e)` will give the innermost conversion in the hidden AST.
*/
Element getImmediateParent(Element e) {
// `unique` is used here to tell the optimizer that there is in fact only one result
// this is tested by the `library-tests/parent/no_double_parents.ql` test
result = unique(Element x | e = Impl::getImmediateChild(x, _, _) | x)
}
/**
* Gets the immediate child indexed at `index`. Indexes are not guaranteed to be contiguous, but are guaranteed to be distinct. `accessor` is bound to the member predicate call resulting in the given child.
*/
Element getImmediateChildAndAccessor(Element e, int index, string accessor) {
exists(string partialAccessor | result = Impl::getImmediateChild(e, index, partialAccessor) and accessor = "getImmediate" + partialAccessor)
}
/**
* Gets the child indexed at `index`. Indexes are not guaranteed to be contiguous, but are guaranteed to be distinct. `accessor` is bound to the member predicate call resulting in the given child.
*/
Element getChildAndAccessor(Element e, int index, string accessor) {
exists(string partialAccessor | result = Impl::getImmediateChild(e, index, partialAccessor).resolve() and accessor = "get" + partialAccessor)
}

View File

@@ -1,6 +0,0 @@
{{^is_predicate}}
Gets the {{#is_repeated}}`index`th {{/is_repeated}}{{doc}}{{#is_repeated}} (0-based){{/is_repeated}}{{#is_optional}}, if it exists{{/is_optional}}.
{{/is_predicate}}
{{#is_predicate}}
Holds if {{doc}}.
{{/is_predicate}}

View File

@@ -1,18 +0,0 @@
// generated by {{generator}}, remove this comment if you wish to edit this file
private import {{base_import}}
{{#has_ipa_accessors}}
private import {{import_prefix}}.Raw
private import {{import_prefix}}.Synth
{{/has_ipa_accessors}}
{{#ql_internal}}
/**
* INTERNAL: Do not use.
*/
{{/ql_internal}}
class {{name}} extends Generated::{{name}} {
{{#ipa_accessors}}
private
cached {{type}} getUnderlying{{argument}}() { this = Synth::T{{name}}({{#constructorparams}}{{^first}},{{/first}}{{param}}{{/constructorparams}})}
{{/ipa_accessors}}
}

View File

@@ -1,16 +0,0 @@
// generated by {{generator}}
import {{elements_module}}
import TestUtils
from {{class_name}} x{{#properties}}, {{#type}}{{.}}{{/type}}{{^type}}string{{/type}} {{getter}}{{/properties}}
where toBeTested(x) and not x.isUnknown()
{{#properties}}
{{#type}}
and {{getter}} = x.{{getter}}()
{{/type}}
{{^type}}
and if x.{{getter}}() then {{getter}} = "yes" else {{getter}} = "no"
{{/type}}
{{/properties}}
select x{{#show_ql_class}}, x.getPrimaryQlClasses(){{/show_ql_class}}{{#properties}}, "{{getter}}:", {{getter}}{{/properties}}

View File

@@ -1,4 +0,0 @@
// generated by {{generator}}
After a source file is added in this directory and {{generator}} is run again, test queries
will appear and this file will be deleted

View File

@@ -1,10 +0,0 @@
// generated by {{generator}}
import {{elements_module}}
import TestUtils
{{#property}}
from {{class_name}} x{{#is_repeated}}, int index{{/is_repeated}}
where toBeTested(x) and not x.isUnknown()
select x, {{#is_repeated}}index, {{/is_repeated}}x.{{getter}}({{#is_repeated}}index{{/is_repeated}})
{{/property}}

View File

@@ -1,13 +0,0 @@
// generated by {{generator}} from {{source}}
// clang-format off
#pragma once
namespace codeql {
{{#tags}}
// {{id}}
struct {{name}}Tag {{#has_bases}}: {{#bases}}{{^first}}, {{/first}}{{base}}Tag{{/bases}} {{/has_bases}}{
static constexpr const char* prefix = "{{name}}";
};
{{/tags}}
}

View File

@@ -1,15 +0,0 @@
// generated by {{generator}} from {{source}}
// clang-format off
#include "./TrapEntries.h"
namespace codeql {
{{#traps}}
// {{table_name}}
std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e) {
out << "{{table_name}}("{{#fields}}{{^first}} << ", "{{/first}}
<< {{#get_streamer}}e.{{field_name}}{{/get_streamer}}{{/fields}} << ")";
return out;
}
{{/traps}}
}

View File

@@ -1,45 +0,0 @@
// generated by {{generator}} from {{source}}
// clang-format off
#pragma once
#include <iostream>
#include <string>
#include "{{trap_library_dir}}/TrapLabel.h"
#include "{{trap_library_dir}}/TrapTagTraits.h"
#include "{{gen_dir}}/TrapTags.h"
namespace codeql {
{{#traps}}
// {{table_name}}
struct {{name}}Trap {
static constexpr const char* NAME = "{{name}}Trap";
{{#fields}}
{{type}} {{field_name}}{};
{{/fields}}
template <typename F>
void forEachLabel(F f) {
{{#fields}}
{{#is_label}}
f("{{field_name}}", -1, {{field_name}});
{{/is_label}}
{{/fields}}
}
};
std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e);
{{#id}}
namespace detail {
template<>
struct ToBindingTrapFunctor<{{type}}> {
using type = {{name}}Trap;
};
}
{{/id}}
{{/traps}}
}

View File

@@ -1,28 +0,0 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")
py_library(
name = "utils",
testonly = True,
srcs = ["utils.py"],
deps = [
"//swift/codegen/lib",
requirement("pytest"),
],
)
[
py_test(
name = src[:-len(".py")],
size = "small",
srcs = [src],
deps = [
":utils",
"//swift/codegen/generators",
],
)
for src in glob(["test_*.py"])
]
test_suite(
name = "test",
)

View File

@@ -1,117 +0,0 @@
import sys
from copy import deepcopy
import pytest
from swift.codegen.lib import cpp
@pytest.mark.parametrize("keyword", cpp.cpp_keywords)
def test_field_keyword_name(keyword):
    # fields named after C++ keywords get a trailing underscore
    f = cpp.Field(keyword, "int")
    assert f.field_name == keyword + "_"
def test_field_name():
    # ordinary field names are used as-is
    f = cpp.Field("foo", "int")
    assert f.field_name == "foo"
@pytest.mark.parametrize("type,expected", [
    ("std::string", "trapQuoted(value)"),
    ("bool", '(value ? "true" : "false")'),
    ("something_else", "value"),
])
def test_field_get_streamer(type, expected):
    # the streamer wraps a value expression in the type-appropriate C++ output form
    f = cpp.Field("name", type)
    assert f.get_streamer()("value") == expected
@pytest.mark.parametrize("is_optional,is_repeated,is_predicate,expected", [
    (False, False, False, True),
    (True, False, False, False),
    (False, True, False, False),
    (True, True, False, False),
    (False, False, True, False),
])
def test_field_is_single(is_optional, is_repeated, is_predicate, expected):
    # a field is "single" only when it is neither optional, repeated, nor a predicate
    f = cpp.Field("name", "type", is_optional=is_optional, is_repeated=is_repeated, is_predicate=is_predicate)
    assert f.is_single is expected
@pytest.mark.parametrize("is_optional,is_repeated,expected", [
    (False, False, "bar"),
    (True, False, "std::optional<bar>"),
    (False, True, "std::vector<bar>"),
    (True, True, "std::vector<std::optional<bar>>"),
])
def test_field_modal_types(is_optional, is_repeated, expected):
    # optional/repeated flags are reflected as C++ type wrappers
    f = cpp.Field("name", "bar", is_optional=is_optional, is_repeated=is_repeated)
    assert f.type == expected
def test_trap_has_first_field_marked():
    # constructing a Trap must set `first` on its first field only
    fields = [
        cpp.Field("a", "x"),
        cpp.Field("b", "y"),
        cpp.Field("c", "z"),
    ]
    expected = deepcopy(fields)
    expected[0].first = True
    t = cpp.Trap("table_name", "name", fields)
    assert t.fields == expected
def test_tag_has_first_base_marked():
    # constructing a Tag must mark its first base with `first=True`
    bases = ["a", "b", "c"]
    expected = [cpp.TagBase("a", first=True), cpp.TagBase("b"), cpp.TagBase("c")]
    t = cpp.Tag("name", bases, "id")
    assert t.bases == expected
@pytest.mark.parametrize("bases,expected", [
    ([], False),
    (["a"], True),
    (["a", "b"], True)
])
def test_tag_has_bases(bases, expected):
    # `has_bases` holds iff at least one base is present
    t = cpp.Tag("name", bases, "id")
    assert t.has_bases is expected
def test_class_has_first_base_marked():
    # constructing a Class must mark its first base wrapper with `first=True`
    bases = [
        cpp.Class("a"),
        cpp.Class("b"),
        cpp.Class("c"),
    ]
    expected = [cpp.ClassBase(c) for c in bases]
    expected[0].first = True
    c = cpp.Class("foo", bases=bases)
    assert c.bases == expected
@pytest.mark.parametrize("bases,expected", [
    ([], False),
    (["a"], True),
    (["a", "b"], True)
])
def test_class_has_bases(bases, expected):
    # `has_bases` holds iff at least one base is present
    t = cpp.Class("name", [cpp.Class(b) for b in bases])
    assert t.has_bases is expected
def test_class_single_fields():
    # `single_fields` must exclude optional and repeated fields
    fields = [
        cpp.Field("a", "A"),
        cpp.Field("b", "B", is_optional=True),
        cpp.Field("c", "C"),
        cpp.Field("d", "D", is_repeated=True),
        cpp.Field("e", "E"),
    ]
    c = cpp.Class("foo", fields=fields)
    assert c.single_fields == fields[::2]
if __name__ == '__main__':
    # allow running this test file directly, forwarding CLI args to pytest
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,206 +0,0 @@
import sys
from swift.codegen.generators import cppgen
from swift.codegen.lib import cpp
from swift.codegen.test.utils import *
# pretend output directory used by the generation fixtures below
output_dir = pathlib.Path("path", "to", "output")


@pytest.fixture
def generate_grouped(opts, renderer, input):
    # fixture: run cpp generation over the given schema classes and return
    # the generated classes keyed by output directory relative to `output_dir`
    opts.cpp_output = output_dir

    def ret(classes):
        input.classes = {cls.name: cls for cls in classes}
        generated = run_generation(cppgen.generate, opts, renderer)
        # sanity checks common to all tests: one TrapClasses file per group,
        # with parent includes only outside the root directory
        for f, g in generated.items():
            assert isinstance(g, cpp.ClassList), f
            assert g.include_parent is (f.parent != output_dir)
            assert f.name == "TrapClasses", f
        return {str(f.parent.relative_to(output_dir)): g.classes for f, g in generated.items()}

    return ret
@pytest.fixture
def generate(generate_grouped):
    # fixture: like `generate_grouped`, but asserts everything lands in the
    # root group and returns just that class list
    def ret(classes):
        generated = generate_grouped(classes)
        assert set(generated) == {"."}
        return generated["."]

    return ret
def test_empty(generate):
    # an empty schema yields no generated classes
    assert generate([]) == []
def test_empty_class(generate):
    # a lone fieldless class is emitted final, with a pluralized trap name
    assert generate([
        schema.Class(name="MyClass"),
    ]) == [
        cpp.Class(name="MyClass", final=True, trap_name="MyClasses")
    ]
def test_two_class_hierarchy(generate):
base = cpp.Class(name="A")
assert generate([
schema.Class(name="A", derived={"B"}),
schema.Class(name="B", bases=["A"]),
]) == [
base,
cpp.Class(name="B", bases=[base], final=True, trap_name="Bs"),
]
@pytest.mark.parametrize("type,expected", [
("a", "a"),
("string", "std::string"),
("boolean", "bool"),
("MyClass", "TrapLabel<MyClassTag>"),
])
@pytest.mark.parametrize("property_cls,optional,repeated,trap_name", [
(schema.SingleProperty, False, False, None),
(schema.OptionalProperty, True, False, "MyClassProps"),
(schema.RepeatedProperty, False, True, "MyClassProps"),
(schema.RepeatedOptionalProperty, True, True, "MyClassProps"),
])
def test_class_with_field(generate, type, expected, property_cls, optional, repeated, trap_name):
assert generate([
schema.Class(name="MyClass", properties=[property_cls("prop", type)]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field("prop", expected, is_optional=optional,
is_repeated=repeated, trap_name=trap_name)],
trap_name="MyClasses",
final=True)
]
def test_class_field_with_null(generate, input):
input.null = "Null"
a = cpp.Class(name="A")
assert generate([
schema.Class(name="A", derived={"B"}),
schema.Class(name="B", bases=["A"], properties=[
schema.SingleProperty("x", "A"),
schema.SingleProperty("y", "B"),
])
]) == [
a,
cpp.Class(name="B", bases=[a], final=True, trap_name="Bs",
fields=[
cpp.Field("x", "TrapLabel<ATag>"),
cpp.Field("y", "TrapLabel<BOrNoneTag>"),
]),
]
def test_class_with_predicate(generate):
assert generate([
schema.Class(name="MyClass", properties=[
schema.PredicateProperty("prop")]),
]) == [
cpp.Class(name="MyClass",
fields=[
cpp.Field("prop", "bool", trap_name="MyClassProp", is_predicate=True)],
trap_name="MyClasses",
final=True)
]
@pytest.mark.parametrize("name",
["start_line", "start_column", "end_line", "end_column", "index", "num_whatever", "width"])
def test_class_with_overridden_unsigned_field(generate, name):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty(name, "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field(name, "unsigned")],
trap_name="MyClasses",
final=True)
]
def test_class_with_overridden_underscore_field(generate):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty("something_", "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field("something", "bar")],
trap_name="MyClasses",
final=True)
]
@pytest.mark.parametrize("name", cpp.cpp_keywords)
def test_class_with_keyword_field(generate, name):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty(name, "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field(name + "_", "bar")],
trap_name="MyClasses",
final=True)
]
def test_classes_with_dirs(generate_grouped):
cbase = cpp.Class(name="CBase")
assert generate_grouped([
schema.Class(name="A"),
schema.Class(name="B", group="foo"),
schema.Class(name="CBase", derived={"C"}, group="bar"),
schema.Class(name="C", bases=["CBase"], group="bar"),
schema.Class(name="D", group="foo/bar/baz"),
]) == {
".": [cpp.Class(name="A", trap_name="As", final=True)],
"foo": [cpp.Class(name="B", trap_name="Bs", final=True)],
"bar": [cbase, cpp.Class(name="C", bases=[cbase], trap_name="Cs", final=True)],
"foo/bar/baz": [cpp.Class(name="D", trap_name="Ds", final=True)],
}
def test_cpp_skip_pragma(generate):
assert generate([
schema.Class(name="A", properties=[
schema.SingleProperty("x", "foo"),
schema.SingleProperty("y", "bar", pragmas=["x", "cpp_skip", "y"]),
])
]) == [
cpp.Class(name="A", final=True, trap_name="As", fields=[
cpp.Field("x", "foo"),
]),
]
def test_ipa_classes_ignored(generate):
assert generate([
schema.Class(
name="W",
ipa=schema.IpaInfo(),
),
schema.Class(
name="X",
ipa=schema.IpaInfo(from_class="A"),
),
schema.Class(
name="Y",
ipa=schema.IpaInfo(on_arguments={"a": "A", "b": "int"}),
),
schema.Class(
name="Z",
),
]) == [
cpp.Class(name="Z", final=True, trap_name="Zs"),
]
if __name__ == '__main__':
    # allow running this test file directly, forwarding CLI args to pytest
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,52 +0,0 @@
import sys
from copy import deepcopy
from swift.codegen.lib import dbscheme
from swift.codegen.test.utils import *
def test_dbcolumn_name():
    # ordinary column names are used as-is
    assert dbscheme.Column("foo", "some_type").name == "foo"
@pytest.mark.parametrize("keyword", dbscheme.dbscheme_keywords)
def test_dbcolumn_keyword_name(keyword):
    # columns named after dbscheme keywords get a trailing underscore
    assert dbscheme.Column(keyword, "some_type").name == keyword + "_"
@pytest.mark.parametrize("type,binding,lhstype,rhstype", [
    ("builtin_type", False, "builtin_type", "builtin_type ref"),
    ("builtin_type", True, "builtin_type", "builtin_type ref"),
    ("@at_type", False, "int", "@at_type ref"),
    ("@at_type", True, "unique int", "@at_type"),
])
def test_dbcolumn_types(type, binding, lhstype, rhstype):
    # @-types map to int on the lhs; binding only matters for @-types
    col = dbscheme.Column("foo", type, binding)
    assert col.lhstype == lhstype
    assert col.rhstype == rhstype
def test_keyset_has_first_id_marked():
    # constructing a KeySet must mark only the first id with `first`
    ids = ["a", "b", "c"]
    ks = dbscheme.KeySet(ids)
    assert ks.ids[0].first
    assert [id.id for id in ks.ids] == ids
def test_table_has_first_column_marked():
    # constructing a Table must set `first` on its first column only
    columns = [dbscheme.Column("a", "x"), dbscheme.Column("b", "y", binding=True), dbscheme.Column("c", "z")]
    expected = deepcopy(columns)
    table = dbscheme.Table("foo", columns)
    expected[0].first = True
    assert table.columns == expected
def test_union_has_first_case_marked():
    # constructing a Union must mark only its first rhs case with `first`
    rhs = ["a", "b", "c"]
    u = dbscheme.Union(lhs="x", rhs=rhs)
    assert u.rhs[0].first
    assert [c.type for c in u.rhs] == rhs
if __name__ == '__main__':
    # allow running this test file directly, forwarding CLI args to pytest
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,544 +0,0 @@
import collections
import sys
from swift.codegen.generators import dbschemegen
from swift.codegen.lib import dbscheme
from swift.codegen.test.utils import *
InputExpectedPair = collections.namedtuple("InputExpectedPair", ("input", "expected"))
@pytest.fixture(params=[
    InputExpectedPair(None, None),
    InputExpectedPair("foodir", pathlib.Path("foodir")),
])
def dir_param(request):
    # fixture pairing a schema group string with the expected dbscheme dir path
    return request.param
@pytest.fixture
def generate(opts, input, renderer):
def func(classes, null=None):
input.classes = {cls.name: cls for cls in classes}
input.null = null
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
assert out is opts.dbscheme
return data
return func
def test_empty(generate):
    # No classes -> empty scheme.
    assert generate([]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[],
    )


def test_includes(input, opts, generate):
    # Include files listed in the schema are read from disk and embedded.
    includes = ["foo", "bar"]
    input.includes = includes
    for i in includes:
        write(opts.schema.parent / i, i + " data")
    assert generate([]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[
            dbscheme.SchemeInclude(
                src=pathlib.Path(i),
                data=i + " data",
            ) for i in includes
        ],
        declarations=[],
    )


def test_empty_final_class(generate, dir_param):
    # A final class with no properties maps to one table with just the id column.
    assert generate([
        schema.Class("Object", group=dir_param.input),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Table(
                name="objects",
                columns=[
                    dbscheme.Column('id', '@object', binding=True),
                ],
                dir=dir_param.expected,
            )
        ],
    )


def test_final_class_with_single_scalar_field(generate, dir_param):
    # A single scalar property becomes a column in the class's own table.
    assert generate([
        schema.Class("Object", group=dir_param.input, properties=[
            schema.SingleProperty("foo", "bar"),
        ]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Table(
                name="objects",
                columns=[
                    dbscheme.Column('id', '@object', binding=True),
                    dbscheme.Column('foo', 'bar'),
                ], dir=dir_param.expected,
            )
        ],
    )


def test_final_class_with_single_class_field(generate, dir_param):
    # A class-typed property becomes an @-entity-typed column.
    assert generate([
        schema.Class("Object", group=dir_param.input, properties=[
            schema.SingleProperty("foo", "Bar"),
        ]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Table(
                name="objects",
                columns=[
                    dbscheme.Column('id', '@object', binding=True),
                    dbscheme.Column('foo', '@bar'),
                ], dir=dir_param.expected,
            )
        ],
    )


def test_final_class_with_optional_field(generate, dir_param):
    # Optional properties get their own keyed side table.
    assert generate([
        schema.Class("Object", group=dir_param.input, properties=[
            schema.OptionalProperty("foo", "bar"),
        ]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Table(
                name="objects",
                columns=[
                    dbscheme.Column('id', '@object', binding=True),
                ], dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="object_foos",
                keyset=dbscheme.KeySet(["id"]),
                columns=[
                    dbscheme.Column('id', '@object'),
                    dbscheme.Column('foo', 'bar'),
                ], dir=dir_param.expected,
            ),
        ],
    )


@pytest.mark.parametrize("property_cls", [schema.RepeatedProperty, schema.RepeatedOptionalProperty])
def test_final_class_with_repeated_field(generate, property_cls, dir_param):
    # Repeated (and repeated-optional) properties get an indexed side table.
    assert generate([
        schema.Class("Object", group=dir_param.input, properties=[
            property_cls("foo", "bar"),
        ]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Table(
                name="objects",
                columns=[
                    dbscheme.Column('id', '@object', binding=True),
                ], dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="object_foos",
                keyset=dbscheme.KeySet(["id", "index"]),
                columns=[
                    dbscheme.Column('id', '@object'),
                    dbscheme.Column('index', 'int'),
                    dbscheme.Column('foo', 'bar'),
                ], dir=dir_param.expected,
            ),
        ],
    )


def test_final_class_with_predicate_field(generate, dir_param):
    # Predicate properties map to an id-only side table (presence == true).
    assert generate([
        schema.Class("Object", group=dir_param.input, properties=[
            schema.PredicateProperty("foo"),
        ]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Table(
                name="objects",
                columns=[
                    dbscheme.Column('id', '@object', binding=True),
                ], dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="object_foo",
                keyset=dbscheme.KeySet(["id"]),
                columns=[
                    dbscheme.Column('id', '@object'),
                ], dir=dir_param.expected,
            ),
        ],
    )


def test_final_class_with_more_fields(generate, dir_param):
    # Mixed property kinds: singles go in the main table, each optional /
    # repeated / predicate property gets its own side table.
    assert generate([
        schema.Class("Object", group=dir_param.input, properties=[
            schema.SingleProperty("one", "x"),
            schema.SingleProperty("two", "y"),
            schema.OptionalProperty("three", "z"),
            schema.RepeatedProperty("four", "u"),
            schema.RepeatedOptionalProperty("five", "v"),
            schema.PredicateProperty("six"),
        ]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Table(
                name="objects",
                columns=[
                    dbscheme.Column('id', '@object', binding=True),
                    dbscheme.Column('one', 'x'),
                    dbscheme.Column('two', 'y'),
                ], dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="object_threes",
                keyset=dbscheme.KeySet(["id"]),
                columns=[
                    dbscheme.Column('id', '@object'),
                    dbscheme.Column('three', 'z'),
                ], dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="object_fours",
                keyset=dbscheme.KeySet(["id", "index"]),
                columns=[
                    dbscheme.Column('id', '@object'),
                    dbscheme.Column('index', 'int'),
                    dbscheme.Column('four', 'u'),
                ], dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="object_fives",
                keyset=dbscheme.KeySet(["id", "index"]),
                columns=[
                    dbscheme.Column('id', '@object'),
                    dbscheme.Column('index', 'int'),
                    dbscheme.Column('five', 'v'),
                ], dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="object_six",
                keyset=dbscheme.KeySet(["id"]),
                columns=[
                    dbscheme.Column('id', '@object'),
                ], dir=dir_param.expected,
            ),
        ],
    )
def test_empty_class_with_derived(generate):
    # A base class with derived classes yields a union; property-less
    # non-final bases get no table of their own.
    assert generate([
        schema.Class(name="Base", derived={"Left", "Right"}),
        schema.Class(name="Left", bases=["Base"]),
        schema.Class(name="Right", bases=["Base"]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Union(
                lhs="@base",
                rhs=["@left", "@right"],
            ),
            dbscheme.Table(
                name="lefts",
                columns=[dbscheme.Column("id", "@left", binding=True)],
            ),
            dbscheme.Table(
                name="rights",
                columns=[dbscheme.Column("id", "@right", binding=True)],
            ),
        ],
    )


def test_class_with_derived_and_single_property(generate, dir_param):
    # A non-final base with a single property gets a keyed (non-binding) table.
    assert generate([
        schema.Class(
            name="Base",
            derived={"Left", "Right"},
            group=dir_param.input,
            properties=[
                schema.SingleProperty("single", "Prop"),
            ]),
        schema.Class(name="Left", bases=["Base"]),
        schema.Class(name="Right", bases=["Base"]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Union(
                lhs="@base",
                rhs=["@left", "@right"],
            ),
            dbscheme.Table(
                name="bases",
                keyset=dbscheme.KeySet(["id"]),
                columns=[
                    dbscheme.Column('id', '@base'),
                    dbscheme.Column('single', '@prop'),
                ],
                dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="lefts",
                columns=[dbscheme.Column("id", "@left", binding=True)],
            ),
            dbscheme.Table(
                name="rights",
                columns=[dbscheme.Column("id", "@right", binding=True)],
            ),
        ],
    )


def test_class_with_derived_and_optional_property(generate, dir_param):
    # Optional property on a non-final base -> keyed side table.
    assert generate([
        schema.Class(
            name="Base",
            derived={"Left", "Right"},
            group=dir_param.input,
            properties=[
                schema.OptionalProperty("opt", "Prop"),
            ]),
        schema.Class(name="Left", bases=["Base"]),
        schema.Class(name="Right", bases=["Base"]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Union(
                lhs="@base",
                rhs=["@left", "@right"],
            ),
            dbscheme.Table(
                name="base_opts",
                keyset=dbscheme.KeySet(["id"]),
                columns=[
                    dbscheme.Column('id', '@base'),
                    dbscheme.Column('opt', '@prop'),
                ],
                dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="lefts",
                columns=[dbscheme.Column("id", "@left", binding=True)],
            ),
            dbscheme.Table(
                name="rights",
                columns=[dbscheme.Column("id", "@right", binding=True)],
            ),
        ],
    )


def test_class_with_derived_and_repeated_property(generate, dir_param):
    # Repeated property on a non-final base -> indexed side table.
    assert generate([
        schema.Class(
            name="Base",
            group=dir_param.input,
            derived={"Left", "Right"},
            properties=[
                schema.RepeatedProperty("rep", "Prop"),
            ]),
        schema.Class(name="Left", bases=["Base"]),
        schema.Class(name="Right", bases=["Base"]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Union(
                lhs="@base",
                rhs=["@left", "@right"],
            ),
            dbscheme.Table(
                name="base_reps",
                keyset=dbscheme.KeySet(["id", "index"]),
                columns=[
                    dbscheme.Column('id', '@base'),
                    dbscheme.Column('index', 'int'),
                    dbscheme.Column('rep', '@prop'),
                ],
                dir=dir_param.expected,
            ),
            dbscheme.Table(
                name="lefts",
                columns=[dbscheme.Column("id", "@left", binding=True)],
            ),
            dbscheme.Table(
                name="rights",
                columns=[dbscheme.Column("id", "@right", binding=True)],
            ),
        ],
    )
def test_null_class(generate):
    # With a designated null class, class-typed columns become *_or_none
    # unions that include the null type.
    assert generate([
        schema.Class(
            name="Base",
            derived={"W", "X", "Y", "Z", "Null"},
        ),
        schema.Class(
            name="W",
            bases=["Base"],
            properties=[
                schema.SingleProperty("w", "W"),
                schema.SingleProperty("x", "X"),
                schema.OptionalProperty("y", "Y"),
                schema.RepeatedProperty("z", "Z"),
            ]
        ),
        schema.Class(
            name="X",
            bases=["Base"],
        ),
        schema.Class(
            name="Y",
            bases=["Base"],
        ),
        schema.Class(
            name="Z",
            bases=["Base"],
        ),
        schema.Class(
            name="Null",
            bases=["Base"],
        ),
    ], null="Null") == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Union(
                lhs="@base",
                rhs=["@null", "@w", "@x", "@y", "@z"],
            ),
            dbscheme.Table(
                name="ws",
                columns=[
                    dbscheme.Column('id', '@w', binding=True),
                    dbscheme.Column('w', '@w_or_none'),
                    dbscheme.Column('x', '@x_or_none'),
                ],
            ),
            dbscheme.Table(
                name="w_ies",
                keyset=dbscheme.KeySet(["id"]),
                columns=[
                    dbscheme.Column('id', '@w'),
                    dbscheme.Column('y', '@y_or_none'),
                ],
            ),
            dbscheme.Table(
                name="w_zs",
                keyset=dbscheme.KeySet(["id", "index"]),
                columns=[
                    dbscheme.Column('id', '@w'),
                    dbscheme.Column('index', 'int'),
                    dbscheme.Column('z', '@z_or_none'),
                ],
            ),
            dbscheme.Table(
                name="xes",
                columns=[
                    dbscheme.Column('id', '@x', binding=True),
                ],
            ),
            dbscheme.Table(
                name="ys",
                columns=[
                    dbscheme.Column('id', '@y', binding=True),
                ],
            ),
            dbscheme.Table(
                name="zs",
                columns=[
                    dbscheme.Column('id', '@z', binding=True),
                ],
            ),
            dbscheme.Table(
                name="nulls",
                columns=[
                    dbscheme.Column('id', '@null', binding=True),
                ],
            ),
            dbscheme.Union(
                lhs="@w_or_none",
                rhs=["@w", "@null"],
            ),
            dbscheme.Union(
                lhs="@x_or_none",
                rhs=["@x", "@null"],
            ),
            dbscheme.Union(
                lhs="@y_or_none",
                rhs=["@y", "@null"],
            ),
            dbscheme.Union(
                lhs="@z_or_none",
                rhs=["@z", "@null"],
            ),
        ],
    )


def test_ipa_classes_ignored(generate):
    # IPA (synthesized) classes produce no dbscheme declarations at all.
    assert generate([
        schema.Class(name="A", ipa=schema.IpaInfo()),
        schema.Class(name="B", ipa=schema.IpaInfo(from_class="A")),
        schema.Class(name="C", ipa=schema.IpaInfo(on_arguments={"x": "A"})),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[],
    )


def test_ipa_derived_classes_ignored(generate):
    # IPA subclasses are also excluded from the base's union.
    assert generate([
        schema.Class(name="A", derived={"B", "C"}),
        schema.Class(name="B", bases=["A"], ipa=schema.IpaInfo()),
        schema.Class(name="C", bases=["A"]),
    ]) == dbscheme.Scheme(
        src=schema_file.name,
        includes=[],
        declarations=[
            dbscheme.Union("@a", ["@c"]),
            dbscheme.Table(
                name="cs",
                columns=[
                    dbscheme.Column("id", "@c", binding=True),
                ],
            )
        ],
    )


if __name__ == '__main__':
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,123 +0,0 @@
# Tests for the dbscheme loader: parse dbscheme text back into model objects.
import sys
from copy import deepcopy
from swift.codegen.lib import dbscheme
from swift.codegen.loaders.dbschemeloader import iterload
from swift.codegen.test.utils import *


@pytest.fixture
def load(tmp_path):
    # Returns a function that writes dbscheme text to a temp file and parses it.
    file = tmp_path / "test.dbscheme"

    def ret(yml):
        write(file, yml)
        return list(iterload(file))

    return ret


def test_load_empty(load):
    assert load("") == []


def test_load_one_empty_table(load):
    assert load("""
test_foos();
""") == [
        dbscheme.Table(name="test_foos", columns=[])
    ]


def test_load_table_with_keyset(load):
    # keyset ids may be separated with irregular spacing.
    assert load("""
#keyset[x, y,z]
test_foos();
""") == [
        dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]))
    ]


# (column declaration text, expected parsed Column) pairs, covering spacing,
# trailing underscores, ref vs. binding, and `unique`.
expected_columns = [
    ("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)),
    (" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)),
    ("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)),
    ("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)),
    ("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)),
    ("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)),
]


@pytest.mark.parametrize("column,expected", expected_columns)
def test_load_table_with_column(load, column, expected):
    assert load(f"""
foos(
{column}
);
""") == [
        dbscheme.Table(name="foos", columns=[deepcopy(expected)])
    ]


def test_load_table_with_multiple_columns(load):
    columns = ",\n".join(c for c, _ in expected_columns)
    expected = [deepcopy(e) for _, e in expected_columns]
    assert load(f"""
foos(
{columns}
);
""") == [
        dbscheme.Table(name="foos", columns=expected)
    ]


def test_load_table_with_multiple_columns_and_dir(load):
    # A `//dir=...` trailer on the table line sets the table's directory.
    columns = ",\n".join(c for c, _ in expected_columns)
    expected = [deepcopy(e) for _, e in expected_columns]
    assert load(f"""
foos( //dir=foo/bar/baz
{columns}
);
""") == [
        dbscheme.Table(name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz"))
    ]


def test_load_multiple_table_with_columns(load):
    tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)]
    expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)]
    assert load("\n".join(tables)) == expected


def test_union(load):
    assert load("@foo = @bar | @baz | @bla;") == [
        dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
    ]


def test_table_and_union(load):
    assert load("""
foos();
@foo = @bar | @baz | @bla;""") == [
        dbscheme.Table(name="foos", columns=[]),
        dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
    ]


def test_comments_ignored(load):
    # Both // line comments and /* */ block comments (even mid-token-sequence)
    # must be stripped before parsing.
    assert load("""
// fake_table();
foos(/* x */unique /*y*/int/*
z
*/ id/* */: /* * */ @bar/*,
int ignored: int ref*/);
@foo = @bar | @baz | @bla; // | @xxx""") == [
        dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]),
        dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
    ]


if __name__ == '__main__':
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,160 +0,0 @@
# Tests for the QL model helpers (ql.Property, ql.Class, ql.Base, ...).
import sys
from copy import deepcopy
from swift.codegen.lib import ql
from swift.codegen.test.utils import *


def test_property_has_first_table_param_marked():
    tableparams = ["a", "b", "c"]
    prop = ql.Property("Prop", "foo", "props", tableparams)
    assert prop.tableparams[0].first
    assert [p.param for p in prop.tableparams] == tableparams


@pytest.mark.parametrize("type,expected", [
    ("Foo", True),
    ("Bar", True),
    ("foo", False),
    ("bar", False),
    (None, False),
])
def test_property_is_a_class(type, expected):
    tableparams = ["a", "result", "b"]
    # NOTE(review): both branches of this conditional yield "result", so
    # expected_tableparams is always identical to tableparams — looks like a
    # leftover from an intended transformation for class-typed properties;
    # confirm against ql.Property before changing.
    expected_tableparams = ["a", "result" if expected else "result", "b"]
    prop = ql.Property("Prop", type, tableparams=tableparams)
    assert prop.type_is_class is expected
    assert [p.param for p in prop.tableparams] == expected_tableparams


@pytest.mark.parametrize("name,expected_getter", [
    ("Argument", "getAnArgument"),
    ("Element", "getAnElement"),
    ("Integer", "getAnInteger"),
    ("Operator", "getAnOperator"),
    ("Unit", "getAUnit"),
    ("Whatever", "getAWhatever"),
])
def test_property_indefinite_article(name, expected_getter):
    # "a" vs "an" is chosen from the property name's initial letter.
    prop = ql.Property(name, plural="X")
    assert prop.indefinite_getter == expected_getter


@pytest.mark.parametrize("plural,expected", [
    (None, False),
    ("", False),
    ("X", True),
])
def test_property_is_repeated(plural, expected):
    prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural)
    assert prop.is_repeated is expected


@pytest.mark.parametrize("is_optional,is_predicate,plural,expected", [
    (False, False, None, True),
    (False, False, "", True),
    (False, False, "X", False),
    (True, False, None, False),
    (False, True, None, False),
])
def test_property_is_single(is_optional, is_predicate, plural, expected):
    # "single" means: not optional, not a predicate, and not repeated.
    prop = ql.Property("foo", "Foo", "props", ["result"], plural=plural,
                       is_predicate=is_predicate, is_optional=is_optional)
    assert prop.is_single is expected


def test_property_no_plural_no_indefinite_getter():
    prop = ql.Property("Prop", "Foo", "props", ["result"])
    assert prop.indefinite_getter is None


def test_property_getter():
    prop = ql.Property("Prop", "Foo")
    assert prop.getter == "getProp"


def test_property_predicate_getter():
    # Predicate properties use the bare name, without a "get" prefix.
    prop = ql.Property("prop", is_predicate=True)
    assert prop.getter == "prop"


def test_class_processes_bases():
    # Base list order is preserved and each base is linked to its predecessor.
    bases = ["B", "Ab", "C", "Aa"]
    expected = [ql.Base("B"), ql.Base("Ab", prev="B"), ql.Base("C", prev="Ab"), ql.Base("Aa", prev="C")]
    cls = ql.Class("Foo", bases=bases)
    assert cls.bases == expected


def test_class_has_first_property_marked():
    props = [
        ql.Property(f"Prop{x}", f"Foo{x}", f"props{x}", [f"{x}"]) for x in range(4)
    ]
    expected = deepcopy(props)
    expected[0].first = True
    cls = ql.Class("Class", properties=props)
    assert cls.properties == expected


def test_root_class():
    # A class with no bases is the root.
    cls = ql.Class("Class")
    assert cls.root


def test_non_root_class():
    cls = ql.Class("Class", bases=["A"])
    assert not cls.root


@pytest.mark.parametrize("prev_child,is_child", [(None, False), ("", True), ("x", True)])
def test_is_child(prev_child, is_child):
    # Any non-None prev_child (including the empty string) marks a child.
    p = ql.Property("Foo", "int", prev_child=prev_child)
    assert p.is_child is is_child


def test_empty_class_no_children():
    cls = ql.Class("Class", properties=[])
    assert cls.has_children is False


def test_class_no_children():
    cls = ql.Class("Class", properties=[ql.Property("Foo", "int"), ql.Property("Bar", "string")])
    assert cls.has_children is False


def test_class_with_children():
    cls = ql.Class("Class", properties=[ql.Property("Foo", "int"), ql.Property("Child", "x", prev_child=""),
                                        ql.Property("Bar", "string")])
    assert cls.has_children is True


@pytest.mark.parametrize("doc,ql_internal,expected",
                         [
                             (["foo", "bar"], False, True),
                             (["foo", "bar"], True, True),
                             ([], False, False),
                             ([], True, True),
                         ])
def test_has_doc(doc, ql_internal, expected):
    # Internal classes always count as documented (they get a stock comment).
    cls = ql.Class("Class", doc=doc, ql_internal=ql_internal)
    assert cls.has_doc is expected


def test_property_with_description():
    prop = ql.Property("X", "int", description=["foo", "bar"])
    assert prop.has_description is True


def test_class_without_description():
    prop = ql.Property("X", "int")
    assert prop.has_description is False


def test_ipa_accessor_has_first_constructor_param_marked():
    params = ["a", "b", "c"]
    x = ql.IpaUnderlyingAccessor("foo", "bar", params)
    assert x.constructorparams[0].first
    assert [p.param for p in x.constructorparams] == params


if __name__ == '__main__':
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,865 +0,0 @@
import pathlib
import subprocess
import sys
import pytest
from swift.codegen.generators import qlgen
from swift.codegen.lib import ql
from swift.codegen.test.utils import *
@pytest.fixture(autouse=True)
def run_mock():
    # Auto-applied mock of subprocess.run so `codeql query format` is never
    # actually spawned; defaults to a successful (returncode 0) run.
    with mock.patch("subprocess.run") as ret:
        ret.return_value.returncode = 0
        yield ret


# these are lambdas so that they will use patched paths when called
def stub_path(): return paths.root_dir / "ql/lib/stub/path"


def ql_output_path(): return paths.root_dir / "ql/lib/other/path"


def ql_test_output_path(): return paths.root_dir / "ql/test/path"


def generated_registry_path(): return paths.root_dir / "registry.list"


def import_file(): return stub_path().with_suffix(".qll")


def children_file(): return ql_output_path() / "ParentChild.qll"


# QL module names corresponding to the stub/generated paths above
stub_import = "stub.path"
stub_import_prefix = stub_import + "."
root_import = stub_import_prefix + "Element"
gen_import = "other.path"
gen_import_prefix = gen_import + "."


@pytest.fixture
def qlgen_opts(opts):
    # Options pre-populated with all qlgen output locations, formatting on.
    opts.ql_stub_output = stub_path()
    opts.ql_output = ql_output_path()
    opts.ql_test_output = ql_test_output_path()
    opts.generated_registry = generated_registry_path()
    opts.ql_format = True
    opts.root_dir = paths.root_dir
    opts.force = False
    return opts
@pytest.fixture
def generate(input, qlgen_opts, renderer, render_manager):
    # Runs qlgen.generate over the given schema classes, returning the mapping
    # of output file path -> rendered data.
    render_manager.written = []

    def func(classes):
        input.classes = {cls.name: cls for cls in classes}
        return run_managed_generation(qlgen.generate, qlgen_opts, renderer, render_manager)

    return func


@pytest.fixture
def generate_import_list(generate):
    # Like `generate`, but returns only the rendered top-level import list.
    def func(classes):
        ret = generate(classes)
        assert import_file() in ret
        return ret[import_file()]

    return func


@pytest.fixture
def generate_children_implementations(generate):
    # Like `generate`, but returns only the rendered ParentChild.qll data.
    def func(classes):
        ret = generate(classes)
        assert children_file() in ret
        return ret[children_file()]

    return func


def _filter_generated_classes(ret, output_test_files=False):
    # Buckets generated files into stub / base / test sets by which output
    # directory they landed in, then returns either the test files or the
    # (stub, base) data pairs keyed by relative path.
    files = {x for x in ret}
    print(files)
    files.remove(import_file())
    files.remove(children_file())
    stub_files = set()
    base_files = set()
    test_files = set()
    for f in files:
        try:
            stub_files.add(f.relative_to(stub_path()))
            print(f)
        except ValueError:
            try:
                base_files.add(f.relative_to(ql_output_path()))
            except ValueError:
                try:
                    test_files.add(f.relative_to(ql_test_output_path()))
                except ValueError:
                    assert False, f"{f} is in wrong directory"
    if output_test_files:
        return {
            str(f): ret[ql_test_output_path() / f]
            for f in test_files
        }
    # synthesization support files are generated but have no stub counterpart
    base_files -= {pathlib.Path(f"{name}.qll") for name in
                   ("Raw", "Synth", "SynthConstructors", "PureSynthConstructors")}
    assert base_files <= stub_files
    return {
        str(f): (ret[stub_path() / f], ret[ql_output_path() / f])
        for f in base_files
    }


@pytest.fixture
def generate_classes(generate):
    def func(classes):
        return _filter_generated_classes(generate(classes))

    return func


@pytest.fixture
def generate_tests(generate):
    def func(classes):
        return _filter_generated_classes(generate(classes), output_test_files=True)

    return func


def a_ql_class(**kwargs):
    # ql.Class with the generated-code import prefix applied.
    return ql.Class(**kwargs, import_prefix=gen_import)


def a_ql_stub(**kwargs):
    # ql.Stub with the generated-code import prefix applied.
    return ql.Stub(**kwargs, import_prefix=gen_import)
def test_one_empty_class(generate_classes):
    # A single class yields one stub + one final generated class.
    assert generate_classes([
        schema.Class("A")
    ]) == {
        "A.qll": (a_ql_stub(name="A", base_import=gen_import_prefix + "A"),
                  a_ql_class(name="A", final=True)),
    }


def test_hierarchy(generate_classes):
    # Only leaf classes are final; bases are imported from the stub package.
    assert generate_classes([
        schema.Class("D", bases=["B", "C"]),
        schema.Class("C", bases=["A"], derived={"D"}),
        schema.Class("B", bases=["A"], derived={"D"}),
        schema.Class("A", derived={"B", "C"}),
    ]) == {
        "A.qll": (a_ql_stub(name="A", base_import=gen_import_prefix + "A"),
                  a_ql_class(name="A")),
        "B.qll": (a_ql_stub(name="B", base_import=gen_import_prefix + "B"),
                  a_ql_class(name="B", bases=["A"], imports=[stub_import_prefix + "A"])),
        "C.qll": (a_ql_stub(name="C", base_import=gen_import_prefix + "C"),
                  a_ql_class(name="C", bases=["A"], imports=[stub_import_prefix + "A"])),
        "D.qll": (a_ql_stub(name="D", base_import=gen_import_prefix + "D"),
                  a_ql_class(name="D", final=True, bases=["B", "C"],
                             imports=[stub_import_prefix + cls for cls in "BC"])),
    }


def test_hierarchy_imports(generate_import_list):
    # The import list is sorted and contains all classes.
    assert generate_import_list([
        schema.Class("D", bases=["B", "C"]),
        schema.Class("C", bases=["A"], derived={"D"}),
        schema.Class("B", bases=["A"], derived={"D"}),
        schema.Class("A", derived={"B", "C"}),
    ]) == ql.ImportList([stub_import_prefix + cls for cls in "ABCD"])


def test_internal_not_in_import_list(generate_import_list):
    # Classes tagged ql_internal are excluded from the public import list.
    assert generate_import_list([
        schema.Class("D", bases=["B", "C"]),
        schema.Class("C", bases=["A"], derived={"D"}, pragmas=["ql_internal"]),
        schema.Class("B", bases=["A"], derived={"D"}),
        schema.Class("A", derived={"B", "C"}, pragmas=["ql_internal"]),
    ]) == ql.ImportList([stub_import_prefix + cls for cls in "BD"])


def test_hierarchy_children(generate_children_implementations):
    # ParentChild.qll covers all classes, internal ones included.
    assert generate_children_implementations([
        schema.Class("A", derived={"B", "C"}, pragmas=["ql_internal"]),
        schema.Class("B", bases=["A"], derived={"D"}),
        schema.Class("C", bases=["A"], derived={"D"}, pragmas=["ql_internal"]),
        schema.Class("D", bases=["B", "C"]),
    ]) == ql.GetParentImplementation(
        classes=[a_ql_class(name="A", ql_internal=True),
                 a_ql_class(name="B", bases=["A"], imports=[
                     stub_import_prefix + "A"]),
                 a_ql_class(name="C", bases=["A"], imports=[
                     stub_import_prefix + "A"], ql_internal=True),
                 a_ql_class(name="D", final=True, bases=["B", "C"],
                            imports=[stub_import_prefix + cls for cls in "BC"]),
                 ],
        imports=[stub_import] + [stub_import_prefix + cls for cls in "AC"],
    )


def test_single_property(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.SingleProperty("foo", "bar")]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="Foo", type="bar", tablename="my_objects",
                                                    tableparams=["this", "result"], doc="foo of this my object"),
                                    ])),
    }


def test_children(generate_classes):
    # Child properties are chained via prev_child; each property kind maps to
    # the appropriate table/params shape.
    assert generate_classes([
        schema.Class("FakeRoot"),
        schema.Class("MyObject", properties=[
            schema.SingleProperty("a", "int"),
            schema.SingleProperty("child_1", "int", is_child=True),
            schema.RepeatedProperty("bs", "int"),
            schema.RepeatedProperty("children", "int", is_child=True),
            schema.OptionalProperty("c", "int"),
            schema.OptionalProperty("child_3", "int", is_child=True),
            schema.RepeatedOptionalProperty("d", "int"),
            schema.RepeatedOptionalProperty("child_4", "int", is_child=True),
        ]),
    ]) == {
        "FakeRoot.qll": (a_ql_stub(name="FakeRoot", base_import=gen_import_prefix + "FakeRoot"),
                         a_ql_class(name="FakeRoot", final=True)),
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="A", type="int", tablename="my_objects",
                                                    tableparams=["this", "result", "_"],
                                                    doc="a of this my object"),
                                        ql.Property(singular="Child1", type="int", tablename="my_objects",
                                                    tableparams=["this", "_", "result"], prev_child="",
                                                    doc="child 1 of this my object"),
                                        ql.Property(singular="B", plural="Bs", type="int",
                                                    tablename="my_object_bs",
                                                    tableparams=["this", "index", "result"],
                                                    doc="b of this my object",
                                                    doc_plural="bs of this my object"),
                                        ql.Property(singular="Child", plural="Children", type="int",
                                                    tablename="my_object_children",
                                                    tableparams=["this", "index", "result"], prev_child="Child1",
                                                    doc="child of this my object",
                                                    doc_plural="children of this my object"),
                                        ql.Property(singular="C", type="int", tablename="my_object_cs",
                                                    tableparams=["this", "result"], is_optional=True,
                                                    doc="c of this my object"),
                                        ql.Property(singular="Child3", type="int",
                                                    tablename="my_object_child_3s",
                                                    tableparams=["this", "result"], is_optional=True,
                                                    prev_child="Child", doc="child 3 of this my object"),
                                        ql.Property(singular="D", plural="Ds", type="int",
                                                    tablename="my_object_ds",
                                                    tableparams=["this", "index", "result"], is_optional=True,
                                                    doc="d of this my object",
                                                    doc_plural="ds of this my object"),
                                        ql.Property(singular="Child4", plural="Child4s", type="int",
                                                    tablename="my_object_child_4s",
                                                    tableparams=["this", "index", "result"], is_optional=True,
                                                    prev_child="Child3", doc="child 4 of this my object",
                                                    doc_plural="child 4s of this my object"),
                                    ])),
    }


def test_single_properties(generate_classes):
    # Multiple single properties share the main table with "_" placeholders.
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.SingleProperty("one", "x"),
            schema.SingleProperty("two", "y"),
            schema.SingleProperty("three", "z"),
        ]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="One", type="x", tablename="my_objects",
                                                    tableparams=["this", "result", "_", "_"],
                                                    doc="one of this my object"),
                                        ql.Property(singular="Two", type="y", tablename="my_objects",
                                                    tableparams=["this", "_", "result", "_"],
                                                    doc="two of this my object"),
                                        ql.Property(singular="Three", type="z", tablename="my_objects",
                                                    tableparams=["this", "_", "_", "result"],
                                                    doc="three of this my object"),
                                    ])),
    }


@pytest.mark.parametrize("is_child,prev_child", [(False, None), (True, "")])
def test_optional_property(generate_classes, is_child, prev_child):
    assert generate_classes([
        schema.Class("FakeRoot"),
        schema.Class("MyObject", properties=[
            schema.OptionalProperty("foo", "bar", is_child=is_child)]),
    ]) == {
        "FakeRoot.qll": (a_ql_stub(name="FakeRoot", base_import=gen_import_prefix + "FakeRoot"),
                         a_ql_class(name="FakeRoot", final=True)),
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True, properties=[
                             ql.Property(singular="Foo", type="bar", tablename="my_object_foos",
                                         tableparams=["this", "result"],
                                         is_optional=True, prev_child=prev_child, doc="foo of this my object"),
                         ])),
    }


@pytest.mark.parametrize("is_child,prev_child", [(False, None), (True, "")])
def test_repeated_property(generate_classes, is_child, prev_child):
    assert generate_classes([
        schema.Class("FakeRoot"),
        schema.Class("MyObject", properties=[
            schema.RepeatedProperty("foo", "bar", is_child=is_child)]),
    ]) == {
        "FakeRoot.qll": (a_ql_stub(name="FakeRoot", base_import=gen_import_prefix + "FakeRoot"),
                         a_ql_class(name="FakeRoot", final=True)),
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True, properties=[
                             ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos",
                                         tableparams=["this", "index", "result"], prev_child=prev_child,
                                         doc="foo of this my object", doc_plural="foos of this my object"),
                         ])),
    }


@pytest.mark.parametrize("is_child,prev_child", [(False, None), (True, "")])
def test_repeated_optional_property(generate_classes, is_child, prev_child):
    assert generate_classes([
        schema.Class("FakeRoot"),
        schema.Class("MyObject", properties=[
            schema.RepeatedOptionalProperty("foo", "bar", is_child=is_child)]),
    ]) == {
        "FakeRoot.qll": (a_ql_stub(name="FakeRoot", base_import=gen_import_prefix + "FakeRoot"),
                         a_ql_class(name="FakeRoot", final=True)),
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True, properties=[
                             ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos",
                                         tableparams=["this", "index", "result"], is_optional=True,
                                         prev_child=prev_child, doc="foo of this my object",
                                         doc_plural="foos of this my object"),
                         ])),
    }


def test_predicate_property(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.PredicateProperty("is_foo")]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True, properties=[
                             ql.Property(singular="isFoo", type="predicate", tablename="my_object_is_foo",
                                         tableparams=["this"], is_predicate=True, doc="this my object is foo"),
                         ])),
    }


@pytest.mark.parametrize("is_child,prev_child", [(False, None), (True, "")])
def test_single_class_property(generate_classes, is_child, prev_child):
    # Class-typed properties pull in an import of the referenced stub class.
    assert generate_classes([
        schema.Class("Bar"),
        schema.Class("MyObject", properties=[
            schema.SingleProperty("foo", "Bar", is_child=is_child)]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(
                             name="MyObject", final=True, imports=[stub_import_prefix + "Bar"], properties=[
                                 ql.Property(singular="Foo", type="Bar", tablename="my_objects",
                                             tableparams=[
                                                 "this", "result"],
                                             prev_child=prev_child, doc="foo of this my object"),
                             ],
                         )),
        "Bar.qll": (a_ql_stub(name="Bar", base_import=gen_import_prefix + "Bar"),
                    a_ql_class(name="Bar", final=True)),
    }
def test_class_with_doc(generate_classes):
doc = ["Very important class.", "Very."]
assert generate_classes([
schema.Class("A", doc=doc),
]) == {
"A.qll": (a_ql_stub(name="A", base_import=gen_import_prefix + "A"),
a_ql_class(name="A", final=True, doc=doc)),
}
def test_class_dir(generate_classes):
dir = "another/rel/path"
assert generate_classes([
schema.Class("A", derived={"B"}, group=dir),
schema.Class("B", bases=["A"]),
]) == {
f"{dir}/A.qll": (a_ql_stub(name="A", base_import=gen_import_prefix + "another.rel.path.A"),
a_ql_class(name="A", dir=pathlib.Path(dir))),
"B.qll": (a_ql_stub(name="B", base_import=gen_import_prefix + "B"),
a_ql_class(name="B", final=True, bases=["A"],
imports=[stub_import_prefix + "another.rel.path.A"])),
}
def test_root_element_cannot_have_children(generate_classes):
with pytest.raises(qlgen.RootElementHasChildren):
generate_classes([
schema.Class('A', properties=[schema.SingleProperty("x", is_child=True)])
])
def test_class_dir_imports(generate_import_list):
dir = "another/rel/path"
assert generate_import_list([
schema.Class("A", derived={"B"}, group=dir),
schema.Class("B", bases=["A"]),
]) == ql.ImportList([
stub_import_prefix + "B",
stub_import_prefix + "another.rel.path.A",
])
# After generation, only .ql/.qll files written by the render manager are
# passed to `codeql query format`; other extensions (.txt) are excluded.
def test_format(opts, generate, render_manager, run_mock):
    opts.codeql_binary = "my_fake_codeql"
    run_mock.return_value.stderr = "some\nlines\n"
    render_manager.written = [
        pathlib.Path("x", "foo.ql"),
        pathlib.Path("bar.qll"),
        pathlib.Path("y", "baz.txt"),
    ]
    generate([schema.Class('A')])
    assert run_mock.mock_calls == [
        mock.call(["my_fake_codeql", "query", "format", "--in-place", "--", "x/foo.ql", "bar.qll"],
                  stderr=subprocess.PIPE, text=True),
    ]
# A non-zero exit code from the formatter surfaces as qlgen.FormatError.
def test_format_error(opts, generate, render_manager, run_mock):
    opts.codeql_binary = "my_fake_codeql"
    run_mock.return_value.stderr = "some\nlines\n"
    run_mock.return_value.returncode = 1
    render_manager.written = [
        pathlib.Path("x", "foo.ql"),
        pathlib.Path("bar.qll"),
        pathlib.Path("y", "baz.txt"),
    ]
    with pytest.raises(qlgen.FormatError):
        generate([schema.Class('A')])
# The renderer's manage() call receives all generated/stub/test paths plus the
# import file, and forwards the registry path and the force flag unchanged.
@pytest.mark.parametrize("force", [False, True])
def test_manage_parameters(opts, generate, renderer, force):
    opts.force = force
    ql_a = opts.ql_output / "A.qll"
    ql_b = opts.ql_output / "B.qll"
    stub_a = opts.ql_stub_output / "A.qll"
    stub_b = opts.ql_stub_output / "B.qll"
    test_a = opts.ql_test_output / "A.ql"
    test_b = opts.ql_test_output / "MISSING_SOURCE.txt"
    test_c = opts.ql_test_output / "B.txt"
    write(ql_a)
    write(ql_b)
    write(stub_a)
    write(stub_b)
    write(test_a)
    write(test_b)
    write(test_c)
    generate([schema.Class('A')])
    assert renderer.mock_calls == [
        mock.call.manage(generated={ql_a, ql_b, test_a, test_b, import_file()}, stubs={stub_a, stub_b},
                         registry=opts.generated_registry, force=force)
    ]
# Stubs reported as customized by the render manager are not regenerated.
def test_modified_stub_skipped(qlgen_opts, generate, render_manager):
    stub = qlgen_opts.ql_stub_output / "A.qll"
    render_manager.is_customized_stub.side_effect = lambda f: f == stub
    assert stub not in generate([schema.Class('A')])
# When a class has no test source file, a MISSING_SOURCE.txt marker with
# instructions is emitted instead of a test query.
def test_test_missing_source(generate_tests):
    # BUG FIX: the `==` comparison result was previously discarded (no
    # `assert`), so this test could never fail regardless of the output.
    assert generate_tests([
        schema.Class("A"),
    ]) == {
        "A/MISSING_SOURCE.txt": ql.MissingTestInstructions(),
    }
# Helper: expected ClassTester output bound to the stub elements module.
def a_ql_class_tester(**kwargs):
    return ql.ClassTester(**kwargs, elements_module=stub_import)
# Helper: expected PropertyTester output bound to the stub elements module.
def a_ql_property_tester(**kwargs):
    return ql.PropertyTester(**kwargs, elements_module=stub_import)
# A class with a test source file gets a generated class-tester query.
def test_test_source_present(opts, generate_tests):
    write(opts.ql_test_output / "A" / "test.swift")
    assert generate_tests([
        schema.Class("A"),
    ]) == {
        "A/A.ql": a_ql_class_tester(class_name="A"),
    }
# The class group is reflected in the test query's output directory.
def test_test_source_present_with_dir(opts, generate_tests):
    write(opts.ql_test_output / "foo" / "A" / "test.swift")
    assert generate_tests([
        schema.Class("A", group="foo"),
    ]) == {
        "foo/A/A.ql": a_ql_class_tester(class_name="A"),
    }
# Total (single/predicate) properties, including inherited ones, appear in the
# class tester; predicate properties have no type.
def test_test_total_properties(opts, generate_tests):
    write(opts.ql_test_output / "B" / "test.swift")
    assert generate_tests([
        schema.Class("A", derived={"B"}, properties=[
            schema.SingleProperty("x", "string"),
        ]),
        schema.Class("B", bases=["A"], properties=[
            schema.PredicateProperty("y", "int"),
        ]),
    ]) == {
        "B/B.ql": a_ql_class_tester(class_name="B", properties=[
            ql.PropertyForTest(getter="getX", type="string"),
            ql.PropertyForTest(getter="y"),
        ])
    }
# Partial (optional/repeated) properties get dedicated per-property test
# queries; the class tester only keeps the total accessors (hasX, count getters).
def test_test_partial_properties(opts, generate_tests):
    write(opts.ql_test_output / "B" / "test.swift")
    assert generate_tests([
        schema.Class("A", derived={"B", "C"}, properties=[
            schema.OptionalProperty("x", "string"),
        ]),
        schema.Class("B", bases=["A"], properties=[
            schema.RepeatedProperty("y", "bool"),
            schema.RepeatedOptionalProperty("z", "int"),
        ]),
    ]) == {
        "B/B.ql": a_ql_class_tester(class_name="B", properties=[
            ql.PropertyForTest(getter="hasX"),
            ql.PropertyForTest(getter="getNumberOfYs", type="int"),
        ]),
        "B/B_getX.ql": a_ql_property_tester(class_name="B",
                                            property=ql.PropertyForTest(getter="getX", is_total=False,
                                                                        type="string")),
        "B/B_getY.ql": a_ql_property_tester(class_name="B",
                                            property=ql.PropertyForTest(getter="getY", is_total=False,
                                                                        is_repeated=True,
                                                                        type="bool")),
        "B/B_getZ.ql": a_ql_property_tester(class_name="B",
                                            property=ql.PropertyForTest(getter="getZ", is_total=False,
                                                                        is_repeated=True,
                                                                        type="int")),
    }
# Properties inherited along both sides of a diamond are listed only once.
def test_test_properties_deduplicated(opts, generate_tests):
    write(opts.ql_test_output / "Final" / "test.swift")
    assert generate_tests([
        schema.Class("Base", derived={"A", "B"}, properties=[
            schema.SingleProperty("x", "string"),
            schema.RepeatedProperty("y", "bool"),
        ]),
        schema.Class("A", bases=["Base"], derived={"Final"}),
        schema.Class("B", bases=["Base"], derived={"Final"}),
        schema.Class("Final", bases=["A", "B"]),
    ]) == {
        "Final/Final.ql": a_ql_class_tester(class_name="Final", properties=[
            ql.PropertyForTest(getter="getX", type="string"),
            ql.PropertyForTest(getter="getNumberOfYs", type="int"),
        ]),
        "Final/Final_getY.ql": a_ql_property_tester(class_name="Final",
                                                    property=ql.PropertyForTest(getter="getY", is_total=False,
                                                                                is_repeated=True,
                                                                                type="bool")),
    }
# Properties carrying the qltest_skip pragma (anywhere in the pragma list)
# are excluded from generated test queries.
def test_test_properties_skipped(opts, generate_tests):
    write(opts.ql_test_output / "Derived" / "test.swift")
    assert generate_tests([
        schema.Class("Base", derived={"Derived"}, properties=[
            schema.SingleProperty("x", "string", pragmas=["qltest_skip", "foo"]),
            schema.RepeatedProperty("y", "int", pragmas=["bar", "qltest_skip"]),
        ]),
        schema.Class("Derived", bases=["Base"], properties=[
            schema.PredicateProperty("a", pragmas=["qltest_skip"]),
            schema.OptionalProperty(
                "b", "int", pragmas=["bar", "qltest_skip", "baz"]),
        ]),
    ]) == {
        "Derived/Derived.ql": a_ql_class_tester(class_name="Derived"),
    }
# qltest_skip on a base class drops that class's properties from derived testers.
def test_test_base_class_skipped(opts, generate_tests):
    write(opts.ql_test_output / "Derived" / "test.swift")
    assert generate_tests([
        schema.Class("Base", derived={"Derived"}, pragmas=["qltest_skip", "foo"], properties=[
            schema.SingleProperty("x", "string"),
            schema.RepeatedProperty("y", "int"),
        ]),
        schema.Class("Derived", bases=["Base"]),
    ]) == {
        "Derived/Derived.ql": a_ql_class_tester(class_name="Derived"),
    }
# qltest_skip on a final class suppresses its test output entirely.
def test_test_final_class_skipped(opts, generate_tests):
    write(opts.ql_test_output / "Derived" / "test.swift")
    assert generate_tests([
        schema.Class("Base", derived={"Derived"}),
        schema.Class("Derived", bases=["Base"], pragmas=["qltest_skip", "foo"], properties=[
            schema.SingleProperty("x", "string"),
            schema.RepeatedProperty("y", "int"),
        ]),
    ]) == {}
# qltest_collapse_hierarchy makes the whole subtree tested via a single tester
# on the collapsed root (with the QL class column shown).
def test_test_class_hierarchy_collapse(opts, generate_tests):
    write(opts.ql_test_output / "Base" / "test.swift")
    assert generate_tests([
        schema.Class("Base", derived={"D1", "D2"}, pragmas=["foo", "qltest_collapse_hierarchy"]),
        schema.Class("D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")]),
        schema.Class("D2", bases=["Base"], derived={"D3"}, properties=[schema.SingleProperty("y", "string")]),
        schema.Class("D3", bases=["D2"], properties=[schema.SingleProperty("z", "string")]),
    ]) == {
        "Base/Base.ql": a_ql_class_tester(class_name="Base", show_ql_class=True),
    }
# qltest_uncollapse_hierarchy re-enables per-class testers below a collapsed root.
def test_test_class_hierarchy_uncollapse(opts, generate_tests):
    for d in ("Base", "D3", "D4"):
        write(opts.ql_test_output / d / "test.swift")
    assert generate_tests([
        schema.Class("Base", derived={"D1", "D2"}, pragmas=["foo", "qltest_collapse_hierarchy"]),
        schema.Class("D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")]),
        schema.Class("D2", bases=["Base"], derived={"D3", "D4"}, pragmas=["qltest_uncollapse_hierarchy", "bar"]),
        schema.Class("D3", bases=["D2"]),
        schema.Class("D4", bases=["D2"]),
    ]) == {
        "Base/Base.ql": a_ql_class_tester(class_name="Base", show_ql_class=True),
        "D3/D3.ql": a_ql_class_tester(class_name="D3"),
        "D4/D4.ql": a_ql_class_tester(class_name="D4"),
    }
# Uncollapsing directly on a final class also restores its tester.
def test_test_class_hierarchy_uncollapse_at_final(opts, generate_tests):
    for d in ("Base", "D3"):
        write(opts.ql_test_output / d / "test.swift")
    assert generate_tests([
        schema.Class("Base", derived={"D1", "D2"}, pragmas=["foo", "qltest_collapse_hierarchy"]),
        schema.Class("D1", bases=["Base"], properties=[schema.SingleProperty("x", "string")]),
        schema.Class("D2", bases=["Base"], derived={"D3"}),
        schema.Class("D3", bases=["D2"], pragmas=["qltest_uncollapse_hierarchy", "bar"]),
    ]) == {
        "Base/Base.ql": a_ql_class_tester(class_name="Base", show_ql_class=True),
        "D3/D3.ql": a_ql_class_tester(class_name="D3"),
    }
# A property description list is carried through alongside the default doc.
def test_property_description(generate_classes):
    description = ["Lorem", "Ipsum"]
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.SingleProperty("foo", "bar", description=description),
        ]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="Foo", type="bar", tablename="my_objects",
                                                    tableparams=["this", "result"],
                                                    doc="foo of this my object",
                                                    description=description),
                                    ])),
    }
# An explicit `doc` on a property replaces the derived default doc string.
def test_property_doc_override(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.SingleProperty("foo", "bar", doc="baz")]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="Foo", type="bar", tablename="my_objects",
                                                    tableparams=["this", "result"], doc="baz"),
                                    ])),
    }
# For repeated properties a doc override yields both singular and plural forms.
def test_repeated_property_doc_override(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.RepeatedProperty("x", "int", doc="children of this"),
            schema.RepeatedOptionalProperty("y", "int", doc="child of this")]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="X", plural="Xes", type="int",
                                                    tablename="my_object_xes",
                                                    tableparams=["this", "index", "result"],
                                                    doc="child of this", doc_plural="children of this"),
                                        ql.Property(singular="Y", plural="Ys", type="int",
                                                    tablename="my_object_ies", is_optional=True,
                                                    tableparams=["this", "index", "result"],
                                                    doc="child of this", doc_plural="children of this"),
                                    ])),
    }
# Each known abbreviation in a property name is expanded in the derived doc.
@pytest.mark.parametrize("abbr,expected", list(qlgen.abbreviations.items()))
def test_property_doc_abbreviations(generate_classes, abbr, expected):
    expected_doc = f"foo {expected} bar of this object"
    assert generate_classes([
        schema.Class("Object", properties=[
            schema.SingleProperty(f"foo_{abbr}_bar", "baz")]),
    ]) == {
        "Object.qll": (a_ql_stub(name="Object", base_import=gen_import_prefix + "Object"),
                       a_ql_class(name="Object", final=True,
                                  properties=[
                                      ql.Property(singular=f"Foo{abbr.capitalize()}Bar", type="baz",
                                                  tablename="objects",
                                                  tableparams=["this", "result"], doc=expected_doc),
                                  ])),
    }
# Abbreviations embedded inside a longer word must not be expanded.
@pytest.mark.parametrize("abbr,expected", list(qlgen.abbreviations.items()))
def test_property_doc_abbreviations_ignored_if_within_word(generate_classes, abbr, expected):
    expected_doc = f"foo {abbr}acadabra bar of this object"
    assert generate_classes([
        schema.Class("Object", properties=[
            schema.SingleProperty(f"foo_{abbr}acadabra_bar", "baz")]),
    ]) == {
        "Object.qll": (a_ql_stub(name="Object", base_import=gen_import_prefix + "Object"),
                       a_ql_class(name="Object", final=True,
                                  properties=[
                                      ql.Property(singular=f"Foo{abbr.capitalize()}acadabraBar", type="baz",
                                                  tablename="objects",
                                                  tableparams=["this", "result"], doc=expected_doc),
                                  ])),
    }
# A {placeholder} in a repeated-property doc override is pluralized/singularized
# to produce the two doc forms.
def test_repeated_property_doc_override_with_format(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.RepeatedProperty("x", "int", doc="special {children} of this"),
            schema.RepeatedOptionalProperty("y", "int", doc="special {child} of this")]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="X", plural="Xes", type="int",
                                                    tablename="my_object_xes",
                                                    tableparams=["this", "index", "result"],
                                                    doc="special child of this",
                                                    doc_plural="special children of this"),
                                        ql.Property(singular="Y", plural="Ys", type="int",
                                                    tablename="my_object_ies", is_optional=True,
                                                    tableparams=["this", "index", "result"],
                                                    doc="special child of this",
                                                    doc_plural="special children of this"),
                                    ])),
    }
# Multiple {placeholders} in one doc override are all processed.
def test_repeated_property_doc_override_with_multiple_formats(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.RepeatedProperty("x", "int", doc="{cat} or {dog}"),
            schema.RepeatedOptionalProperty("y", "int", doc="{cats} or {dogs}")]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="X", plural="Xes", type="int",
                                                    tablename="my_object_xes",
                                                    tableparams=["this", "index", "result"],
                                                    doc="cat or dog", doc_plural="cats or dogs"),
                                        ql.Property(singular="Y", plural="Ys", type="int",
                                                    tablename="my_object_ies", is_optional=True,
                                                    tableparams=["this", "index", "result"],
                                                    doc="cat or dog", doc_plural="cats or dogs"),
                                    ])),
    }
# Placeholders are also stripped for single (non-repeated) properties.
def test_property_doc_override_with_format(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.SingleProperty("foo", "bar", doc="special {baz} of this")]),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="Foo", type="bar", tablename="my_objects",
                                                    tableparams=["this", "result"], doc="special baz of this"),
                                    ])),
    }
# A class-level default_doc_name replaces the class name in derived docs.
def test_property_on_class_with_default_doc_name(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", properties=[
            schema.SingleProperty("foo", "bar")],
            default_doc_name="baz"),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
                         a_ql_class(name="MyObject", final=True,
                                    properties=[
                                        ql.Property(singular="Foo", type="bar", tablename="my_objects",
                                                    tableparams=["this", "result"], doc="foo of this baz"),
                                    ])),
    }
# An IPA class synthesized from another class gets a stub accessor for the
# underlying raw entity.
def test_stub_on_class_with_ipa_from_class(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", ipa=schema.IpaInfo(from_class="A")),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject", ipa_accessors=[
            ql.IpaUnderlyingAccessor(argument="Entity", type="Raw::A", constructorparams=["result"]),
        ]),
            a_ql_class(name="MyObject", final=True, ipa=True)),
    }
# An IPA class synthesized on arguments gets one accessor per constructor
# argument, with `_` wildcards for the other positions.
def test_stub_on_class_with_ipa_on_arguments(generate_classes):
    assert generate_classes([
        schema.Class("MyObject", ipa=schema.IpaInfo(on_arguments={"base": "A", "index": "int", "label": "string"})),
    ]) == {
        "MyObject.qll": (a_ql_stub(name="MyObject", base_import=gen_import_prefix + "MyObject", ipa_accessors=[
            ql.IpaUnderlyingAccessor(argument="Base", type="Raw::A", constructorparams=["result", "_", "_"]),
            ql.IpaUnderlyingAccessor(argument="Index", type="int", constructorparams=["_", "result", "_"]),
            ql.IpaUnderlyingAccessor(argument="Label", type="string", constructorparams=["_", "_", "result"]),
        ]),
            a_ql_class(name="MyObject", final=True, ipa=True)),
    }
# Allow running this test module directly; extra CLI args are forwarded to pytest.
if __name__ == '__main__':
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,355 +0,0 @@
import sys
import pytest
from swift.codegen.test.utils import *
import hashlib
generator = "foo"
# Patch the pystache.Renderer class for the duration of each test.
@pytest.fixture
def pystache_renderer_cls():
    with mock.patch("pystache.Renderer") as ret:
        yield ret
# Provide the mock renderer *instance* returned by the patched class.
@pytest.fixture
def pystache_renderer(pystache_renderer_cls):
    ret = mock.Mock()
    pystache_renderer_cls.return_value = ret
    return ret
# System under test: a render.Renderer rooted at the repository root.
@pytest.fixture
def sut(pystache_renderer):
    return render.Renderer(generator, paths.root_dir)
def assert_file(file, text):
    """Assert that the file at *file* exists and its full content equals *text*."""
    with open(file) as inp:
        actual = inp.read()
    assert actual == text
def hash(text):
    """Return the hex-encoded SHA-256 digest of *text*.

    NOTE: intentionally shadows the builtin ``hash`` to mirror the registry
    file format used by the renderer under test.
    """
    return hashlib.sha256(text.encode()).hexdigest()
# The renderer is constructed with the templates dir and an identity escape
# function (no HTML escaping of rendered values).
def test_constructor(pystache_renderer_cls, sut):
    pystache_init, = pystache_renderer_cls.mock_calls
    assert set(pystache_init.kwargs) == {'search_dirs', 'escape'}
    assert pystache_init.kwargs['search_dirs'] == str(paths.templates_dir)
    an_object = object()
    assert pystache_init.kwargs['escape'](an_object) is an_object
# render() writes the pystache output to the target file, passing the data's
# template name and the generator tag.
def test_render(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    text = "some text"
    pystache_renderer.render_name.side_effect = (text,)
    output = paths.root_dir / "some/output.txt"
    sut.render(data, output)
    assert_file(output, text)
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(data.template, data, generator=generator),
    ]
# Managed rendering records the written file in the registry as
# "<relpath> <hash-before-postprocessing> <hash-after>".
def test_managed_render(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    text = "some text"
    pystache_renderer.render_name.side_effect = (text,)
    output = paths.root_dir / "some/output.txt"
    registry = paths.root_dir / "a/registry.list"
    write(registry)
    with sut.manage(generated=(), stubs=(), registry=registry) as renderer:
        renderer.render(data, output)
    assert renderer.written == {output}
    assert_file(output, text)
    assert_file(registry, f"some/output.txt {hash(text)} {hash(text)}\n")
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(data.template, data, generator=generator),
    ]
# A missing registry file is created on the fly.
def test_managed_render_with_no_registry(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    text = "some text"
    pystache_renderer.render_name.side_effect = (text,)
    output = paths.root_dir / "some/output.txt"
    registry = paths.root_dir / "a/registry.list"
    with sut.manage(generated=(), stubs=(), registry=registry) as renderer:
        renderer.render(data, output)
    assert renderer.written == {output}
    assert_file(output, text)
    assert_file(registry, f"some/output.txt {hash(text)} {hash(text)}\n")
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(data.template, data, generator=generator),
    ]
# Post-processing the output after render updates the second (post) hash in
# the registry while keeping the first (raw render) hash.
def test_managed_render_with_post_processing(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    text = "some text"
    postprocessed_text = "some other text"
    pystache_renderer.render_name.side_effect = (text,)
    output = paths.root_dir / "some/output.txt"
    registry = paths.root_dir / "a/registry.list"
    write(registry)
    with sut.manage(generated=(), stubs=(), registry=registry) as renderer:
        renderer.render(data, output)
        assert renderer.written == {output}
        assert_file(output, text)
        write(output, postprocessed_text)
    assert_file(registry, f"some/output.txt {hash(text)} {hash(postprocessed_text)}\n")
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(data.template, data, generator=generator),
    ]
# Files declared generated/stub but not re-rendered are erased and dropped
# from the registry.
def test_managed_render_with_erasing(pystache_renderer, sut):
    output = paths.root_dir / "some/output.txt"
    stub = paths.root_dir / "some/stub.txt"
    registry = paths.root_dir / "a/registry.list"
    write(output)
    write(stub, "// generated bla bla")
    write(registry)
    with sut.manage(generated=(output,), stubs=(stub,), registry=registry) as renderer:
        pass
    assert not output.is_file()
    assert not stub.is_file()
    assert_file(registry, "")
    assert pystache_renderer.mock_calls == []
# Re-rendering identical content to an unchanged generated file skips the write.
def test_managed_render_with_skipping_of_generated_file(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    output = paths.root_dir / "some/output.txt"
    some_output = "some output"
    registry = paths.root_dir / "a/registry.list"
    write(output, some_output)
    write(registry, f"some/output.txt {hash(some_output)} {hash(some_output)}\n")
    pystache_renderer.render_name.side_effect = (some_output,)
    with sut.manage(generated=(output,), stubs=(), registry=registry) as renderer:
        renderer.render(data, output)
    assert renderer.written == set()
    assert_file(output, some_output)
    assert_file(registry, f"some/output.txt {hash(some_output)} {hash(some_output)}\n")
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(data.template, data, generator=generator),
    ]
# A stub whose raw-render hash matches the registry keeps its post-processed
# content instead of being overwritten.
def test_managed_render_with_skipping_of_stub_file(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    stub = paths.root_dir / "some/stub.txt"
    some_output = "// generated some output"
    some_processed_output = "// generated some processed output"
    registry = paths.root_dir / "a/registry.list"
    write(stub, some_processed_output)
    write(registry, f"some/stub.txt {hash(some_output)} {hash(some_processed_output)}\n")
    pystache_renderer.render_name.side_effect = (some_output,)
    with sut.manage(generated=(), stubs=(stub,), registry=registry) as renderer:
        renderer.render(data, stub)
    assert renderer.written == set()
    assert_file(stub, some_processed_output)
    assert_file(registry, f"some/stub.txt {hash(some_output)} {hash(some_processed_output)}\n")
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(data.template, data, generator=generator),
    ]
# A manually modified generated file makes manage() fail fast with render.Error.
def test_managed_render_with_modified_generated_file(pystache_renderer, sut):
    output = paths.root_dir / "some/output.txt"
    some_processed_output = "// some processed output"
    registry = paths.root_dir / "a/registry.list"
    write(output, "// something else")
    write(registry, f"some/output.txt whatever {hash(some_processed_output)}\n")
    with pytest.raises(render.Error):
        sut.manage(generated=(output,), stubs=(), registry=registry)
# A modified stub that still carries the "generated" marker is also an error.
def test_managed_render_with_modified_stub_file_still_marked_as_generated(pystache_renderer, sut):
    stub = paths.root_dir / "some/stub.txt"
    some_processed_output = "// generated some processed output"
    registry = paths.root_dir / "a/registry.list"
    write(stub, "// generated something else")
    write(registry, f"some/stub.txt whatever {hash(some_processed_output)}\n")
    with pytest.raises(render.Error):
        sut.manage(generated=(), stubs=(stub,), registry=registry)
# A stub whose "generated" marker was removed is treated as customized: it is
# left alone and silently dropped from the registry.
def test_managed_render_with_modified_stub_file_not_marked_as_generated(pystache_renderer, sut):
    stub = paths.root_dir / "some/stub.txt"
    some_processed_output = "// generated some processed output"
    registry = paths.root_dir / "a/registry.list"
    write(stub, "// no more generated")
    write(registry, f"some/stub.txt whatever {hash(some_processed_output)}\n")
    with sut.manage(generated=(), stubs=(stub,), registry=registry) as renderer:
        pass
    assert_file(registry, "")
class MyError(Exception):
    """Sentinel exception raised by tests to abort a managed-render context."""
# On an exception inside manage(), registry entries for freshly written files
# and for files that no longer exist are dropped; others are kept.
def test_managed_render_exception_drops_written_and_inexsistent_from_registry(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    text = "some text"
    pystache_renderer.render_name.side_effect = (text,)
    output = paths.root_dir / "some/output.txt"
    registry = paths.root_dir / "x/registry.list"
    write(output, text)
    write(paths.root_dir / "a")
    write(paths.root_dir / "c")
    write(registry, "a a a\n"
                    f"some/output.txt whatever {hash(text)}\n"
                    "b b b\n"
                    "c c c")
    with pytest.raises(MyError):
        with sut.manage(generated=(), stubs=(), registry=registry) as renderer:
            renderer.render(data, output)
            raise MyError
    assert_file(registry, "a a a\nc c c\n")
# Even on clean exit, registry entries whose files are gone are pruned.
def test_managed_render_drops_inexsistent_from_registry(pystache_renderer, sut):
    registry = paths.root_dir / "x/registry.list"
    write(paths.root_dir / "a")
    write(paths.root_dir / "c")
    write(registry, f"a {hash('')} {hash('')}\n"
                    "b b b\n"
                    f"c {hash('')} {hash('')}")
    with sut.manage(generated=(), stubs=(), registry=registry):
        pass
    assert_file(registry, f"a {hash('')} {hash('')}\nc {hash('')} {hash('')}\n")
# An exception inside manage() must not trigger the erasure of stale files.
def test_managed_render_exception_does_not_erase(pystache_renderer, sut):
    output = paths.root_dir / "some/output.txt"
    stub = paths.root_dir / "some/stub.txt"
    registry = paths.root_dir / "a/registry.list"
    write(output)
    write(stub, "// generated bla bla")
    write(registry)
    with pytest.raises(MyError):
        with sut.manage(generated=(output,), stubs=(stub,), registry=registry) as renderer:
            raise MyError
    assert output.is_file()
    assert stub.is_file()
# Data with an `extensions` list renders once per extension, appending the
# extension to both the template name and the output file.
def test_render_with_extensions(pystache_renderer, sut):
    data = mock.Mock(spec=("template", "extensions"))
    data.template = "test_template"
    data.extensions = ["foo", "bar", "baz"]
    output = pathlib.Path("my", "test", "file")
    expected_outputs = [pathlib.Path("my", "test", p) for p in ("file.foo", "file.bar", "file.baz")]
    rendered = [f"text{i}" for i in range(len(expected_outputs))]
    pystache_renderer.render_name.side_effect = rendered
    sut.render(data, output)
    expected_templates = ["test_template_foo", "test_template_bar", "test_template_baz"]
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(t, data, generator=generator)
        for t in expected_templates
    ]
    for expected_output, expected_contents in zip(expected_outputs, rendered):
        assert_file(expected_output, expected_contents)
# With force=True an up-to-date generated file is rewritten anyway.
def test_managed_render_with_force_not_skipping_generated_file(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    output = paths.root_dir / "some/output.txt"
    some_output = "some output"
    registry = paths.root_dir / "a/registry.list"
    write(output, some_output)
    write(registry, f"some/output.txt {hash(some_output)} {hash(some_output)}\n")
    pystache_renderer.render_name.side_effect = (some_output,)
    with sut.manage(generated=(output,), stubs=(), registry=registry, force=True) as renderer:
        renderer.render(data, output)
    assert renderer.written == {output}
    assert_file(output, some_output)
    assert_file(registry, f"some/output.txt {hash(some_output)} {hash(some_output)}\n")
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(data.template, data, generator=generator),
    ]
# With force=True a stub is overwritten with the fresh render, discarding
# post-processed content, and the registry records the new hashes.
def test_managed_render_with_force_not_skipping_stub_file(pystache_renderer, sut):
    data = mock.Mock(spec=("template",))
    stub = paths.root_dir / "some/stub.txt"
    some_output = "// generated some output"
    some_processed_output = "// generated some processed output"
    registry = paths.root_dir / "a/registry.list"
    write(stub, some_processed_output)
    write(registry, f"some/stub.txt {hash(some_output)} {hash(some_processed_output)}\n")
    pystache_renderer.render_name.side_effect = (some_output,)
    with sut.manage(generated=(), stubs=(stub,), registry=registry, force=True) as renderer:
        renderer.render(data, stub)
    assert renderer.written == {stub}
    assert_file(stub, some_output)
    assert_file(registry, f"some/stub.txt {hash(some_output)} {hash(some_output)}\n")
    assert pystache_renderer.mock_calls == [
        mock.call.render_name(data.template, data, generator=generator),
    ]
# force=True suppresses the modified-file error for generated files...
def test_managed_render_with_force_ignores_modified_generated_file(sut):
    output = paths.root_dir / "some/output.txt"
    some_processed_output = "// some processed output"
    registry = paths.root_dir / "a/registry.list"
    write(output, "// something else")
    write(registry, f"some/output.txt whatever {hash(some_processed_output)}\n")
    with sut.manage(generated=(output,), stubs=(), registry=registry, force=True):
        pass
# ...and for stubs still marked as generated.
def test_managed_render_with_force_ignores_modified_stub_file_still_marked_as_generated(sut):
    stub = paths.root_dir / "some/stub.txt"
    some_processed_output = "// generated some processed output"
    registry = paths.root_dir / "a/registry.list"
    write(stub, "// generated something else")
    write(registry, f"some/stub.txt whatever {hash(some_processed_output)}\n")
    with sut.manage(generated=(), stubs=(stub,), registry=registry, force=True):
        pass
# Allow running this test module directly; extra CLI args are forwarded to pytest.
if __name__ == '__main__':
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,686 +0,0 @@
import sys
import pytest
from swift.codegen.test.utils import *
from swift.codegen.lib import schemadefs as defs
from swift.codegen.loaders.schemaloader import load
# Loading an empty schema class yields empty collections and no null class.
def test_empty_schema():
    @load
    class data:
        pass
    assert data.classes == {}
    assert data.includes == set()
    assert data.null is None
    assert data.null_class is None
# A single nested class becomes the root class of the schema.
def test_one_empty_class():
    @load
    class data:
        class MyClass:
            pass
    assert data.classes == {
        'MyClass': schema.Class('MyClass'),
    }
    assert data.root_class is data.classes['MyClass']
# Inheritance between nested classes is translated to bases/derived links.
def test_two_empty_classes():
    @load
    class data:
        class MyClass1:
            pass
        class MyClass2(MyClass1):
            pass
    assert data.classes == {
        'MyClass1': schema.Class('MyClass1', derived={'MyClass2'}),
        'MyClass2': schema.Class('MyClass2', bases=['MyClass1']),
    }
    assert data.root_class is data.classes['MyClass1']
# Bases defined outside the schema class are rejected.
def test_no_external_bases():
    class A:
        pass
    with pytest.raises(schema.Error):
        @load
        class data:
            class MyClass(A):
                pass
# The schema must have exactly one root class.
def test_no_multiple_roots():
    with pytest.raises(schema.Error):
        @load
        class data:
            class MyClass1:
                pass
            class MyClass2:
                pass
# Diamond inheritance is supported and preserved in bases/derived sets.
def test_empty_classes_diamond():
    @load
    class data:
        class A:
            pass
        class B(A):
            pass
        class C(A):
            pass
        class D(B, C):
            pass
    assert data.classes == {
        'A': schema.Class('A', derived={'B', 'C'}),
        'B': schema.Class('B', bases=['A'], derived={'D'}),
        'C': schema.Class('C', bases=['A'], derived={'D'}),
        'D': schema.Class('D', bases=['B', 'C']),
    }
#
# The @defs.group decorator assigns a group name to a class.
def test_group():
    @load
    class data:
        @defs.group("xxx")
        class A:
            pass
    assert data.classes == {
        'A': schema.Class('A', group="xxx"),
    }
# A group propagates to derived classes (here D inherits C's group).
def test_group_is_inherited():
    @load
    class data:
        class A:
            pass
        class B(A):
            pass
        @defs.group('xxx')
        class C(A):
            pass
        class D(B, C):
            pass
    assert data.classes == {
        'A': schema.Class('A', derived={'B', 'C'}),
        'B': schema.Class('B', bases=['A'], derived={'D'}),
        'C': schema.Class('C', bases=['A'], derived={'D'}, group='xxx'),
        'D': schema.Class('D', bases=['B', 'C'], group='xxx'),
    }
# Inheriting conflicting groups from different bases is an error.
def test_no_mixed_groups_in_bases():
    with pytest.raises(schema.Error):
        @load
        class data:
            class A:
                pass
            @defs.group('x')
            class B(A):
                pass
            @defs.group('y')
            class C(A):
                pass
            class D(B, C):
                pass
#
# Schema class names must start with an uppercase letter.
def test_lowercase_rejected():
    with pytest.raises(schema.Error):
        @load
        class data:
            class aLowerCase:
                pass
# Annotations with builtin-type specs map to the corresponding property kinds.
def test_properties():
    @load
    class data:
        class A:
            one: defs.string
            two: defs.optional[defs.int]
            three: defs.list[defs.boolean]
            four: defs.list[defs.optional[defs.string]]
            five: defs.predicate
    assert data.classes == {
        'A': schema.Class('A', properties=[
            schema.SingleProperty('one', 'string'),
            schema.OptionalProperty('two', 'int'),
            schema.RepeatedProperty('three', 'boolean'),
            schema.RepeatedOptionalProperty('four', 'string'),
            schema.PredicateProperty('five'),
        ]),
    }
# Class-typed properties reference schema classes by name; the outer `A` here
# shadows nothing inside the loaded schema.
def test_class_properties():
    class A:
        pass
    @load
    class data:
        class A:
            pass
        class B(A):
            one: A
            two: defs.optional[A]
            three: defs.list[A]
            four: defs.list[defs.optional[A]]
    assert data.classes == {
        'A': schema.Class('A', derived={'B'}),
        'B': schema.Class('B', bases=['A'], properties=[
            schema.SingleProperty('one', 'A'),
            schema.OptionalProperty('two', 'A'),
            schema.RepeatedProperty('three', 'A'),
            schema.RepeatedOptionalProperty('four', 'A'),
        ]),
    }
# Forward references as strings resolve to the same class-typed properties.
def test_string_reference_class_properties():
    @load
    class data:
        class A:
            one: "A"
            two: defs.optional["A"]
            three: defs.list["A"]
            four: defs.list[defs.optional["A"]]
    assert data.classes == {
        'A': schema.Class('A', properties=[
            schema.SingleProperty('one', 'A'),
            schema.OptionalProperty('two', 'A'),
            schema.RepeatedProperty('three', 'A'),
            schema.RepeatedOptionalProperty('four', 'A'),
        ]),
    }
# A string reference to an undefined class is an error, in all property shapes.
@pytest.mark.parametrize("spec", [lambda t: t, lambda t: defs.optional[t], lambda t: defs.list[t],
                                  lambda t: defs.list[defs.optional[t]]])
def test_string_reference_dangling(spec):
    with pytest.raises(schema.Error):
        @load
        class data:
            class A:
                x: spec("B")
# The `| defs.child` modifier marks properties as children.
def test_children():
    @load
    class data:
        class A:
            one: "A" | defs.child
            two: defs.optional["A"] | defs.child
            three: defs.list["A"] | defs.child
            four: defs.list[defs.optional["A"]] | defs.child
    assert data.classes == {
        'A': schema.Class('A', properties=[
            schema.SingleProperty('one', 'A', is_child=True),
            schema.OptionalProperty('two', 'A', is_child=True),
            schema.RepeatedProperty('three', 'A', is_child=True),
            schema.RepeatedOptionalProperty('four', 'A', is_child=True),
        ]),
    }
# Builtin-typed and predicate properties cannot be children.
@pytest.mark.parametrize("spec", [defs.string, defs.int, defs.boolean, defs.predicate])
def test_builtin_and_predicate_children_not_allowed(spec):
    with pytest.raises(schema.Error):
        @load
        class data:
            class A:
                x: spec | defs.child
# Mapping from pragma definition object to the pragma name it should record.
_pragmas = [(defs.qltest.skip, "qltest_skip"),
            (defs.qltest.collapse_hierarchy, "qltest_collapse_hierarchy"),
            (defs.qltest.uncollapse_hierarchy, "qltest_uncollapse_hierarchy"),
            (defs.cpp.skip, "cpp_skip"),
            (defs.ql.internal, "ql_internal"),
            ]
# A single pragma applied via `|` ends up in the property's pragma list.
@pytest.mark.parametrize("pragma,expected", _pragmas)
def test_property_with_pragma(pragma, expected):
    @load
    class data:
        class A:
            x: defs.string | pragma
    assert data.classes == {
        'A': schema.Class('A', properties=[
            schema.SingleProperty('x', 'string', pragmas=[expected]),
        ]),
    }
# Chaining all pragmas with `|=` records all of them, in order.
def test_property_with_pragmas():
    spec = defs.string
    for pragma, _ in _pragmas:
        spec |= pragma
    @load
    class data:
        class A:
            x: spec
    assert data.classes == {
        'A': schema.Class('A', properties=[
            schema.SingleProperty('x', 'string', pragmas=[expected for _, expected in _pragmas]),
        ]),
    }
# Pragmas are also usable as class decorators.
@pytest.mark.parametrize("pragma,expected", _pragmas)
def test_class_with_pragma(pragma, expected):
    @load
    class data:
        @pragma
        class A:
            pass
    assert data.classes == {
        'A': schema.Class('A', pragmas=[expected]),
    }
# Applying all pragmas (as plain calls) records all of them on the class.
def test_class_with_pragmas():
    def apply_pragmas(cls):
        for p, _ in _pragmas:
            p(cls)
    @load
    class data:
        class A:
            pass
        apply_pragmas(A)
    assert data.classes == {
        'A': schema.Class('A', pragmas=[e for _, e in _pragmas]),
    }
# @defs.synth.from_class marks the class as synthesized from another class and
# flags the ancestor as (partially) synthesized (ipa=True).
def test_ipa_from_class():
    @load
    class data:
        class A:
            pass
        @defs.synth.from_class(A)
        class B(A):
            pass
    assert data.classes == {
        'A': schema.Class('A', derived={'B'}, ipa=True),
        'B': schema.Class('B', bases=['A'], ipa=schema.IpaInfo(from_class="A")),
    }
# The source class of from_class can be given as a forward string reference.
def test_ipa_from_class_ref():
    @load
    class data:
        @defs.synth.from_class("B")
        class A:
            pass
        class B(A):
            pass
    assert data.classes == {
        'A': schema.Class('A', derived={'B'}, ipa=schema.IpaInfo(from_class="B")),
        'B': schema.Class('B', bases=['A']),
    }
# A from_class reference to an undefined class is an error.
def test_ipa_from_class_dangling():
    with pytest.raises(schema.Error):
        @load
        class data:
            @defs.synth.from_class("X")
            class A:
                pass
def test_ipa_class_on():
@load
class data:
class A:
pass
@defs.synth.on_arguments(a=A, i=defs.int)
class B(A):
pass
assert data.classes == {
'A': schema.Class('A', derived={'B'}, ipa=True),
'B': schema.Class('B', bases=['A'], ipa=schema.IpaInfo(on_arguments={'a': 'A', 'i': 'int'})),
}
def test_ipa_class_on_ref():
class A:
pass
@load
class data:
@defs.synth.on_arguments(b="B", i=defs.int)
class A:
pass
class B(A):
pass
assert data.classes == {
'A': schema.Class('A', derived={'B'}, ipa=schema.IpaInfo(on_arguments={'b': 'B', 'i': 'int'})),
'B': schema.Class('B', bases=['A']),
}
def test_ipa_class_on_dangling():
with pytest.raises(schema.Error):
@load
class data:
@defs.synth.on_arguments(s=defs.string, a="A", i=defs.int)
class B:
pass
# In a class hierarchy, Base and Intermediate end up flagged ipa=True (all of
# their visible descendants are synthesized), while Root and C — which sit on
# a non-synthesized branch — do not.
def test_ipa_class_hierarchy():
    @load
    class data:
        class Root:
            pass
        class Base(Root):
            pass
        class Intermediate(Base):
            pass
        @defs.synth.on_arguments(a=Base, i=defs.int)
        class A(Intermediate):
            pass
        @defs.synth.from_class(Base)
        class B(Base):
            pass
        class C(Root):
            pass
    assert data.classes == {
        'Root': schema.Class('Root', derived={'Base', 'C'}),
        'Base': schema.Class('Base', bases=['Root'], derived={'Intermediate', 'B'}, ipa=True),
        'Intermediate': schema.Class('Intermediate', bases=['Base'], derived={'A'}, ipa=True),
        'A': schema.Class('A', bases=['Intermediate'], ipa=schema.IpaInfo(on_arguments={'a': 'Base', 'i': 'int'})),
        'B': schema.Class('B', bases=['Base'], ipa=schema.IpaInfo(from_class='Base')),
        'C': schema.Class('C', bases=['Root']),
    }
# A class docstring is recorded as the class's doc lines.
def test_class_docstring():
    @load
    class data:
        class A:
            """Very important class."""
    assert data.classes == {
        'A': schema.Class('A', doc=["Very important class."])
    }
# defs.desc attaches a description to a property.
def test_property_docstring():
    @load
    class data:
        class A:
            x: int | defs.desc("very important property.")
    assert data.classes == {
        'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important property."])])
    }
# Multi-line doc text is split into one entry per line.
def test_class_docstring_newline():
    @load
    class data:
        class A:
            """Very important
            class."""
    assert data.classes == {
        'A': schema.Class('A', doc=["Very important", "class."])
    }
def test_property_docstring_newline():
    @load
    class data:
        class A:
            x: int | defs.desc("""very important
            property.""")
    assert data.classes == {
        'A': schema.Class('A',
                          properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
    }
# Surrounding blank lines and common indentation are stripped from doc text.
def test_class_docstring_stripped():
    @load
    class data:
        class A:
            """
            Very important class.
            """
    assert data.classes == {
        'A': schema.Class('A', doc=["Very important class."])
    }
def test_property_docstring_stripped():
    @load
    class data:
        class A:
            x: int | defs.desc("""
            very important property.
            """)
    assert data.classes == {
        'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important property."])])
    }
# A paragraph break in the doc text is preserved as an empty doc line.
def test_class_docstring_split():
    @load
    class data:
        class A:
            """Very important class.

            As said, very important."""
    assert data.classes == {
        'A': schema.Class('A', doc=["Very important class.", "", "As said, very important."])
    }
def test_property_docstring_split():
    @load
    class data:
        class A:
            x: int | defs.desc("""very important property.

            Very very important.""")
    assert data.classes == {
        'A': schema.Class('A', properties=[
            schema.SingleProperty('x', 'int', description=["very important property.", "", "Very very important."])])
    }
# Indentation relative to the common margin is preserved in doc lines.
def test_class_docstring_indent():
    @load
    class data:
        class A:
            """
            Very important class.
              As said, very important.
            """
    assert data.classes == {
        'A': schema.Class('A', doc=["Very important class.", "  As said, very important."])
    }
def test_property_docstring_indent():
    @load
    class data:
        class A:
            x: int | defs.desc("""
            very important property.
              Very very important.
            """)
    assert data.classes == {
        'A': schema.Class('A', properties=[
            schema.SingleProperty('x', 'int', description=["very important property.", "  Very very important."])])
    }
# defs.doc overrides the generated doc phrase of a property.
def test_property_doc_override():
    @load
    class data:
        class A:
            x: int | defs.doc("y")
    assert data.classes == {
        'A': schema.Class('A', properties=[
            schema.SingleProperty('x', 'int', doc="y")]),
    }
# The doc override must be a single line...
def test_property_doc_override_no_newlines():
    with pytest.raises(schema.Error):
        @load
        class data:
            class A:
                x: int | defs.doc("no multiple\nlines")
# ...and must not end with a dot.
def test_property_doc_override_no_trailing_dot():
    with pytest.raises(schema.Error):
        @load
        class data:
            class A:
                x: int | defs.doc("no dots please.")
# ql.default_doc_name overrides the default name used in generated QL docs.
def test_class_default_doc_name():
    @load
    class data:
        @defs.ql.default_doc_name("b")
        class A:
            pass
    assert data.classes == {
        'A': schema.Class('A', default_doc_name="b"),
    }
# defs.use_for_null designates the class that represents null entities; it is
# exposed via data.null (name) and data.null_class (the class itself).
def test_null_class():
    @load
    class data:
        class Root:
            pass
        @defs.use_for_null
        class Null(Root):
            pass
    assert data.classes == {
        'Root': schema.Class('Root', derived={'Null'}),
        'Null': schema.Class('Null', bases=['Root']),
    }
    assert data.null == 'Null'
    assert data.null_class is data.classes[data.null]
# The null class must be a leaf: deriving from it is an error.
def test_null_class_cannot_be_derived():
    with pytest.raises(schema.Error):
        @load
        class data:
            class Root:
                pass
            @defs.use_for_null
            class Null(Root):
                pass
            class Impossible(Null):
                pass
# At most one class may be designated as the null class.
def test_null_class_cannot_be_defined_multiple_times():
    with pytest.raises(schema.Error):
        @load
        class data:
            class Root:
                pass
            @defs.use_for_null
            class Null1(Root):
                pass
            @defs.use_for_null
            class Null2(Root):
                pass
# Class names containing all-uppercase acronyms (e.g. ROTFLNode) are rejected.
def test_uppercase_acronyms_are_rejected():
    with pytest.raises(schema.Error):
        @load
        class data:
            class Root:
                pass
            class ROTFLNode(Root):
                pass
# Allow running this test file directly; extra CLI arguments are forwarded to pytest.
if __name__ == '__main__':
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,188 +0,0 @@
import sys
from swift.codegen.generators import trapgen
from swift.codegen.lib import cpp, dbscheme
from swift.codegen.test.utils import *
# Fake C++ output directory passed to trapgen via opts; nothing is written.
output_dir = pathlib.Path("path", "to", "output")
@pytest.fixture
def generate_grouped(opts, renderer, dbscheme_input):
    """Fixture returning a function that runs trapgen over given dbscheme entities.

    The returned function yields a pair of (TrapEntries data keyed by directory
    relative to the output dir, the single TrapTags data at the output root).
    """
    opts.cpp_output = output_dir
    def ret(entities):
        dbscheme_input.entities = entities
        generated = run_generation(trapgen.generate, opts, renderer)
        dirs = {f.parent for f in generated}
        # trapgen must only ever render TrapEntries/TrapTags files, with
        # exactly one TrapTags at the output root
        assert all(isinstance(f, pathlib.Path) for f in generated)
        assert all(f.name in ("TrapEntries", "TrapTags") for f in generated)
        assert set(f for f in generated if f.name == "TrapTags") == {output_dir / "TrapTags"}
        return ({
            str(d.relative_to(output_dir)): generated[d / "TrapEntries"] for d in dirs
        }, generated[output_dir / "TrapTags"])
    return ret
@pytest.fixture
def generate_grouped_traps(generate_grouped):
    """Like generate_grouped, but returning only the trap lists per directory."""
    def ret(entities):
        generated, _ = generate_grouped(entities)
        assert all(isinstance(g, cpp.TrapList) for g in generated.values())
        return {d: traps.traps for d, traps in generated.items()}
    return ret
@pytest.fixture
def generate_traps(generate_grouped_traps):
    """Generate traps for entities all expected to land in the root directory."""
    def ret(entities):
        generated = generate_grouped_traps(entities)
        assert set(generated) == {"."}
        return generated["."]
    return ret
@pytest.fixture
def generate_tags(generate_grouped):
    """Generate entities and return only the resulting list of tags."""
    def ret(entities):
        _, tags = generate_grouped(entities)
        assert isinstance(tags, cpp.TagList)
        return tags.tags
    return ret
# No dbscheme entities yields no traps and no tags.
def test_empty_traps(generate_traps):
    assert generate_traps([]) == []
def test_empty_tags(generate_tags):
    assert generate_tags([]) == []
# A table without columns trips an assertion in trapgen.
def test_one_empty_table_rejected(generate_traps):
    with pytest.raises(AssertionError):
        generate_traps([
            dbscheme.Table(name="foos", columns=[]),
        ])
# A single table maps to a single trap with a capitalized name and one field
# per column.
# NOTE(review): this test was accidentally defined twice with byte-identical
# bodies; the second definition silently shadowed the first, so only one copy
# ever ran. The duplicate has been removed.
def test_one_table(generate_traps):
    assert generate_traps([
        dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
    ]) == [
        cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]),
    ]
# A binding column becomes the trap's id field (while remaining a regular field).
def test_one_table_with_id(generate_traps):
    assert generate_traps([
        dbscheme.Table(name="foos", columns=[
            dbscheme.Column("bla", "int", binding=True)]),
    ]) == [
        cpp.Trap("foos", name="Foos", fields=[cpp.Field(
            "bla", "int")], id=cpp.Field("bla", "int")),
    ]
# With multiple binding columns, the first one is chosen as the id.
def test_one_table_with_two_binding_first_is_id(generate_traps):
    assert generate_traps([
        dbscheme.Table(name="foos", columns=[
            dbscheme.Column("x", "a", binding=True),
            dbscheme.Column("y", "b", binding=True),
        ]),
    ]) == [
        cpp.Trap("foos", name="Foos", fields=[
            cpp.Field("x", "a"),
            cpp.Field("y", "b"),
        ], id=cpp.Field("x", "a")),
    ]
# dbscheme primitive types map to C++ types, and @-types to typed trap labels.
@pytest.mark.parametrize("column,field", [
    (dbscheme.Column("x", "string"), cpp.Field("x", "std::string")),
    (dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")),
    (dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel<DbTypeTag>")),
])
def test_one_table_special_types(generate_traps, column, field):
    assert generate_traps([
        dbscheme.Table(name="foos", columns=[column]),
    ]) == [
        cpp.Trap("foos", name="Foos", fields=[field]),
    ]
# Columns with these well-known names are forced to `unsigned`, regardless of
# their declared dbscheme type.
@pytest.mark.parametrize("name", ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"])
def test_one_table_overridden_unsigned_field(generate_traps, name):
    assert generate_traps([
        dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]),
    ]) == [
        cpp.Trap("foos", name="Foos", fields=[cpp.Field(name, "unsigned")]),
    ]
# A trailing underscore is stripped from field names (presumably used to dodge
# reserved words in the dbscheme — confirm).
def test_one_table_overridden_underscore_named_field(generate_traps):
    assert generate_traps([
        dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]),
    ]) == [
        cpp.Trap("foos", name="Foos", fields=[cpp.Field("whatever", "bar")]),
    ]
# Tables carrying a dir attribute are grouped under the corresponding
# subdirectory of the output dir; dir-less tables go to the root (".").
def test_tables_with_dir(generate_grouped_traps):
    assert generate_grouped_traps([
        dbscheme.Table(name="x", columns=[dbscheme.Column("i", "int")]),
        dbscheme.Table(name="y", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo")),
        dbscheme.Table(name="z", columns=[dbscheme.Column("i", "int")], dir=pathlib.Path("foo/bar")),
    ]) == {
        ".": [cpp.Trap("x", name="X", fields=[cpp.Field("i", "int")])],
        "foo": [cpp.Trap("y", name="Y", fields=[cpp.Field("i", "int")])],
        "foo/bar": [cpp.Trap("z", name="Z", fields=[cpp.Field("i", "int")])],
    }
# Plain tables contribute no tags.
def test_one_table_no_tags(generate_tags):
    assert generate_tags([
        dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
    ]) == []
# A union yields a tag for the lhs plus one per rhs member; per the expected
# output, the base tag comes first and members are ordered alphabetically.
def test_one_union_tags(generate_tags):
    assert generate_tags([
        dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
    ]) == [
        cpp.Tag(name="LeftHandSide", bases=[], id="@left_hand_side"),
        cpp.Tag(name="A", bases=["LeftHandSide"], id="@a"),
        cpp.Tag(name="B", bases=["LeftHandSide"], id="@b"),
        cpp.Tag(name="C", bases=["LeftHandSide"], id="@c"),
    ]
# Tags from multiple unions are merged: a tag can acquire several bases, and
# bases are emitted before the tags deriving from them.
def test_multiple_union_tags(generate_tags):
    assert generate_tags([
        dbscheme.Union(lhs="@d", rhs=["@a"]),
        dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
        dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
    ]) == [
        cpp.Tag(name="D", bases=[], id="@d"),
        cpp.Tag(name="E", bases=[], id="@e"),
        cpp.Tag(name="A", bases=["D"], id="@a"),
        cpp.Tag(name="F", bases=["E"], id="@f"),
        cpp.Tag(name="B", bases=["A"], id="@b"),
        cpp.Tag(name="C", bases=["A", "E"], id="@c"),
    ]
# Allow running this test file directly; extra CLI arguments are forwarded to pytest.
if __name__ == '__main__':
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,84 +0,0 @@
import pathlib
from unittest import mock
import pytest
from swift.codegen.lib import render, schema, paths
# Relative paths used by the fixtures below to build fake input locations
# under each test's tmp_path.
schema_dir = pathlib.Path("a", "dir")
schema_file = schema_dir / "schema.py"
dbscheme_file = pathlib.Path("another", "dir", "test.dbscheme")
def write(out, contents=""):
    """Write `contents` to the `pathlib.Path` `out`, creating missing parent dirs.

    By default an empty file is created.
    """
    out.parent.mkdir(parents=True, exist_ok=True)
    # Path.write_text replaces the previous open()/write() pair, which also
    # shadowed the `out` parameter with the file handle.
    out.write_text(contents)
@pytest.fixture
def renderer():
    """A mocked render.Renderer."""
    return mock.Mock(spec=render.Renderer)
@pytest.fixture
def render_manager(renderer):
    """A mocked render.RenderManager, usable as a context manager."""
    ret = mock.Mock(spec=render.RenderManager)
    ret.__enter__ = mock.Mock(return_value=ret)
    ret.__exit__ = mock.Mock(return_value=None)
    # by default, no stub is considered customized by hand
    ret.is_customized_stub.return_value = False
    return ret
@pytest.fixture
def opts():
    """Mocked command-line options, with root_dir pointing at the real root."""
    ret = mock.MagicMock()
    ret.root_dir = paths.root_dir
    return ret
@pytest.fixture(autouse=True)
def override_paths(tmp_path):
    # every test runs with codegen's notion of the repository root and of its
    # own executable redirected into a temporary directory
    with mock.patch("swift.codegen.lib.paths.root_dir", tmp_path), \
            mock.patch("swift.codegen.lib.paths.exe_file", tmp_path / "exe"):
        yield
@pytest.fixture
def input(opts, tmp_path):
    """Mock schema loading, yielding the (empty) schema that will be "loaded".

    NOTE(review): the fixture name shadows the `input` builtin; renaming it
    would break every test requesting the fixture by name, so it is kept.
    """
    opts.schema = tmp_path / schema_file
    with mock.patch("swift.codegen.loaders.schemaloader.load_file") as load_mock:
        load_mock.return_value = schema.Schema([])
        yield load_mock.return_value
    # teardown: the schema must have been loaded exactly once, from opts.schema
    assert load_mock.mock_calls == [
        mock.call(opts.schema),
    ], load_mock.mock_calls
@pytest.fixture
def dbscheme_input(opts, tmp_path):
    """Mock dbscheme loading; tests set `.entities` on the yielded mock."""
    opts.dbscheme = tmp_path / dbscheme_file
    with mock.patch("swift.codegen.loaders.dbschemeloader.iterload") as load_mock:
        load_mock.entities = []
        load_mock.side_effect = lambda _: load_mock.entities
        yield load_mock
    # teardown: the dbscheme must have been iterated exactly once
    assert load_mock.mock_calls == [
        mock.call(opts.dbscheme),
    ], load_mock.mock_calls
def run_generation(generate, opts, renderer):
    """Run `generate` with a mocked renderer, capturing what gets rendered.

    Returns a mapping from output path to the data rendered to it.
    """
    rendered = {}

    def record(data, out):
        rendered[out] = data

    renderer.render.side_effect = record
    generate(opts, renderer)
    return rendered
def run_managed_generation(generate, opts, renderer, render_manager):
    """Run `generate` with a mocked managed renderer, capturing what gets rendered.

    `renderer.manage(...)` will hand out `render_manager` (once); everything
    rendered through the manager is returned as an output-path-to-data mapping.
    """
    rendered = {}
    renderer.manage.side_effect = (render_manager,)

    def record(data, out):
        rendered[out] = data

    render_manager.render.side_effect = record
    generate(opts, renderer)
    return rendered