Swift: add possibility to specify null class

This commit is contained in:
Paolo Tranquilli
2022-11-04 08:07:04 +01:00
parent e00585ca24
commit 2aa528852e
10 changed files with 281 additions and 32 deletions

View File

@@ -12,15 +12,14 @@ Each class in the schema gets a corresponding `struct` in `TrapClasses.h`, where
"""
import functools
import pathlib
from typing import Dict
import typing
import inflection
from swift.codegen.lib import cpp, schema
def _get_type(t: str) -> str:
def _get_type(t: str, add_or_none_except: typing.Optional[str] = None) -> str:
if t is None:
# this is a predicate
return "bool"
@@ -29,11 +28,15 @@ def _get_type(t: str) -> str:
if t == "boolean":
return "bool"
if t[0].isupper():
return f"TrapLabel<{t}Tag>"
if add_or_none_except is not None and t != add_or_none_except:
suffix = "OrNone"
else:
suffix = ""
return f"TrapLabel<{t}{suffix}Tag>"
return t
def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
def _get_field(cls: schema.Class, p: schema.Property, add_or_none_except: typing.Optional[str] = None) -> cpp.Field:
trap_name = None
if not p.is_single:
trap_name = inflection.camelize(f"{cls.name}_{p.name}")
@@ -41,7 +44,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
trap_name = inflection.pluralize(trap_name)
args = dict(
field_name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),
type=_get_type(p.type),
base_type=_get_type(p.type, add_or_none_except),
is_optional=p.is_optional,
is_repeated=p.is_repeated,
is_predicate=p.is_predicate,
@@ -52,8 +55,13 @@ def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
class Processor:
def __init__(self, data: Dict[str, schema.Class]):
self._classmap = data
def __init__(self, data: schema.Schema):
self._classmap = data.classes
if data.null:
root_type = next(iter(data.classes))
self._add_or_none_except = root_type
else:
self._add_or_none_except = None
@functools.lru_cache(maxsize=None)
def _get_class(self, name: str) -> cpp.Class:
@@ -64,7 +72,10 @@ class Processor:
return cpp.Class(
name=name,
bases=[self._get_class(b) for b in cls.bases],
fields=[_get_field(cls, p) for p in cls.properties if "cpp_skip" not in p.pragmas],
fields=[
_get_field(cls, p, self._add_or_none_except)
for p in cls.properties if "cpp_skip" not in p.pragmas
],
final=not cls.derived,
trap_name=trap_name,
)
@@ -78,8 +89,8 @@ class Processor:
def generate(opts, renderer):
assert opts.cpp_output
processor = Processor(schema.load_file(opts.schema).classes)
processor = Processor(schema.load_file(opts.schema))
out = opts.cpp_output
for dir, classes in processor.get_classes().items():
renderer.render(cpp.ClassList(classes, opts.schema,
include_parent=bool(dir)), out / dir / "TrapClasses")
include_parent=bool(dir)), out / dir / "TrapClasses")

View File

@@ -13,6 +13,7 @@ Moreover:
as columns
The type hierarchy will be translated to corresponding `union` declarations.
"""
import typing
import inflection
@@ -23,14 +24,21 @@ from typing import Set, List
log = logging.getLogger(__name__)
def dbtype(typename):
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes """
def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> str:
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.
For class types, appends an underscore followed by `null` if provided
"""
if typename[0].isupper():
return "@" + inflection.underscore(typename)
underscored = inflection.underscore(typename)
if add_or_none_except is not None and typename != add_or_none_except:
suffix = "_or_none"
else:
suffix = ""
return f"@{underscored}{suffix}"
return typename
def cls_to_dbscheme(cls: schema.Class):
def cls_to_dbscheme(cls: schema.Class, add_or_none_except: typing.Optional[str] = None):
""" Yield all dbscheme entities needed to model class `cls` """
if cls.derived:
yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived))
@@ -48,7 +56,7 @@ def cls_to_dbscheme(cls: schema.Class):
columns=[
Column("id", type=dbtype(cls.name), binding=binding),
] + [
Column(f.name, dbtype(f.type)) for f in cls.properties if f.is_single
Column(f.name, dbtype(f.type, add_or_none_except)) for f in cls.properties if f.is_single
],
dir=dir,
)
@@ -61,7 +69,7 @@ def cls_to_dbscheme(cls: schema.Class):
columns=[
Column("id", type=dbtype(cls.name)),
Column("index", type="int"),
Column(inflection.singularize(f.name), dbtype(f.type)),
Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
],
dir=dir,
)
@@ -71,7 +79,7 @@ def cls_to_dbscheme(cls: schema.Class):
name=inflection.tableize(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
Column(f.name, dbtype(f.type)),
Column(f.name, dbtype(f.type, add_or_none_except)),
],
dir=dir,
)
@@ -87,7 +95,17 @@ def cls_to_dbscheme(cls: schema.Class):
def get_declarations(data: schema.Schema):
return [d for cls in data.classes.values() for d in cls_to_dbscheme(cls)]
add_or_none_except = data.root_class.name if data.null else None
declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, add_or_none_except)]
if data.null:
property_classes = {
prop.type for cls in data.classes.values() for prop in cls.properties
if cls.name != data.null and prop.type and prop.type[0].isupper()
}
declarations += [
Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) for t in sorted(property_classes)
]
return declarations
def get_includes(data: schema.Schema, include_dir: pathlib.Path, swift_dir: pathlib.Path):

