mirror of
https://github.com/github/codeql.git
synced 2025-12-17 01:03:14 +01:00
Swift: add trapgen unit tests
Closes: https://github.com/github/codeql-c-team/issues/981
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
load("@swift_codegen_deps//:requirements.bzl", "requirement")
|
||||
|
||||
py_binary(
|
||||
name = "codegen",
|
||||
srcs = glob(["*.py"]),
|
||||
srcs = glob(
|
||||
["*.py"],
|
||||
exclude = ["trapgen.py"],
|
||||
),
|
||||
visibility = ["//swift/codegen/test:__pkg__"],
|
||||
deps = ["//swift/codegen/lib"],
|
||||
)
|
||||
@@ -12,5 +17,8 @@ py_binary(
|
||||
srcs = ["trapgen.py"],
|
||||
data = ["//swift/codegen/templates:cpp"],
|
||||
visibility = ["//swift:__subpackages__"],
|
||||
deps = ["//swift/codegen/lib"],
|
||||
deps = [
|
||||
"//swift/codegen/lib",
|
||||
requirement("toposort"),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -21,12 +21,14 @@ class Field:
|
||||
type: str
|
||||
first: bool = False
|
||||
|
||||
@property
|
||||
def cpp_name(self):
|
||||
if self.name in cpp_keywords:
|
||||
return self.name + "_"
|
||||
return self.name
|
||||
|
||||
def stream(self):
|
||||
# using @property breaks pystache internals here
|
||||
def get_streamer(self):
|
||||
if self.type == "std::string":
|
||||
return lambda x: f"trapQuoted({x})"
|
||||
elif self.type == "bool":
|
||||
@@ -65,6 +67,7 @@ class Tag:
|
||||
self.bases = [TagBase(b) for b in self.bases]
|
||||
self.bases[0].first = True
|
||||
|
||||
@property
|
||||
def has_bases(self):
|
||||
return bool(self.bases)
|
||||
|
||||
|
||||
@@ -144,13 +144,10 @@ def get_union(match):
|
||||
|
||||
|
||||
def iterload(file):
|
||||
data = Re.comment.sub("", file.read())
|
||||
with open(file) as file:
|
||||
data = Re.comment.sub("", file.read())
|
||||
for e in Re.entity.finditer(data):
|
||||
if e["table"]:
|
||||
yield get_table(e)
|
||||
elif e["union"]:
|
||||
yield get_union(e)
|
||||
|
||||
|
||||
def load(file):
|
||||
return list(iterload(file))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
pystache
|
||||
pyyaml
|
||||
inflection
|
||||
pystache
|
||||
pytest
|
||||
pyyaml
|
||||
toposort
|
||||
|
||||
@@ -25,7 +25,7 @@ struct {{name}}Trap {
|
||||
|
||||
inline std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e) {
|
||||
out << "{{table_name}}("{{#fields}}{{^first}} << ", "{{/first}}
|
||||
<< {{#stream}}e.{{cpp_name}}{{/stream}}{{/fields}} << ")";
|
||||
<< {{#get_streamer}}e.{{cpp_name}}{{/get_streamer}}{{/fields}} << ")";
|
||||
return out;
|
||||
}
|
||||
{{/traps}}
|
||||
|
||||
@@ -18,6 +18,7 @@ py_library(
|
||||
deps = [
|
||||
":utils",
|
||||
"//swift/codegen",
|
||||
"//swift/codegen:trapgen",
|
||||
],
|
||||
)
|
||||
for src in glob(["test_*.py"])
|
||||
|
||||
60
swift/codegen/test/test_cpp.py
Normal file
60
swift/codegen/test/test_cpp.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
|
||||
from swift.codegen.lib import cpp
|
||||
|
||||
|
||||
@pytest.mark.parametrize("keyword", cpp.cpp_keywords)
|
||||
def test_field_keyword_cpp_name(keyword):
|
||||
f = cpp.Field(keyword, "int")
|
||||
assert f.cpp_name == keyword + "_"
|
||||
|
||||
|
||||
def test_field_cpp_name():
|
||||
f = cpp.Field("foo", "int")
|
||||
assert f.cpp_name == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type,expected", [
|
||||
("std::string", "trapQuoted(value)"),
|
||||
("bool", '(value ? "true" : "false")'),
|
||||
("something_else", "value"),
|
||||
])
|
||||
def test_field_get_streamer(type, expected):
|
||||
f = cpp.Field("name", type)
|
||||
assert f.get_streamer()("value") == expected
|
||||
|
||||
|
||||
def test_trap_has_first_field_marked():
|
||||
fields = [
|
||||
cpp.Field("a", "x"),
|
||||
cpp.Field("b", "y"),
|
||||
cpp.Field("c", "z"),
|
||||
]
|
||||
expected = deepcopy(fields)
|
||||
expected[0].first = True
|
||||
t = cpp.Trap("table_name", "name", fields)
|
||||
assert t.fields == expected
|
||||
|
||||
|
||||
def test_tag_has_first_base_marked():
|
||||
bases = ["a", "b", "c"]
|
||||
expected = [cpp.TagBase("a", first=True), cpp.TagBase("b"), cpp.TagBase("c")]
|
||||
t = cpp.Tag("name", bases, 0, "id")
|
||||
assert t.bases == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bases,expected", [
|
||||
([], False),
|
||||
(["a"], True),
|
||||
(["a", "b"], True)
|
||||
])
|
||||
def test_tag_has_bases(bases, expected):
|
||||
t = cpp.Tag("name", bases, 0, "id")
|
||||
assert t.has_bases is expected
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main())
|
||||
@@ -48,5 +48,107 @@ def test_union_has_first_case_marked():
|
||||
assert [c.type for c in u.rhs] == rhs
|
||||
|
||||
|
||||
# load tests
|
||||
@pytest.fixture
|
||||
def load(tmp_path):
|
||||
file = tmp_path / "test.dbscheme"
|
||||
|
||||
def ret(yml):
|
||||
write(file, yml)
|
||||
return list(dbscheme.iterload(file))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def test_load_empty(load):
|
||||
assert load("") == []
|
||||
|
||||
|
||||
def test_load_one_empty_table(load):
|
||||
assert load("""
|
||||
test_foos();
|
||||
""") == [
|
||||
dbscheme.Table(name="test_foos", columns=[])
|
||||
]
|
||||
|
||||
|
||||
def test_load_table_with_keyset(load):
|
||||
assert load("""
|
||||
#keyset[x, y,z]
|
||||
test_foos();
|
||||
""") == [
|
||||
dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]))
|
||||
]
|
||||
|
||||
|
||||
expected_columns = [
|
||||
("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)),
|
||||
(" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)),
|
||||
("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)),
|
||||
("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)),
|
||||
("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)),
|
||||
("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("column,expected", expected_columns)
|
||||
def test_load_table_with_column(load, column, expected):
|
||||
assert load(f"""
|
||||
foos(
|
||||
{column}
|
||||
);
|
||||
""") == [
|
||||
dbscheme.Table(name="foos", columns=[deepcopy(expected)])
|
||||
]
|
||||
|
||||
|
||||
def test_load_table_with_multiple_columns(load):
|
||||
columns = ",\n".join(c for c, _ in expected_columns)
|
||||
expected = [deepcopy(e) for _, e in expected_columns]
|
||||
assert load(f"""
|
||||
foos(
|
||||
{columns}
|
||||
);
|
||||
""") == [
|
||||
dbscheme.Table(name="foos", columns=expected)
|
||||
]
|
||||
|
||||
|
||||
def test_load_multiple_table_with_columns(load):
|
||||
tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)]
|
||||
expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)]
|
||||
assert load("\n".join(tables)) == expected
|
||||
|
||||
|
||||
def test_union(load):
|
||||
assert load("@foo = @bar | @baz | @bla;") == [
|
||||
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
|
||||
]
|
||||
|
||||
|
||||
def test_table_and_union(load):
|
||||
assert load("""
|
||||
foos();
|
||||
|
||||
@foo = @bar | @baz | @bla;""") == [
|
||||
dbscheme.Table(name="foos", columns=[]),
|
||||
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
|
||||
]
|
||||
|
||||
|
||||
def test_comments_ignored(load):
|
||||
assert load("""
|
||||
// fake_table();
|
||||
foos(/* x */unique /*y*/int/*
|
||||
z
|
||||
*/ id/* */: /* * */ @bar/*,
|
||||
int ignored: int ref*/);
|
||||
|
||||
@foo = @bar | @baz | @bla; // | @xxx""") == [
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]),
|
||||
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
|
||||
]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main())
|
||||
|
||||
163
swift/codegen/test/test_trapgen.py
Normal file
163
swift/codegen/test/test_trapgen.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import sys
|
||||
|
||||
from swift.codegen import trapgen
|
||||
from swift.codegen.lib import cpp, dbscheme
|
||||
from swift.codegen.test.utils import *
|
||||
|
||||
output_dir = pathlib.Path("path", "to", "output")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generate(opts, renderer, dbscheme_input):
|
||||
opts.trap_output = output_dir
|
||||
|
||||
def ret(entities):
|
||||
dbscheme_input.entities = entities
|
||||
generated = run_generation(trapgen.generate, opts, renderer)
|
||||
assert set(generated) == {output_dir /
|
||||
"TrapEntries.h", output_dir / "TrapTags.h"}
|
||||
return generated[output_dir / "TrapEntries.h"], generated[output_dir / "TrapTags.h"]
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generate_traps(generate):
|
||||
def ret(entities):
|
||||
traps, _ = generate(entities)
|
||||
assert isinstance(traps, cpp.TrapList)
|
||||
return traps.traps
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generate_tags(generate):
|
||||
def ret(entities):
|
||||
_, tags = generate(entities)
|
||||
assert isinstance(tags, cpp.TagList)
|
||||
return tags.tags
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def test_empty(generate):
|
||||
assert generate([]) == (cpp.TrapList([]), cpp.TagList([]))
|
||||
|
||||
|
||||
def test_one_empty_table_rejected(generate_traps):
|
||||
with pytest.raises(AssertionError):
|
||||
generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[]),
|
||||
])
|
||||
|
||||
|
||||
def test_one_table(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]),
|
||||
]
|
||||
|
||||
|
||||
def test_one_table(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]),
|
||||
]
|
||||
|
||||
|
||||
def test_one_table_with_id(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[
|
||||
dbscheme.Column("bla", "int", binding=True)]),
|
||||
]) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[cpp.Field(
|
||||
"bla", "int")], id=cpp.Field("bla", "int")),
|
||||
]
|
||||
|
||||
|
||||
def test_one_table_with_two_binding_first_is_id(generate_traps):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[
|
||||
dbscheme.Column("x", "a", binding=True),
|
||||
dbscheme.Column("y", "b", binding=True),
|
||||
]),
|
||||
]) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[
|
||||
cpp.Field("x", "a"),
|
||||
cpp.Field("y", "b"),
|
||||
], id=cpp.Field("x", "a")),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("column,field", [
|
||||
(dbscheme.Column("x", "string"), cpp.Field("x", "std::string")),
|
||||
(dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")),
|
||||
(dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel<DbTypeTag>")),
|
||||
])
|
||||
def test_one_table_special_types(generate_traps, column, field):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name="foos", columns=[column]),
|
||||
]) == [
|
||||
cpp.Trap("foos", name="Foos", fields=[field]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("table,name,column,field", [
|
||||
("locations", "Locations", dbscheme.Column(
|
||||
"startWhatever", "bar"), cpp.Field("startWhatever", "unsigned")),
|
||||
("locations", "Locations", dbscheme.Column(
|
||||
"endWhatever", "bar"), cpp.Field("endWhatever", "unsigned")),
|
||||
("foos", "Foos", dbscheme.Column("startWhatever", "bar"),
|
||||
cpp.Field("startWhatever", "bar")),
|
||||
("foos", "Foos", dbscheme.Column("endWhatever", "bar"),
|
||||
cpp.Field("endWhatever", "bar")),
|
||||
("foos", "Foos", dbscheme.Column("index", "bar"), cpp.Field("index", "unsigned")),
|
||||
("foos", "Foos", dbscheme.Column("num_whatever", "bar"),
|
||||
cpp.Field("num_whatever", "unsigned")),
|
||||
("foos", "Foos", dbscheme.Column("whatever_", "bar"), cpp.Field("whatever", "bar")),
|
||||
])
|
||||
def test_one_table_overridden_fields(generate_traps, table, name, column, field):
|
||||
assert generate_traps([
|
||||
dbscheme.Table(name=table, columns=[column]),
|
||||
]) == [
|
||||
cpp.Trap(table, name=name, fields=[field]),
|
||||
]
|
||||
|
||||
|
||||
def test_one_table_no_tags(generate_tags):
|
||||
assert generate_tags([
|
||||
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
|
||||
]) == []
|
||||
|
||||
|
||||
def test_one_union_tags(generate_tags):
|
||||
assert generate_tags([
|
||||
dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
|
||||
]) == [
|
||||
cpp.Tag(name="LeftHandSide", bases=[], index=0, id="@left_hand_side"),
|
||||
cpp.Tag(name="A", bases=["LeftHandSide"], index=1, id="@a"),
|
||||
cpp.Tag(name="B", bases=["LeftHandSide"], index=2, id="@b"),
|
||||
cpp.Tag(name="C", bases=["LeftHandSide"], index=3, id="@c"),
|
||||
]
|
||||
|
||||
|
||||
def test_multiple_union_tags(generate_tags):
|
||||
assert generate_tags([
|
||||
dbscheme.Union(lhs="@d", rhs=["@a"]),
|
||||
dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
|
||||
dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
|
||||
]) == [
|
||||
cpp.Tag(name="D", bases=[], index=0, id="@d"),
|
||||
cpp.Tag(name="E", bases=[], index=1, id="@e"),
|
||||
cpp.Tag(name="A", bases=["D"], index=2, id="@a"),
|
||||
cpp.Tag(name="F", bases=["E"], index=3, id="@f"),
|
||||
cpp.Tag(name="B", bases=["A"], index=4, id="@b"),
|
||||
cpp.Tag(name="C", bases=["A", "E"], index=5, id="@c"),
|
||||
]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(pytest.main())
|
||||
@@ -7,6 +7,7 @@ from swift.codegen.lib import render, schema
|
||||
|
||||
schema_dir = pathlib.Path("a", "dir")
|
||||
schema_file = schema_dir / "schema.yml"
|
||||
dbscheme_file = pathlib.Path("another", "dir", "test.dbscheme")
|
||||
|
||||
|
||||
def write(out, contents=""):
|
||||
@@ -38,7 +39,19 @@ def input(opts, tmp_path):
|
||||
load_mock.return_value = schema.Schema([])
|
||||
yield load_mock.return_value
|
||||
assert load_mock.mock_calls == [
|
||||
mock.call(opts.schema)
|
||||
mock.call(opts.schema),
|
||||
], load_mock.mock_calls
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dbscheme_input(opts, tmp_path):
|
||||
opts.dbscheme = tmp_path / dbscheme_file
|
||||
with mock.patch("swift.codegen.lib.dbscheme.iterload") as load_mock:
|
||||
load_mock.entities = []
|
||||
load_mock.side_effect = lambda _: load_mock.entities
|
||||
yield load_mock
|
||||
assert load_mock.mock_calls == [
|
||||
mock.call(opts.dbscheme),
|
||||
], load_mock.mock_calls
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import inflection
|
||||
from toposort import toposort_flatten
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
from lib import paths, dbscheme, generator, cpp
|
||||
from swift.codegen.lib import paths, dbscheme, generator, cpp
|
||||
|
||||
field_overrides = [
|
||||
(re.compile(r"locations.*::(start|end).*|.*::(index|num_.*)"), {"type": "unsigned"}),
|
||||
@@ -76,54 +76,26 @@ def get_trap(t: dbscheme.Table):
|
||||
)
|
||||
|
||||
|
||||
def get_guard(path):
|
||||
path = path.relative_to(paths.swift_dir)
|
||||
return str(path.with_suffix("")).replace("/", "_").upper()
|
||||
|
||||
|
||||
def get_topologically_ordered_tags(tags):
|
||||
degree_to_nodes = collections.defaultdict(set)
|
||||
nodes_to_degree = {}
|
||||
lookup = {}
|
||||
for name, t in tags.items():
|
||||
degree = len(t["bases"])
|
||||
degree_to_nodes[degree].add(name)
|
||||
nodes_to_degree[name] = degree
|
||||
while degree_to_nodes[0]:
|
||||
sinks = degree_to_nodes.pop(0)
|
||||
for sink in sorted(sinks):
|
||||
yield sink
|
||||
for d in tags[sink]["derived"]:
|
||||
degree = nodes_to_degree[d]
|
||||
degree_to_nodes[degree].remove(d)
|
||||
degree -= 1
|
||||
nodes_to_degree[d] = degree
|
||||
degree_to_nodes[degree].add(d)
|
||||
if any(degree_to_nodes.values()):
|
||||
raise ValueError("not a dag!")
|
||||
|
||||
|
||||
def generate(opts, renderer):
|
||||
tag_graph = collections.defaultdict(lambda: {"bases": [], "derived": []})
|
||||
tag_graph = {}
|
||||
out = opts.trap_output
|
||||
|
||||
traps = []
|
||||
with open(opts.dbscheme) as input:
|
||||
for e in dbscheme.iterload(input):
|
||||
if e.is_table:
|
||||
traps.append(get_trap(e))
|
||||
elif e.is_union:
|
||||
for d in e.rhs:
|
||||
tag_graph[e.lhs]["derived"].append(d.type)
|
||||
tag_graph[d.type]["bases"].append(e.lhs)
|
||||
for e in dbscheme.iterload(opts.dbscheme):
|
||||
if e.is_table:
|
||||
traps.append(get_trap(e))
|
||||
elif e.is_union:
|
||||
tag_graph.setdefault(e.lhs, set())
|
||||
for d in e.rhs:
|
||||
tag_graph.setdefault(d.type, set()).add(e.lhs)
|
||||
|
||||
renderer.render(cpp.TrapList(traps), out / "TrapEntries.h")
|
||||
|
||||
tags = []
|
||||
for index, tag in enumerate(get_topologically_ordered_tags(tag_graph)):
|
||||
for index, tag in enumerate(toposort_flatten(tag_graph)):
|
||||
tags.append(cpp.Tag(
|
||||
name=get_tag_name(tag),
|
||||
bases=[get_tag_name(b) for b in sorted(tag_graph[tag]["bases"])],
|
||||
bases=[get_tag_name(b) for b in sorted(tag_graph[tag])],
|
||||
index=index,
|
||||
id=tag,
|
||||
))
|
||||
|
||||
Reference in New Issue
Block a user