Merge pull request #9094 from redsun82/swift-codegen-predicate-properties

Swift codegen: add predicate properties
This commit is contained in:
Mathias Vorreiter Pedersen
2022-05-09 17:17:10 +01:00
committed by GitHub
17 changed files with 179 additions and 60 deletions

View File

@@ -8,6 +8,9 @@ from swift.codegen.lib import cpp, generator, schema
def _get_type(t: str, trap_affix: str) -> str:
if t is None:
# this is a predicate
return "bool"
if t == "string":
return "std::string"
if t == "boolean":
@@ -20,12 +23,15 @@ def _get_type(t: str, trap_affix: str) -> str:
def _get_field(cls: schema.Class, p: schema.Property, trap_affix: str) -> cpp.Field:
trap_name = None
if not p.is_single:
trap_name = inflection.pluralize(inflection.camelize(f"{cls.name}_{p.name}"))
trap_name = inflection.camelize(f"{cls.name}_{p.name}")
if not p.is_predicate:
trap_name = inflection.pluralize(trap_name)
args = dict(
name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),
type=_get_type(p.type, trap_affix),
is_optional=p.is_optional,
is_repeated=p.is_repeated,
is_predicate=p.is_predicate,
trap_name=trap_name,
)
args.update(cpp.get_field_override(p.name))

View File

@@ -57,6 +57,15 @@ def cls_to_dbscheme(cls: schema.Class):
Column(f.name, dbtype(f.type)),
],
)
elif f.is_predicate:
yield Table(
keyset=KeySet(["id"]),
name=inflection.underscore(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
],
)
def get_declarations(data: schema.Schema):

View File

@@ -35,6 +35,7 @@ class Field:
type: str
is_optional: bool = False
is_repeated: bool = False
is_predicate: bool = False
trap_name: str = None
first: bool = False
@@ -61,7 +62,7 @@ class Field:
@property
def is_single(self):
return not (self.is_optional or self.is_repeated)
return not (self.is_optional or self.is_repeated or self.is_predicate)

View File

@@ -14,29 +14,35 @@ class Param:
@dataclass
class Property:
singular: str
type: str
tablename: str
tableparams: List[Param]
type: str = None
tablename: str = None
tableparams: List[Param] = field(default_factory=list)
plural: str = None
first: bool = False
local_var: str = "x"
is_optional: bool = False
is_predicate: bool = False
def __post_init__(self):
assert self.tableparams
if self.type_is_class:
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
self.tableparams = [Param(x) for x in self.tableparams]
self.tableparams[0].first = True
if self.tableparams:
if self.type_is_class:
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
self.tableparams = [Param(x) for x in self.tableparams]
self.tableparams[0].first = True
@property
def indefinite_article(self):
def getter(self):
return f"get{self.singular}" if not self.is_predicate else self.singular
@property
def indefinite_getter(self):
if self.plural:
return "An" if self.singular[0] in "AEIO" else "A"
article = "An" if self.singular[0] in "AEIO" else "A"
return f"get{article}{self.singular}"
@property
def type_is_class(self):
return self.type[0].isupper()
return bool(self.type) and self.type[0].isupper()
@property
def is_repeated(self):

View File

@@ -15,9 +15,10 @@ class Property:
is_single: ClassVar = False
is_optional: ClassVar = False
is_repeated: ClassVar = False
is_predicate: ClassVar = False
name: str
type: str
type: str = None
@dataclass
@@ -41,6 +42,11 @@ class RepeatedOptionalProperty(Property):
is_repeated: ClassVar = True
@dataclass
class PredicateProperty(Property):
is_predicate: ClassVar = True
@dataclass
class Class:
name: str
@@ -58,17 +64,15 @@ class Schema:
def _parse_property(name, type):
if type.endswith("?*"):
cls = RepeatedOptionalProperty
type = type[:-2]
return RepeatedOptionalProperty(name, type[:-2])
elif type.endswith("*"):
cls = RepeatedProperty
type = type[:-1]
return RepeatedProperty(name, type[:-1])
elif type.endswith("?"):
cls = OptionalProperty
type = type[:-1]
return OptionalProperty(name, type[:-1])
elif type == "predicate":
return PredicateProperty(name)
else:
cls = SingleProperty
return cls(name, type)
return SingleProperty(name, type)
class _DirSelector:

View File

@@ -36,6 +36,14 @@ def get_ql_property(cls: schema.Class, prop: schema.Property):
tableparams=["this", "result"],
is_optional=True,
)
elif prop.is_predicate:
return ql.Property(
singular=inflection.camelize(prop.name, uppercase_first_letter=False),
type="predicate",
tablename=inflection.underscore(f"{cls.name}_{prop.name}"),
tableparams=["this"],
is_predicate=True,
)
def get_ql_class(cls: schema.Class):

