mirror of
https://github.com/github/codeql.git
synced 2025-12-17 01:03:14 +01:00
Swift: add possibility to specify null class
This commit is contained in:
@@ -12,15 +12,14 @@ Each class in the schema gets a corresponding `struct` in `TrapClasses.h`, where
|
||||
"""
|
||||
|
||||
import functools
|
||||
import pathlib
|
||||
from typing import Dict
|
||||
import typing
|
||||
|
||||
import inflection
|
||||
|
||||
from swift.codegen.lib import cpp, schema
|
||||
|
||||
|
||||
def _get_type(t: str) -> str:
|
||||
def _get_type(t: str, add_or_none_except: typing.Optional[str] = None) -> str:
|
||||
if t is None:
|
||||
# this is a predicate
|
||||
return "bool"
|
||||
@@ -29,11 +28,15 @@ def _get_type(t: str) -> str:
|
||||
if t == "boolean":
|
||||
return "bool"
|
||||
if t[0].isupper():
|
||||
return f"TrapLabel<{t}Tag>"
|
||||
if add_or_none_except is not None and t != add_or_none_except:
|
||||
suffix = "OrNone"
|
||||
else:
|
||||
suffix = ""
|
||||
return f"TrapLabel<{t}{suffix}Tag>"
|
||||
return t
|
||||
|
||||
|
||||
def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
|
||||
def _get_field(cls: schema.Class, p: schema.Property, add_or_none_except: typing.Optional[str] = None) -> cpp.Field:
|
||||
trap_name = None
|
||||
if not p.is_single:
|
||||
trap_name = inflection.camelize(f"{cls.name}_{p.name}")
|
||||
@@ -41,7 +44,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
|
||||
trap_name = inflection.pluralize(trap_name)
|
||||
args = dict(
|
||||
field_name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),
|
||||
type=_get_type(p.type),
|
||||
base_type=_get_type(p.type, add_or_none_except),
|
||||
is_optional=p.is_optional,
|
||||
is_repeated=p.is_repeated,
|
||||
is_predicate=p.is_predicate,
|
||||
@@ -52,8 +55,13 @@ def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
|
||||
|
||||
|
||||
class Processor:
|
||||
def __init__(self, data: Dict[str, schema.Class]):
|
||||
self._classmap = data
|
||||
def __init__(self, data: schema.Schema):
|
||||
self._classmap = data.classes
|
||||
if data.null:
|
||||
root_type = next(iter(data.classes))
|
||||
self._add_or_none_except = root_type
|
||||
else:
|
||||
self._add_or_none_except = None
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _get_class(self, name: str) -> cpp.Class:
|
||||
@@ -64,7 +72,10 @@ class Processor:
|
||||
return cpp.Class(
|
||||
name=name,
|
||||
bases=[self._get_class(b) for b in cls.bases],
|
||||
fields=[_get_field(cls, p) for p in cls.properties if "cpp_skip" not in p.pragmas],
|
||||
fields=[
|
||||
_get_field(cls, p, self._add_or_none_except)
|
||||
for p in cls.properties if "cpp_skip" not in p.pragmas
|
||||
],
|
||||
final=not cls.derived,
|
||||
trap_name=trap_name,
|
||||
)
|
||||
@@ -78,8 +89,8 @@ class Processor:
|
||||
|
||||
def generate(opts, renderer):
|
||||
assert opts.cpp_output
|
||||
processor = Processor(schema.load_file(opts.schema).classes)
|
||||
processor = Processor(schema.load_file(opts.schema))
|
||||
out = opts.cpp_output
|
||||
for dir, classes in processor.get_classes().items():
|
||||
renderer.render(cpp.ClassList(classes, opts.schema,
|
||||
include_parent=bool(dir)), out / dir / "TrapClasses")
|
||||
include_parent=bool(dir)), out / dir / "TrapClasses")
|
||||
|
||||
@@ -13,6 +13,7 @@ Moreover:
|
||||
as columns
|
||||
The type hierarchy will be translated to corresponding `union` declarations.
|
||||
"""
|
||||
import typing
|
||||
|
||||
import inflection
|
||||
|
||||
@@ -23,14 +24,21 @@ from typing import Set, List
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def dbtype(typename):
|
||||
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes """
|
||||
def dbtype(typename: str, add_or_none_except: typing.Optional[str] = None) -> str:
|
||||
""" translate a type to a dbscheme counterpart, using `@lower_underscore` format for classes.
|
||||
For class types, appends an underscore followed by `null` if provided
|
||||
"""
|
||||
if typename[0].isupper():
|
||||
return "@" + inflection.underscore(typename)
|
||||
underscored = inflection.underscore(typename)
|
||||
if add_or_none_except is not None and typename != add_or_none_except:
|
||||
suffix = "_or_none"
|
||||
else:
|
||||
suffix = ""
|
||||
return f"@{underscored}{suffix}"
|
||||
return typename
|
||||
|
||||
|
||||
def cls_to_dbscheme(cls: schema.Class):
|
||||
def cls_to_dbscheme(cls: schema.Class, add_or_none_except: typing.Optional[str] = None):
|
||||
""" Yield all dbscheme entities needed to model class `cls` """
|
||||
if cls.derived:
|
||||
yield Union(dbtype(cls.name), (dbtype(c) for c in cls.derived))
|
||||
@@ -48,7 +56,7 @@ def cls_to_dbscheme(cls: schema.Class):
|
||||
columns=[
|
||||
Column("id", type=dbtype(cls.name), binding=binding),
|
||||
] + [
|
||||
Column(f.name, dbtype(f.type)) for f in cls.properties if f.is_single
|
||||
Column(f.name, dbtype(f.type, add_or_none_except)) for f in cls.properties if f.is_single
|
||||
],
|
||||
dir=dir,
|
||||
)
|
||||
@@ -61,7 +69,7 @@ def cls_to_dbscheme(cls: schema.Class):
|
||||
columns=[
|
||||
Column("id", type=dbtype(cls.name)),
|
||||
Column("index", type="int"),
|
||||
Column(inflection.singularize(f.name), dbtype(f.type)),
|
||||
Column(inflection.singularize(f.name), dbtype(f.type, add_or_none_except)),
|
||||
],
|
||||
dir=dir,
|
||||
)
|
||||
@@ -71,7 +79,7 @@ def cls_to_dbscheme(cls: schema.Class):
|
||||
name=inflection.tableize(f"{cls.name}_{f.name}"),
|
||||
columns=[
|
||||
Column("id", type=dbtype(cls.name)),
|
||||
Column(f.name, dbtype(f.type)),
|
||||
Column(f.name, dbtype(f.type, add_or_none_except)),
|
||||
],
|
||||
dir=dir,
|
||||
)
|
||||
@@ -87,7 +95,17 @@ def cls_to_dbscheme(cls: schema.Class):
|
||||
|
||||
|
||||
def get_declarations(data: schema.Schema):
|
||||
return [d for cls in data.classes.values() for d in cls_to_dbscheme(cls)]
|
||||
add_or_none_except = data.root_class.name if data.null else None
|
||||
declarations = [d for cls in data.classes.values() for d in cls_to_dbscheme(cls, add_or_none_except)]
|
||||
if data.null:
|
||||
property_classes = {
|
||||
prop.type for cls in data.classes.values() for prop in cls.properties
|
||||
if cls.name != data.null and prop.type and prop.type[0].isupper()
|
||||
}
|
||||
declarations += [
|
||||
Union(dbtype(t, data.null), [dbtype(t), dbtype(data.null)]) for t in sorted(property_classes)
|
||||
]
|
||||
return declarations
|
||||
|
||||
|
||||
def get_includes(data: schema.Schema, include_dir: pathlib.Path, swift_dir: pathlib.Path):
|
||||
|
||||
@@ -147,7 +147,7 @@ def get_ql_property(cls: schema.Class, prop: schema.Property, prev_child: str =
|
||||
return ql.Property(**args)
|
||||
|
||||
|
||||
def get_ql_class(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
|
||||
def get_ql_class(cls: schema.Class):
|
||||
pragmas = {k: True for k in cls.pragmas if k.startswith("ql")}
|
||||
prev_child = ""
|
||||
properties = []
|
||||
@@ -314,7 +314,7 @@ def generate(opts, renderer):
|
||||
|
||||
data = schema.load_file(input)
|
||||
|
||||
classes = {name: get_ql_class(cls, data.classes) for name, cls in data.classes.items()}
|
||||
classes = {name: get_ql_class(cls) for name, cls in data.classes.items()}
|
||||
if not classes:
|
||||
raise NoClasses
|
||||
root = next(iter(classes.values()))
|
||||
|
||||
@@ -41,10 +41,10 @@ def get_cpp_type(schema_type: str):
|
||||
def get_field(c: dbscheme.Column):
|
||||
args = {
|
||||
"field_name": c.schema_name,
|
||||
"type": c.type,
|
||||
"base_type": c.type,
|
||||
}
|
||||
args.update(cpp.get_field_override(c.schema_name))
|
||||
args["type"] = get_cpp_type(args["type"])
|
||||
args["base_type"] = get_cpp_type(args["base_type"])
|
||||
return cpp.Field(**args)
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "
|
||||
"xor", "xor_eq"}
|
||||
|
||||
_field_overrides = [
|
||||
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"type": "unsigned"}),
|
||||
(re.compile(r"(start|end)_(line|column)|(.*_)?index|width|num_.*"), {"base_type": "unsigned"}),
|
||||
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
|
||||
]
|
||||
|
||||
@@ -32,7 +32,7 @@ def get_field_override(field: str):
|
||||
@dataclass
|
||||
class Field:
|
||||
field_name: str
|
||||
type: str
|
||||
base_type: str
|
||||
is_optional: bool = False
|
||||
is_repeated: bool = False
|
||||
is_predicate: bool = False
|
||||
@@ -40,13 +40,18 @@ class Field:
|
||||
first: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.is_optional:
|
||||
self.type = f"std::optional<{self.type}>"
|
||||
if self.is_repeated:
|
||||
self.type = f"std::vector<{self.type}>"
|
||||
if self.field_name in cpp_keywords:
|
||||
self.field_name += "_"
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
type = self.base_type
|
||||
if self.is_optional:
|
||||
type = f"std::optional<{type}>"
|
||||
if self.is_repeated:
|
||||
type = f"std::vector<{type}>"
|
||||
return type
|
||||
|
||||
# using @property breaks pystache internals here
|
||||
def get_streamer(self):
|
||||
if self.type == "std::string":
|
||||
@@ -60,6 +65,10 @@ class Field:
|
||||
def is_single(self):
|
||||
return not (self.is_optional or self.is_repeated or self.is_predicate)
|
||||
|
||||
@property
|
||||
def is_label(self):
|
||||
return self.base_type.startswith("TrapLabel<")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trap:
|
||||
|
||||
@@ -115,6 +115,8 @@ child = _ChildModifier()
|
||||
doc = _DocModifier
|
||||
desc = _DescModifier
|
||||
|
||||
use_for_null = _annotate(null=True)
|
||||
|
||||
qltest = _Namespace(
|
||||
skip=_Pragma("qltest_skip"),
|
||||
collapse_hierarchy=_Pragma("qltest_collapse_hierarchy"),
|
||||
|
||||
@@ -55,6 +55,14 @@ class Property:
|
||||
def is_predicate(self) -> bool:
|
||||
return self.kind == self.Kind.PREDICATE
|
||||
|
||||
@property
|
||||
def has_class_type(self) -> bool:
|
||||
return bool(self.type) and self.type[0].isupper()
|
||||
|
||||
@property
|
||||
def has_builtin_type(self) -> bool:
|
||||
return bool(self.type) and self.type[0].islower()
|
||||
|
||||
|
||||
SingleProperty = functools.partial(Property, Property.Kind.SINGLE)
|
||||
OptionalProperty = functools.partial(Property, Property.Kind.OPTIONAL)
|
||||
@@ -104,6 +112,16 @@ class Class:
|
||||
class Schema:
|
||||
classes: Dict[str, Class] = field(default_factory=dict)
|
||||
includes: Set[str] = field(default_factory=set)
|
||||
null: Optional[str] = None
|
||||
|
||||
@property
|
||||
def root_class(self):
|
||||
# always the first in the dictionary
|
||||
return next(iter(self.classes.values()))
|
||||
|
||||
@property
|
||||
def null_class(self):
|
||||
return self.classes[self.null] if self.null else None
|
||||
|
||||
|
||||
predicate_marker = object()
|
||||
@@ -195,6 +213,8 @@ def _get_class(cls: type) -> Class:
|
||||
raise Error(f"Class name must be capitalized, found {cls.__name__}")
|
||||
if len({b._group for b in cls.__bases__ if hasattr(b, "_group")}) > 1:
|
||||
raise Error(f"Bases with mixed groups for {cls.__name__}")
|
||||
if any(getattr(b, "_null", False) for b in cls.__bases__):
|
||||
raise Error(f"Null class cannot be derived")
|
||||
return Class(name=cls.__name__,
|
||||
bases=[b.__name__ for b in cls.__bases__ if b is not object],
|
||||
derived={d.__name__ for d in cls.__subclasses__()},
|
||||
@@ -233,6 +253,7 @@ def load(m: types.ModuleType) -> Schema:
|
||||
known = {"int", "string", "boolean"}
|
||||
known.update(n for n in m.__dict__ if not n.startswith("__"))
|
||||
import swift.codegen.lib.schema.defs as defs
|
||||
null = None
|
||||
for name, data in m.__dict__.items():
|
||||
if hasattr(defs, name):
|
||||
continue
|
||||
@@ -247,8 +268,13 @@ def load(m: types.ModuleType) -> Schema:
|
||||
f"Only one root class allowed, found second root {name}")
|
||||
cls.check_types(known)
|
||||
classes[name] = cls
|
||||
if getattr(data, "_null", False):
|
||||
if null is not None:
|
||||
raise Error(f"Null class {null} already defined, second null class {name} not allowed")
|
||||
null = name
|
||||
cls.is_null_class = True
|
||||
|
||||
return Schema(includes=includes, classes=_toposort_classes_by_group(classes))
|
||||
return Schema(includes=includes, classes=_toposort_classes_by_group(classes), null=null)
|
||||
|
||||
|
||||
def load_file(path: pathlib.Path) -> Schema:
|
||||
|
||||
@@ -80,6 +80,25 @@ def test_class_with_field(generate, type, expected, property_cls, optional, repe
|
||||
]
|
||||
|
||||
|
||||
def test_class_field_with_null(generate, input):
|
||||
input.null = "Null"
|
||||
a = cpp.Class(name="A")
|
||||
assert generate([
|
||||
schema.Class(name="A", derived={"B"}),
|
||||
schema.Class(name="B", bases=["A"], properties=[
|
||||
schema.SingleProperty("x", "A"),
|
||||
schema.SingleProperty("y", "B"),
|
||||
])
|
||||
]) == [
|
||||
a,
|
||||
cpp.Class(name="B", bases=[a], final=True, trap_name="Bs",
|
||||
fields=[
|
||||
cpp.Field("x", "TrapLabel<ATag>"),
|
||||
cpp.Field("y", "TrapLabel<BOrNoneTag>"),
|
||||
]),
|
||||
]
|
||||
|
||||
|
||||
def test_class_with_predicate(generate):
|
||||
assert generate([
|
||||
schema.Class(name="MyClass", properties=[
|
||||
|
||||
@@ -18,8 +18,9 @@ def dir_param(request):
|
||||
|
||||
@pytest.fixture
|
||||
def generate(opts, input, renderer):
|
||||
def func(classes):
|
||||
def func(classes, null=None):
|
||||
input.classes = {cls.name: cls for cls in classes}
|
||||
input.null = null
|
||||
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
|
||||
assert out is opts.dbscheme
|
||||
return data
|
||||
@@ -359,5 +360,114 @@ def test_class_with_derived_and_repeated_property(generate, dir_param):
|
||||
)
|
||||
|
||||
|
||||
def test_null_class(generate):
|
||||
assert generate([
|
||||
schema.Class(
|
||||
name="Base",
|
||||
derived={"W", "X", "Y", "Z", "Null"},
|
||||
),
|
||||
schema.Class(
|
||||
name="W",
|
||||
bases=["Base"],
|
||||
properties=[
|
||||
schema.SingleProperty("w", "W"),
|
||||
schema.SingleProperty("x", "X"),
|
||||
schema.OptionalProperty("y", "Y"),
|
||||
schema.RepeatedProperty("z", "Z"),
|
||||
]
|
||||
),
|
||||
schema.Class(
|
||||
name="X",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Y",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Z",
|
||||
bases=["Base"],
|
||||
),
|
||||
schema.Class(
|
||||
name="Null",
|
||||
bases=["Base"],
|
||||
),
|
||||
], null="Null") == dbscheme.Scheme(
|
||||
src=schema_file,
|
||||
includes=[],
|
||||
declarations=[
|
||||
dbscheme.Union(
|
||||
lhs="@base",
|
||||
rhs=["@null", "@w", "@x", "@y", "@z"],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="ws",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@w', binding=True),
|
||||
dbscheme.Column('w', '@w_or_none'),
|
||||
dbscheme.Column('x', '@x_or_none'),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="w_ies",
|
||||
keyset=dbscheme.KeySet(["id"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@w'),
|
||||
dbscheme.Column('y', '@y_or_none'),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="w_zs",
|
||||
keyset=dbscheme.KeySet(["id", "index"]),
|
||||
columns=[
|
||||
dbscheme.Column('id', '@w'),
|
||||
dbscheme.Column('index', 'int'),
|
||||
dbscheme.Column('z', '@z_or_none'),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="xes",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@x', binding=True),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="ys",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@y', binding=True),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="zs",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@z', binding=True),
|
||||
],
|
||||
),
|
||||
dbscheme.Table(
|
||||
name="nulls",
|
||||
columns=[
|
||||
dbscheme.Column('id', '@null', binding=True),
|
||||
],
|
||||
),
|
||||
dbscheme.Union(
|
||||
lhs="@w_or_none",
|
||||
rhs=["@w", "@null"],
|
||||
),
|
||||
dbscheme.Union(
|
||||
lhs="@x_or_none",
|
||||
rhs=["@x", "@null"],
|
||||
),
|
||||
dbscheme.Union(
|
||||
lhs="@y_or_none",
|
||||
rhs=["@y", "@null"],
|
||||
),
|
||||
dbscheme.Union(
|
||||
lhs="@z_or_none",
|
||||
rhs=["@z", "@null"],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
@@ -13,6 +13,8 @@ def test_empty_schema():
|
||||
|
||||
assert data.classes == {}
|
||||
assert data.includes == set()
|
||||
assert data.null is None
|
||||
assert data.null_class is None
|
||||
|
||||
|
||||
def test_one_empty_class():
|
||||
@@ -24,6 +26,7 @@ def test_one_empty_class():
|
||||
assert data.classes == {
|
||||
'MyClass': schema.Class('MyClass'),
|
||||
}
|
||||
assert data.root_class is data.classes['MyClass']
|
||||
|
||||
|
||||
def test_two_empty_classes():
|
||||
@@ -39,6 +42,7 @@ def test_two_empty_classes():
|
||||
'MyClass1': schema.Class('MyClass1', derived={'MyClass2'}),
|
||||
'MyClass2': schema.Class('MyClass2', bases=['MyClass1']),
|
||||
}
|
||||
assert data.root_class is data.classes['MyClass1']
|
||||
|
||||
|
||||
def test_no_external_bases():
|
||||
@@ -452,7 +456,8 @@ def test_property_docstring_newline():
|
||||
property.""")
|
||||
|
||||
assert data.classes == {
|
||||
'A': schema.Class('A', properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
|
||||
'A': schema.Class('A',
|
||||
properties=[schema.SingleProperty('x', 'int', description=["very important", "property."])])
|
||||
}
|
||||
|
||||
|
||||
@@ -566,5 +571,54 @@ def test_class_default_doc_name():
|
||||
}
|
||||
|
||||
|
||||
def test_null_class():
|
||||
@schema.load
|
||||
class data:
|
||||
class Root:
|
||||
pass
|
||||
|
||||
@defs.use_for_null
|
||||
class Null(Root):
|
||||
pass
|
||||
|
||||
assert data.classes == {
|
||||
'Root': schema.Class('Root', derived={'Null'}),
|
||||
'Null': schema.Class('Null', bases=['Root']),
|
||||
}
|
||||
assert data.null == 'Null'
|
||||
assert data.null_class is data.classes[data.null]
|
||||
|
||||
|
||||
def test_null_class_cannot_be_derived():
|
||||
with pytest.raises(schema.Error):
|
||||
@schema.load
|
||||
class data:
|
||||
class Root:
|
||||
pass
|
||||
|
||||
@defs.use_for_null
|
||||
class Null(Root):
|
||||
pass
|
||||
|
||||
class Impossible(Null):
|
||||
pass
|
||||
|
||||
|
||||
def test_null_class_cannot_be_defined_multiple_times():
|
||||
with pytest.raises(schema.Error):
|
||||
@schema.load
|
||||
class data:
|
||||
class Root:
|
||||
pass
|
||||
|
||||
@defs.use_for_null
|
||||
class Null1(Root):
|
||||
pass
|
||||
|
||||
@defs.use_for_null
|
||||
class Null2(Root):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main([__file__] + sys.argv[1:]))
|
||||
|
||||
Reference in New Issue
Block a user