View File

@@ -147,7 +147,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, prev_child: str =
return ql.Property(**args)
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
def get_ql_class(cls: schema.Class):
pragmas = {k: True for k in cls.pragmas if k.startswith("ql")}
prev_child = ""
properties = []
@@ -314,7 +314,7 @@ def generate(opts, renderer):
data = schema.load_file(input)
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items()}
classes = {name: get_ql_class(cls) for name, cls in data.classes.items()}
if not classes:
raise NoClasses
root = next(iter(classes.values()))

View File

@@ -41,10 +41,10 @@ def get_cpp_type(schema_type: str):
def get_field(c: dbscheme.Column):
args = {
"field_name": c.schema_name,
"type": c.type,
"base_type": c.type,
}
args.update(cpp.get_field_override(c.schema_name))
args["type"] = get_cpp_type(args["type"])
args["base_type"] = get_cpp_type(args["base_type"])
return cpp.Field(**args)

View File

@@ -16,7 +16,7 @@ cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "
"xor", "xor_eq"}
_field_overrides = [
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"type": "unsigned"}),
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": "unsigned"}),
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
]
@@ -32,7 +32,7 @@ def get_field_override(field: str):
@dataclass
class Field:
field_name: str
type: str
base_type: str
is_optional: bool = False
is_repeated: bool = False
is_predicate: bool = False
@@ -40,13 +40,18 @@ class Field:
first: bool = False
def __post_init__(self):
if self.is_optional:
self.type = f"std::optional<{self.type}>"
if self.is_repeated:
self.type = f"std::vector<{self.type}>"
if self.field_name in cpp_keywords:
self.field_name += "_"
@property
def type(self) -> str:
type = self.base_type
if self.is_optional:
type = f"std::optional<{type}>"
if self.is_repeated:
type = f"std::vector<{type}>"
return type
# using @property breaks pystache internals here
def get_streamer(self):
if self.type == "std::string":
@@ -60,6 +65,10 @@ class Field:
def is_single(self):
return not (self.is_optional or self.is_repeated or self.is_predicate)
@property
def is_label(self):
return self.base_type.startswith("TrapLabel<")
@dataclass
class Trap:

View File

