Swift: some restructuring of codegen

Loading of the schema and dbscheme has been moved to a separate
`loaders` package for better separation of concerns.
This commit is contained in:
Paolo Tranquilli
2023-02-14 09:51:01 +01:00
parent 781aab3eb7
commit 8e079320f3
19 changed files with 531 additions and 352 deletions

View File

@@ -6,6 +6,6 @@ py_library(
visibility = ["//swift/codegen:__subpackages__"],
deps = [
"//swift/codegen/lib",
requirement("toposort"),
"//swift/codegen/loaders",
],
)

View File

@@ -17,6 +17,7 @@ import typing
import inflection
from swift.codegen.lib import cpp, schema
from swift.codegen.loaders import schemaloader
def _get_type(t: str, add_or_none_except: typing.Optional[str] = None) -> str:
@@ -90,7 +91,7 @@ class Processor:
def generate(opts, renderer):
assert opts.cpp_output
processor = Processor(schema.load_file(opts.schema))
processor = Processor(schemaloader.load_file(opts.schema))
out = opts.cpp_output
for dir, classes in processor.get_classes().items():
renderer.render(cpp.ClassList(classes, opts.schema,

View File

@@ -18,8 +18,8 @@ import typing
import inflection
from swift.codegen.lib import schema
from swift.codegen.loaders import schemaloader
from swift.codegen.lib.dbscheme import *
from typing import Set, List
log = logging.getLogger(__name__)
@@ -123,7 +123,7 @@ def generate(opts, renderer):
input = opts.schema
out = opts.dbscheme
data = schema.load_file(input)
data = schemaloader.load_file(input)
dbscheme = Scheme(src=input.relative_to(opts.swift_dir),
includes=get_includes(data, include_dir=input.parent, swift_dir=opts.swift_dir),

View File

@@ -30,6 +30,7 @@ import itertools
import inflection
from swift.codegen.lib import schema, ql
from swift.codegen.loaders import schemaloader
log = logging.getLogger(__name__)
@@ -297,7 +298,7 @@ def generate(opts, renderer):
stubs = {q for q in stub_out.rglob("*.qll")}
data = schema.load_file(input)
data = schemaloader.load_file(input)
classes = {name: get_ql_class(cls) for name, cls in data.classes.items()}
if not classes:

View File

@@ -18,6 +18,7 @@ import inflection
from toposort import toposort_flatten
from swift.codegen.lib import dbscheme, cpp
from swift.codegen.loaders import dbschemeloader
log = logging.getLogger(__name__)
@@ -73,7 +74,7 @@ def generate(opts, renderer):
out = opts.cpp_output
traps = {pathlib.Path(): []}
for e in dbscheme.iterload(opts.dbscheme):
for e in dbschemeloader.iterload(opts.dbscheme):
if e.is_table:
traps.setdefault(e.dir, []).append(get_trap(e))
elif e.is_union:

View File

@@ -2,11 +2,10 @@ load("@swift_codegen_deps//:requirements.bzl", "requirement")
py_library(
name = "lib",
srcs = glob(["**/*.py"]),
srcs = glob(["*.py"]),
visibility = ["//swift/codegen:__subpackages__"],
deps = [
requirement("pystache"),
requirement("pyyaml"),
requirement("inflection"),
],
)

View File

@@ -105,54 +105,3 @@ class Scheme:
src: str
includes: List[SchemeInclude]
declarations: List[Decl]
class Re:
entity = re.compile(
"(?m)"
r"(?:^#keyset\[(?P<tablekeys>[\w\s,]+)\][\s\n]*)?^(?P<table>\w+)\("
r"(?:\s*//dir=(?P<tabledir>\S*))?(?P<tablebody>[^\)]*)"
r"\);?"
"|"
r"^(?P<union>@\w+)\s*=\s*(?P<unionbody>@\w+(?:\s*\|\s*@\w+)*)\s*;?"
)
field = re.compile(r"(?m)[\w\s]*\s(?P<field>\w+)\s*:\s*(?P<type>@?\w+)(?P<ref>\s+ref)?")
key = re.compile(r"@\w+")
comment = re.compile(r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$") # lookahead avoid ignoring metadata like //dir=foo
def get_column(match):
return Column(
schema_name=match["field"].rstrip("_"),
type=match["type"],
binding=not match["ref"],
)
def get_table(match):
keyset = None
if match["tablekeys"]:
keyset = KeySet(k.strip() for k in match["tablekeys"].split(","))
return Table(
name=match["table"],
columns=[get_column(f) for f in Re.field.finditer(match["tablebody"])],
keyset=keyset,
dir=pathlib.PosixPath(match["tabledir"]) if match["tabledir"] else None,
)
def get_union(match):
return Union(
lhs=match["union"],
rhs=(d[0] for d in Re.key.finditer(match["unionbody"])),
)
def iterload(file):
with open(file) as file:
data = Re.comment.sub("", file.read())
for e in Re.entity.finditer(data):
if e["table"]:
yield get_table(e)
elif e["union"]:
yield get_union(e)

View File

@@ -1,15 +1,9 @@
""" schema.yml format representation """
import pathlib
import re
import types
""" schema format representation """
import typing
from dataclasses import dataclass, field
from typing import List, Set, Union, Dict, Optional
from enum import Enum, auto
import functools
import importlib.util
from toposort import toposort_flatten
import inflection
class Error(Exception):
@@ -198,125 +192,3 @@ def split_doc(doc):
while trimmed and not trimmed[0]:
trimmed.pop(0)
return trimmed
@dataclass
class _PropertyNamer(PropertyModifier):
name: str
def modify(self, prop: Property):
prop.name = self.name.rstrip("_")
def _get_class(cls: type) -> Class:
if not isinstance(cls, type):
raise Error(f"Only class definitions allowed in schema, found {cls}")
# we must check that going to dbscheme names and back is preserved
# In particular this will not happen if uppercase acronyms are included in the name
to_underscore_and_back = inflection.camelize(inflection.underscore(cls.__name__), uppercase_first_letter=True)
if cls.__name__ != to_underscore_and_back:
raise Error(f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} "
f"instead of {to_underscore_and_back}")
if len({b._group for b in cls.__bases__ if hasattr(b, "_group")}) > 1:
raise Error(f"Bases with mixed groups for {cls.__name__}")
if any(getattr(b, "_null", False) for b in cls.__bases__):
raise Error(f"Null class cannot be derived")
return Class(name=cls.__name__,
bases=[b.__name__ for b in cls.__bases__ if b is not object],
derived={d.__name__ for d in cls.__subclasses__()},
# getattr to inherit from bases
group=getattr(cls, "_group", ""),
# in the following we don't use `getattr` to avoid inheriting
pragmas=cls.__dict__.get("_pragmas", []),
ipa=cls.__dict__.get("_ipa", None),
properties=[
a | _PropertyNamer(n)
for n, a in cls.__dict__.get("__annotations__", {}).items()
],
doc=split_doc(cls.__doc__),
default_doc_name=cls.__dict__.get("_doc_name"),
)
def _toposort_classes_by_group(classes: typing.Dict[str, Class]) -> typing.Dict[str, Class]:
groups = {}
ret = {}
for name, cls in classes.items():
groups.setdefault(cls.group, []).append(name)
for group, grouped in sorted(groups.items()):
inheritance = {name: classes[name].bases for name in grouped}
for name in toposort_flatten(inheritance):
ret[name] = classes[name]
return ret
def _fill_ipa_information(classes: typing.Dict[str, Class]):
""" Take a dictionary where the `ipa` field is filled for all explicitly synthesized classes
and update it so that all non-final classes that have only synthesized final descendants
get `True` as` value for the `ipa` field
"""
if not classes:
return
is_ipa: typing.Dict[str, bool] = {}
def fill_is_ipa(name: str):
if name not in is_ipa:
cls = classes[name]
for d in cls.derived:
fill_is_ipa(d)
if cls.ipa is not None:
is_ipa[name] = True
elif not cls.derived:
is_ipa[name] = False
else:
is_ipa[name] = all(is_ipa[d] for d in cls.derived)
root = next(iter(classes))
fill_is_ipa(root)
for name, cls in classes.items():
if cls.ipa is None and is_ipa[name]:
cls.ipa = True
def load(m: types.ModuleType) -> Schema:
includes = set()
classes = {}
known = {"int", "string", "boolean"}
known.update(n for n in m.__dict__ if not n.startswith("__"))
import swift.codegen.lib.schema.defs as defs
null = None
for name, data in m.__dict__.items():
if hasattr(defs, name):
continue
if name == "__includes":
includes = set(data)
continue
if name.startswith("__"):
continue
cls = _get_class(data)
if classes and not cls.bases:
raise Error(
f"Only one root class allowed, found second root {name}")
cls.check_types(known)
classes[name] = cls
if getattr(data, "_null", False):
if null is not None:
raise Error(f"Null class {null} already defined, second null class {name} not allowed")
null = name
cls.is_null_class = True
_fill_ipa_information(classes)
return Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
def load_file(path: pathlib.Path) -> Schema:
spec = importlib.util.spec_from_file_location("schema", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return load(module)

View File

@@ -1 +0,0 @@
from .schema import *

View File

@@ -0,0 +1,11 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")
py_library(
name = "loaders",
srcs = glob(["*.py"]),
visibility = ["//swift/codegen:__subpackages__"],
deps = [
requirement("toposort"),
requirement("inflection"),
],
)

View File

@@ -0,0 +1,54 @@
import pathlib
import re
from swift.codegen.lib import dbscheme
class _Re:
entity = re.compile(
"(?m)"
r"(?:^#keyset\[(?P<tablekeys>[\w\s,]+)\][\s\n]*)?^(?P<table>\w+)\("
r"(?:\s*//dir=(?P<tabledir>\S*))?(?P<tablebody>[^\)]*)"
r"\);?"
"|"
r"^(?P<union>@\w+)\s*=\s*(?P<unionbody>@\w+(?:\s*\|\s*@\w+)*)\s*;?"
)
field = re.compile(r"(?m)[\w\s]*\s(?P<field>\w+)\s*:\s*(?P<type>@?\w+)(?P<ref>\s+ref)?")
key = re.compile(r"@\w+")
comment = re.compile(r"(?m)(?s)/\*.*?\*/|//(?!dir=)[^\n]*$") # lookahead avoid ignoring metadata like //dir=foo
def _get_column(match):
return dbscheme.Column(
schema_name=match["field"].rstrip("_"),
type=match["type"],
binding=not match["ref"],
)
def _get_table(match):
keyset = None
if match["tablekeys"]:
keyset = dbscheme.KeySet(k.strip() for k in match["tablekeys"].split(","))
return dbscheme.Table(
name=match["table"],
columns=[_get_column(f) for f in _Re.field.finditer(match["tablebody"])],
keyset=keyset,
dir=pathlib.PosixPath(match["tabledir"]) if match["tabledir"] else None,
)
def _get_union(match):
return dbscheme.Union(
lhs=match["union"],
rhs=(d[0] for d in _Re.key.finditer(match["unionbody"])),
)
def iterload(file):
with open(file) as file:
data = _Re.comment.sub("", file.read())
for e in _Re.entity.finditer(data):
if e["table"]:
yield _get_table(e)
elif e["union"]:
yield _get_union(e)

View File

@@ -0,0 +1,133 @@
""" schema loader """
import inflection
import typing
import types
import pathlib
import importlib
from dataclasses import dataclass
from toposort import toposort_flatten
from swift.codegen.lib import schema, schemadefs
@dataclass
class _PropertyNamer(schema.PropertyModifier):
name: str
def modify(self, prop: schema.Property):
prop.name = self.name.rstrip("_")
def _get_class(cls: type) -> schema.Class:
if not isinstance(cls, type):
raise schema.Error(f"Only class definitions allowed in schema, found {cls}")
# we must check that going to dbscheme names and back is preserved
# In particular this will not happen if uppercase acronyms are included in the name
to_underscore_and_back = inflection.camelize(inflection.underscore(cls.__name__), uppercase_first_letter=True)
if cls.__name__ != to_underscore_and_back:
raise schema.Error(f"Class name must be upper camel-case, without capitalized acronyms, found {cls.__name__} "
f"instead of {to_underscore_and_back}")
if len({b._group for b in cls.__bases__ if hasattr(b, "_group")}) > 1:
raise schema.Error(f"Bases with mixed groups for {cls.__name__}")
if any(getattr(b, "_null", False) for b in cls.__bases__):
raise schema.Error(f"Null class cannot be derived")
return schema.Class(name=cls.__name__,
bases=[b.__name__ for b in cls.__bases__ if b is not object],
derived={d.__name__ for d in cls.__subclasses__()},
# getattr to inherit from bases
group=getattr(cls, "_group", ""),
# in the following we don't use `getattr` to avoid inheriting
pragmas=cls.__dict__.get("_pragmas", []),
ipa=cls.__dict__.get("_ipa", None),
properties=[
a | _PropertyNamer(n)
for n, a in cls.__dict__.get("__annotations__", {}).items()
],
doc=schema.split_doc(cls.__doc__),
default_doc_name=cls.__dict__.get("_doc_name"),
)
def _toposort_classes_by_group(classes: typing.Dict[str, schema.Class]) -> typing.Dict[str, schema.Class]:
groups = {}
ret = {}
for name, cls in classes.items():
groups.setdefault(cls.group, []).append(name)
for group, grouped in sorted(groups.items()):
inheritance = {name: classes[name].bases for name in grouped}
for name in toposort_flatten(inheritance):
ret[name] = classes[name]
return ret
def _fill_ipa_information(classes: typing.Dict[str, schema.Class]):
""" Take a dictionary where the `ipa` field is filled for all explicitly synthesized classes
and update it so that all non-final classes that have only synthesized final descendants
get `True` as` value for the `ipa` field
"""
if not classes:
return
is_ipa: typing.Dict[str, bool] = {}
def fill_is_ipa(name: str):
if name not in is_ipa:
cls = classes[name]
for d in cls.derived:
fill_is_ipa(d)
if cls.ipa is not None:
is_ipa[name] = True
elif not cls.derived:
is_ipa[name] = False
else:
is_ipa[name] = all(is_ipa[d] for d in cls.derived)
root = next(iter(classes))
fill_is_ipa(root)
for name, cls in classes.items():
if cls.ipa is None and is_ipa[name]:
cls.ipa = True
def load(m: types.ModuleType) -> schema.Schema:
includes = set()
classes = {}
known = {"int", "string", "boolean"}
known.update(n for n in m.__dict__ if not n.startswith("__"))
import swift.codegen.lib.schemadefs as defs
null = None
for name, data in m.__dict__.items():
if hasattr(defs, name):
continue
if name == "__includes":
includes = set(data)
continue
if name.startswith("__"):
continue
cls = _get_class(data)
if classes and not cls.bases:
raise schema.Error(
f"Only one root class allowed, found second root {name}")
cls.check_types(known)
classes[name] = cls
if getattr(data, "_null", False):
if null is not None:
raise schema.Error(f"Null class {null} already defined, second null class {name} not allowed")
null = name
cls.is_null_class = True
_fill_ipa_information(classes)
return schema.Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
def load_file(path: pathlib.Path) -> schema.Schema:
spec = importlib.util.spec_from_file_location("schema", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return load(module)

View File

@@ -0,0 +1,149 @@
from typing import Callable as _Callable
from swift.codegen.lib import schema as _schema
import inspect as _inspect
from dataclasses import dataclass as _dataclass
class _ChildModifier(_schema.PropertyModifier):
def modify(self, prop: _schema.Property):
if prop.type is None or prop.type[0].islower():
raise _schema.Error("Non-class properties cannot be children")
prop.is_child = True
@_dataclass
class _DocModifier(_schema.PropertyModifier):
doc: str
def modify(self, prop: _schema.Property):
if "\n" in self.doc or self.doc[-1] == ".":
raise _schema.Error("No newlines or trailing dots are allowed in doc, did you intend to use desc?")
prop.doc = self.doc
@_dataclass
class _DescModifier(_schema.PropertyModifier):
description: str
def modify(self, prop: _schema.Property):
prop.description = _schema.split_doc(self.description)
def include(source: str):
# add to `includes` variable in calling context
_inspect.currentframe().f_back.f_locals.setdefault(
"__includes", []).append(source)
class _Namespace:
""" simple namespacing mechanism """
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
qltest = _Namespace()
ql = _Namespace()
cpp = _Namespace()
synth = _Namespace()
@_dataclass
class _Pragma(_schema.PropertyModifier):
""" A class or property pragma.
For properties, it functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
For schema classes it acts as a python decorator with `@`.
"""
pragma: str
def __post_init__(self):
namespace, _, name = self.pragma.partition('_')
setattr(globals()[namespace], name, self)
def modify(self, prop: _schema.Property):
prop.pragmas.append(self.pragma)
def __call__(self, cls: type) -> type:
""" use this pragma as a decorator on classes """
if "_pragmas" in cls.__dict__: # not using hasattr as we don't want to land on inherited pragmas
cls._pragmas.append(self.pragma)
else:
cls._pragmas = [self.pragma]
return cls
class _Optionalizer(_schema.PropertyModifier):
def modify(self, prop: _schema.Property):
K = _schema.Property.Kind
if prop.kind != K.SINGLE:
raise _schema.Error(
"Optional should only be applied to simple property types")
prop.kind = K.OPTIONAL
class _Listifier(_schema.PropertyModifier):
def modify(self, prop: _schema.Property):
K = _schema.Property.Kind
if prop.kind == K.SINGLE:
prop.kind = K.REPEATED
elif prop.kind == K.OPTIONAL:
prop.kind = K.REPEATED_OPTIONAL
else:
raise _schema.Error(
"Repeated should only be applied to simple or optional property types")
class _TypeModifier:
""" Modifies types using get item notation """
def __init__(self, modifier: _schema.PropertyModifier):
self.modifier = modifier
def __getitem__(self, item):
return item | self.modifier
_ClassDecorator = _Callable[[type], type]
def _annotate(**kwargs) -> _ClassDecorator:
def f(cls: type) -> type:
for k, v in kwargs.items():
setattr(cls, f"_{k}", v)
return cls
return f
boolean = "boolean"
int = "int"
string = "string"
predicate = _schema.predicate_marker
optional = _TypeModifier(_Optionalizer())
list = _TypeModifier(_Listifier())
child = _ChildModifier()
doc = _DocModifier
desc = _DescModifier
use_for_null = _annotate(null=True)
_Pragma("qltest_skip")
_Pragma("qltest_collapse_hierarchy")
_Pragma("qltest_uncollapse_hierarchy")
ql.default_doc_name = lambda doc: _annotate(doc_name=doc)
_Pragma("ql_internal")
_Pragma("cpp_skip")
def group(name: str = "") -> _ClassDecorator:
return _annotate(group=name)
synth.from_class = lambda ref: _annotate(ipa=_schema.IpaInfo(
from_class=_schema.get_type_name(ref)))
synth.on_arguments = lambda **kwargs: _annotate(
ipa=_schema.IpaInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()}))

View File

@@ -48,119 +48,5 @@ def test_union_has_first_case_marked():
assert [c.type for c in u.rhs] == rhs
# load tests
@pytest.fixture
def load(tmp_path):
file = tmp_path / "test.dbscheme"
def ret(yml):
write(file, yml)
return list(dbscheme.iterload(file))
return ret
def test_load_empty(load):
assert load("") == []
def test_load_one_empty_table(load):
assert load("""
test_foos();
""") == [
dbscheme.Table(name="test_foos", columns=[])
]
def test_load_table_with_keyset(load):
assert load("""
#keyset[x, y,z]
test_foos();
""") == [
dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]))
]
expected_columns = [
("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)),
(" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)),
("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)),
("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)),
("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)),
("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)),
]
@pytest.mark.parametrize("column,expected", expected_columns)
def test_load_table_with_column(load, column, expected):
assert load(f"""
foos(
{column}
);
""") == [
dbscheme.Table(name="foos", columns=[deepcopy(expected)])
]
def test_load_table_with_multiple_columns(load):
columns = ",\n".join(c for c, _ in expected_columns)
expected = [deepcopy(e) for _, e in expected_columns]
assert load(f"""
foos(
{columns}
);
""") == [
dbscheme.Table(name="foos", columns=expected)
]
def test_load_table_with_multiple_columns_and_dir(load):
columns = ",\n".join(c for c, _ in expected_columns)
expected = [deepcopy(e) for _, e in expected_columns]
assert load(f"""
foos( //dir=foo/bar/baz
{columns}
);
""") == [
dbscheme.Table(name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz"))
]
def test_load_multiple_table_with_columns(load):
tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)]
expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)]
assert load("\n".join(tables)) == expected
def test_union(load):
assert load("@foo = @bar | @baz | @bla;") == [
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
def test_table_and_union(load):
assert load("""
foos();
@foo = @bar | @baz | @bla;""") == [
dbscheme.Table(name="foos", columns=[]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
def test_comments_ignored(load):
assert load("""
// fake_table();
foos(/* x */unique /*y*/int/*
z
*/ id/* */: /* * */ @bar/*,
int ignored: int ref*/);
@foo = @bar | @baz | @bla; // | @xxx""") == [
dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
if __name__ == '__main__':
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -0,0 +1,123 @@
import sys
from copy import deepcopy
from swift.codegen.lib import dbscheme
from swift.codegen.loaders.dbschemeloader import iterload
from swift.codegen.test.utils import *
@pytest.fixture
def load(tmp_path):
file = tmp_path / "test.dbscheme"
def ret(yml):
write(file, yml)
return list(iterload(file))
return ret
def test_load_empty(load):
assert load("") == []
def test_load_one_empty_table(load):
assert load("""
test_foos();
""") == [
dbscheme.Table(name="test_foos", columns=[])
]
def test_load_table_with_keyset(load):
assert load("""
#keyset[x, y,z]
test_foos();
""") == [
dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]))
]
expected_columns = [
("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)),
(" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)),
("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)),
("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)),
("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)),
("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)),
]
@pytest.mark.parametrize("column,expected", expected_columns)
def test_load_table_with_column(load, column, expected):
assert load(f"""
foos(
{column}
);
""") == [
dbscheme.Table(name="foos", columns=[deepcopy(expected)])
]
def test_load_table_with_multiple_columns(load):
columns = ",\n".join(c for c, _ in expected_columns)
expected = [deepcopy(e) for _, e in expected_columns]
assert load(f"""
foos(
{columns}
);
""") == [
dbscheme.Table(name="foos", columns=expected)
]
def test_load_table_with_multiple_columns_and_dir(load):
columns = ",\n".join(c for c, _ in expected_columns)
expected = [deepcopy(e) for _, e in expected_columns]
assert load(f"""
foos( //dir=foo/bar/baz
{columns}
);
""") == [
dbscheme.Table(name="foos", columns=expected, dir=pathlib.Path("foo/bar/baz"))
]
def test_load_multiple_table_with_columns(load):
tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)]
expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)]
assert load("\n".join(tables)) == expected
def test_union(load):
assert load("@foo = @bar | @baz | @bla;") == [
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
def test_table_and_union(load):
assert load("""
foos();
@foo = @bar | @baz | @bla;""") == [
dbscheme.Table(name="foos", columns=[]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
def test_comments_ignored(load):
assert load("""
// fake_table();
foos(/* x */unique /*y*/int/*
z
*/ id/* */: /* * */ @bar/*,
int ignored: int ref*/);
@foo = @bar | @baz | @bla; // | @xxx""") == [
dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]
if __name__ == '__main__':
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -3,11 +3,12 @@ import sys
import pytest
from swift.codegen.test.utils import *
from swift.codegen.lib.schema import defs
from swift.codegen.lib import schemadefs as defs
from swift.codegen.loaders.schemaloader import load
def test_empty_schema():
@schema.load
@load
class data:
pass
@@ -18,7 +19,7 @@ def test_empty_schema():
def test_one_empty_class():
@schema.load
@load
class data:
class MyClass:
pass
@@ -30,7 +31,7 @@ def test_one_empty_class():
def test_two_empty_classes():
@schema.load
@load
class data:
class MyClass1:
pass
@@ -50,7 +51,7 @@ def test_no_external_bases():
pass
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class MyClass(A):
pass
@@ -58,7 +59,7 @@ def test_no_external_bases():
def test_no_multiple_roots():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class MyClass1:
pass
@@ -68,7 +69,7 @@ def test_no_multiple_roots():
def test_empty_classes_diamond():
@schema.load
@load
class data:
class A:
pass
@@ -92,7 +93,7 @@ def test_empty_classes_diamond():
#
def test_group():
@schema.load
@load
class data:
@defs.group("xxx")
class A:
@@ -104,7 +105,7 @@ def test_group():
def test_group_is_inherited():
@schema.load
@load
class data:
class A:
pass
@@ -129,7 +130,7 @@ def test_group_is_inherited():
def test_no_mixed_groups_in_bases():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class A:
pass
@@ -151,14 +152,14 @@ def test_no_mixed_groups_in_bases():
def test_lowercase_rejected():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class aLowerCase:
pass
def test_properties():
@schema.load
@load
class data:
class A:
one: defs.string
@@ -182,7 +183,7 @@ def test_class_properties():
class A:
pass
@schema.load
@load
class data:
class A:
pass
@@ -205,7 +206,7 @@ def test_class_properties():
def test_string_reference_class_properties():
@schema.load
@load
class data:
class A:
one: "A"
@@ -227,14 +228,14 @@ def test_string_reference_class_properties():
lambda t: defs.list[defs.optional[t]]])
def test_string_reference_dangling(spec):
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class A:
x: spec("B")
def test_children():
@schema.load
@load
class data:
class A:
one: "A" | defs.child
@@ -255,7 +256,7 @@ def test_children():
@pytest.mark.parametrize("spec", [defs.string, defs.int, defs.boolean, defs.predicate])
def test_builtin_and_predicate_children_not_allowed(spec):
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class A:
x: spec | defs.child
@@ -271,7 +272,7 @@ _pragmas = [(defs.qltest.skip, "qltest_skip"),
@pytest.mark.parametrize("pragma,expected", _pragmas)
def test_property_with_pragma(pragma, expected):
@schema.load
@load
class data:
class A:
x: defs.string | pragma
@@ -288,7 +289,7 @@ def test_property_with_pragmas():
for pragma, _ in _pragmas:
spec |= pragma
@schema.load
@load
class data:
class A:
x: spec
@@ -302,7 +303,7 @@ def test_property_with_pragmas():
@pytest.mark.parametrize("pragma,expected", _pragmas)
def test_class_with_pragma(pragma, expected):
@schema.load
@load
class data:
@pragma
class A:
@@ -318,7 +319,7 @@ def test_class_with_pragmas():
for p, _ in _pragmas:
p(cls)
@schema.load
@load
class data:
class A:
pass
@@ -331,7 +332,7 @@ def test_class_with_pragmas():
def test_ipa_from_class():
@schema.load
@load
class data:
class A:
pass
@@ -347,7 +348,7 @@ def test_ipa_from_class():
def test_ipa_from_class_ref():
@schema.load
@load
class data:
@defs.synth.from_class("B")
class A:
@@ -364,7 +365,7 @@ def test_ipa_from_class_ref():
def test_ipa_from_class_dangling():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
@defs.synth.from_class("X")
class A:
@@ -372,7 +373,7 @@ def test_ipa_from_class_dangling():
def test_ipa_class_on():
@schema.load
@load
class data:
class A:
pass
@@ -391,7 +392,7 @@ def test_ipa_class_on_ref():
class A:
pass
@schema.load
@load
class data:
@defs.synth.on_arguments(b="B", i=defs.int)
class A:
@@ -408,7 +409,7 @@ def test_ipa_class_on_ref():
def test_ipa_class_on_dangling():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
@defs.synth.on_arguments(s=defs.string, a="A", i=defs.int)
class B:
@@ -416,7 +417,7 @@ def test_ipa_class_on_dangling():
def test_ipa_class_hierarchy():
@schema.load
@load
class data:
class Root:
pass
@@ -449,7 +450,7 @@ def test_ipa_class_hierarchy():
def test_class_docstring():
@schema.load
@load
class data:
class A:
"""Very important class."""
@@ -460,7 +461,7 @@ def test_class_docstring():
def test_property_docstring():
@schema.load
@load
class data:
class A:
x: int | defs.desc("very important property.")
@@ -471,7 +472,7 @@ def test_property_docstring():
def test_class_docstring_newline():
@schema.load
@load
class data:
class A:
"""Very important
@@ -483,7 +484,7 @@ def test_class_docstring_newline():
def test_property_docstring_newline():
@schema.load
@load
class data:
class A:
x: int | defs.desc("""very important
@@ -496,7 +497,7 @@ def test_property_docstring_newline():
def test_class_docstring_stripped():
@schema.load
@load
class data:
class A:
"""
@@ -511,7 +512,7 @@ def test_class_docstring_stripped():
def test_property_docstring_stripped():
@schema.load
@load
class data:
class A:
x: int | defs.desc("""
@@ -526,7 +527,7 @@ def test_property_docstring_stripped():
def test_class_docstring_split():
@schema.load
@load
class data:
class A:
"""Very important class.
@@ -539,7 +540,7 @@ def test_class_docstring_split():
def test_property_docstring_split():
@schema.load
@load
class data:
class A:
x: int | defs.desc("""very important property.
@@ -553,7 +554,7 @@ def test_property_docstring_split():
def test_class_docstring_indent():
@schema.load
@load
class data:
class A:
"""
@@ -567,7 +568,7 @@ def test_class_docstring_indent():
def test_property_docstring_indent():
@schema.load
@load
class data:
class A:
x: int | defs.desc("""
@@ -582,7 +583,7 @@ def test_property_docstring_indent():
def test_property_doc_override():
@schema.load
@load
class data:
class A:
x: int | defs.doc("y")
@@ -595,7 +596,7 @@ def test_property_doc_override():
def test_property_doc_override_no_newlines():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class A:
x: int | defs.doc("no multiple\nlines")
@@ -603,14 +604,14 @@ def test_property_doc_override_no_newlines():
def test_property_doc_override_no_trailing_dot():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class A:
x: int | defs.doc("no dots please.")
def test_class_default_doc_name():
@schema.load
@load
class data:
@defs.ql.default_doc_name("b")
class A:
@@ -622,7 +623,7 @@ def test_class_default_doc_name():
def test_null_class():
@schema.load
@load
class data:
class Root:
pass
@@ -641,7 +642,7 @@ def test_null_class():
def test_null_class_cannot_be_derived():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class Root:
pass
@@ -656,7 +657,7 @@ def test_null_class_cannot_be_derived():
def test_null_class_cannot_be_defined_multiple_times():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class Root:
pass
@@ -672,7 +673,7 @@ def test_null_class_cannot_be_defined_multiple_times():
def test_uppercase_acronyms_are_rejected():
with pytest.raises(schema.Error):
@schema.load
@load
class data:
class Root:
pass

View File

@@ -47,7 +47,7 @@ def override_paths(tmp_path):
@pytest.fixture
def input(opts, tmp_path):
opts.schema = tmp_path / schema_file
with mock.patch("swift.codegen.lib.schema.load_file") as load_mock:
with mock.patch("swift.codegen.loaders.schemaloader.load_file") as load_mock:
load_mock.return_value = schema.Schema([])
yield load_mock.return_value
assert load_mock.mock_calls == [
@@ -58,7 +58,7 @@ def input(opts, tmp_path):
@pytest.fixture
def dbscheme_input(opts, tmp_path):
opts.dbscheme = tmp_path / dbscheme_file
with mock.patch("swift.codegen.lib.dbscheme.iterload") as load_mock:
with mock.patch("swift.codegen.loaders.dbschemeloader.iterload") as load_mock:
load_mock.entities = []
load_mock.side_effect = lambda _: load_mock.entities
yield load_mock