Swift: testing non-trivial dataclass properties

This commit is contained in:
Paolo Tranquilli
2022-04-27 10:12:28 +02:00
parent 7f0476049f
commit 0100c7171d
10 changed files with 316 additions and 184 deletions

3
.gitignore vendored
View File

@@ -20,6 +20,9 @@
# python virtual environment folder # python virtual environment folder
.venv/ .venv/
# binary files created by pytest-cov
.coverage
# It's useful (though not required) to be able to unpack codeql in the ql checkout itself # It's useful (though not required) to be able to unpack codeql in the ql checkout itself
/codeql/ /codeql/

Binary file not shown.

View File

@@ -19,42 +19,42 @@ def dbtype(typename):
def cls_to_dbscheme(cls: schema.Class): def cls_to_dbscheme(cls: schema.Class):
""" Yield all dbscheme entities needed to model class `cls` """ """ Yield all dbscheme entities needed to model class `cls` """
if cls.derived: if cls.derived:
yield DbUnion(dbtype(cls.name), (dbtype(c) for c in cls.derived)) yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived))
# output a table specific to a class only if it is a leaf class or it has 1-to-1 properties # output a table specific to a class only if it is a leaf class or it has 1-to-1 properties
# Leaf classes need a table to bind the `@` ids # Leaf classes need a table to bind the `@` ids
# 1-to-1 properties are added to a class specific table # 1-to-1 properties are added to a class specific table
# in other cases, separate tables are used for the properties, and a class specific table is unneeded # in other cases, separate tables are used for the properties, and a class specific table is unneeded
if not cls.derived or any(f.is_single for f in cls.properties): if not cls.derived or any(f.is_single for f in cls.properties):
binding = not cls.derived binding = not cls.derived
keyset = DbKeySet(["id"]) if cls.derived else None keyset = KeySet(["id"]) if cls.derived else None
yield DbTable( yield Table(
keyset=keyset, keyset=keyset,
name=inflection.tableize(cls.name), name=inflection.tableize(cls.name),
columns=[ columns=[
DbColumn("id", type=dbtype(cls.name), binding=binding), Column("id", type=dbtype(cls.name), binding=binding),
] + [ ] + [
DbColumn(f.name, dbtype(f.type)) for f in cls.properties if f.is_single Column(f.name, dbtype(f.type)) for f in cls.properties if f.is_single
] ]
) )
# use property-specific tables for 1-to-many and 1-to-at-most-1 properties # use property-specific tables for 1-to-many and 1-to-at-most-1 properties
for f in cls.properties: for f in cls.properties:
if f.is_optional: if f.is_optional:
yield DbTable( yield Table(
keyset=DbKeySet(["id"]), keyset=KeySet(["id"]),
name=inflection.tableize(f"{cls.name}_{f.name}"), name=inflection.tableize(f"{cls.name}_{f.name}"),
columns=[ columns=[
DbColumn("id", type=dbtype(cls.name)), Column("id", type=dbtype(cls.name)),
DbColumn(f.name, dbtype(f.type)), Column(f.name, dbtype(f.type)),
], ],
) )
elif f.is_repeated: elif f.is_repeated:
yield DbTable( yield Table(
keyset=DbKeySet(["id", "index"]), keyset=KeySet(["id", "index"]),
name=inflection.tableize(f"{cls.name}_{f.name}"), name=inflection.tableize(f"{cls.name}_{f.name}"),
columns=[ columns=[
DbColumn("id", type=dbtype(cls.name)), Column("id", type=dbtype(cls.name)),
DbColumn("index", type="int"), Column("index", type="int"),
DbColumn(inflection.singularize(f.name), dbtype(f.type)), Column(inflection.singularize(f.name), dbtype(f.type)),
] ]
) )
@@ -68,7 +68,7 @@ def get_includes(data: schema.Schema, include_dir: pathlib.Path):
for inc in data.includes: for inc in data.includes:
inc = include_dir / inc inc = include_dir / inc
with open(inc) as inclusion: with open(inc) as inclusion:
includes.append(DbSchemeInclude(src=inc.relative_to(paths.swift_dir), data=inclusion.read())) includes.append(SchemeInclude(src=inc.relative_to(paths.swift_dir), data=inclusion.read()))
return includes return includes
@@ -78,7 +78,7 @@ def generate(opts, renderer):
data = schema.load(input) data = schema.load(input)
dbscheme = DbScheme(src=input.relative_to(paths.swift_dir), dbscheme = Scheme(src=input.relative_to(paths.swift_dir),
includes=get_includes(data, include_dir=input.parent), includes=get_includes(data, include_dir=input.parent),
declarations=get_declarations(data)) declarations=get_declarations(data))

