mirror of
https://github.com/github/codeql.git
synced 2025-12-16 08:43:11 +01:00
Swift: testing non-trivial dataclass properties
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -20,6 +20,9 @@
|
||||
# python virtual environment folder
|
||||
.venv/
|
||||
|
||||
# binary files created by pytest-cov
|
||||
.coverage
|
||||
|
||||
# It's useful (though not required) to be able to unpack codeql in the ql checkout itself
|
||||
/codeql/
|
||||
|
||||
|
||||
Binary file not shown.
@@ -19,42 +19,42 @@ def dbtype(typename):
|
||||
def cls_to_dbscheme(cls: schema.Class):
|
||||
""" Yield all dbscheme entities needed to model class `cls` """
|
||||
if cls.derived:
|
||||
yield DbUnion(dbtype(cls.name), (dbtype(c) for c in cls.derived))
|
||||
yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived))
|
||||
# output a table specific to a class only if it is a leaf class or it has 1-to-1 properties
|
||||
# Leaf classes need a table to bind the `@` ids
|
||||
# 1-to-1 properties are added to a class specific table
|
||||
# in other cases, separate tables are used for the properties, and a class specific table is unneeded
|
||||
if not cls.derived or any(f.is_single for f in cls.properties):
|
||||
binding = not cls.derived
|
||||
keyset = DbKeySet(["id"]) if cls.derived else None
|
||||
yield DbTable(
|
||||
keyset = KeySet(["id"]) if cls.derived else None
|
||||
yield Table(
|
||||
keyset=keyset,
|
||||
name=inflection.tableize(cls.name),
|
||||
columns=[
|
||||
DbColumn("id", type=dbtype(cls.name), binding=binding),
|
||||
Column("id", type=dbtype(cls.name), binding=binding),
|
||||
] + [
|
||||
DbColumn(f.name, dbtype(f.type)) for f in cls.properties if f.is_single
|
||||
Column(f.name, dbtype(f.type)) for f in cls.properties if f.is_single
|
||||
]
|
||||
)
|
||||
# use property-specific tables for 1-to-many and 1-to-at-most-1 properties
|
||||
for f in cls.properties:
|
||||
if f.is_optional:
|
||||
yield DbTable(
|
||||
keyset=DbKeySet(["id"]),
|
||||
yield Table(
|
||||
keyset=KeySet(["id"]),
|
||||
name=inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
columns=[
|
||||
DbColumn("id", type=dbtype(cls.name)),
|
||||
DbColumn(f.name, dbtype(f.type)),
|
||||
Column("id", type=dbtype(cls.name)),
|
||||
Column(f.name, dbtype(f.type)),
|
||||
],
|
||||
)
|
||||
elif f.is_repeated:
|
||||
yield DbTable(
|
||||
keyset=DbKeySet(["id", "index"]),
|
||||
yield Table(
|
||||
keyset=KeySet(["id", "index"]),
|
||||
name=inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
columns=[
|
||||
DbColumn("id", type=dbtype(cls.name)),
|
||||
DbColumn("index", type="int"),
|
||||
DbColumn(inflection.singularize(f.name), dbtype(f.type)),
|
||||
Column("id", type=dbtype(cls.name)),
|
||||
Column("index", type="int"),
|
||||
Column(inflection.singularize(f.name), dbtype(f.type)),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -68,7 +68,7 @@ def get_includes(data: schema.Schema, include_dir: pathlib.Path):
|
||||
for inc in data.includes:
|
||||
inc = include_dir / inc
|
||||
with open(inc) as inclusion:
|
||||
includes.append(DbSchemeInclude(src=inc.relative_to(paths.swift_dir), data=inclusion.read()))
|
||||
includes.append(SchemeInclude(src=inc.relative_to(paths.swift_dir), data=inclusion.read()))
|
||||
return includes
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ def generate(opts, renderer):
|
||||
|
||||
data = schema.load(input)
|
||||
|
||||
dbscheme = DbScheme(src=input.relative_to(paths.swift_dir),
|
||||
dbscheme = Scheme(src=input.relative_to(paths.swift_dir),
|
||||
includes=get_includes(data, include_dir=input.parent),
|
||||
declarations=get_declarations(data))
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ dbscheme_keywords = {"case", "boolean", "int", "string", "type"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DbColumn:
|
||||
class Column:
|
||||
schema_name: str
|
||||
type: str
|
||||
binding: bool = False
|
||||
@@ -36,33 +36,33 @@ class DbColumn:
|
||||
|
||||
|
||||
@dataclass
|
||||
class DbKeySetId:
|
||||
class KeySetId:
|
||||
id: str
|
||||
first: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class DbKeySet:
|
||||
ids: List[DbKeySetId]
|
||||
class KeySet:
|
||||
ids: List[KeySetId]
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.ids
|
||||
self.ids = [DbKeySetId(x) for x in self.ids]
|
||||
self.ids = [KeySetId(x) for x in self.ids]
|
||||
self.ids[0].first = True
|
||||
|
||||
|
||||
class DbDecl:
|
||||
class Decl:
|
||||
is_table = False
|
||||
is_union = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class DbTable(DbDecl):
|
||||
class Table(Decl):
|
||||
is_table: ClassVar = True
|
||||
|
||||
name: str
|
||||
columns: List[DbColumn]
|
||||
keyset: DbKeySet = None
|
||||
columns: List[Column]
|
||||
keyset: KeySet = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.columns:
|
||||
@@ -70,35 +70,35 @@ class DbTable(DbDecl):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DbUnionCase:
|
||||
class UnionCase:
|
||||
type: str
|
||||
first: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class DbUnion(DbDecl):
|
||||
class Union(Decl):
|
||||
is_union: ClassVar = True
|
||||
|
||||
lhs: str
|
||||
rhs: List[DbUnionCase]
|
||||
rhs: List[UnionCase]
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.rhs
|
||||
self.rhs = [DbUnionCase(x) for x in self.rhs]
|
||||
self.rhs = [UnionCase(x) for x in self.rhs]
|
||||
self.rhs.sort(key=lambda c: c.type)
|
||||
self.rhs[0].first = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DbSchemeInclude:
|
||||
class SchemeInclude:
|
||||
src: str
|
||||
data: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DbScheme:
|
||||
class Scheme:
|
||||
template: ClassVar = 'dbscheme'
|
||||
|
||||
src: str
|
||||
includes: List[DbSchemeInclude]
|
||||
declarations: List[DbDecl]
|
||||
includes: List[SchemeInclude]
|
||||
declarations: List[Decl]
|
||||
|
||||
@@ -6,20 +6,20 @@ import inflection
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlParam:
|
||||
class Param:
|
||||
param: str
|
||||
type: str = None
|
||||
first: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlProperty:
|
||||
class Property:
|
||||
singular: str
|
||||
type: str
|
||||
tablename: str
|
||||
tableparams: List[QlParam]
|
||||
tableparams: List[Param]
|
||||
plural: str = None
|
||||
params: List[QlParam] = field(default_factory=list)
|
||||
params: List[Param] = field(default_factory=list)
|
||||
first: bool = False
|
||||
local_var: str = "x"
|
||||
|
||||
@@ -31,7 +31,7 @@ class QlProperty:
|
||||
assert self.tableparams
|
||||
if self.type_is_class:
|
||||
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
|
||||
self.tableparams = [QlParam(x) for x in self.tableparams]
|
||||
self.tableparams = [Param(x) for x in self.tableparams]
|
||||
self.tableparams[0].first = True
|
||||
|
||||
@property
|
||||
@@ -45,13 +45,13 @@ class QlProperty:
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlClass:
|
||||
class Class:
|
||||
template: ClassVar = 'ql_class'
|
||||
|
||||
name: str
|
||||
bases: List[str] = field(default_factory=list)
|
||||
final: bool = False
|
||||
properties: List[QlProperty] = field(default_factory=list)
|
||||
properties: List[Property] = field(default_factory=list)
|
||||
dir: pathlib.Path = pathlib.Path()
|
||||
imports: List[str] = field(default_factory=list)
|
||||
|
||||
@@ -74,7 +74,7 @@ class QlClass:
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlStub:
|
||||
class Stub:
|
||||
template: ClassVar = 'ql_stub'
|
||||
|
||||
name: str
|
||||
@@ -82,7 +82,7 @@ class QlStub:
|
||||
|
||||
|
||||
@dataclass
|
||||
class QlImportList:
|
||||
class ImportList:
|
||||
template: ClassVar = 'ql_imports'
|
||||
|
||||
imports: List[str] = field(default_factory=list)
|
||||
|
||||
@@ -12,32 +12,32 @@ log = logging.getLogger(__name__)
|
||||
|
||||
def get_ql_property(cls: schema.Class, prop: schema.Property):
|
||||
if prop.is_single:
|
||||
return ql.QlProperty(
|
||||
return ql.Property(
|
||||
singular=inflection.camelize(prop.name),
|
||||
type=prop.type,
|
||||
tablename=inflection.tableize(cls.name),
|
||||
tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
|
||||
)
|
||||
elif prop.is_optional:
|
||||
return ql.QlProperty(
|
||||
return ql.Property(
|
||||
singular=inflection.camelize(prop.name),
|
||||
type=prop.type,
|
||||
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
|
||||
tableparams=["this", "result"],
|
||||
)
|
||||
elif prop.is_repeated:
|
||||
return ql.QlProperty(
|
||||
return ql.Property(
|
||||
singular=inflection.singularize(inflection.camelize(prop.name)),
|
||||
plural=inflection.pluralize(inflection.camelize(prop.name)),
|
||||
type=prop.type,
|
||||
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
|
||||
tableparams=["this", "index", "result"],
|
||||
params=[ql.QlParam("index", type="int")],
|
||||
params=[ql.Param("index", type="int")],
|
||||
)
|
||||
|
||||
|
||||
def get_ql_class(cls: schema.Class):
|
||||
return ql.QlClass(
|
||||
return ql.Class(
|
||||
name=cls.name,
|
||||
bases=cls.bases,
|
||||
final=not cls.derived,
|
||||
@@ -51,7 +51,7 @@ def get_import(file):
|
||||
return str(stem).replace("/", ".")
|
||||
|
||||
|
||||
def get_types_used_by(cls: ql.QlClass):
|
||||
def get_types_used_by(cls: ql.Class):
|
||||
for b in cls.bases:
|
||||
yield b
|
||||
for p in cls.properties:
|
||||
@@ -60,7 +60,7 @@ def get_types_used_by(cls: ql.QlClass):
|
||||
yield param.type
|
||||
|
||||
|
||||
def get_classes_used_by(cls: ql.QlClass):
|
||||
def get_classes_used_by(cls: ql.Class):
|
||||
return sorted(set(t for t in get_types_used_by(cls) if t[0].isupper()))
|
||||
|
||||
|
||||
@@ -98,12 +98,12 @@ def generate(opts, renderer):
|
||||
renderer.render(c, qll)
|
||||
stub_file = (stub_out / c.path).with_suffix(".qll")
|
||||
if not stub_file.is_file() or is_generated(stub_file):
|
||||
stub = ql.QlStub(name=c.name, base_import=get_import(qll))
|
||||
stub = ql.Stub(name=c.name, base_import=get_import(qll))
|
||||
renderer.render(stub, stub_file)
|
||||
|
||||
# for example path/to/syntax/generated -> path/to/syntax.qll
|
||||
include_file = stub_out.with_suffix(".qll")
|
||||
all_imports = ql.QlImportList([v for _, v in sorted(imports.items())])
|
||||
all_imports = ql.ImportList([v for _, v in sorted(imports.items())])
|
||||
renderer.render(all_imports, include_file)
|
||||
|
||||
renderer.cleanup(existing)
|
||||
|
||||
52
swift/codegen/test/test_dbscheme.py
Normal file
52
swift/codegen/test/test_dbscheme.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
from swift.codegen.lib import dbscheme
|
||||
from swift.codegen.test.utils import *
|
||||
|
||||
|
||||
def test_dbcolumn_name():
|
||||
assert dbscheme.Column("foo", "some_type").name == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("keyword", dbscheme.dbscheme_keywords)
|
||||
def test_dbcolumn_keyword_name(keyword):
|
||||
assert dbscheme.Column(keyword, "some_type").name == keyword + "_"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type,binding,lhstype,rhstype", [
|
||||
("builtin_type", False, "builtin_type", "builtin_type ref"),
|
||||
("builtin_type", True, "builtin_type", "builtin_type ref"),
|
||||
("@at_type", False, "int", "@at_type ref"),
|
||||
("@at_type", True, "unique int", "@at_type"),
|
||||
])
|
||||
def test_dbcolumn_types(type, binding, lhstype, rhstype):
|
||||
col = dbscheme.Column("foo", type, binding)
|
||||
assert col.lhstype == lhstype
|
||||
assert col.rhstype == rhstype
|
||||
|
||||
|
||||
def test_keyset_has_first_id_marked():
|
||||
ids = ["a", "b", "c"]
|
||||
ks = dbscheme.KeySet(ids)
|
||||
assert ks.ids[0].first
|
||||
assert [id.id for id in ks.ids] == ids
|
||||
|
||||
|
||||
def test_table_has_first_column_marked():
|
||||
columns = [dbscheme.Column("a", "x"), dbscheme.Column("b", "y", binding=True), dbscheme.Column("c", "z")]
|
||||
expected = deepcopy(columns)
|
||||
table = dbscheme.Table("foo", columns)
|
||||
expected[0].first = True
|
||||
assert table.columns == expected
|
||||
|
||||
|
||||
def test_union_has_first_case_marked():
|
||||
rhs = ["a", "b", "c"]
|
||||
u = dbscheme.Union(lhs="x", rhs=rhs)
|
||||
assert u.rhs[0].first
|
||||
assert [c.type for c in u.rhs] == rhs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main())
|
||||
@@ -1,10 +1,10 @@
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
from swift.codegen import dbschemegen
|
||||
from swift.codegen.lib import dbscheme, paths
|
||||
from swift.codegen.lib import dbscheme
|
||||
from swift.codegen.test.utils import *
|
||||
|
||||
|
||||
def generate(opts, renderer):
|
||||
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
|
||||
assert out is opts.dbscheme
|
||||
@@ -12,7 +12,7 @@ def generate(opts, renderer):
|
||||
|
||||
|
||||
def test_empty(opts, input, renderer):
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[],
|
||||
@@ -25,10 +25,10 @@ def test_includes(opts, input, renderer):
|
||||
for i in includes:
|
||||
write(opts.schema.parent / i, i + " data")
|
||||
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[
|
||||
dbscheme.DbSchemeInclude(
|
||||
dbscheme.SchemeInclude(
|
||||
src=schema_dir / i,
|
||||
data=i + " data",
|
||||
) for i in includes
|
||||
@@ -41,14 +41,14 @@ def test_empty_final_class(opts, input, renderer):
|
||||
input.classes = [
|
||||
schema.Class("Object"),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object', binding=True),
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
]
|
||||
)
|
||||
],
|
||||
@@ -62,15 +62,15 @@ def test_final_class_with_single_scalar_field(opts, input, renderer):
|
||||
schema.SingleProperty("foo", "bar"),
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object', binding=True),
|
||||
dbscheme.DbColumn('foo', 'bar'),
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
dbscheme.Column('foo', 'bar'),
|
||||
]
|
||||
)
|
||||
],
|
||||
@@ -83,15 +83,15 @@ def test_final_class_with_single_class_field(opts, input, renderer):
|
||||
schema.SingleProperty("foo", "Bar"),
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object', binding=True),
|
||||
dbscheme.DbColumn('foo', '@bar'),
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
dbscheme.Column('foo', '@bar'),
|
||||
]
|
||||
)
|
||||
],
|
||||
@@ -104,22 +104,22 @@ def test_final_class_with_optional_field(opts, input, renderer):
|
||||
schema.OptionalProperty("foo", "bar"),
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object', binding=True),
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
]
|
||||
),
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="object_foos",
|
||||
keyset=dbscheme.DbKeySet(["id"]),
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object'),
|
||||
dbscheme.DbColumn('foo', 'bar'),
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('foo', 'bar'),
|
||||
]
|
||||
),
|
||||
],
|
||||
@@ -132,23 +132,23 @@ def test_final_class_with_repeated_field(opts, input, renderer):
|
||||
schema.RepeatedProperty("foo", "bar"),
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object', binding=True),
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
]
|
||||
),
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="object_foos",
|
||||
keyset=dbscheme.DbKeySet(["id", "index"]),
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object'),
|
||||
dbscheme.DbColumn('index', 'int'),
|
||||
dbscheme.DbColumn('foo', 'bar'),
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('foo', 'bar'),
|
||||
]
|
||||
),
|
||||
],
|
||||
@@ -164,33 +164,33 @@ def test_final_class_with_more_fields(opts, input, renderer):
|
||||
schema.RepeatedProperty("four", "w"),
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="objects",
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object', binding=True),
|
||||
dbscheme.DbColumn('one', 'x'),
|
||||
dbscheme.DbColumn('two', 'y'),
|
||||
dbscheme.Column('id', '@object', binding=True),
|
||||
dbscheme.Column('one', 'x'),
|
||||
dbscheme.Column('two', 'y'),
|
||||
]
|
||||
),
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="object_threes",
|
||||
keyset=dbscheme.DbKeySet(["id"]),
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object'),
|
||||
dbscheme.DbColumn('three', 'z'),
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('three', 'z'),
|
||||
]
|
||||
),
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="object_fours",
|
||||
keyset=dbscheme.DbKeySet(["id", "index"]),
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@object'),
|
||||
dbscheme.DbColumn('index', 'int'),
|
||||
dbscheme.DbColumn('four', 'w'),
|
||||
dbscheme.Column('id', '@object'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('four', 'w'),
|
||||
]
|
||||
),
|
||||
],
|
||||
@@ -203,11 +203,11 @@ def test_empty_class_with_derived(opts, input, renderer):
|
||||
name="Base",
|
||||
derived={"Left", "Right"}),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbUnion(
|
||||
dbscheme.Union(
|
||||
lhs="@base",
|
||||
rhs=["@left", "@right"],
|
||||
),
|
||||
@@ -224,20 +224,20 @@ def test_class_with_derived_and_single_property(opts, input, renderer):
|
||||
schema.SingleProperty("single", "Prop"),
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbUnion(
|
||||
dbscheme.Union(
|
||||
lhs="@base",
|
||||
rhs=["@left", "@right"],
|
||||
),
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="bases",
|
||||
keyset=dbscheme.DbKeySet(["id"]),
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@base'),
|
||||
dbscheme.DbColumn('single', '@prop'),
|
||||
dbscheme.Column('id', '@base'),
|
||||
dbscheme.Column('single', '@prop'),
|
||||
]
|
||||
)
|
||||
],
|
||||
@@ -253,20 +253,20 @@ def test_class_with_derived_and_optional_property(opts, input, renderer):
|
||||
schema.OptionalProperty("opt", "Prop"),
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbUnion(
|
||||
dbscheme.Union(
|
||||
lhs="@base",
|
||||
rhs=["@left", "@right"],
|
||||
),
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="base_opts",
|
||||
keyset=dbscheme.DbKeySet(["id"]),
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@base'),
|
||||
dbscheme.DbColumn('opt', '@prop'),
|
||||
dbscheme.Column('id', '@base'),
|
||||
dbscheme.Column('opt', '@prop'),
|
||||
]
|
||||
)
|
||||
],
|
||||
@@ -282,47 +282,26 @@ def test_class_with_derived_and_repeated_property(opts, input, renderer):
|
||||
schema.RepeatedProperty("rep", "Prop"),
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == dbscheme.DbScheme(
|
||||
assert generate(opts, renderer) == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.DbUnion(
|
||||
dbscheme.Union(
|
||||
lhs="@base",
|
||||
rhs=["@left", "@right"],
|
||||
),
|
||||
dbscheme.DbTable(
|
||||
dbscheme.Table(
|
||||
name="base_reps",
|
||||
keyset=dbscheme.DbKeySet(["id", "index"]),
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.DbColumn('id', '@base'),
|
||||
dbscheme.DbColumn('index', 'int'),
|
||||
dbscheme.DbColumn('rep', '@prop'),
|
||||
dbscheme.Column('id', '@base'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('rep', '@prop'),
|
||||
]
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_dbcolumn_name():
|
||||
assert dbscheme.DbColumn("foo", "some_type").name == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("keyword", dbscheme.dbscheme_keywords)
|
||||
def test_dbcolumn_keyword_name(keyword):
|
||||
assert dbscheme.DbColumn(keyword, "some_type").name == keyword + "_"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type,binding,lhstype,rhstype", [
|
||||
("builtin_type", False, "builtin_type", "builtin_type ref"),
|
||||
("builtin_type", True, "builtin_type", "builtin_type ref"),
|
||||
("@at_type", False, "int", "@at_type ref"),
|
||||
("@at_type", True, "unique int", "@at_type"),
|
||||
])
|
||||
def test_dbcolumn_types(type, binding, lhstype, rhstype):
|
||||
col = dbscheme.DbColumn("foo", type, binding)
|
||||
assert col.lhstype == lhstype
|
||||
assert col.rhstype == rhstype
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main())
|
||||
|
||||
97
swift/codegen/test/test_ql.py
Normal file
97
swift/codegen/test/test_ql.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
from swift.codegen.lib import ql
|
||||
from swift.codegen.test.utils import *
|
||||
|
||||
|
||||
def test_property_has_first_param_marked():
|
||||
params = [ql.Param("a", "x"), ql.Param("b", "y"), ql.Param("c", "z")]
|
||||
expected = deepcopy(params)
|
||||
expected[0].first = True
|
||||
prop = ql.Property("Prop", "foo", "props", ["this"], params=params)
|
||||
assert prop.params == expected
|
||||
|
||||
|
||||
def test_property_has_first_table_param_marked():
|
||||
tableparams = ["a", "b", "c"]
|
||||
prop = ql.Property("Prop", "foo", "props", tableparams)
|
||||
assert prop.tableparams[0].first
|
||||
assert [p.param for p in prop.tableparams] == tableparams
|
||||
assert all(p.type is None for p in prop.tableparams)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("params,expected_local_var", [
|
||||
(["a", "b", "c"], "x"),
|
||||
(["a", "x", "c"], "x_"),
|
||||
(["a", "x", "x_", "c"], "x__"),
|
||||
(["a", "x", "x_", "x__"], "x___"),
|
||||
])
|
||||
def test_property_local_var_avoids_params_collision(params, expected_local_var):
|
||||
prop = ql.Property("Prop", "foo", "props", ["this"], params=[ql.Param(p) for p in params])
|
||||
assert prop.local_var == expected_local_var
|
||||
|
||||
|
||||
def test_property_not_a_class():
|
||||
tableparams = ["x", "result", "y"]
|
||||
prop = ql.Property("Prop", "foo", "props", tableparams)
|
||||
assert not prop.type_is_class
|
||||
assert [p.param for p in prop.tableparams] == tableparams
|
||||
|
||||
|
||||
def test_property_is_a_class():
|
||||
tableparams = ["x", "result", "y"]
|
||||
prop = ql.Property("Prop", "Foo", "props", tableparams)
|
||||
assert prop.type_is_class
|
||||
assert [p.param for p in prop.tableparams] == ["x", prop.local_var, "y"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name,expected_article", [
|
||||
("Argument", "An"),
|
||||
("Element", "An"),
|
||||
("Integer", "An"),
|
||||
("Operator", "An"),
|
||||
("Unit", "A"),
|
||||
("Whatever", "A"),
|
||||
])
|
||||
def test_property_indefinite_article(name, expected_article):
|
||||
prop = ql.Property(name, "Foo", "props", ["x"], plural="X")
|
||||
assert prop.indefinite_article == expected_article
|
||||
|
||||
|
||||
def test_property_no_plural_no_indefinite_article():
|
||||
prop = ql.Property("Prop", "Foo", "props", ["x"])
|
||||
assert prop.indefinite_article is None
|
||||
|
||||
|
||||
def test_class_sorts_bases():
|
||||
bases = ["B", "Ab", "C", "Aa"]
|
||||
expected = ["Aa", "Ab", "B", "C"]
|
||||
cls = ql.Class("Foo", bases=bases)
|
||||
assert cls.bases == expected
|
||||
|
||||
|
||||
def test_class_has_first_property_marked():
|
||||
props = [
|
||||
ql.Property(f"Prop{x}", f"Foo{x}", f"props{x}", [f"{x}"]) for x in range(4)
|
||||
]
|
||||
expected = deepcopy(props)
|
||||
expected[0].first = True
|
||||
cls = ql.Class("Class", properties=props)
|
||||
assert cls.properties == expected
|
||||
|
||||
|
||||
def test_class_db_id():
|
||||
cls = ql.Class("ThisIsMyClass")
|
||||
assert cls.db_id == "@this_is_my_class"
|
||||
|
||||
def test_root_class():
|
||||
cls = ql.Class("Class")
|
||||
assert cls.root
|
||||
|
||||
def test_non_root_class():
|
||||
cls = ql.Class("Class", bases=["A"])
|
||||
assert not cls.root
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main())
|
||||
@@ -12,12 +12,13 @@ def run_mock():
|
||||
yield ret
|
||||
|
||||
|
||||
# these are lambdas so that they will use patched paths when called
|
||||
stub_path = lambda: paths.swift_dir / "ql/lib/stub/path"
|
||||
ql_output_path = lambda: paths.swift_dir / "ql/lib/other/path"
|
||||
import_file = lambda: stub_path().with_suffix(".qll")
|
||||
stub_import_prefix = "stub.path."
|
||||
gen_import_prefix = "other.path."
|
||||
index_param = ql.QlParam("index", "int")
|
||||
index_param = ql.Param("index", "int")
|
||||
|
||||
|
||||
def generate(opts, renderer, written=None):
|
||||
@@ -29,7 +30,7 @@ def generate(opts, renderer, written=None):
|
||||
|
||||
def test_empty(opts, input, renderer):
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList()
|
||||
import_file(): ql.ImportList()
|
||||
}
|
||||
|
||||
|
||||
@@ -38,9 +39,9 @@ def test_one_empty_class(opts, input, renderer):
|
||||
schema.Class("A")
|
||||
]
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList([stub_import_prefix + "A"]),
|
||||
stub_path() / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "A"),
|
||||
ql_output_path() / "A.qll": ql.QlClass(name="A", final=True),
|
||||
import_file(): ql.ImportList([stub_import_prefix + "A"]),
|
||||
stub_path() / "A.qll": ql.Stub(name="A", base_import=gen_import_prefix + "A"),
|
||||
ql_output_path() / "A.qll": ql.Class(name="A", final=True),
|
||||
}
|
||||
|
||||
|
||||
@@ -52,16 +53,16 @@ def test_hierarchy(opts, input, renderer):
|
||||
schema.Class("A", derived={"B", "C"}),
|
||||
]
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList([stub_import_prefix + cls for cls in "ABCD"]),
|
||||
stub_path() / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "A"),
|
||||
stub_path() / "B.qll": ql.QlStub(name="B", base_import=gen_import_prefix + "B"),
|
||||
stub_path() / "C.qll": ql.QlStub(name="C", base_import=gen_import_prefix + "C"),
|
||||
stub_path() / "D.qll": ql.QlStub(name="D", base_import=gen_import_prefix + "D"),
|
||||
ql_output_path() / "A.qll": ql.QlClass(name="A"),
|
||||
ql_output_path() / "B.qll": ql.QlClass(name="B", bases=["A"], imports=[stub_import_prefix + "A"]),
|
||||
ql_output_path() / "C.qll": ql.QlClass(name="C", bases=["A"], imports=[stub_import_prefix + "A"]),
|
||||
ql_output_path() / "D.qll": ql.QlClass(name="D", final=True, bases=["B", "C"],
|
||||
imports=[stub_import_prefix + cls for cls in "BC"]),
|
||||
import_file(): ql.ImportList([stub_import_prefix + cls for cls in "ABCD"]),
|
||||
stub_path() / "A.qll": ql.Stub(name="A", base_import=gen_import_prefix + "A"),
|
||||
stub_path() / "B.qll": ql.Stub(name="B", base_import=gen_import_prefix + "B"),
|
||||
stub_path() / "C.qll": ql.Stub(name="C", base_import=gen_import_prefix + "C"),
|
||||
stub_path() / "D.qll": ql.Stub(name="D", base_import=gen_import_prefix + "D"),
|
||||
ql_output_path() / "A.qll": ql.Class(name="A"),
|
||||
ql_output_path() / "B.qll": ql.Class(name="B", bases=["A"], imports=[stub_import_prefix + "A"]),
|
||||
ql_output_path() / "C.qll": ql.Class(name="C", bases=["A"], imports=[stub_import_prefix + "A"]),
|
||||
ql_output_path() / "D.qll": ql.Class(name="D", final=True, bases=["B", "C"],
|
||||
imports=[stub_import_prefix + cls for cls in "BC"]),
|
||||
|
||||
}
|
||||
|
||||
@@ -71,10 +72,10 @@ def test_single_property(opts, input, renderer):
|
||||
schema.Class("MyObject", properties=[schema.SingleProperty("foo", "bar")]),
|
||||
]
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]),
|
||||
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[
|
||||
ql.QlProperty(singular="Foo", type="bar", tablename="my_objects", tableparams=["this", "result"]),
|
||||
import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
|
||||
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
|
||||
ql.Property(singular="Foo", type="bar", tablename="my_objects", tableparams=["this", "result"]),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -88,12 +89,12 @@ def test_single_properties(opts, input, renderer):
|
||||
]),
|
||||
]
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]),
|
||||
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[
|
||||
ql.QlProperty(singular="One", type="x", tablename="my_objects", tableparams=["this", "result", "_", "_"]),
|
||||
ql.QlProperty(singular="Two", type="y", tablename="my_objects", tableparams=["this", "_", "result", "_"]),
|
||||
ql.QlProperty(singular="Three", type="z", tablename="my_objects", tableparams=["this", "_", "_", "result"]),
|
||||
import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
|
||||
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
|
||||
ql.Property(singular="One", type="x", tablename="my_objects", tableparams=["this", "result", "_", "_"]),
|
||||
ql.Property(singular="Two", type="y", tablename="my_objects", tableparams=["this", "_", "result", "_"]),
|
||||
ql.Property(singular="Three", type="z", tablename="my_objects", tableparams=["this", "_", "_", "result"]),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -103,10 +104,10 @@ def test_optional_property(opts, input, renderer):
|
||||
schema.Class("MyObject", properties=[schema.OptionalProperty("foo", "bar")]),
|
||||
]
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]),
|
||||
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[
|
||||
ql.QlProperty(singular="Foo", type="bar", tablename="my_object_foos", tableparams=["this", "result"]),
|
||||
import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
|
||||
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
|
||||
ql.Property(singular="Foo", type="bar", tablename="my_object_foos", tableparams=["this", "result"]),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -116,11 +117,11 @@ def test_repeated_property(opts, input, renderer):
|
||||
schema.Class("MyObject", properties=[schema.RepeatedProperty("foo", "bar")]),
|
||||
]
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]),
|
||||
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[
|
||||
ql.QlProperty(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", params=[index_param],
|
||||
tableparams=["this", "index", "result"]),
|
||||
import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
|
||||
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
|
||||
ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", params=[index_param],
|
||||
tableparams=["this", "index", "result"]),
|
||||
])
|
||||
}
|
||||
|
||||
@@ -131,15 +132,15 @@ def test_single_class_property(opts, input, renderer):
|
||||
schema.Class("Bar"),
|
||||
]
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList([stub_import_prefix + cls for cls in ("Bar", "MyObject")]),
|
||||
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
stub_path() / "Bar.qll": ql.QlStub(name="Bar", base_import=gen_import_prefix + "Bar"),
|
||||
ql_output_path() / "MyObject.qll": ql.QlClass(
|
||||
import_file(): ql.ImportList([stub_import_prefix + cls for cls in ("Bar", "MyObject")]),
|
||||
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
|
||||
stub_path() / "Bar.qll": ql.Stub(name="Bar", base_import=gen_import_prefix + "Bar"),
|
||||
ql_output_path() / "MyObject.qll": ql.Class(
|
||||
name="MyObject", final=True, imports=[stub_import_prefix + "Bar"], properties=[
|
||||
ql.QlProperty(singular="Foo", type="Bar", tablename="my_objects", tableparams=["this", "result"]),
|
||||
ql.Property(singular="Foo", type="Bar", tablename="my_objects", tableparams=["this", "result"]),
|
||||
],
|
||||
),
|
||||
ql_output_path() / "Bar.qll": ql.QlClass(name="Bar", final=True)
|
||||
ql_output_path() / "Bar.qll": ql.Class(name="Bar", final=True)
|
||||
}
|
||||
|
||||
|
||||
@@ -150,15 +151,15 @@ def test_class_dir(opts, input, renderer):
|
||||
schema.Class("B", bases={"A"}),
|
||||
]
|
||||
assert generate(opts, renderer) == {
|
||||
import_file(): ql.QlImportList([
|
||||
import_file(): ql.ImportList([
|
||||
stub_import_prefix + "another.rel.path.A",
|
||||
stub_import_prefix + "B",
|
||||
]),
|
||||
stub_path() / dir / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "another.rel.path.A"),
|
||||
stub_path() / "B.qll": ql.QlStub(name="B", base_import=gen_import_prefix + "B"),
|
||||
ql_output_path() / dir / "A.qll": ql.QlClass(name="A", dir=dir),
|
||||
ql_output_path() / "B.qll": ql.QlClass(name="B", final=True, bases=["A"],
|
||||
imports=[stub_import_prefix + "another.rel.path.A"])
|
||||
stub_path() / dir / "A.qll": ql.Stub(name="A", base_import=gen_import_prefix + "another.rel.path.A"),
|
||||
stub_path() / "B.qll": ql.Stub(name="B", base_import=gen_import_prefix + "B"),
|
||||
ql_output_path() / dir / "A.qll": ql.Class(name="A", dir=dir),
|
||||
ql_output_path() / "B.qll": ql.Class(name="B", final=True, bases=["A"],
|
||||
imports=[stub_import_prefix + "another.rel.path.A"])
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user