View File

@@ -59,6 +59,7 @@ AnyFunctionType:
result: Type
param_types: Type*
param_labels: string*
is_throwing: predicate
AnyGenericType:
_extends: Type

View File

@@ -33,6 +33,9 @@ struct {{name}}{{#final}} : Binding<{{name}}Tag>{{#bases}}, {{ref.name}}{{/bases
{{ref.name}}::emit(id, out);
{{/bases}}
{{#fields}}
{{#is_predicate}}
if ({{name}}) out << {{trap_name}}{{trap_affix}}{id} << '\n';
{{/is_predicate}}
{{#is_optional}}
{{^is_repeated}}
if ({{name}}) out << {{trap_name}}{{trap_affix}}{id, *{{name}}} << '\n';

View File

@@ -21,7 +21,7 @@ class {{name}}Base extends {{db_id}}{{#bases}}, {{.}}{{/bases}} {
{{/final}}
{{#properties}}
{{type}} get{{singular}}({{#is_repeated}}int index{{/is_repeated}}) {
{{type}} {{getter}}({{#is_repeated}}int index{{/is_repeated}}) {
{{#type_is_class}}
exists({{type}} {{local_var}} |
{{tablename}}({{#tableparams}}{{^first}}, {{/first}}{{param}}{{/tableparams}})
@@ -34,13 +34,13 @@ class {{name}}Base extends {{db_id}}{{#bases}}, {{.}}{{/bases}} {
}
{{#is_repeated}}
{{type}} get{{indefinite_article}}{{singular}}() {
result = get{{singular}}(_)
{{type}} {{indefinite_getter}}() {
result = {{getter}}(_)
}
{{^is_optional}}
int getNumberOf{{plural}}() {
result = count(get{{indefinite_article}}{{singular}}())
result = count({{indefinite_getter}}())
}
{{/is_optional}}
{{/is_repeated}}

View File

@@ -27,14 +27,15 @@ def test_field_get_streamer(type, expected):
assert f.get_streamer()("value") == expected
@pytest.mark.parametrize("is_optional,is_repeated,expected", [
(False, False, True),
(True, False, False),
(False, True, False),
(True, True, False),
@pytest.mark.parametrize("is_optional,is_repeated,is_predicate,expected", [
(False, False, False, True),
(True, False, False, False),
(False, True, False, False),
(True, True, False, False),
(False, False, True, False),
])
def test_field_is_single(is_optional, is_repeated, expected):
f = cpp.Field("name", "type", is_optional=is_optional, is_repeated=is_repeated)
def test_field_is_single(is_optional, is_repeated, is_predicate, expected):
f = cpp.Field("name", "type", is_optional=is_optional, is_repeated=is_repeated, is_predicate=is_predicate)
assert f.is_single is expected

View File

@@ -92,6 +92,17 @@ def test_class_with_field(generate, type, expected, property_cls, optional, repe
]
def test_class_with_predicate(generate):
assert generate([
schema.Class(name="MyClass", properties=[schema.PredicateProperty("prop")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field("prop", "bool", trap_name="MyClassProp", is_predicate=True)],
trap_name="MyClasses",
final=True)
]
@pytest.mark.parametrize("name",
["start_line", "start_column", "end_line", "end_column", "index", "num_whatever", "width"])
def test_class_with_overridden_unsigned_field(generate, name):

View File

@@ -156,6 +156,33 @@ def test_final_class_with_repeated_field(opts, input, renderer, property_cls):
)
def test_final_class_with_predicate_field(opts, input, renderer):
input.classes = [
schema.Class("Object", properties=[
schema.PredicateProperty("foo"),
]),
]
assert generate(opts, renderer) == dbscheme.Scheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.Table(
name="objects",
columns=[
dbscheme.Column('id', '@object', binding=True),
]
),
dbscheme.Table(
name="object_foo",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@object'),
]
),
],
)
def test_final_class_with_more_fields(opts, input, renderer):
input.classes = [
schema.Class("Object", properties=[
@@ -164,6 +191,7 @@ def test_final_class_with_more_fields(opts, input, renderer):
schema.OptionalProperty("three", "z"),
schema.RepeatedProperty("four", "u"),
schema.RepeatedOptionalProperty("five", "v"),
schema.PredicateProperty("six"),
]),
]
assert generate(opts, renderer) == dbscheme.Scheme(
@@ -204,6 +232,13 @@ def test_final_class_with_more_fields(opts, input, renderer):
dbscheme.Column('five', 'v'),
]
),
dbscheme.Table(
name="object_six",
keyset=dbscheme.KeySet(["id"]),
columns=[
dbscheme.Column('id', '@object'),
]
),
],
)

View File

@@ -12,31 +12,32 @@ def test_property_has_first_table_param_marked():
assert [p.param for p in prop.tableparams] == tableparams
def test_property_not_a_class():
tableparams = ["x", "result", "y"]
prop = ql.Property("Prop", "foo", "props", tableparams)
assert not prop.type_is_class
assert [p.param for p in prop.tableparams] == tableparams
def test_property_is_a_class():
tableparams = ["x", "result", "y"]
prop = ql.Property("Prop", "Foo", "props", tableparams)
assert prop.type_is_class
assert [p.param for p in prop.tableparams] == ["x", prop.local_var, "y"]
@pytest.mark.parametrize("name,expected_article", [
("Argument", "An"),
("Element", "An"),
("Integer", "An"),
("Operator", "An"),
("Unit", "A"),
("Whatever", "A"),
@pytest.mark.parametrize("type,expected", [
("Foo", True),
("Bar", True),
("foo", False),
("bar", False),
(None, False),
])
def test_property_indefinite_article(name, expected_article):
prop = ql.Property(name, "Foo", "props", ["x"], plural="X")
assert prop.indefinite_article == expected_article
def test_property_is_a_class(type, expected):
tableparams = ["a", "result", "b"]
expected_tableparams = ["a", "x" if expected else "result", "b"]
prop = ql.Property("Prop", type, tableparams=tableparams)
assert prop.type_is_class is expected
assert [p.param for p in prop.tableparams] == expected_tableparams
@pytest.mark.parametrize("name,expected_getter", [
("Argument", "getAnArgument"),
("Element", "getAnElement"),
("Integer", "getAnInteger"),
("Operator", "getAnOperator"),
("Unit", "getAUnit"),
("Whatever", "getAWhatever"),
])
def test_property_indefinite_article(name, expected_getter):
prop = ql.Property(name, plural="X")
assert prop.indefinite_getter == expected_getter
@pytest.mark.parametrize("plural,expected", [
@@ -49,9 +50,19 @@ def test_property_is_plural(plural, expected):
assert prop.is_repeated is expected
def test_property_no_plural_no_indefinite_article():
def test_property_no_plural_no_indefinite_getter():
prop = ql.Property("Prop", "Foo", "props", ["x"])
assert prop.indefinite_article is None
assert prop.indefinite_getter is None
def test_property_getter():
prop = ql.Property("Prop", "Foo")
assert prop.getter == "getProp"
def test_property_predicate_getter():
prop = ql.Property("prop", is_predicate=True)
assert prop.getter == "prop"
def test_class_sorts_bases():

View File

@@ -2,7 +2,7 @@ import subprocess
import sys
from swift.codegen import qlgen
from swift.codegen.lib import ql, paths
from swift.codegen.lib import ql
from swift.codegen.test.utils import *
@@ -141,6 +141,20 @@ def test_repeated_optional_property(opts, input, renderer):
}
def test_predicate_property(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[schema.PredicateProperty("is_foo")]),
]
assert generate(opts, renderer) == {
import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
ql.Property(singular="isFoo", type="predicate", tablename="my_object_is_foo", tableparams=["this"],
is_predicate=True),
])
}
def test_single_class_property(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[schema.SingleProperty("foo", "Bar")]),

View File

@@ -141,6 +141,7 @@ A:
two: int?
three: bool*
four: x?*
five: predicate
""")
assert ret.classes == [
schema.Class(root_name, derived={'A'}),
@@ -149,6 +150,7 @@ A:
schema.OptionalProperty('two', 'int'),
schema.RepeatedProperty('three', 'bool'),
schema.RepeatedOptionalProperty('four', 'x'),
schema.PredicateProperty('five'),
]),
]

View File

@@ -25,4 +25,6 @@ class AnyFunctionTypeBase extends @any_function_type, Type {
string getAParamLabel() { result = getParamLabel(_) }
int getNumberOfParamLabels() { result = count(getAParamLabel()) }
predicate isThrowing() { any_function_type_is_throwing(this) }
}

View File

@@ -167,6 +167,11 @@ any_function_type_param_labels(
string param_label: string ref
);
#keyset[id]
any_function_type_is_throwing(
int id: @any_function_type ref
);
@any_generic_type =
@nominal_or_bound_generic_nominal_type
| @unbound_generic_type