View File

@@ -10,7 +10,7 @@ dbscheme_keywords = {"case", "boolean", "int", "string", "type"}
@dataclass @dataclass
class DbColumn: class Column:
schema_name: str schema_name: str
type: str type: str
binding: bool = False binding: bool = False
@@ -36,33 +36,33 @@ class DbColumn:
@dataclass @dataclass
class DbKeySetId: class KeySetId:
id: str id: str
first: bool = False first: bool = False
@dataclass @dataclass
class DbKeySet: class KeySet:
ids: List[DbKeySetId] ids: List[KeySetId]
def __post_init__(self): def __post_init__(self):
assert self.ids assert self.ids
self.ids = [DbKeySetId(x) for x in self.ids] self.ids = [KeySetId(x) for x in self.ids]
self.ids[0].first = True self.ids[0].first = True
class DbDecl: class Decl:
is_table = False is_table = False
is_union = False is_union = False
@dataclass @dataclass
class DbTable(DbDecl): class Table(Decl):
is_table: ClassVar = True is_table: ClassVar = True
name: str name: str
columns: List[DbColumn] columns: List[Column]
keyset: DbKeySet = None keyset: KeySet = None
def __post_init__(self): def __post_init__(self):
if self.columns: if self.columns:
@@ -70,35 +70,35 @@ class DbTable(DbDecl):
@dataclass @dataclass
class DbUnionCase: class UnionCase:
type: str type: str
first: bool = False first: bool = False
@dataclass @dataclass
class DbUnion(DbDecl): class Union(Decl):
is_union: ClassVar = True is_union: ClassVar = True
lhs: str lhs: str
rhs: List[DbUnionCase] rhs: List[UnionCase]
def __post_init__(self): def __post_init__(self):
assert self.rhs assert self.rhs
self.rhs = [DbUnionCase(x) for x in self.rhs] self.rhs = [UnionCase(x) for x in self.rhs]
self.rhs.sort(key=lambda c: c.type) self.rhs.sort(key=lambda c: c.type)
self.rhs[0].first = True self.rhs[0].first = True
@dataclass @dataclass
class DbSchemeInclude: class SchemeInclude:
src: str src: str
data: str data: str
@dataclass @dataclass
class DbScheme: class Scheme:
template: ClassVar = 'dbscheme' template: ClassVar = 'dbscheme'
src: str src: str
includes: List[DbSchemeInclude] includes: List[SchemeInclude]
declarations: List[DbDecl] declarations: List[Decl]

View File

@@ -6,20 +6,20 @@ import inflection
@dataclass @dataclass
class QlParam: class Param:
param: str param: str
type: str = None type: str = None
first: bool = False first: bool = False
@dataclass @dataclass
class QlProperty: class Property:
singular: str singular: str
type: str type: str
tablename: str tablename: str
tableparams: List[QlParam] tableparams: List[Param]
plural: str = None plural: str = None
params: List[QlParam] = field(default_factory=list) params: List[Param] = field(default_factory=list)
first: bool = False first: bool = False
local_var: str = "x" local_var: str = "x"
@@ -31,7 +31,7 @@ class QlProperty:
assert self.tableparams assert self.tableparams
if self.type_is_class: if self.type_is_class:
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams] self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
self.tableparams = [QlParam(x) for x in self.tableparams] self.tableparams = [Param(x) for x in self.tableparams]
self.tableparams[0].first = True self.tableparams[0].first = True
@property @property
@@ -45,13 +45,13 @@ class QlProperty:
@dataclass @dataclass
class QlClass: class Class:
template: ClassVar = 'ql_class' template: ClassVar = 'ql_class'
name: str name: str
bases: List[str] = field(default_factory=list) bases: List[str] = field(default_factory=list)
final: bool = False final: bool = False
properties: List[QlProperty] = field(default_factory=list) properties: List[Property] = field(default_factory=list)
dir: pathlib.Path = pathlib.Path() dir: pathlib.Path = pathlib.Path()
imports: List[str] = field(default_factory=list) imports: List[str] = field(default_factory=list)
@@ -74,7 +74,7 @@ class QlClass:
@dataclass @dataclass
class QlStub: class Stub:
template: ClassVar = 'ql_stub' template: ClassVar = 'ql_stub'
name: str name: str
@@ -82,7 +82,7 @@ class QlStub:
@dataclass @dataclass
class QlImportList: class ImportList:
template: ClassVar = 'ql_imports' template: ClassVar = 'ql_imports'
imports: List[str] = field(default_factory=list) imports: List[str] = field(default_factory=list)

