Swift: add trapgen unit tests

Closes: https://github.com/github/codeql-c-team/issues/981
This commit is contained in:
Paolo Tranquilli
2022-05-03 17:33:25 +02:00
parent 8e33653d25
commit 10c5c8e71f
11 changed files with 372 additions and 52 deletions

View File

@@ -1,6 +1,11 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")
py_binary(
name = "codegen",
srcs = glob(["*.py"]),
srcs = glob(
["*.py"],
exclude = ["trapgen.py"],
),
visibility = ["//swift/codegen/test:__pkg__"],
deps = ["//swift/codegen/lib"],
)
@@ -12,5 +17,8 @@ py_binary(
srcs = ["trapgen.py"],
data = ["//swift/codegen/templates:cpp"],
visibility = ["//swift:__subpackages__"],
deps = ["//swift/codegen/lib"],
deps = [
"//swift/codegen/lib",
requirement("toposort"),
],
)

View File

@@ -21,12 +21,14 @@ class Field:
type: str
first: bool = False
@property
def cpp_name(self):
    """Return the field name, with ``_`` appended when it is listed in ``cpp_keywords``."""
    clashes = self.name in cpp_keywords
    return self.name + "_" if clashes else self.name
def stream(self):
# using @property breaks pystache internals here
def get_streamer(self):
if self.type == "std::string":
return lambda x: f"trapQuoted({x})"
elif self.type == "bool":
@@ -65,6 +67,7 @@ class Tag:
self.bases = [TagBase(b) for b in self.bases]
self.bases[0].first = True
@property
def has_bases(self):
    """Tell whether ``self.bases`` is non-empty."""
    return True if self.bases else False

View File

@@ -144,13 +144,10 @@ def get_union(match):
def iterload(file):
data = Re.comment.sub("", file.read())
with open(file) as file:
data = Re.comment.sub("", file.read())
for e in Re.entity.finditer(data):
if e["table"]:
yield get_table(e)
elif e["union"]:
yield get_union(e)
def load(file):
    """Eagerly collect everything yielded by ``iterload(file)`` into a list."""
    return [entity for entity in iterload(file)]

View File

@@ -1,4 +1,5 @@
pystache
pyyaml
inflection
pystache
pytest
pyyaml
toposort

View File

