Swift: add structured C++ generated classes

This adds `cppgen`, which generates structured C++ classes mirroring the QL
classes defined in `schema.yml`.

An example of generated code at the time of this commit can be found
[in this gist][1].

[1]: https://gist.github.com/redsun82/57304ddb487a8aa40eaa0caa695048fa

Closes https://github.com/github/codeql-c-team/issues/863
This commit is contained in:
Paolo Tranquilli
2022-05-04 18:13:54 +02:00
parent 10c5c8e71f
commit d5d1eb717d
21 changed files with 445 additions and 60 deletions

View File

@@ -1,5 +1,17 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")
# The schema definition file, exposed to every package under //swift
# (the `cppgen` genrule passes it to the generator via `--schema`).
filegroup(
name = "schema",
srcs = ["schema.yml"],
visibility = ["//swift:__subpackages__"],
)
# All `*.dbscheme` files of this package, exposed to every package under //swift.
# The `cppgen` genrule lists this target in its `srcs` next to `schema`.
filegroup(
name = "schema_includes",
srcs = glob(["*.dbscheme"]),
visibility = ["//swift:__subpackages__"],
)
py_binary(
name = "codegen",
srcs = glob(
@@ -15,6 +27,17 @@ py_binary(
py_binary(
name = "trapgen",
srcs = ["trapgen.py"],
data = ["//swift/codegen/templates:trap"],
visibility = ["//swift:__subpackages__"],
deps = [
"//swift/codegen/lib",
requirement("toposort"),
],
)
py_binary(
name = "cppgen",
srcs = ["cppgen.py"],
data = ["//swift/codegen/templates:cpp"],
visibility = ["//swift:__subpackages__"],
deps = [

67
swift/codegen/cppgen.py Normal file
View File

@@ -0,0 +1,67 @@
import functools
import inflection
from typing import Dict
from toposort import toposort_flatten
from swift.codegen.lib import cpp, generator, schema
def _get_type(t: str) -> str:
if t == "string":
return "std::string"
if t == "boolean":
return "bool"
if t[0].isupper():
return f"TrapLabel<{t}Tag>"
return t
def _get_field(cls: schema.Class, p: schema.Property) -> cpp.Field:
    """Build the C++ field for property `p` of schema class `cls`.

    Properties that are not single get a trap name derived from the class and
    property names. A property whose name is a C++ keyword gets a trailing
    underscore. Overrides from `cpp.get_field_override` are applied last and
    therefore win over the values computed here.
    """
    if p.is_single:
        trap_name = None
    else:
        table = inflection.camelize(f"{cls.name}_{p.name}")
        trap_name = inflection.pluralize(table) + "Trap"
    field_name = p.name + "_" if p.name in cpp.cpp_keywords else p.name
    args = {
        "name": field_name,
        "type": _get_type(p.type),
        "is_optional": p.is_optional,
        "is_repeated": p.is_repeated,
        "trap_name": trap_name,
    }
    args.update(cpp.get_field_override(p.name))
    return cpp.Field(**args)
class Processor:
    """Converts schema classes into their `cpp.Class` counterparts.

    Converted classes are memoized per instance, so each schema class maps to a
    single `cpp.Class` object that is shared by every class listing it as a base.
    """

    def __init__(self, data: Dict[str, schema.Class]):
        self._classmap = data
        # Per-instance memo. A `functools.cache` on the method would be stored on
        # the class and keyed by `self`, keeping every processor (and its class
        # map) alive for the lifetime of the process.
        self._converted = {}

    def _get_class(self, name: str) -> cpp.Class:
        """Return the C++ class for schema class `name`, converting it on first use.

        Raises `KeyError` if `name` is not in the class map.
        """
        if name in self._converted:
            return self._converted[name]
        cls = self._classmap[name]
        trap_name = None
        # a trap name is set for classes with no derived classes, and for classes
        # with at least one single property
        if not cls.derived or any(p.is_single for p in cls.properties):
            trap_name = inflection.pluralize(cls.name) + "Trap"
        ret = cpp.Class(
            name=name,
            bases=[self._get_class(b) for b in cls.bases],
            fields=[_get_field(cls, p) for p in cls.properties],
            final=not cls.derived,
            trap_name=trap_name,
        )
        self._converted[name] = ret
        return ret

    def get_classes(self):
        """Return all converted classes in the order given by `toposort_flatten`
        over the class -> bases graph."""
        inheritance_graph = {k: cls.bases for k, cls in self._classmap.items()}
        return [self._get_class(cls) for cls in toposort_flatten(inheritance_graph)]
def generate(opts, renderer):
    """Render `TrapClasses.h` under `opts.cpp_output` from the schema at `opts.schema`."""
    classes_by_name = {}
    for cls in schema.load(opts.schema).classes:
        classes_by_name[cls.name] = cls
    processor = Processor(classes_by_name)
    destination_dir = opts.cpp_output
    class_list = cpp.ClassList(processor.get_classes())
    renderer.render(class_list, destination_dir / "TrapClasses.h")
tags = ("cpp", "schema")
if __name__ == "__main__":
generator.run()

View File

@@ -1,3 +1,4 @@
import re
from dataclasses import dataclass, field
from typing import List, ClassVar
@@ -14,13 +15,35 @@ cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "
"typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while",
"xor", "xor_eq"}
# Field overrides, tried in order. Each entry pairs a pattern that must match the
# whole field name with either a dict of `Field` arguments or a callable that
# builds such a dict from the match object.
_field_overrides = [
    (re.compile(r"(start|end)_(line|column)|index|num_.*"), {"type": "unsigned"}),
    (re.compile(r"(.*)_"), lambda m: {"name": m[1]}),
]


def get_field_override(field: str):
    """Return the `Field` argument overrides for a field named `field`.

    Only the first matching entry of `_field_overrides` is used. The result is
    an empty dict when nothing matches. The returned dict is always a fresh
    object, so callers may modify it freely.
    """
    for pattern, override in _field_overrides:
        m = pattern.fullmatch(field)
        if m:
            # copy static entries so a caller cannot alter the shared table
            return override(m) if callable(override) else dict(override)
    return {}
@dataclass
class Field:
name: str
type: str
is_optional: bool = False
is_repeated: bool = False
trap_name: str = None
first: bool = False
def __post_init__(self):
if self.is_optional:
self.type = f"std::optional<{self.type}>"
elif self.is_repeated:
self.type = f"std::vector<{self.type}>"
@property
def cpp_name(self):
if self.name in cpp_keywords:
@@ -36,6 +59,12 @@ class Field:
else:
return lambda x: x
@property
def is_single(self):
    """True when the field is neither optional nor repeated."""
    return not self.is_optional and not self.is_repeated
@dataclass
class Trap:
@@ -74,13 +103,48 @@ class Tag:
@dataclass
class TrapList:
    """Container for the trap entries handed to the renderer."""
    # presumably the name of the mustache template used for rendering -- confirm against the renderer
    template: ClassVar = 'trap_traps'

    traps: List[Trap] = field(default_factory=list)
@dataclass
class TagList:
    """Container for the tags handed to the renderer."""
    # presumably the name of the mustache template used for rendering -- confirm against the renderer
    template: ClassVar = 'trap_tags'

    tags: List[Tag] = field(default_factory=list)
# One entry in a `Class`'s list of bases, as seen by the template.
@dataclass
class ClassBase:
# the base class being referenced
ref: 'Class'
# set to True by `Class.__post_init__` on the first base only
first: bool = False
@dataclass
class Class:
    """A generated C++ class.

    `bases` is passed in as a list of `Class` objects. After construction it
    holds `ClassBase` wrappers sorted by class name, with the first one flagged.
    """
    name: str
    bases: List[ClassBase] = field(default_factory=list)
    final: bool = False
    fields: List[Field] = field(default_factory=list)
    trap_name: str = None

    def __post_init__(self):
        ordered = list(self.bases)
        ordered.sort(key=lambda base: base.name)
        self.bases = [ClassBase(base) for base in ordered]
        for head in self.bases[:1]:
            head.first = True

    @property
    def has_bases(self):
        """Whether this class has at least one base."""
        return len(self.bases) > 0

    @property
    def single_fields(self):
        """The fields that are neither optional nor repeated, in declaration order."""
        return list(filter(lambda f: f.is_single, self.fields))
# Container for the generated classes; `cppgen.generate` renders it to `TrapClasses.h`.
@dataclass
class ClassList:
# presumably the name of the mustache template used for rendering -- confirm against the renderer
template: ClassVar = "cpp_classes"
# the classes to render
classes: List[Class]

View File

@@ -15,7 +15,7 @@ def _init_options():
Option("--ql-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/generated")
Option("--ql-stub-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/elements")
Option("--codeql-binary", tags=["ql"], default="codeql")
Option("--trap-output", tags=["trap"], type=_abspath, required=True)
Option("--cpp-output", tags=["cpp"], type=_abspath, required=True)
def _abspath(x):

View File

@@ -1,5 +1,11 @@
package(default_visibility = ["//swift:__subpackages__"])
filegroup(
name = "trap",
srcs = glob(["trap_*.mustache"]),
)
filegroup(
name = "cpp",
srcs = glob(["cpp_*.mustache"]),
visibility = ["//swift:__subpackages__"],
)

View File

@@ -0,0 +1,46 @@
// generated by {{generator}}
// clang-format off
{{! One struct per class. emit() first emits every base, then this class's own trap entries, each followed by a newline. }}
#pragma once
#include <iostream>
#include <optional>
#include <vector>
#include "swift/extractor/trap/TrapLabel.h"
#include "swift/extractor/trap/TrapEntries.h"
namespace codeql {
{{#classes}}
struct {{name}}{{#final}} : Binding<{{name}}Tag>{{#bases}}, {{ref.name}}{{/bases}}{{/final}}{{^final}}{{#has_bases}}: {{#bases}}{{^first}}, {{/first}}{{ref.name}}{{/bases}}{{/has_bases}}{{/final}} {
  {{#fields}}
  {{type}} {{name}}{};
  {{/fields}}
  {{#final}}
  friend std::ostream& operator<<(std::ostream& out, const {{name}}& x) {
    x.emit(out);
    return out;
  }
  {{/final}}
 protected:
  void emit({{^final}}TrapLabel<{{name}}Tag> id, {{/final}}std::ostream& out) const {
    {{#bases}}
    {{ref.name}}::emit(id, out);
    {{/bases}}
    {{#trap_name}}
    out << {{.}}{id{{#single_fields}}, {{name}}{{/single_fields}}} << '\n';
    {{/trap_name}}
    {{#fields}}
    {{#is_optional}}
    if ({{name}}) out << {{trap_name}}{id, *{{name}}} << '\n';
    {{/is_optional}}
    {{#is_repeated}}
    for (auto i = 0u; i < {{name}}.size(); ++i) out << {{trap_name}}{id, i, {{name}}[i]} << '\n';
    {{/is_repeated}}
    {{/fields}}
  }
};
{{/classes}}
}

View File

@@ -27,6 +27,27 @@ def test_field_get_streamer(type, expected):
assert f.get_streamer()("value") == expected
@pytest.mark.parametrize(("is_optional", "is_repeated", "expected"), [
    (False, False, True),
    (True, False, False),
    (False, True, False),
    (True, True, False),
])
def test_field_is_single(is_optional, is_repeated, expected):
    """A field is single only when it is neither optional nor repeated."""
    modifiers = dict(is_optional=is_optional, is_repeated=is_repeated)
    assert cpp.Field("name", "type", **modifiers).is_single is expected
@pytest.mark.parametrize(("is_optional", "is_repeated", "expected"), [
    (False, False, "bar"),
    (True, False, "std::optional<bar>"),
    (False, True, "std::vector<bar>"),
])
def test_field_modal_types(is_optional, is_repeated, expected):
    """Optional and repeated fields have their type wrapped on construction."""
    modifiers = dict(is_optional=is_optional, is_repeated=is_repeated)
    built = cpp.Field("name", "bar", **modifiers)
    assert built.type == expected
def test_trap_has_first_field_marked():
fields = [
cpp.Field("a", "x"),
@@ -56,5 +77,39 @@ def test_tag_has_bases(bases, expected):
assert t.has_bases is expected
def test_class_has_first_base_marked():
    """Bases are wrapped in `ClassBase`, and only the first one is flagged."""
    bases = [cpp.Class(n) for n in ("a", "b", "c")]
    expected = [cpp.ClassBase(b, first=(i == 0)) for i, b in enumerate(bases)]
    assert cpp.Class("foo", bases=bases).bases == expected
@pytest.mark.parametrize(("bases", "expected"), [
    ([], False),
    (["a"], True),
    (["a", "b"], True),
])
def test_class_has_bases(bases, expected):
    """`has_bases` reflects whether any base was given."""
    base_classes = [cpp.Class(b) for b in bases]
    assert cpp.Class("name", base_classes).has_bases is expected
def test_class_single_fields():
    """`single_fields` keeps the non-optional, non-repeated fields in order."""
    a, c, e = (cpp.Field(n, n.upper()) for n in "ace")
    fields = [
        a,
        cpp.Field("b", "B", is_optional=True),
        c,
        cpp.Field("d", "D", is_repeated=True),
        e,
    ]
    assert cpp.Class("foo", fields=fields).single_fields == [a, c, e]
if __name__ == '__main__':
    # Run only this file's tests, forwarding extra command-line arguments to pytest.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -0,0 +1,127 @@
import sys
from swift.codegen import cppgen
from swift.codegen.lib import cpp
from swift.codegen.test.utils import *
output_dir = pathlib.Path("path", "to", "output")
@pytest.fixture
def generate(opts, renderer, input):
    """Fixture returning a function that runs `cppgen.generate` on the given
    schema classes and returns the classes passed to the renderer."""
    opts.cpp_output = output_dir
    target = output_dir / "TrapClasses.h"

    def run(classes):
        input.classes = classes
        rendered = run_generation(cppgen.generate, opts, renderer)
        # exactly one file is expected, holding a ClassList
        assert set(rendered) == {target}
        class_list = rendered[target]
        assert isinstance(class_list, cpp.ClassList)
        return class_list.classes

    return run
def test_empty(generate):
    """No schema classes in, no C++ classes out."""
    produced = generate([])
    assert produced == []
def test_empty_class(generate):
    """A lone class without properties is final and gets a trap name."""
    produced = generate([schema.Class(name="MyClass")])
    expected = cpp.Class(name="MyClass", final=True, trap_name="MyClassesTrap")
    assert produced == [expected]
def test_two_class_hierarchy(generate):
    """A base with one derived class: only the derived one is final with a trap name."""
    parent = cpp.Class(name="A")
    produced = generate([
        schema.Class(name="A", derived={"B"}),
        schema.Class(name="B", bases={"A"}),
    ])
    child = cpp.Class(name="B", bases=[parent], final=True, trap_name="BsTrap")
    assert produced == [parent, child]
def test_complex_hierarchy_topologically_ordered(generate):
    """Output lists bases before derived classes, whatever order the input uses."""
    a, b = cpp.Class(name="A"), cpp.Class(name="B")
    c, d = cpp.Class(name="C", bases=[a]), cpp.Class(name="D", bases=[a])
    e = cpp.Class(name="E", bases=[b, c, d], final=True, trap_name="EsTrap")
    f = cpp.Class(name="F", bases=[c], final=True, trap_name="FsTrap")
    shuffled_schema = [
        schema.Class(name="F", bases={"C"}),
        schema.Class(name="B", derived={"E"}),
        schema.Class(name="D", bases={"A"}, derived={"E"}),
        schema.Class(name="C", bases={"A"}, derived={"E", "F"}),
        schema.Class(name="E", bases={"B", "C", "D"}),
        schema.Class(name="A", derived={"C", "D"}),
    ]
    assert generate(shuffled_schema) == [a, b, c, d, e, f]
@pytest.mark.parametrize(("type", "expected"), [
    ("a", "a"),
    ("string", "std::string"),
    ("boolean", "bool"),
    ("MyClass", "TrapLabel<MyClassTag>"),
])
@pytest.mark.parametrize(("property_cls", "optional", "repeated", "trap_name"), [
    (schema.SingleProperty, False, False, None),
    (schema.OptionalProperty, True, False, "MyClassPropsTrap"),
    (schema.RepeatedProperty, False, True, "MyClassPropsTrap"),
])
def test_class_with_field(generate, type, expected, property_cls, optional, repeated, trap_name):
    """Each property kind and schema type maps to the expected C++ field."""
    prop = property_cls("prop", type)
    field = cpp.Field("prop", expected, is_optional=optional,
                      is_repeated=repeated, trap_name=trap_name)
    produced = generate([schema.Class(name="MyClass", properties=[prop])])
    assert produced == [
        cpp.Class(name="MyClass", fields=[field], trap_name="MyClassesTrap", final=True),
    ]
@pytest.mark.parametrize("name", ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"])
def test_class_with_overridden_unsigned_field(generate, name):
    """Fields with these names get type `unsigned` regardless of the schema type."""
    prop = schema.SingleProperty(name, "bar")
    produced = generate([schema.Class(name="MyClass", properties=[prop])])
    field = cpp.Field(name, "unsigned")
    assert produced == [
        cpp.Class(name="MyClass", fields=[field], trap_name="MyClassesTrap", final=True),
    ]
def test_class_with_overridden_underscore_field(generate):
    """A trailing underscore in the property name is dropped from the field name."""
    prop = schema.SingleProperty("something_", "bar")
    produced = generate([schema.Class(name="MyClass", properties=[prop])])
    field = cpp.Field("something", "bar")
    assert produced == [
        cpp.Class(name="MyClass", fields=[field], trap_name="MyClassesTrap", final=True),
    ]
# `cpp_keywords` is a set: sort it so the parametrization order (and hence the
# collected test list) is the same on every run and in every worker process.
@pytest.mark.parametrize("name", sorted(cpp.cpp_keywords))
def test_class_with_keyword_field(generate, name):
    """A property named like a C++ keyword gets a trailing underscore in its field name."""
    assert generate([
        schema.Class(name="MyClass", properties=[
            schema.SingleProperty(name, "bar")]),
    ]) == [
        cpp.Class(name="MyClass",
                  fields=[cpp.Field(name + "_", "bar")],
                  trap_name="MyClassesTrap",
                  final=True)
    ]
if __name__ == '__main__':
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -151,4 +151,4 @@ int ignored: int ref*/);
if __name__ == '__main__':
    # Run only this file's tests, forwarding extra command-line arguments to pytest.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -304,4 +304,4 @@ def test_class_with_derived_and_repeated_property(opts, input, renderer):
if __name__ == '__main__':
    # Run only this file's tests, forwarding extra command-line arguments to pytest.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -97,4 +97,4 @@ def test_non_root_class():
if __name__ == '__main__':
    # Run only this file's tests, forwarding extra command-line arguments to pytest.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -195,4 +195,4 @@ def test_empty_cleanup(opts, input, renderer, tmp_path):
if __name__ == '__main__':
    # Run only this file's tests, forwarding extra command-line arguments to pytest.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -76,4 +76,4 @@ def test_cleanup(sut):
if __name__ == '__main__':
    # Run only this file's tests, forwarding extra command-line arguments to pytest.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -152,4 +152,4 @@ A:
if __name__ == '__main__':
    # Run only this file's tests, forwarding extra command-line arguments to pytest.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -9,7 +9,7 @@ output_dir = pathlib.Path("path", "to", "output")
@pytest.fixture
def generate(opts, renderer, dbscheme_input):
opts.trap_output = output_dir
opts.cpp_output = output_dir
def ret(entities):
dbscheme_input.entities = entities
@@ -105,28 +105,22 @@ def test_one_table_special_types(generate_traps, column, field):
]
@pytest.mark.parametrize("table,name,column,field", [
("locations", "Locations", dbscheme.Column(
"startWhatever", "bar"), cpp.Field("startWhatever", "unsigned")),
("locations", "Locations", dbscheme.Column(
"endWhatever", "bar"), cpp.Field("endWhatever", "unsigned")),
("foos", "Foos", dbscheme.Column("startWhatever", "bar"),
cpp.Field("startWhatever", "bar")),
("foos", "Foos", dbscheme.Column("endWhatever", "bar"),
cpp.Field("endWhatever", "bar")),
("foos", "Foos", dbscheme.Column("index", "bar"), cpp.Field("index", "unsigned")),
("foos", "Foos", dbscheme.Column("num_whatever", "bar"),
cpp.Field("num_whatever", "unsigned")),
("foos", "Foos", dbscheme.Column("whatever_", "bar"), cpp.Field("whatever", "bar")),
])
def test_one_table_overridden_fields(generate_traps, table, name, column, field):
@pytest.mark.parametrize("name", ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"])
def test_one_table_overridden_unsigned_field(generate_traps, name):
assert generate_traps([
dbscheme.Table(name=table, columns=[column]),
dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]),
]) == [
cpp.Trap(table, name=name, fields=[field]),
cpp.Trap("foos", name="Foos", fields=[cpp.Field(name, "unsigned")]),
]
def test_one_table_overridden_underscore_named_field(generate_traps):
assert generate_traps([
dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]),
]) == [
cpp.Trap("foos", name="Foos", fields=[cpp.Field("whatever", "bar")]),
]
def test_one_table_no_tags(generate_tags):
assert generate_tags([
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
@@ -160,4 +154,4 @@ def test_multiple_union_tags(generate_tags):
if __name__ == '__main__':
    # Run only this file's tests, forwarding extra command-line arguments to pytest.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -1,36 +1,17 @@
#!/usr/bin/env python3
import logging
import os
import re
import sys
import inflection
from toposort import toposort_flatten
sys.path.append(os.path.dirname(__file__))
from swift.codegen.lib import dbscheme, generator, cpp
from swift.codegen.lib import paths, dbscheme, generator, cpp
field_overrides = [
(re.compile(r"locations.*::(start|end).*|.*::(index|num_.*)"), {"type": "unsigned"}),
(re.compile(r".*::(.*)_"), lambda m: {"name": m[1]}),
]
log = logging.getLogger(__name__)
def get_field_override(table, field):
spec = f"{table}::{field}"
for r, o in field_overrides:
m = r.fullmatch(spec)
if m and callable(o):
return o(m)
elif m:
return o
return {}
def get_tag_name(s):
assert s.startswith("@")
return inflection.camelize(s[1:])
@@ -52,7 +33,7 @@ def get_field(c: dbscheme.Column, table: str):
"name": c.schema_name,
"type": c.type,
}
args.update(get_field_override(table, c.schema_name))
args.update(cpp.get_field_override(c.schema_name))
args["type"] = get_cpp_type(args["type"])
return cpp.Field(**args)
@@ -78,7 +59,7 @@ def get_trap(t: dbscheme.Table):
def generate(opts, renderer):
tag_graph = {}
out = opts.trap_output
out = opts.cpp_output
traps = []
for e in dbscheme.iterload(opts.dbscheme):
@@ -102,7 +83,7 @@ def generate(opts, renderer):
renderer.render(cpp.TagList(tags), out / "TrapTags.h")
tags = ("trap", "dbscheme")
tags = ("cpp", "dbscheme")
if __name__ == "__main__":
generator.run()

View File

@@ -12,7 +12,7 @@
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/Path.h>
#include "swift/extractor/trap/TrapEntries.h"
#include "swift/extractor/trap/TrapClasses.h"
using namespace codeql;
@@ -75,9 +75,10 @@ static void extractFile(const SwiftExtractorConfiguration& config, swift::Source
}
trap << "\n\n";
TrapLabel<FileTag> label{};
trap << label << "=*\n";
trap << FilesTrap{label, srcFilePath.str().str()} << "\n";
File f;
f.id = TrapLabel<FileTag>{};
f.name = srcFilePath.str().str();
trap << f.id << "=*\n" << f;
// TODO: Pick a better name to avoid collisions
std::string trapName = file.getFilename().str() + ".trap";

View File

@@ -1,16 +1,32 @@
genrule(
name = "gen",
name = "trapgen",
srcs = ["//swift:dbscheme"],
outs = [
"TrapEntries.h",
"TrapTags.h",
],
cmd = "$(location //swift/codegen:trapgen) --dbscheme $< --trap-output $(RULEDIR)",
cmd = "$(location //swift/codegen:trapgen) --dbscheme $< --cpp-output $(RULEDIR)",
exec_tools = ["//swift/codegen:trapgen"],
)
# Produces TrapClasses.h by running the cppgen tool on the schema, writing
# into this rule's output directory.
# NOTE(review): schema_includes is listed in srcs but never named in cmd --
# presumably so that files the schema refers to are present when cppgen runs; confirm.
genrule(
name = "cppgen",
srcs = [
"//swift/codegen:schema",
"//swift/codegen:schema_includes",
],
outs = [
"TrapClasses.h",
],
cmd = "$(location //swift/codegen:cppgen) --schema $(location //swift/codegen:schema) --cpp-output $(RULEDIR)",
exec_tools = ["//swift/codegen:cppgen"],
)
cc_library(
name = "trap",
hdrs = glob(["*.h"]) + [":gen"],
hdrs = glob(["*.h"]) + [
":trapgen",
":cppgen",
],
visibility = ["//visibility:public"],
)

View File

@@ -54,6 +54,11 @@ inline auto trapQuoted(const std::string& s) {
return std::quoted(s, '"', '"');
}
// Holds the trap label of an entity tagged with `Tag`. The generated final
// classes derive from `Binding<...Tag>` and read `id` when emitting their entries.
template <typename Tag>
struct Binding {
TrapLabel<Tag> id;
};
} // namespace codeql
namespace std {