View File

@@ -12,32 +12,32 @@ log = logging.getLogger(__name__)
def get_ql_property(cls: schema.Class, prop: schema.Property): def get_ql_property(cls: schema.Class, prop: schema.Property):
if prop.is_single: if prop.is_single:
return ql.QlProperty( return ql.Property(
singular=inflection.camelize(prop.name), singular=inflection.camelize(prop.name),
type=prop.type, type=prop.type,
tablename=inflection.tableize(cls.name), tablename=inflection.tableize(cls.name),
tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single], tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
) )
elif prop.is_optional: elif prop.is_optional:
return ql.QlProperty( return ql.Property(
singular=inflection.camelize(prop.name), singular=inflection.camelize(prop.name),
type=prop.type, type=prop.type,
tablename=inflection.tableize(f"{cls.name}_{prop.name}"), tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
tableparams=["this", "result"], tableparams=["this", "result"],
) )
elif prop.is_repeated: elif prop.is_repeated:
return ql.QlProperty( return ql.Property(
singular=inflection.singularize(inflection.camelize(prop.name)), singular=inflection.singularize(inflection.camelize(prop.name)),
plural=inflection.pluralize(inflection.camelize(prop.name)), plural=inflection.pluralize(inflection.camelize(prop.name)),
type=prop.type, type=prop.type,
tablename=inflection.tableize(f"{cls.name}_{prop.name}"), tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
tableparams=["this", "index", "result"], tableparams=["this", "index", "result"],
params=[ql.QlParam("index", type="int")], params=[ql.Param("index", type="int")],
) )
def get_ql_class(cls: schema.Class): def get_ql_class(cls: schema.Class):
return ql.QlClass( return ql.Class(
name=cls.name, name=cls.name,
bases=cls.bases, bases=cls.bases,
final=not cls.derived, final=not cls.derived,
@@ -51,7 +51,7 @@ def get_import(file):
return str(stem).replace("/", ".") return str(stem).replace("/", ".")
def get_types_used_by(cls: ql.QlClass): def get_types_used_by(cls: ql.Class):
for b in cls.bases: for b in cls.bases:
yield b yield b
for p in cls.properties: for p in cls.properties:
@@ -60,7 +60,7 @@ def get_types_used_by(cls: ql.QlClass):
yield param.type yield param.type
def get_classes_used_by(cls: ql.QlClass): def get_classes_used_by(cls: ql.Class):
return sorted(set(t for t in get_types_used_by(cls) if t[0].isupper())) return sorted(set(t for t in get_types_used_by(cls) if t[0].isupper()))
@@ -98,12 +98,12 @@ def generate(opts, renderer):
renderer.render(c, qll) renderer.render(c, qll)
stub_file = (stub_out / c.path).with_suffix(".qll") stub_file = (stub_out / c.path).with_suffix(".qll")
if not stub_file.is_file() or is_generated(stub_file): if not stub_file.is_file() or is_generated(stub_file):
stub = ql.QlStub(name=c.name, base_import=get_import(qll)) stub = ql.Stub(name=c.name, base_import=get_import(qll))
renderer.render(stub, stub_file) renderer.render(stub, stub_file)
# for example path/to/syntax/generated -> path/to/syntax.qll # for example path/to/syntax/generated -> path/to/syntax.qll
include_file = stub_out.with_suffix(".qll") include_file = stub_out.with_suffix(".qll")
all_imports = ql.QlImportList([v for _, v in sorted(imports.items())]) all_imports = ql.ImportList([v for _, v in sorted(imports.items())])
renderer.render(all_imports, include_file) renderer.render(all_imports, include_file)
renderer.cleanup(existing) renderer.cleanup(existing)

View File

@@ -0,0 +1,52 @@
import sys
from copy import deepcopy
from swift.codegen.lib import dbscheme
from swift.codegen.test.utils import *
def test_dbcolumn_name():
assert dbscheme.Column("foo", "some_type").name == "foo"
@pytest.mark.parametrize("keyword", dbscheme.dbscheme_keywords)
def test_dbcolumn_keyword_name(keyword):
assert dbscheme.Column(keyword, "some_type").name == keyword + "_"
@pytest.mark.parametrize("type,binding,lhstype,rhstype", [
("builtin_type", False, "builtin_type", "builtin_type ref"),
("builtin_type", True, "builtin_type", "builtin_type ref"),
("@at_type", False, "int", "@at_type ref"),
("@at_type", True, "unique int", "@at_type"),
])
def test_dbcolumn_types(type, binding, lhstype, rhstype):
col = dbscheme.Column("foo", type, binding)
assert col.lhstype == lhstype
assert col.rhstype == rhstype
def test_keyset_has_first_id_marked():
ids = ["a", "b", "c"]
ks = dbscheme.KeySet(ids)
assert ks.ids[0].first
assert [id.id for id in ks.ids] == ids
def test_table_has_first_column_marked():
columns = [dbscheme.Column("a", "x"), dbscheme.Column("b", "y", binding=True), dbscheme.Column("c", "z")]
expected = deepcopy(columns)
table = dbscheme.Table("foo", columns)
expected[0].first = True
assert table.columns == expected
def test_union_has_first_case_marked():
rhs = ["a", "b", "c"]
u = dbscheme.Union(lhs="x", rhs=rhs)
assert u.rhs[0].first
assert [c.type for c in u.rhs] == rhs
if __name__ == '__main__':
sys.exit(pytest.main())

View File

@@ -1,10 +1,10 @@
import pathlib
import sys import sys
from swift.codegen import dbschemegen from swift.codegen import dbschemegen
from swift.codegen.lib import dbscheme, paths from swift.codegen.lib import dbscheme
from swift.codegen.test.utils import * from swift.codegen.test.utils import *
def generate(opts, renderer): def generate(opts, renderer):
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items() (out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
assert out is opts.dbscheme assert out is opts.dbscheme
@@ -12,7 +12,7 @@ def generate(opts, renderer):
def test_empty(opts, input, renderer): def test_empty(opts, input, renderer):
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[], declarations=[],
@@ -25,10 +25,10 @@ def test_includes(opts, input, renderer):
for i in includes: for i in includes:
write(opts.schema.parent / i, i + " data") write(opts.schema.parent / i, i + " data")
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[ includes=[
dbscheme.DbSchemeInclude( dbscheme.SchemeInclude(
src=schema_dir / i, src=schema_dir / i,
data=i + " data", data=i + " data",
) for i in includes ) for i in includes
@@ -41,14 +41,14 @@ def test_empty_final_class(opts, input, renderer):
input.classes = [ input.classes = [
schema.Class("Object"), schema.Class("Object"),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbTable( dbscheme.Table(
name="objects", name="objects",
columns=[ columns=[
dbscheme.DbColumn('id', '@object', binding=True), dbscheme.Column('id', '@object', binding=True),
] ]
) )
], ],
@@ -62,15 +62,15 @@ def test_final_class_with_single_scalar_field(opts, input, renderer):
schema.SingleProperty("foo", "bar"), schema.SingleProperty("foo", "bar"),
]), ]),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbTable( dbscheme.Table(
name="objects", name="objects",
columns=[ columns=[
dbscheme.DbColumn('id', '@object', binding=True), dbscheme.Column('id', '@object', binding=True),
dbscheme.DbColumn('foo', 'bar'), dbscheme.Column('foo', 'bar'),
] ]
) )
], ],
@@ -83,15 +83,15 @@ def test_final_class_with_single_class_field(opts, input, renderer):
schema.SingleProperty("foo", "Bar"), schema.SingleProperty("foo", "Bar"),
]), ]),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbTable( dbscheme.Table(
name="objects", name="objects",
columns=[ columns=[
dbscheme.DbColumn('id', '@object', binding=True), dbscheme.Column('id', '@object', binding=True),
dbscheme.DbColumn('foo', '@bar'), dbscheme.Column('foo', '@bar'),
] ]
) )
], ],
@@ -104,22 +104,22 @@ def test_final_class_with_optional_field(opts, input, renderer):
schema.OptionalProperty("foo", "bar"), schema.OptionalProperty("foo", "bar"),
]), ]),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbTable( dbscheme.Table(
name="objects", name="objects",
columns=[ columns=[
dbscheme.DbColumn('id', '@object', binding=True), dbscheme.Column('id', '@object', binding=True),
] ]
), ),
dbscheme.DbTable( dbscheme.Table(
name="object_foos", name="object_foos",
keyset=dbscheme.DbKeySet(["id"]), keyset=dbscheme.KeySet(["id"]),
columns=[ columns=[
dbscheme.DbColumn('id', '@object'), dbscheme.Column('id', '@object'),
dbscheme.DbColumn('foo', 'bar'), dbscheme.Column('foo', 'bar'),
] ]
), ),
], ],
@@ -132,23 +132,23 @@ def test_final_class_with_repeated_field(opts, input, renderer):
schema.RepeatedProperty("foo", "bar"), schema.RepeatedProperty("foo", "bar"),
]), ]),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbTable( dbscheme.Table(
name="objects", name="objects",
columns=[ columns=[
dbscheme.DbColumn('id', '@object', binding=True), dbscheme.Column('id', '@object', binding=True),
] ]
), ),
dbscheme.DbTable( dbscheme.Table(
name="object_foos", name="object_foos",
keyset=dbscheme.DbKeySet(["id", "index"]), keyset=dbscheme.KeySet(["id", "index"]),
columns=[ columns=[
dbscheme.DbColumn('id', '@object'), dbscheme.Column('id', '@object'),
dbscheme.DbColumn('index', 'int'), dbscheme.Column('index', 'int'),
dbscheme.DbColumn('foo', 'bar'), dbscheme.Column('foo', 'bar'),
] ]
), ),
], ],
@@ -164,33 +164,33 @@ def test_final_class_with_more_fields(opts, input, renderer):
schema.RepeatedProperty("four", "w"), schema.RepeatedProperty("four", "w"),
]), ]),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbTable( dbscheme.Table(
name="objects", name="objects",
columns=[ columns=[
dbscheme.DbColumn('id', '@object', binding=True), dbscheme.Column('id', '@object', binding=True),
dbscheme.DbColumn('one', 'x'), dbscheme.Column('one', 'x'),
dbscheme.DbColumn('two', 'y'), dbscheme.Column('two', 'y'),
] ]
), ),
dbscheme.DbTable( dbscheme.Table(
name="object_threes", name="object_threes",
keyset=dbscheme.DbKeySet(["id"]), keyset=dbscheme.KeySet(["id"]),
columns=[ columns=[
dbscheme.DbColumn('id', '@object'), dbscheme.Column('id', '@object'),
dbscheme.DbColumn('three', 'z'), dbscheme.Column('three', 'z'),
] ]
), ),
dbscheme.DbTable( dbscheme.Table(
name="object_fours", name="object_fours",
keyset=dbscheme.DbKeySet(["id", "index"]), keyset=dbscheme.KeySet(["id", "index"]),
columns=[ columns=[
dbscheme.DbColumn('id', '@object'), dbscheme.Column('id', '@object'),
dbscheme.DbColumn('index', 'int'), dbscheme.Column('index', 'int'),
dbscheme.DbColumn('four', 'w'), dbscheme.Column('four', 'w'),
] ]
), ),
], ],
@@ -203,11 +203,11 @@ def test_empty_class_with_derived(opts, input, renderer):
name="Base", name="Base",
derived={"Left", "Right"}), derived={"Left", "Right"}),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbUnion( dbscheme.Union(
lhs="@base", lhs="@base",
rhs=["@left", "@right"], rhs=["@left", "@right"],
), ),
@@ -224,20 +224,20 @@ def test_class_with_derived_and_single_property(opts, input, renderer):
schema.SingleProperty("single", "Prop"), schema.SingleProperty("single", "Prop"),
]), ]),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbUnion( dbscheme.Union(
lhs="@base", lhs="@base",
rhs=["@left", "@right"], rhs=["@left", "@right"],
), ),
dbscheme.DbTable( dbscheme.Table(
name="bases", name="bases",
keyset=dbscheme.DbKeySet(["id"]), keyset=dbscheme.KeySet(["id"]),
columns=[ columns=[
dbscheme.DbColumn('id', '@base'), dbscheme.Column('id', '@base'),
dbscheme.DbColumn('single', '@prop'), dbscheme.Column('single', '@prop'),
] ]
) )
], ],
@@ -253,20 +253,20 @@ def test_class_with_derived_and_optional_property(opts, input, renderer):
schema.OptionalProperty("opt", "Prop"), schema.OptionalProperty("opt", "Prop"),
]), ]),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbUnion( dbscheme.Union(
lhs="@base", lhs="@base",
rhs=["@left", "@right"], rhs=["@left", "@right"],
), ),
dbscheme.DbTable( dbscheme.Table(
name="base_opts", name="base_opts",
keyset=dbscheme.DbKeySet(["id"]), keyset=dbscheme.KeySet(["id"]),
columns=[ columns=[
dbscheme.DbColumn('id', '@base'), dbscheme.Column('id', '@base'),
dbscheme.DbColumn('opt', '@prop'), dbscheme.Column('opt', '@prop'),
] ]
) )
], ],
@@ -282,47 +282,26 @@ def test_class_with_derived_and_repeated_property(opts, input, renderer):
schema.RepeatedProperty("rep", "Prop"), schema.RepeatedProperty("rep", "Prop"),
]), ]),
] ]
assert generate(opts, renderer) == dbscheme.DbScheme( assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file, src=schema_file,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.DbUnion( dbscheme.Union(
lhs="@base", lhs="@base",
rhs=["@left", "@right"], rhs=["@left", "@right"],
), ),
dbscheme.DbTable( dbscheme.Table(
name="base_reps", name="base_reps",
keyset=dbscheme.DbKeySet(["id", "index"]), keyset=dbscheme.KeySet(["id", "index"]),
columns=[ columns=[
dbscheme.DbColumn('id', '@base'), dbscheme.Column('id', '@base'),
dbscheme.DbColumn('index', 'int'), dbscheme.Column('index', 'int'),
dbscheme.DbColumn('rep', '@prop'), dbscheme.Column('rep', '@prop'),
] ]
) )
], ],
) )
def test_dbcolumn_name():
assert dbscheme.DbColumn("foo", "some_type").name == "foo"
@pytest.mark.parametrize("keyword", dbscheme.dbscheme_keywords)
def test_dbcolumn_keyword_name(keyword):
assert dbscheme.DbColumn(keyword, "some_type").name == keyword + "_"
@pytest.mark.parametrize("type,binding,lhstype,rhstype", [
("builtin_type", False, "builtin_type", "builtin_type ref"),
("builtin_type", True, "builtin_type", "builtin_type ref"),
("@at_type", False, "int", "@at_type ref"),
("@at_type", True, "unique int", "@at_type"),
])
def test_dbcolumn_types(type, binding, lhstype, rhstype):
col = dbscheme.DbColumn("foo", type, binding)
assert col.lhstype == lhstype
assert col.rhstype == rhstype
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(pytest.main()) sys.exit(pytest.main())