@@ -25,7 +25,7 @@ struct {{name}}Trap {
inline std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e) {
out << "{{table_name}}("{{#fields}}{{^first}} << ", "{{/first}}
<< {{#stream}}e.{{cpp_name}}{{/stream}}{{/fields}} << ")";
<< {{#get_streamer}}e.{{cpp_name}}{{/get_streamer}}{{/fields}} << ")";
return out;
}
{{/traps}}

View File

@@ -18,6 +18,7 @@ py_library(
deps = [
":utils",
"//swift/codegen",
"//swift/codegen:trapgen",
],
)
for src in glob(["test_*.py"])

View File

@@ -0,0 +1,60 @@
import sys
from copy import deepcopy
import pytest
from swift.codegen.lib import cpp
@pytest.mark.parametrize("keyword", cpp.cpp_keywords)
def test_field_keyword_cpp_name(keyword):
    """A field named like an entry of ``cpp.cpp_keywords`` gets a trailing underscore."""
    field = cpp.Field(keyword, "int")
    assert field.cpp_name == keyword + "_"
def test_field_cpp_name():
    """The name ``foo`` is used unchanged as the C++ name."""
    assert cpp.Field("foo", "int").cpp_name == "foo"
@pytest.mark.parametrize(
    "type,expected",
    [
        ("std::string", "trapQuoted(value)"),
        ("bool", '(value ? "true" : "false")'),
        ("something_else", "value"),
    ],
)
def test_field_get_streamer(type, expected):
    """The streamer of a field wraps an expression according to the field type."""
    streamer = cpp.Field("name", type).get_streamer()
    assert streamer("value") == expected
def test_trap_has_first_field_marked():
    """Building a ``Trap`` sets ``first`` on its first field and leaves the others as they were."""
    fields = [
        cpp.Field(name, type_)
        for name, type_ in (("a", "x"), ("b", "y"), ("c", "z"))
    ]
    # snapshot taken before constructing the Trap, then adjusted to the expected state
    expected = deepcopy(fields)
    expected[0].first = True
    trap = cpp.Trap("table_name", "name", fields)
    assert trap.fields == expected
def test_tag_has_first_base_marked():
    """Building a ``Tag`` turns base names into ``TagBase`` objects, the first one flagged."""
    expected = [
        cpp.TagBase("a", first=True),
        cpp.TagBase("b"),
        cpp.TagBase("c"),
    ]
    tag = cpp.Tag("name", ["a", "b", "c"], 0, "id")
    assert tag.bases == expected
@pytest.mark.parametrize(
    "bases,expected",
    [
        ([], False),
        (["a"], True),
        (["a", "b"], True),
    ],
)
def test_tag_has_bases(bases, expected):
    """``has_bases`` is exactly ``True`` or ``False`` depending on whether bases were given."""
    tag = cpp.Tag("name", bases, 0, "id")
    assert tag.has_bases is expected
if __name__ == "__main__":
    # allow running this test module directly as a script
    sys.exit(pytest.main())

View File

@@ -48,5 +48,107 @@ def test_union_has_first_case_marked():
assert [c.type for c in u.rhs] == rhs
# load tests
@pytest.fixture
def load(tmp_path):
    """Provide a callable that writes text to a temporary file and loads it with ``dbscheme.iterload``."""
    target = tmp_path / "test.dbscheme"

    def ret(yml):
        write(target, yml)
        return [entity for entity in dbscheme.iterload(target)]

    return ret
def test_load_empty(load):
    """Empty input yields no entities."""
    entities = load("")
    assert entities == []
def test_load_one_empty_table(load):
    """A table declared without columns loads as a ``Table`` with an empty column list."""
    loaded = load("""
test_foos();
""")
    assert loaded == [dbscheme.Table(name="test_foos", columns=[])]
def test_load_table_with_keyset(load):
    """A ``#keyset[...]`` line before a table is loaded into that table's ``keyset``."""
    loaded = load("""
#keyset[x, y,z]
test_foos();
""")
    expected_keyset = dbscheme.KeySet(["x", "y", "z"])
    assert loaded == [
        dbscheme.Table(name="test_foos", columns=[], keyset=expected_keyset),
    ]
# Pairs of (column declaration text, Column the loader is expected to produce for it),
# shared by the column-loading tests.
expected_columns = [
    (
        "int foo: int ref",
        dbscheme.Column(schema_name="foo", type="int", binding=False),
    ),
    (
        " int bar : int ref",
        dbscheme.Column(schema_name="bar", type="int", binding=False),
    ),
    (
        "str baz_: str ref",
        dbscheme.Column(schema_name="baz", type="str", binding=False),
    ),
    (
        "int x: @foo ref",
        dbscheme.Column(schema_name="x", type="@foo", binding=False),
    ),
    (
        "int y: @foo",
        dbscheme.Column(schema_name="y", type="@foo", binding=True),
    ),
    (
        "unique int z: @foo",
        dbscheme.Column(schema_name="z", type="@foo", binding=True),
    ),
]
@pytest.mark.parametrize("column,expected", expected_columns)
def test_load_table_with_column(load, column, expected):
    """A table with a single column declaration loads with the matching ``Column``."""
    loaded = load(f"""
foos(
{column}
);
""")
    # copy so the shared parametrization data cannot be altered through the Table
    table = dbscheme.Table(name="foos", columns=[deepcopy(expected)])
    assert loaded == [table]
def test_load_table_with_multiple_columns(load):
    """A table listing all sample columns loads with all the matching ``Column`` objects in order."""
    declarations, parsed = zip(*expected_columns)
    columns = ",\n".join(declarations)
    expected = [deepcopy(column) for column in parsed]
    loaded = load(f"""
foos(
{columns}
);
""")
    assert loaded == [dbscheme.Table(name="foos", columns=expected)]
def test_load_multiple_table_with_columns(load):
    """Several one-column tables in one input load as one ``Table`` each, in order."""
    tables = []
    expected = []
    for i, (col, parsed) in enumerate(expected_columns):
        tables.append(f"table{i}({col});")
        expected.append(dbscheme.Table(name=f"table{i}", columns=[deepcopy(parsed)]))
    assert load("\n".join(tables)) == expected
def test_union(load):
    """A union declaration loads as a ``Union`` with its left-hand side and alternatives."""
    loaded = load("@foo = @bar | @baz | @bla;")
    assert loaded == [dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"])]
def test_table_and_union(load):
    """A table followed by a union loads as both entities, in declaration order."""
    loaded = load("""
foos();
@foo = @bar | @baz | @bla;""")
    assert loaded == [
        dbscheme.Table(name="foos", columns=[]),
        dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
    ]
def test_comments_ignored(load):
    """``//`` and ``/* */`` comments, including multi-line ones, do not affect the loaded entities."""
    loaded = load("""
// fake_table();
foos(/* x */unique /*y*/int/*
z
*/ id/* */: /* * */ @bar/*,
int ignored: int ref*/);
@foo = @bar | @baz | @bla; // | @xxx""")
    id_column = dbscheme.Column(schema_name="id", type="@bar", binding=True)
    assert loaded == [
        dbscheme.Table(name="foos", columns=[id_column]),
        dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
    ]
if __name__ == "__main__":
    # allow running this test module directly as a script
    sys.exit(pytest.main())

View File

@@ -0,0 +1,163 @@
import sys
from swift.codegen import trapgen
from swift.codegen.lib import cpp, dbscheme
from swift.codegen.test.utils import *
output_dir = pathlib.Path("path", "to", "output")
@pytest.fixture
def generate(opts, renderer, dbscheme_input):
    """Provide a callable that runs ``trapgen.generate`` on given entities.

    The callable checks that exactly ``TrapEntries.h`` and ``TrapTags.h`` were
    rendered under ``output_dir`` and returns what was rendered for them, in
    that order.
    """
    opts.trap_output = output_dir
    entries_path = output_dir / "TrapEntries.h"
    tags_path = output_dir / "TrapTags.h"

    def ret(entities):
        # served by the mocked dbscheme loader
        dbscheme_input.entities = entities
        generated = run_generation(trapgen.generate, opts, renderer)
        assert set(generated) == {entries_path, tags_path}
        return generated[entries_path], generated[tags_path]

    return ret
@pytest.fixture
def generate_traps(generate):
    """Provide a callable returning only the traps generated for given entities."""

    def ret(entities):
        trap_list, _tags = generate(entities)
        assert isinstance(trap_list, cpp.TrapList)
        return trap_list.traps

    return ret
@pytest.fixture
def generate_tags(generate):
    """Provide a callable returning only the tags generated for given entities."""

    def ret(entities):
        _traps, tag_list = generate(entities)
        assert isinstance(tag_list, cpp.TagList)
        return tag_list.tags

    return ret
def test_empty(generate):
    """No entities produce an empty trap list and an empty tag list."""
    result = generate([])
    assert result == (cpp.TrapList([]), cpp.TagList([]))
def test_one_empty_table_rejected(generate_traps):
    """Generating from a table without columns raises ``AssertionError``."""
    with pytest.raises(AssertionError):
        tables = [dbscheme.Table(name="foos", columns=[])]
        generate_traps(tables)
def test_one_table(generate_traps):
    """A table with one plain column becomes one trap with one matching field.

    This test used to be defined twice with identical bodies; the second
    definition silently replaced the first, so a single copy is kept.
    """
    assert generate_traps([
        dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
    ]) == [
        cpp.Trap("foos", name="Foos", fields=[cpp.Field("bla", "int")]),
    ]
def test_one_table_with_id(generate_traps):
    """A binding column is emitted as a field and also used as the trap ``id``."""
    tables = [
        dbscheme.Table(
            name="foos",
            columns=[dbscheme.Column("bla", "int", binding=True)],
        ),
    ]
    generated = generate_traps(tables)
    assert generated == [
        cpp.Trap(
            "foos",
            name="Foos",
            fields=[cpp.Field("bla", "int")],
            id=cpp.Field("bla", "int"),
        ),
    ]
def test_one_table_with_two_binding_first_is_id(generate_traps):
    """With two binding columns, the first one is used as the trap ``id``."""
    tables = [
        dbscheme.Table(
            name="foos",
            columns=[
                dbscheme.Column("x", "a", binding=True),
                dbscheme.Column("y", "b", binding=True),
            ],
        ),
    ]
    generated = generate_traps(tables)
    assert generated == [
        cpp.Trap(
            "foos",
            name="Foos",
            fields=[cpp.Field("x", "a"), cpp.Field("y", "b")],
            id=cpp.Field("x", "a"),
        ),
    ]
@pytest.mark.parametrize(
    "column,field",
    [
        (dbscheme.Column("x", "string"), cpp.Field("x", "std::string")),
        (dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")),
        (dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel<DbTypeTag>")),
    ],
)
def test_one_table_special_types(generate_traps, column, field):
    """``string``, ``boolean`` and ``@``-prefixed column types map to the expected C++ field types."""
    tables = [dbscheme.Table(name="foos", columns=[column])]
    generated = generate_traps(tables)
    assert generated == [cpp.Trap("foos", name="Foos", fields=[field])]
@pytest.mark.parametrize(
    "table,name,column,field",
    [
        (
            "locations",
            "Locations",
            dbscheme.Column("startWhatever", "bar"),
            cpp.Field("startWhatever", "unsigned"),
        ),
        (
            "locations",
            "Locations",
            dbscheme.Column("endWhatever", "bar"),
            cpp.Field("endWhatever", "unsigned"),
        ),
        (
            "foos",
            "Foos",
            dbscheme.Column("startWhatever", "bar"),
            cpp.Field("startWhatever", "bar"),
        ),
        (
            "foos",
            "Foos",
            dbscheme.Column("endWhatever", "bar"),
            cpp.Field("endWhatever", "bar"),
        ),
        (
            "foos",
            "Foos",
            dbscheme.Column("index", "bar"),
            cpp.Field("index", "unsigned"),
        ),
        (
            "foos",
            "Foos",
            dbscheme.Column("num_whatever", "bar"),
            cpp.Field("num_whatever", "unsigned"),
        ),
        (
            "foos",
            "Foos",
            dbscheme.Column("whatever_", "bar"),
            cpp.Field("whatever", "bar"),
        ),
    ],
)
def test_one_table_overridden_fields(generate_traps, table, name, column, field):
    """Columns are turned into the expected field for each listed table/column combination."""
    tables = [dbscheme.Table(name=table, columns=[column])]
    generated = generate_traps(tables)
    assert generated == [cpp.Trap(table, name=name, fields=[field])]
def test_one_table_no_tags(generate_tags):
    """A table on its own produces no tags."""
    tables = [
        dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
    ]
    tags = generate_tags(tables)
    assert tags == []
def test_one_union_tags(generate_tags):
    """A union yields a tag for its left-hand side first, then one per alternative in sorted order."""
    unions = [
        dbscheme.Union(lhs="@left_hand_side", rhs=["@b", "@a", "@c"]),
    ]
    tags = generate_tags(unions)
    assert tags == [
        cpp.Tag(name="LeftHandSide", bases=[], index=0, id="@left_hand_side"),
        cpp.Tag(name="A", bases=["LeftHandSide"], index=1, id="@a"),
        cpp.Tag(name="B", bases=["LeftHandSide"], index=2, id="@b"),
        cpp.Tag(name="C", bases=["LeftHandSide"], index=3, id="@c"),
    ]
def test_multiple_union_tags(generate_tags):
    """Tags from several unions come out with every base before the tags derived from it."""
    unions = [
        dbscheme.Union(lhs="@d", rhs=["@a"]),
        dbscheme.Union(lhs="@a", rhs=["@b", "@c"]),
        dbscheme.Union(lhs="@e", rhs=["@c", "@f"]),
    ]
    tags = generate_tags(unions)
    assert tags == [
        cpp.Tag(name="D", bases=[], index=0, id="@d"),
        cpp.Tag(name="E", bases=[], index=1, id="@e"),
        cpp.Tag(name="A", bases=["D"], index=2, id="@a"),
        cpp.Tag(name="F", bases=["E"], index=3, id="@f"),
        cpp.Tag(name="B", bases=["A"], index=4, id="@b"),
        cpp.Tag(name="C", bases=["A", "E"], index=5, id="@c"),
    ]
if __name__ == "__main__":
    # allow running this test module directly as a script
    sys.exit(pytest.main())

View File

@@ -7,6 +7,7 @@ from swift.codegen.lib import render, schema
schema_dir = pathlib.Path("a", "dir")
schema_file = schema_dir / "schema.yml"
dbscheme_file = pathlib.Path("another", "dir", "test.dbscheme")
def write(out, contents=""):
@@ -38,7 +39,19 @@ def input(opts, tmp_path):
load_mock.return_value = schema.Schema([])
yield load_mock.return_value
assert load_mock.mock_calls == [
mock.call(opts.schema)
mock.call(opts.schema),
], load_mock.mock_calls
@pytest.fixture
def dbscheme_input(opts, tmp_path):
    """Patch ``swift.codegen.lib.dbscheme.iterload`` and yield the mock.

    The mock returns whatever is currently stored in its ``entities``
    attribute. After the test, it is checked that the loader was called
    exactly once, with ``opts.dbscheme``.
    """
    opts.dbscheme = tmp_path / dbscheme_file
    with mock.patch("swift.codegen.lib.dbscheme.iterload") as load_mock:
        load_mock.entities = []

        def _serve_entities(_):
            # looked up at call time, so tests can assign entities after setup
            return load_mock.entities

        load_mock.side_effect = _serve_entities
        yield load_mock
    expected_calls = [
        mock.call(opts.dbscheme),
    ]
    assert load_mock.mock_calls == expected_calls, load_mock.mock_calls

View File

@@ -1,16 +1,16 @@
#!/usr/bin/env python3
import collections
import logging
import os
import re
import sys
import inflection
from toposort import toposort_flatten
sys.path.append(os.path.dirname(__file__))
from lib import paths, dbscheme, generator, cpp
from swift.codegen.lib import paths, dbscheme, generator, cpp
field_overrides = [
(re.compile(r"locations.*::(start|end).*|.*::(index|num_.*)"), {"type": "unsigned"}),
@@ -76,54 +76,26 @@ def get_trap(t: dbscheme.Table):
)
def get_guard(path):
    """Build an upper-case guard identifier from ``path``.

    ``path`` is taken relative to ``paths.swift_dir``, its suffix is dropped,
    path separators become underscores and the result is upper-cased.
    Presumably used as a C++ include guard -- confirm against the callers.
    """
    relative = path.relative_to(paths.swift_dir).with_suffix("")
    # as_posix() always separates with "/", so the replacement below also
    # works where str(path) would use a different separator
    return relative.as_posix().replace("/", "_").upper()
def get_topologically_ordered_tags(tags):
    """Yield tag names so that every tag comes after all of its bases.

    ``tags`` maps each tag name to a dict with a ``"bases"`` and a
    ``"derived"`` list of tag names. Tags that become available together are
    yielded in sorted order.

    Raises ``ValueError`` once iteration is exhausted if some tags could not
    be emitted, i.e. the graph is not a DAG.
    """
    degree_to_nodes = collections.defaultdict(set)
    nodes_to_degree = {}
    for name, t in tags.items():
        degree = len(t["bases"])
        degree_to_nodes[degree].add(name)
        nodes_to_degree[name] = degree

    # the defaultdict recreates an empty set at key 0 after each pop, which
    # ends the loop when no tag without pending bases is left
    while degree_to_nodes[0]:
        sinks = degree_to_nodes.pop(0)
        for sink in sorted(sinks):
            yield sink
            # one base of each derived tag has now been emitted
            for d in tags[sink]["derived"]:
                degree = nodes_to_degree[d]
                degree_to_nodes[degree].remove(d)
                degree -= 1
                nodes_to_degree[d] = degree
                degree_to_nodes[degree].add(d)

    # anything still waiting for a base is part of a cycle
    if any(degree_to_nodes.values()):
        raise ValueError("not a dag!")
def generate(opts, renderer):
tag_graph = collections.defaultdict(lambda: {"bases": [], "derived": []})
tag_graph = {}
out = opts.trap_output
traps = []
with open(opts.dbscheme) as input:
for e in dbscheme.iterload(input):
if e.is_table:
traps.append(get_trap(e))
elif e.is_union:
for d in e.rhs:
tag_graph[e.lhs]["derived"].append(d.type)
tag_graph[d.type]["bases"].append(e.lhs)
for e in dbscheme.iterload(opts.dbscheme):
if e.is_table:
traps.append(get_trap(e))
elif e.is_union:
tag_graph.setdefault(e.lhs, set())
for d in e.rhs:
tag_graph.setdefault(d.type, set()).add(e.lhs)
renderer.render(cpp.TrapList(traps), out / "TrapEntries.h")
tags = []
for index, tag in enumerate(get_topologically_ordered_tags(tag_graph)):
for index, tag in enumerate(toposort_flatten(tag_graph)):
tags.append(cpp.Tag(
name=get_tag_name(tag),
bases=[get_tag_name(b) for b in sorted(tag_graph[tag]["bases"])],
bases=[get_tag_name(b) for b in sorted(tag_graph[tag])],
index=index,
id=tag,
))