@@ -115,6 +115,8 @@ child = _ChildModifier()
doc = _DocModifier
desc = _DescModifier
use_for_null = _annotate(null=True)
qltest = _Namespace(
skip=_Pragma("qltest_skip"),
collapse_hierarchy=_Pragma("qltest_collapse_hierarchy"),

View File

@@ -55,6 +55,14 @@ class Property:
def is_predicate(self) -> bool:
return self.kind == self.Kind.PREDICATE
@property
def has_class_type(self) -> bool:
return bool(self.type) and self.type[0].isupper()
@property
def has_builtin_type(self) -> bool:
return bool(self.type) and self.type[0].islower()
SingleProperty = functools.partial(Property, Property.Kind.SINGLE)
OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL)
@@ -104,6 +112,16 @@ class Class:
class Schema:
classes: Dict[str, Class] = field(default_factory=dict)
includes: Set[str] = field(default_factory=set)
null: Optional[str] = None
@property
def root_class(self):
# always the first in the dictionary
return next(iter(self.classes.values()))
@property
def null_class(self):
return self.classes[self.null] if self.null else None
predicate_marker = object()
@@ -195,6 +213,8 @@ def _get_class(cls: type) -> Class:
raise Error(f"Class name must be capitalized, found {cls.__name__}")
if len({b._group for b in cls.__bases__ if hasattr(b, "_group")}) > 1:
raise Error(f"Bases with mixed groups for {cls.__name__}")
if any(getattr(b, "_null", False) for b in cls.__bases__):
raise Error(f"Null class cannot be derived")
return Class(name=cls.__name__,
bases=[b.__name__ for b in cls.__bases__ if b is not object],
derived={d.__name__ for d in cls.__subclasses__()},
@@ -233,6 +253,7 @@ def load(m: types.ModuleType) -> Schema:
known = {"int", "string", "boolean"}
known.update(n for n in m.__dict__ if not n.startswith("__"))
import swift.codegen.lib.schema.defs as defs
null = None
for name, data in m.__dict__.items():
if hasattr(defs, name):
continue
@@ -247,8 +268,13 @@ def load(m: types.ModuleType) -> Schema:
f"Only one root class allowed, found second root {name}")
cls.check_types(known)
classes[name] = cls
if getattr(data, "_null", False):
if null is not None:
raise Error(f"Null class {null} already defined, second null class {name} not allowed")
null = name
cls.is_null_class = True
return Schema(includes=includes, classes=_toposort_classes_by_group(classes))
return Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
def load_file(path: pathlib.Path) -> Schema:

View File