View File

@@ -0,0 +1,97 @@
import sys
from copy import deepcopy
from swift.codegen.lib import ql
from swift.codegen.test.utils import *
def test_property_has_first_param_marked():
params = [ql.Param("a", "x"), ql.Param("b", "y"), ql.Param("c", "z")]
expected = deepcopy(params)
expected[0].first = True
prop = ql.Property("Prop", "foo", "props", ["this"], params=params)
assert prop.params == expected
def test_property_has_first_table_param_marked():
tableparams = ["a", "b", "c"]
prop = ql.Property("Prop", "foo", "props", tableparams)
assert prop.tableparams[0].first
assert [p.param for p in prop.tableparams] == tableparams
assert all(p.type is None for p in prop.tableparams)
@pytest.mark.parametrize("params,expected_local_var", [
(["a", "b", "c"], "x"),
(["a", "x", "c"], "x_"),
(["a", "x", "x_", "c"], "x__"),
(["a", "x", "x_", "x__"], "x___"),
])
def test_property_local_var_avoids_params_collision(params, expected_local_var):
prop = ql.Property("Prop", "foo", "props", ["this"], params=[ql.Param(p) for p in params])
assert prop.local_var == expected_local_var
def test_property_not_a_class():
tableparams = ["x", "result", "y"]
prop = ql.Property("Prop", "foo", "props", tableparams)
assert not prop.type_is_class
assert [p.param for p in prop.tableparams] == tableparams
def test_property_is_a_class():
tableparams = ["x", "result", "y"]
prop = ql.Property("Prop", "Foo", "props", tableparams)
assert prop.type_is_class
assert [p.param for p in prop.tableparams] == ["x", prop.local_var, "y"]
@pytest.mark.parametrize("name,expected_article", [
("Argument", "An"),
("Element", "An"),
("Integer", "An"),
("Operator", "An"),
("Unit", "A"),
("Whatever", "A"),
])
def test_property_indefinite_article(name, expected_article):
prop = ql.Property(name, "Foo", "props", ["x"], plural="X")
assert prop.indefinite_article == expected_article
def test_property_no_plural_no_indefinite_article():
prop = ql.Property("Prop", "Foo", "props", ["x"])
assert prop.indefinite_article is None
def test_class_sorts_bases():
bases = ["B", "Ab", "C", "Aa"]
expected = ["Aa", "Ab", "B", "C"]
cls = ql.Class("Foo", bases=bases)
assert cls.bases == expected
def test_class_has_first_property_marked():
props = [
ql.Property(f"Prop{x}", f"Foo{x}", f"props{x}", [f"{x}"]) for x in range(4)
]
expected = deepcopy(props)
expected[0].first = True
cls = ql.Class("Class", properties=props)
assert cls.properties == expected
def test_class_db_id():
cls = ql.Class("ThisIsMyClass")
assert cls.db_id == "@this_is_my_class"
def test_root_class():
cls = ql.Class("Class")
assert cls.root
def test_non_root_class():
cls = ql.Class("Class", bases=["A"])
assert not cls.root
if __name__ == '__main__':
sys.exit(pytest.main())