@@ -80,6 +80,25 @@ def test_class_with_field(generate, type, expected, property_cls, optional, repe
]
def test_class_field_with_null(generate, input):
input.null = "Null"
a = cpp.Class(name="A")
assert generate([
schema.Class(name="A", derived={"B"}),
schema.Class(name="B", bases=["A"], properties=[
schema.SingleProperty("x", "A"),
schema.SingleProperty("y", "B"),
])
]) == [
a,
cpp.Class(name="B", bases=[a], final=True, trap_name="Bs",
fields=[
cpp.Field("x", "TrapLabel<ATag>"),
cpp.Field("y", "TrapLabel<BOrNoneTag>"),
]),
]
def test_class_with_predicate(generate):
assert generate([
schema.Class(name="MyClass", properties=[

View File

@@ -18,8 +18,9 @@ def dir_param(request):
@pytest.fixture
def generate(opts, input, renderer):
def func(classes):
def func(classes, null=None):
input.classes = {cls.name: cls for cls in classes}
input.null = null
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
assert out is opts.dbscheme
return data
@@ -359,5 +360,114 @@ def test_class_with_derived_and_repeated_property(generate, dir_param):
)
def test_null_class(generate):
assert generate([
schema.Class(
name="Base",
derived={"W", "X", "Y", "Z", "Null"},
),
schema.Class(
name="W",
bases=["Base"],
properties=[
schema.SingleProperty("w", "W"),
schema.SingleProperty("x", "X"),
schema.OptionalProperty("y", "Y"),
schema.RepeatedProperty("z", "Z"),
]
),
schema.Class(
name="X",
bases=["Base"],
),
schema.Class(
name="Y",
bases=["Base"],
),
schema.Class(
name="Z",
bases=["Base"],
),
schema.Class(
name="Null",
bases=["Base"],
),
], null="Null") == dbscheme.Scheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.Union(
lhs="@base",
rhs=["@null", "@w", "@x", "@y", "@z"],
),
dbscheme.Table(
name="ws",
columns=[
dbscheme.Column('id', '@w', binding=True),
dbscheme.Column('w', '@w_or_none'),
dbscheme.Column('x', '@x_or_none'),
],
),
dbscheme.Table(
name="w_ies",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@w'),
dbscheme.Column('y', '@y_or_none'),
],
),
dbscheme.Table(
name="w_zs",
keyset=dbscheme.KeySet(["id", "index"]),
columns=[
dbscheme.Column('id', '@w'),
dbscheme.Column('index', 'int'),
dbscheme.Column('z', '@z_or_none'),
],
),
dbscheme.Table(
name="xes",
columns=[
dbscheme.Column('id', '@x', binding=True),
],
),
dbscheme.Table(
name="ys",
columns=[
dbscheme.Column('id', '@y', binding=True),
],
),
dbscheme.Table(
name="zs",
columns=[
dbscheme.Column('id', '@z', binding=True),
],
),
dbscheme.Table(
name="nulls",
columns=[
dbscheme.Column('id', '@null', binding=True),
],
),
dbscheme.Union(
lhs="@w_or_none",
rhs=["@w", "@null"],
),
dbscheme.Union(
lhs="@x_or_none",
rhs=["@x", "@null"],
),
dbscheme.Union(
lhs="@y_or_none",
rhs=["@y", "@null"],
),
dbscheme.Union(
lhs="@z_or_none",
rhs=["@z", "@null"],
),
],
)
if __name__ == '__main__':
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -13,6 +13,8 @@ def test_empty_schema():
assert data.classes == {}
assert data.includes == set()
assert data.null is None
assert data.null_class is None
def test_one_empty_class():
@@ -24,6 +26,7 @@ def test_one_empty_class():
assert data.classes == {
'MyClass': schema.Class('MyClass'),
}
assert data.root_class is data.classes['MyClass']
def test_two_empty_classes():
@@ -39,6 +42,7 @@ def test_two_empty_classes():
'MyClass1': schema.Class('MyClass1', derived={'MyClass2'}),
'MyClass2': schema.Class('MyClass2', bases=['MyClass1']),
}
assert data.root_class is data.classes['MyClass1']
def test_no_external_bases():
@@ -452,7 +456,8 @@ def test_property_docstring_newline():
property.""")
assert data.classes == {
'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
'A': schema.Class('A',
properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
}
@@ -566,5 +571,54 @@ def test_class_default_doc_name():
}
def test_null_class():
@schema.load
class data:
class Root:
pass
@defs.use_for_null
class Null(Root):
pass
assert data.classes == {
'Root': schema.Class('Root', derived={'Null'}),
'Null': schema.Class('Null', bases=['Root']),
}
assert data.null == 'Null'
assert data.null_class is data.classes[data.null]
def test_null_class_cannot_be_derived():
with pytest.raises(schema.Error):
@schema.load
class data:
class Root:
pass
@defs.use_for_null
class Null(Root):
pass
class Impossible(Null):
pass
def test_null_class_cannot_be_defined_multiple_times():
with pytest.raises(schema.Error):
@schema.load
class data:
class Root:
pass
@defs.use_for_null
class Null1(Root):
pass
@defs.use_for_null
class Null2(Root):
pass
if __name__ == '__main__':
sys.exit(pytest.main([__file__] + sys.argv[1:]))