View File

@@ -12,12 +12,13 @@ def run_mock():
yield ret yield ret
# these are lambdas so that they will use patched paths when called
stub_path = lambda: paths.swift_dir / "ql/lib/stub/path" stub_path = lambda: paths.swift_dir / "ql/lib/stub/path"
ql_output_path = lambda: paths.swift_dir / "ql/lib/other/path" ql_output_path = lambda: paths.swift_dir / "ql/lib/other/path"
import_file = lambda: stub_path().with_suffix(".qll") import_file = lambda: stub_path().with_suffix(".qll")
stub_import_prefix = "stub.path." stub_import_prefix = "stub.path."
gen_import_prefix = "other.path." gen_import_prefix = "other.path."
index_param = ql.QlParam("index", "int") index_param = ql.Param("index", "int")
def generate(opts, renderer, written=None): def generate(opts, renderer, written=None):
@@ -29,7 +30,7 @@ def generate(opts, renderer, written=None):
def test_empty(opts, input, renderer): def test_empty(opts, input, renderer):
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList() import_file(): ql.ImportList()
} }
@@ -38,9 +39,9 @@ def test_one_empty_class(opts, input, renderer):
schema.Class("A") schema.Class("A")
] ]
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "A"]), import_file(): ql.ImportList([stub_import_prefix + "A"]),
stub_path() / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "A"), stub_path() / "A.qll": ql.Stub(name="A", base_import=gen_import_prefix + "A"),
ql_output_path() / "A.qll": ql.QlClass(name="A", final=True), ql_output_path() / "A.qll": ql.Class(name="A", final=True),
} }
@@ -52,16 +53,16 @@ def test_hierarchy(opts, input, renderer):
schema.Class("A", derived={"B", "C"}), schema.Class("A", derived={"B", "C"}),
] ]
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + cls for cls in "ABCD"]), import_file(): ql.ImportList([stub_import_prefix + cls for cls in "ABCD"]),
stub_path() / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "A"), stub_path() / "A.qll": ql.Stub(name="A", base_import=gen_import_prefix + "A"),
stub_path() / "B.qll": ql.QlStub(name="B", base_import=gen_import_prefix + "B"), stub_path() / "B.qll": ql.Stub(name="B", base_import=gen_import_prefix + "B"),
stub_path() / "C.qll": ql.QlStub(name="C", base_import=gen_import_prefix + "C"), stub_path() / "C.qll": ql.Stub(name="C", base_import=gen_import_prefix + "C"),
stub_path() / "D.qll": ql.QlStub(name="D", base_import=gen_import_prefix + "D"), stub_path() / "D.qll": ql.Stub(name="D", base_import=gen_import_prefix + "D"),
ql_output_path() / "A.qll": ql.QlClass(name="A"), ql_output_path() / "A.qll": ql.Class(name="A"),
ql_output_path() / "B.qll": ql.QlClass(name="B", bases=["A"], imports=[stub_import_prefix + "A"]), ql_output_path() / "B.qll": ql.Class(name="B", bases=["A"], imports=[stub_import_prefix + "A"]),
ql_output_path() / "C.qll": ql.QlClass(name="C", bases=["A"], imports=[stub_import_prefix + "A"]), ql_output_path() / "C.qll": ql.Class(name="C", bases=["A"], imports=[stub_import_prefix + "A"]),
ql_output_path() / "D.qll": ql.QlClass(name="D", final=True, bases=["B", "C"], ql_output_path() / "D.qll": ql.Class(name="D", final=True, bases=["B", "C"],
imports=[stub_import_prefix + cls for cls in "BC"]), imports=[stub_import_prefix + cls for cls in "BC"]),
} }
@@ -71,10 +72,10 @@ def test_single_property(opts, input, renderer):
schema.Class("MyObject", properties=[schema.SingleProperty("foo", "bar")]), schema.Class("MyObject", properties=[schema.SingleProperty("foo", "bar")]),
] ]
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]), import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[ ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
ql.QlProperty(singular="Foo", type="bar", tablename="my_objects", tableparams=["this", "result"]), ql.Property(singular="Foo", type="bar", tablename="my_objects", tableparams=["this", "result"]),
]) ])
} }
@@ -88,12 +89,12 @@ def test_single_properties(opts, input, renderer):
]), ]),
] ]
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]), import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[ ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
ql.QlProperty(singular="One", type="x", tablename="my_objects", tableparams=["this", "result", "_", "_"]), ql.Property(singular="One", type="x", tablename="my_objects", tableparams=["this", "result", "_", "_"]),
ql.QlProperty(singular="Two", type="y", tablename="my_objects", tableparams=["this", "_", "result", "_"]), ql.Property(singular="Two", type="y", tablename="my_objects", tableparams=["this", "_", "result", "_"]),
ql.QlProperty(singular="Three", type="z", tablename="my_objects", tableparams=["this", "_", "_", "result"]), ql.Property(singular="Three", type="z", tablename="my_objects", tableparams=["this", "_", "_", "result"]),
]) ])
} }
@@ -103,10 +104,10 @@ def test_optional_property(opts, input, renderer):
schema.Class("MyObject", properties=[schema.OptionalProperty("foo", "bar")]), schema.Class("MyObject", properties=[schema.OptionalProperty("foo", "bar")]),
] ]
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]), import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[ ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
ql.QlProperty(singular="Foo", type="bar", tablename="my_object_foos", tableparams=["this", "result"]), ql.Property(singular="Foo", type="bar", tablename="my_object_foos", tableparams=["this", "result"]),
]) ])
} }
@@ -116,11 +117,11 @@ def test_repeated_property(opts, input, renderer):
schema.Class("MyObject", properties=[schema.RepeatedProperty("foo", "bar")]), schema.Class("MyObject", properties=[schema.RepeatedProperty("foo", "bar")]),
] ]
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]), import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[ ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
ql.QlProperty(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", params=[index_param], ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", params=[index_param],
tableparams=["this", "index", "result"]), tableparams=["this", "index", "result"]),
]) ])
} }
@@ -131,15 +132,15 @@ def test_single_class_property(opts, input, renderer):
schema.Class("Bar"), schema.Class("Bar"),
] ]
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + cls for cls in ("Bar", "MyObject")]), import_file(): ql.ImportList([stub_import_prefix + cls for cls in ("Bar", "MyObject")]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"), stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
stub_path() / "Bar.qll": ql.QlStub(name="Bar", base_import=gen_import_prefix + "Bar"), stub_path() / "Bar.qll": ql.Stub(name="Bar", base_import=gen_import_prefix + "Bar"),
ql_output_path() / "MyObject.qll": ql.QlClass( ql_output_path() / "MyObject.qll": ql.Class(
name="MyObject", final=True, imports=[stub_import_prefix + "Bar"], properties=[ name="MyObject", final=True, imports=[stub_import_prefix + "Bar"], properties=[
ql.QlProperty(singular="Foo", type="Bar", tablename="my_objects", tableparams=["this", "result"]), ql.Property(singular="Foo", type="Bar", tablename="my_objects", tableparams=["this", "result"]),
], ],
), ),
ql_output_path() / "Bar.qll": ql.QlClass(name="Bar", final=True) ql_output_path() / "Bar.qll": ql.Class(name="Bar", final=True)
} }
@@ -150,15 +151,15 @@ def test_class_dir(opts, input, renderer):
schema.Class("B", bases={"A"}), schema.Class("B", bases={"A"}),
] ]
assert generate(opts, renderer) == { assert generate(opts, renderer) == {
import_file(): ql.QlImportList([ import_file(): ql.ImportList([
stub_import_prefix + "another.rel.path.A", stub_import_prefix + "another.rel.path.A",
stub_import_prefix + "B", stub_import_prefix + "B",
]), ]),
stub_path() / dir / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "another.rel.path.A"), stub_path() / dir / "A.qll": ql.Stub(name="A", base_import=gen_import_prefix + "another.rel.path.A"),
stub_path() / "B.qll": ql.QlStub(name="B", base_import=gen_import_prefix + "B"), stub_path() / "B.qll": ql.Stub(name="B", base_import=gen_import_prefix + "B"),
ql_output_path() / dir / "A.qll": ql.QlClass(name="A", dir=dir), ql_output_path() / dir / "A.qll": ql.Class(name="A", dir=dir),
ql_output_path() / "B.qll": ql.QlClass(name="B", final=True, bases=["A"], ql_output_path() / "B.qll": ql.Class(name="B", final=True, bases=["A"],
imports=[stub_import_prefix + "another.rel.path.A"]) imports=[stub_import_prefix + "another.rel.path.A"])
} }