Merge pull request #9034 from redsun82/swift-cpp-gen

Swift: add structured C++ generated classes
This commit is contained in:
Paolo Tranquilli
2022-05-09 17:49:23 +02:00
committed by GitHub
33 changed files with 659 additions and 187 deletions

View File

@@ -1,5 +1,17 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")
filegroup(
name = "schema",
srcs = ["schema.yml"],
visibility = ["//swift:__subpackages__"],
)
filegroup(
name = "schema_includes",
srcs = glob(["*.dbscheme"]),
visibility = ["//swift:__subpackages__"],
)
py_binary(
name = "codegen",
srcs = glob(
@@ -15,6 +27,17 @@ py_binary(
py_binary(
name = "trapgen",
srcs = ["trapgen.py"],
data = ["//swift/codegen/templates:trap"],
visibility = ["//swift:__subpackages__"],
deps = [
"//swift/codegen/lib",
requirement("toposort"),
],
)
py_binary(
name = "cppgen",
srcs = ["cppgen.py"],
data = ["//swift/codegen/templates:cpp"],
visibility = ["//swift:__subpackages__"],
deps = [

69
swift/codegen/cppgen.py Normal file
View File

@@ -0,0 +1,69 @@
import functools
from typing import Dict
import inflection
from toposort import toposort_flatten
from swift.codegen.lib import cpp, generator, schema
def _get_type(t: str, trap_affix: str) -> str:
if t == "string":
return "std::string"
if t == "boolean":
return "bool"
if t[0].isupper():
return f"{trap_affix}Label<{t}Tag>"
return t
def _get_field(cls: schema.Class, p: schema.Property, trap_affix: str) -> cpp.Field:
trap_name = None
if not p.is_single:
trap_name = inflection.pluralize(inflection.camelize(f"{cls.name}_{p.name}"))
args = dict(
name=p.name + ("_" if p.name in cpp.cpp_keywords else ""),
type=_get_type(p.type, trap_affix),
is_optional=p.is_optional,
is_repeated=p.is_repeated,
trap_name=trap_name,
)
args.update(cpp.get_field_override(p.name))
return cpp.Field(**args)
class Processor:
def __init__(self, data: Dict[str, schema.Class], trap_affix: str):
self._classmap = data
self._trap_affix = trap_affix
@functools.lru_cache(maxsize=None)
def _get_class(self, name: str) -> cpp.Class:
cls = self._classmap[name]
trap_name = None
if not cls.derived or any(p.is_single for p in cls.properties):
trap_name = inflection.pluralize(cls.name)
return cpp.Class(
name=name,
bases=[self._get_class(b) for b in cls.bases],
fields=[_get_field(cls, p, self._trap_affix) for p in cls.properties],
final=not cls.derived,
trap_name=trap_name,
)
def get_classes(self):
inheritance_graph = {k: cls.bases for k, cls in self._classmap.items()}
return [self._get_class(cls) for cls in toposort_flatten(inheritance_graph)]
def generate(opts, renderer):
processor = Processor({cls.name: cls for cls in schema.load(opts.schema).classes}, opts.trap_affix)
out = opts.cpp_output
renderer.render(cpp.ClassList(processor.get_classes(), opts.cpp_namespace, opts.trap_affix,
opts.cpp_include_dir), out / f"{opts.trap_affix}Classes.h")
tags = ("cpp", "schema")
if __name__ == "__main__":
generator.run()

View File

@@ -38,16 +38,7 @@ def cls_to_dbscheme(cls: schema.Class):
)
# use property-specific tables for 1-to-many and 1-to-at-most-1 properties
for f in cls.properties:
if f.is_optional:
yield Table(
keyset=KeySet(["id"]),
name=inflection.tableize(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
Column(f.name, dbtype(f.type)),
],
)
elif f.is_repeated:
if f.is_repeated:
yield Table(
keyset=KeySet(["id", "index"]),
name=inflection.tableize(f"{cls.name}_{f.name}"),
@@ -57,18 +48,27 @@ def cls_to_dbscheme(cls: schema.Class):
Column(inflection.singularize(f.name), dbtype(f.type)),
]
)
elif f.is_optional:
yield Table(
keyset=KeySet(["id"]),
name=inflection.tableize(f"{cls.name}_{f.name}"),
columns=[
Column("id", type=dbtype(cls.name)),
Column(f.name, dbtype(f.type)),
],
)
def get_declarations(data: schema.Schema):
return [d for cls in data.classes for d in cls_to_dbscheme(cls)]
def get_includes(data: schema.Schema, include_dir: pathlib.Path):
def get_includes(data: schema.Schema, include_dir: pathlib.Path, swift_dir: pathlib.Path):
includes = []
for inc in data.includes:
inc = include_dir / inc
with open(inc) as inclusion:
includes.append(SchemeInclude(src=inc.relative_to(paths.swift_dir), data=inclusion.read()))
includes.append(SchemeInclude(src=inc.relative_to(swift_dir), data=inclusion.read()))
return includes
@@ -78,8 +78,8 @@ def generate(opts, renderer):
data = schema.load(input)
dbscheme = Scheme(src=input.relative_to(paths.swift_dir),
includes=get_includes(data, include_dir=input.parent),
dbscheme = Scheme(src=input.relative_to(opts.swift_dir),
includes=get_includes(data, include_dir=input.parent, swift_dir=opts.swift_dir),
declarations=get_declarations(data))
renderer.render(dbscheme, out)

View File

@@ -1,3 +1,4 @@
import re
from dataclasses import dataclass, field
from typing import List, ClassVar
@@ -14,13 +15,35 @@ cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "
"typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while",
"xor", "xor_eq"}
_field_overrides = [
(re.compile(r"(start|end)_(line|column)|index|width|num_.*"), {"type": "unsigned"}),
(re.compile(r"(.*)_"), lambda m: {"name": m[1]}),
]
def get_field_override(field: str):
for r, o in _field_overrides:
m = r.fullmatch(field)
if m:
return o(m) if callable(o) else o
return {}
@dataclass
class Field:
name: str
type: str
is_optional: bool = False
is_repeated: bool = False
trap_name: str = None
first: bool = False
def __post_init__(self):
if self.is_optional:
self.type = f"std::optional<{self.type}>"
if self.is_repeated:
self.type = f"std::vector<{self.type}>"
@property
def cpp_name(self):
if self.name in cpp_keywords:
@@ -36,6 +59,12 @@ class Field:
else:
return lambda x: x
@property
def is_single(self):
return not (self.is_optional or self.is_repeated)
@dataclass
class Trap:
@@ -74,13 +103,55 @@ class Tag:
@dataclass
class TrapList:
template: ClassVar = 'cpp_traps'
template: ClassVar = 'trap_traps'
traps: List[Trap] = field(default_factory=list)
traps: List[Trap]
namespace: str
trap_affix: str
include_dir: str
@dataclass
class TagList:
template: ClassVar = 'cpp_tags'
template: ClassVar = 'trap_tags'
tags: List[Tag] = field(default_factory=list)
tags: List[Tag]
namespace: str
@dataclass
class ClassBase:
ref: 'Class'
first: bool = False
@dataclass
class Class:
name: str
bases: List[ClassBase] = field(default_factory=list)
final: bool = False
fields: List[Field] = field(default_factory=list)
trap_name: str = None
def __post_init__(self):
self.bases = [ClassBase(c) for c in sorted(self.bases, key=lambda cls: cls.name)]
if self.bases:
self.bases[0].first = True
@property
def has_bases(self):
return bool(self.bases)
@property
def single_fields(self):
return [f for f in self.fields if f.is_single]
@dataclass
class ClassList:
template: ClassVar = "cpp_classes"
classes: List[Class]
namespace: str
trap_affix: str
include_dir: str

View File

@@ -3,26 +3,30 @@
import argparse
import logging
import sys
from typing import Set
from . import options, render
from . import options, render, paths
def _parse(tags):
def _parse(tags: Set[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser()
for opt in options.get(tags):
opt.add_to(parser)
ret = parser.parse_args()
log_level = logging.DEBUG if ret.verbose else logging.INFO
logging.basicConfig(format="{levelname} {message}", style='{', level=log_level)
return ret
return parser.parse_args()
def run(*modules):
def run(*modules, **kwargs):
""" run generation functions in specified in `modules`, or in current module by default
"""
if modules:
opts = _parse({t for m in modules for t in m.tags})
if kwargs:
opts = argparse.Namespace(**kwargs)
else:
opts = _parse({t for m in modules for t in m.tags})
log_level = logging.DEBUG if opts.verbose else logging.INFO
logging.basicConfig(format="{levelname} {message}", style='{', level=log_level)
exe_path = paths.exe_file.relative_to(opts.swift_dir)
for m in modules:
m.generate(opts, render.Renderer())
m.generate(opts, render.Renderer(exe_path))
else:
run(sys.modules["__main__"])
run(sys.modules["__main__"], **kwargs)

View File

@@ -10,12 +10,18 @@ from . import paths
def _init_options():
Option("--verbose", "-v", action="store_true")
Option("--swift-dir", type=_abspath, default=paths.swift_dir)
Option("--schema", tags=["schema"], type=_abspath, default=paths.swift_dir / "codegen/schema.yml")
Option("--dbscheme", tags=["dbscheme"], type=_abspath, default=paths.swift_dir / "ql/lib/swift.dbscheme")
Option("--ql-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/generated")
Option("--ql-stub-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/elements")
Option("--ql-format", tags=["ql"], action="store_true", default=True)
Option("--no-ql-format", tags=["ql"], action="store_false", dest="ql_format")
Option("--codeql-binary", tags=["ql"], default="codeql")
Option("--trap-output", tags=["trap"], type=_abspath, required=True)
Option("--cpp-output", tags=["cpp"], type=_abspath, required=True)
Option("--cpp-namespace", tags=["cpp"], default="codeql")
Option("--trap-affix", tags=["cpp"], default="Trap")
Option("--cpp-include-dir", tags=["cpp"], required=True)
def _abspath(x):

View File

@@ -15,7 +15,4 @@ except KeyError:
lib_dir = swift_dir / 'codegen' / 'lib'
templates_dir = swift_dir / 'codegen' / 'templates'
try:
exe_file = pathlib.Path(sys.argv[0]).resolve().relative_to(swift_dir)
except ValueError:
exe_file = pathlib.Path(sys.argv[0]).name
exe_file = pathlib.Path(sys.argv[0]).resolve()

View File

@@ -8,7 +8,6 @@ import inflection
@dataclass
class Param:
param: str
type: str = None
first: bool = False
@@ -19,15 +18,11 @@ class Property:
tablename: str
tableparams: List[Param]
plural: str = None
params: List[Param] = field(default_factory=list)
first: bool = False
local_var: str = "x"
is_optional: bool = False
def __post_init__(self):
if self.params:
self.params[0].first = True
while self.local_var in (p.param for p in self.params):
self.local_var += "_"
assert self.tableparams
if self.type_is_class:
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
@@ -43,6 +38,10 @@ class Property:
def type_is_class(self):
return self.type[0].isupper()
@property
def is_repeated(self):
return bool(self.plural)
@dataclass
class Class:

View File

@@ -18,9 +18,10 @@ log = logging.getLogger(__name__)
class Renderer:
""" Template renderer using mustache templates in the `templates` directory """
def __init__(self):
def __init__(self, generator):
self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u)
self.written = set()
self._generator = generator
def render(self, data, output: pathlib.Path):
""" Render `data` to `output`.
@@ -34,7 +35,7 @@ class Renderer:
"""
mnemonic = type(data).__name__
output.parent.mkdir(parents=True, exist_ok=True)
data = self._r.render_name(data.template, data, generator=paths.exe_file)
data = self._r.render_name(data.template, data, generator=self._generator)
with open(output, "w") as out:
out.write(data)
log.debug(f"generated {mnemonic} {output.name}")

View File

@@ -35,6 +35,12 @@ class RepeatedProperty(Property):
is_repeated: ClassVar = True
@dataclass
class RepeatedOptionalProperty(Property):
is_optional: ClassVar = True
is_repeated: ClassVar = True
@dataclass
class Class:
name: str
@@ -51,7 +57,10 @@ class Schema:
def _parse_property(name, type):
if type.endswith("*"):
if type.endswith("?*"):
cls = RepeatedOptionalProperty
type = type[:-2]
elif type.endswith("*"):
cls = RepeatedProperty
type = type[:-1]
elif type.endswith("?"):

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import logging
import pathlib
import subprocess
import inflection
@@ -18,13 +19,6 @@ def get_ql_property(cls: schema.Class, prop: schema.Property):
tablename=inflection.tableize(cls.name),
tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
)
elif prop.is_optional:
return ql.Property(
singular=inflection.camelize(prop.name),
type=prop.type,
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
tableparams=["this", "result"],
)
elif prop.is_repeated:
return ql.Property(
singular=inflection.singularize(inflection.camelize(prop.name)),
@@ -32,7 +26,15 @@ def get_ql_property(cls: schema.Class, prop: schema.Property):
type=prop.type,
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
tableparams=["this", "index", "result"],
params=[ql.Param("index", type="int")],
is_optional=prop.is_optional,
)
elif prop.is_optional:
return ql.Property(
singular=inflection.camelize(prop.name),
type=prop.type,
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
tableparams=["this", "result"],
is_optional=True,
)
@@ -46,8 +48,8 @@ def get_ql_class(cls: schema.Class):
)
def get_import(file):
stem = file.relative_to(paths.swift_dir / "ql/lib").with_suffix("")
def get_import(file: pathlib.Path, swift_dir: pathlib.Path):
stem = file.relative_to(swift_dir / "ql/lib").with_suffix("")
return str(stem).replace("/", ".")
@@ -56,8 +58,6 @@ def get_types_used_by(cls: ql.Class):
yield b
for p in cls.properties:
yield p.type
for param in p.params:
yield param.type
def get_classes_used_by(cls: ql.Class):
@@ -90,7 +90,7 @@ def generate(opts, renderer):
imports = {}
for c in classes:
imports[c.name] = get_import(stub_out / c.path)
imports[c.name] = get_import(stub_out / c.path, opts.swift_dir)
for c in classes:
qll = (out / c.path).with_suffix(".qll")
@@ -98,7 +98,7 @@ def generate(opts, renderer):
renderer.render(c, qll)
stub_file = (stub_out / c.path).with_suffix(".qll")
if not stub_file.is_file() or is_generated(stub_file):
stub = ql.Stub(name=c.name, base_import=get_import(qll))
stub = ql.Stub(name=c.name, base_import=get_import(qll, opts.swift_dir))
renderer.render(stub, stub_file)
# for example path/to/syntax/generated -> path/to/syntax.qll
@@ -107,7 +107,8 @@ def generate(opts, renderer):
renderer.render(all_imports, include_file)
renderer.cleanup(existing)
format(opts.codeql_binary, renderer.written)
if opts.ql_format:
format(opts.codeql_binary, renderer.written)
tags = ("schema", "ql")

View File

@@ -256,7 +256,7 @@ OperatorDecl:
PatternBindingDecl:
_extends: Decl
inits: Expr*
inits: Expr?*
patterns: Pattern*
PoundDiagnosticDecl:

View File

@@ -1,5 +1,11 @@
package(default_visibility = ["//swift:__subpackages__"])
filegroup(
name = "trap",
srcs = glob(["trap_*.mustache"]),
)
filegroup(
name = "cpp",
srcs = glob(["cpp_*.mustache"]),
visibility = ["//swift:__subpackages__"],
)

View File

@@ -0,0 +1,55 @@
// generated by {{generator}}
// clang-format off
#pragma once
#include <iostream>
#include <optional>
#include <vector>
#include "{{include_dir}}/{{trap_affix}}Label.h"
#include "./{{trap_affix}}Entries.h"
namespace {{namespace}} {
{{#classes}}
struct {{name}}{{#final}} : Binding<{{name}}Tag>{{#bases}}, {{ref.name}}{{/bases}}{{/final}}{{^final}}{{#has_bases}}: {{#bases}}{{^first}}, {{/first}}{{ref.name}}{{/bases}}{{/has_bases}}{{/final}} {
{{#fields}}
{{type}} {{name}}{};
{{/fields}}
{{#final}}
friend std::ostream& operator<<(std::ostream& out, const {{name}}& x) {
x.emit(out);
return out;
}
{{/final}}
protected:
void emit({{^final}}{{trap_affix}}Label<{{name}}Tag> id, {{/final}}std::ostream& out) const {
{{#trap_name}}
out << {{.}}{{trap_affix}}{id{{#single_fields}}, {{name}}{{/single_fields}}} << '\n';
{{/trap_name}}
{{#bases}}
{{ref.name}}::emit(id, out);
{{/bases}}
{{#fields}}
{{#is_optional}}
{{^is_repeated}}
if ({{name}}) out << {{trap_name}}{{trap_affix}}{id, *{{name}}} << '\n';
{{/is_repeated}}
{{/is_optional}}
{{#is_repeated}}
for (auto i = 0u; i < {{name}}.size(); ++i) {
{{^is_optional}}
out << {{trap_name}}{{trap_affix}}{id, i, {{name}}[i]};
{{/is_optional}}
{{#is_optional}}
if ({{name}}[i]) out << {{trap_name}}{{trap_affix}}{id, i, *{{name}}[i]};
{{/is_optional}}
}
{{/is_repeated}}
{{/fields}}
}
};
{{/classes}}
}

View File

@@ -1,4 +1,5 @@
// generated by {{generator}}
{{#imports}}
import {{.}}
{{/imports}}
@@ -20,28 +21,28 @@ class {{name}}Base extends {{db_id}}{{#bases}}, {{.}}{{/bases}} {
{{/final}}
{{#properties}}
{{#type_is_class}}
{{type}} get{{singular}}({{#params}}{{^first}}, {{/first}}{{type}} {{param}}{{/params}}) {
{{type}} get{{singular}}({{#is_repeated}}int index{{/is_repeated}}) {
{{#type_is_class}}
exists({{type}} {{local_var}} |
{{tablename}}({{#tableparams}}{{^first}}, {{/first}}{{param}}{{/tableparams}})
and
result = {{local_var}}.resolve())
}
{{/type_is_class}}
{{^type_is_class}}
{{type}} get{{singular}}({{#params}}{{^first}}, {{/first}}{{type}} {{param}}{{/params}}) {
{{/type_is_class}}
{{^type_is_class}}
{{tablename}}({{#tableparams}}{{^first}}, {{/first}}{{param}}{{/tableparams}})
{{/type_is_class}}
}
{{/type_is_class}}
{{#indefinite_article}}
{{#is_repeated}}
{{type}} get{{.}}{{singular}}() {
result = get{{singular}}({{#params}}{{^first}}, {{/first}}_{{/params}})
{{type}} get{{indefinite_article}}{{singular}}() {
result = get{{singular}}(_)
}
{{^is_optional}}
int getNumberOf{{plural}}() {
result = count(get{{.}}{{singular}}())
result = count(get{{indefinite_article}}{{singular}}())
}
{{/indefinite_article}}
{{/is_optional}}
{{/is_repeated}}
{{/properties}}
}

View File

@@ -1,4 +1,5 @@
// generated by {{generator}}, remove this comment if you wish to edit this file
private import {{base_import}}
class {{name}} extends {{name}}Base { }
class {{name}} extends {{name}}Base {}

View File

@@ -2,7 +2,7 @@
// clang-format off
#pragma once
namespace codeql {
namespace {{namespace}} {
{{#tags}}
// {{id}}

View File

@@ -5,14 +5,14 @@
#include <iostream>
#include <string>
#include "swift/extractor/trap/TrapLabel.h"
#include "swift/extractor/trap/TrapTags.h"
#include "{{include_dir}}/{{trap_affix}}Label.h"
#include "./{{trap_affix}}Tags.h"
namespace codeql {
namespace {{namespace}} {
{{#traps}}
// {{table_name}}
struct {{name}}Trap {
struct {{name}}{{trap_affix}} {
static constexpr bool is_binding = {{#id}}true{{/id}}{{^id}}false{{/id}};
{{#id}}
{{type}} getBoundLabel() const { return {{cpp_name}}; }
@@ -23,7 +23,7 @@ struct {{name}}Trap {
{{/fields}}
};
inline std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e) {
inline std::ostream &operator<<(std::ostream &out, const {{name}}{{trap_affix}} &e) {
out << "{{table_name}}("{{#fields}}{{^first}} << ", "{{/first}}
<< {{#get_streamer}}e.{{cpp_name}}{{/get_streamer}}{{/fields}} << ")";
return out;

View File

@@ -27,6 +27,28 @@ def test_field_get_streamer(type, expected):
assert f.get_streamer()("value") == expected
@pytest.mark.parametrize("is_optional,is_repeated,expected", [
(False, False, True),
(True, False, False),
(False, True, False),
(True, True, False),
])
def test_field_is_single(is_optional, is_repeated, expected):
f = cpp.Field("name", "type", is_optional=is_optional, is_repeated=is_repeated)
assert f.is_single is expected
@pytest.mark.parametrize("is_optional,is_repeated,expected", [
(False, False, "bar"),
(True, False, "std::optional<bar>"),
(False, True, "std::vector<bar>"),
(True, True, "std::vector<std::optional<bar>>"),
])
def test_field_modal_types(is_optional, is_repeated, expected):
f = cpp.Field("name", "bar", is_optional=is_optional, is_repeated=is_repeated)
assert f.type == expected
def test_trap_has_first_field_marked():
fields = [
cpp.Field("a", "x"),
@@ -56,5 +78,39 @@ def test_tag_has_bases(bases, expected):
assert t.has_bases is expected
def test_class_has_first_base_marked():
bases = [
cpp.Class("a"),
cpp.Class("b"),
cpp.Class("c"),
]
expected = [cpp.ClassBase(c) for c in bases]
expected[0].first = True
c = cpp.Class("foo", bases=bases)
assert c.bases == expected
@pytest.mark.parametrize("bases,expected", [
([], False),
(["a"], True),
(["a", "b"], True)
])
def test_class_has_bases(bases, expected):
t = cpp.Class("name", [cpp.Class(b) for b in bases])
assert t.has_bases is expected
def test_class_single_fields():
fields = [
cpp.Field("a", "A"),
cpp.Field("b", "B", is_optional=True),
cpp.Field("c", "C"),
cpp.Field("d", "D", is_repeated=True),
cpp.Field("e", "E"),
]
c = cpp.Class("foo", fields=fields)
assert c.single_fields == fields[::2]
if __name__ == '__main__':
sys.exit(pytest.main())
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -0,0 +1,135 @@
import sys
from swift.codegen import cppgen
from swift.codegen.lib import cpp
from swift.codegen.test.utils import *
output_dir = pathlib.Path("path", "to", "output")
@pytest.fixture
def generate(opts, renderer, input):
opts.cpp_output = output_dir
opts.cpp_namespace = "test_namespace"
opts.trap_affix = "TestTrapAffix"
opts.cpp_include_dir = "my/include/dir"
def ret(classes):
input.classes = classes
generated = run_generation(cppgen.generate, opts, renderer)
assert set(generated) == {output_dir / "TestTrapAffixClasses.h"}
generated = generated[output_dir / "TestTrapAffixClasses.h"]
assert isinstance(generated, cpp.ClassList)
assert generated.namespace == opts.cpp_namespace
assert generated.trap_affix == opts.trap_affix
assert generated.include_dir == opts.cpp_include_dir
return generated.classes
return ret
def test_empty(generate):
assert generate([]) == []
def test_empty_class(generate):
assert generate([
schema.Class(name="MyClass"),
]) == [
cpp.Class(name="MyClass", final=True, trap_name="MyClasses")
]
def test_two_class_hierarchy(generate):
base = cpp.Class(name="A")
assert generate([
schema.Class(name="A", derived={"B"}),
schema.Class(name="B", bases={"A"}),
]) == [
base,
cpp.Class(name="B", bases=[base], final=True, trap_name="Bs"),
]
def test_complex_hierarchy_topologically_ordered(generate):
a = cpp.Class(name="A")
b = cpp.Class(name="B")
c = cpp.Class(name="C", bases=[a])
d = cpp.Class(name="D", bases=[a])
e = cpp.Class(name="E", bases=[b, c, d], final=True, trap_name="Es")
f = cpp.Class(name="F", bases=[c], final=True, trap_name="Fs")
assert generate([
schema.Class(name="F", bases={"C"}),
schema.Class(name="B", derived={"E"}),
schema.Class(name="D", bases={"A"}, derived={"E"}),
schema.Class(name="C", bases={"A"}, derived={"E", "F"}),
schema.Class(name="E", bases={"B", "C", "D"}),
schema.Class(name="A", derived={"C", "D"}),
]) == [a, b, c, d, e, f]
@pytest.mark.parametrize("type,expected", [
("a", "a"),
("string", "std::string"),
("boolean", "bool"),
("MyClass", "TestTrapAffixLabel<MyClassTag>"),
])
@pytest.mark.parametrize("property_cls,optional,repeated,trap_name", [
(schema.SingleProperty, False, False, None),
(schema.OptionalProperty, True, False, "MyClassProps"),
(schema.RepeatedProperty, False, True, "MyClassProps"),
(schema.RepeatedOptionalProperty, True, True, "MyClassProps"),
])
def test_class_with_field(generate, type, expected, property_cls, optional, repeated, trap_name):
assert generate([
schema.Class(name="MyClass", properties=[property_cls("prop", type)]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field("prop", expected, is_optional=optional,
is_repeated=repeated, trap_name=trap_name)],
trap_name="MyClasses",
final=True)
]
@pytest.mark.parametrize("name",
["start_line", "start_column", "end_line", "end_column", "index", "num_whatever", "width"])
def test_class_with_overridden_unsigned_field(generate, name):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty(name, "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field(name, "unsigned")],
trap_name="MyClasses",
final=True)
]
def test_class_with_overridden_underscore_field(generate):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty("something_", "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field("something", "bar")],
trap_name="MyClasses",
final=True)
]
@pytest.mark.parametrize("name", cpp.cpp_keywords)
def test_class_with_keyword_field(generate, name):
assert generate([
schema.Class(name="MyClass", properties=[
schema.SingleProperty(name, "bar")]),
]) == [
cpp.Class(name="MyClass",
fields=[cpp.Field(name + "_", "bar")],
trap_name="MyClasses",
final=True)
]
if __name__ == '__main__':
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -151,4 +151,4 @@ int ignored: int ref*/);
if __name__ == '__main__':
sys.exit(pytest.main())
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -126,10 +126,11 @@ def test_final_class_with_optional_field(opts, input, renderer):
)
def test_final_class_with_repeated_field(opts, input, renderer):
@pytest.mark.parametrize("property_cls", [schema.RepeatedProperty, schema.RepeatedOptionalProperty])
def test_final_class_with_repeated_field(opts, input, renderer, property_cls):
input.classes = [
schema.Class("Object", properties=[
schema.RepeatedProperty("foo", "bar"),
property_cls("foo", "bar"),
]),
]
assert generate(opts, renderer) == dbscheme.Scheme(
@@ -161,7 +162,8 @@ def test_final_class_with_more_fields(opts, input, renderer):
schema.SingleProperty("one", "x"),
schema.SingleProperty("two", "y"),
schema.OptionalProperty("three", "z"),
schema.RepeatedProperty("four", "w"),
schema.RepeatedProperty("four", "u"),
schema.RepeatedOptionalProperty("five", "v"),
]),
]
assert generate(opts, renderer) == dbscheme.Scheme(
@@ -190,7 +192,16 @@ def test_final_class_with_more_fields(opts, input, renderer):
columns=[
dbscheme.Column('id', '@object'),
dbscheme.Column('index', 'int'),
dbscheme.Column('four', 'w'),
dbscheme.Column('four', 'u'),
]
),
dbscheme.Table(
name="object_fives",
keyset=dbscheme.KeySet(["id", "index"]),
columns=[
dbscheme.Column('id', '@object'),
dbscheme.Column('index', 'int'),
dbscheme.Column('five', 'v'),
]
),
],
@@ -304,4 +315,4 @@ def test_class_with_derived_and_repeated_property(opts, input, renderer):
if __name__ == '__main__':
sys.exit(pytest.main())
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -5,31 +5,11 @@ from swift.codegen.lib import ql
from swift.codegen.test.utils import *
def test_property_has_first_param_marked():
params = [ql.Param("a", "x"), ql.Param("b", "y"), ql.Param("c", "z")]
expected = deepcopy(params)
expected[0].first = True
prop = ql.Property("Prop", "foo", "props", ["this"], params=params)
assert prop.params == expected
def test_property_has_first_table_param_marked():
tableparams = ["a", "b", "c"]
prop = ql.Property("Prop", "foo", "props", tableparams)
assert prop.tableparams[0].first
assert [p.param for p in prop.tableparams] == tableparams
assert all(p.type is None for p in prop.tableparams)
@pytest.mark.parametrize("params,expected_local_var", [
(["a", "b", "c"], "x"),
(["a", "x", "c"], "x_"),
(["a", "x", "x_", "c"], "x__"),
(["a", "x", "x_", "x__"], "x___"),
])
def test_property_local_var_avoids_params_collision(params, expected_local_var):
prop = ql.Property("Prop", "foo", "props", ["this"], params=[ql.Param(p) for p in params])
assert prop.local_var == expected_local_var
def test_property_not_a_class():
@@ -59,6 +39,16 @@ def test_property_indefinite_article(name, expected_article):
assert prop.indefinite_article == expected_article
@pytest.mark.parametrize("plural,expected", [
(None, False),
("", False),
("X", True),
])
def test_property_is_plural(plural, expected):
prop = ql.Property("foo", "Foo", "props", ["x"], plural=plural)
assert prop.is_repeated is expected
def test_property_no_plural_no_indefinite_article():
prop = ql.Property("Prop", "Foo", "props", ["x"])
assert prop.indefinite_article is None
@@ -97,4 +87,4 @@ def test_non_root_class():
if __name__ == '__main__':
sys.exit(pytest.main())
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -18,12 +18,12 @@ ql_output_path = lambda: paths.swift_dir / "ql/lib/other/path"
import_file = lambda: stub_path().with_suffix(".qll")
stub_import_prefix = "stub.path."
gen_import_prefix = "other.path."
index_param = ql.Param("index", "int")
def generate(opts, renderer, written=None):
opts.ql_stub_output = stub_path()
opts.ql_output = ql_output_path()
opts.ql_format = True
renderer.written = written or []
return run_generation(qlgen.generate, opts, renderer)
@@ -107,7 +107,8 @@ def test_optional_property(opts, input, renderer):
import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
ql.Property(singular="Foo", type="bar", tablename="my_object_foos", tableparams=["this", "result"]),
ql.Property(singular="Foo", type="bar", tablename="my_object_foos", tableparams=["this", "result"],
is_optional=True),
])
}
@@ -120,12 +121,26 @@ def test_repeated_property(opts, input, renderer):
import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", params=[index_param],
ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos",
tableparams=["this", "index", "result"]),
])
}
def test_repeated_optional_property(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[schema.RepeatedOptionalProperty("foo", "bar")]),
]
assert generate(opts, renderer) == {
import_file(): ql.ImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.Stub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.Class(name="MyObject", final=True, properties=[
ql.Property(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos",
tableparams=["this", "index", "result"], is_optional=True),
])
}
def test_single_class_property(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[schema.SingleProperty("foo", "Bar")]),
@@ -195,4 +210,4 @@ def test_empty_cleanup(opts, input, renderer, tmp_path):
if __name__ == '__main__':
sys.exit(pytest.main())
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -7,6 +7,9 @@ from swift.codegen.lib import paths
from swift.codegen.lib import render
generator = "test/foogen"
@pytest.fixture
def pystache_renderer_cls():
with mock.patch("pystache.Renderer") as ret:
@@ -22,7 +25,7 @@ def pystache_renderer(pystache_renderer_cls):
@pytest.fixture
def sut(pystache_renderer):
return render.Renderer()
return render.Renderer(generator)
def test_constructor(pystache_renderer_cls, sut):
@@ -40,7 +43,7 @@ def test_render(pystache_renderer, sut):
with mock.patch("builtins.open", mock.mock_open()) as output_stream:
sut.render(data, output)
assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file),
mock.call.render_name(data.template, data, generator=generator),
]
assert output_stream.mock_calls == [
mock.call(output, 'w'),
@@ -76,4 +79,4 @@ def test_cleanup(sut):
if __name__ == '__main__':
sys.exit(pytest.main())
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -140,6 +140,7 @@ A:
one: string
two: int?
three: bool*
four: x?*
""")
assert ret.classes == [
schema.Class(root_name, derived={'A'}),
@@ -147,9 +148,10 @@ A:
schema.SingleProperty('one', 'string'),
schema.OptionalProperty('two', 'int'),
schema.RepeatedProperty('three', 'bool'),
schema.RepeatedOptionalProperty('four', 'x'),
]),
]
if __name__ == '__main__':
sys.exit(pytest.main())
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -9,40 +9,51 @@ output_dir = pathlib.Path("path", "to", "output")
@pytest.fixture
def generate(opts, renderer, dbscheme_input):
opts.trap_output = output_dir
opts.cpp_output = output_dir
opts.cpp_namespace = "test_namespace"
opts.trap_affix = "TrapAffix"
opts.cpp_include_dir = "my/include/dir"
def ret(entities):
dbscheme_input.entities = entities
generated = run_generation(trapgen.generate, opts, renderer)
assert set(generated) == {output_dir /
"TrapEntries.h", output_dir / "TrapTags.h"}
return generated[output_dir / "TrapEntries.h"], generated[output_dir / "TrapTags.h"]
"TrapAffixEntries.h", output_dir / "TrapAffixTags.h"}
return generated[output_dir / "TrapAffixEntries.h"], generated[output_dir / "TrapAffixTags.h"]
return ret
@pytest.fixture
def generate_traps(generate):
def generate_traps(opts, generate):
def ret(entities):
traps, _ = generate(entities)
assert isinstance(traps, cpp.TrapList)
assert traps.namespace == opts.cpp_namespace
assert traps.trap_affix == opts.trap_affix
assert traps.include_dir == opts.cpp_include_dir
return traps.traps
return ret
@pytest.fixture
def generate_tags(generate):
def generate_tags(opts, generate):
def ret(entities):
_, tags = generate(entities)
assert isinstance(tags, cpp.TagList)
assert tags.namespace == opts.cpp_namespace
return tags.tags
return ret
def test_empty(generate):
assert generate([]) == (cpp.TrapList([]), cpp.TagList([]))
def test_empty_traps(generate_traps):
assert generate_traps([]) == []
def test_empty_tags(generate_tags):
assert generate_tags([]) == []
def test_one_empty_table_rejected(generate_traps):
@@ -95,7 +106,7 @@ def test_one_table_with_two_binding_first_is_id(generate_traps):
@pytest.mark.parametrize("column,field", [
(dbscheme.Column("x", "string"), cpp.Field("x", "std::string")),
(dbscheme.Column("y", "boolean"), cpp.Field("y", "bool")),
(dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapLabel<DbTypeTag>")),
(dbscheme.Column("z", "@db_type"), cpp.Field("z", "TrapAffixLabel<DbTypeTag>")),
])
def test_one_table_special_types(generate_traps, column, field):
assert generate_traps([
@@ -105,28 +116,22 @@ def test_one_table_special_types(generate_traps, column, field):
]
@pytest.mark.parametrize("table,name,column,field", [
("locations", "Locations", dbscheme.Column(
"startWhatever", "bar"), cpp.Field("startWhatever", "unsigned")),
("locations", "Locations", dbscheme.Column(
"endWhatever", "bar"), cpp.Field("endWhatever", "unsigned")),
("foos", "Foos", dbscheme.Column("startWhatever", "bar"),
cpp.Field("startWhatever", "bar")),
("foos", "Foos", dbscheme.Column("endWhatever", "bar"),
cpp.Field("endWhatever", "bar")),
("foos", "Foos", dbscheme.Column("index", "bar"), cpp.Field("index", "unsigned")),
("foos", "Foos", dbscheme.Column("num_whatever", "bar"),
cpp.Field("num_whatever", "unsigned")),
("foos", "Foos", dbscheme.Column("whatever_", "bar"), cpp.Field("whatever", "bar")),
])
def test_one_table_overridden_fields(generate_traps, table, name, column, field):
@pytest.mark.parametrize("name", ["start_line", "start_column", "end_line", "end_column", "index", "num_whatever"])
def test_one_table_overridden_unsigned_field(generate_traps, name):
assert generate_traps([
dbscheme.Table(name=table, columns=[column]),
dbscheme.Table(name="foos", columns=[dbscheme.Column(name, "bar")]),
]) == [
cpp.Trap(table, name=name, fields=[field]),
cpp.Trap("foos", name="Foos", fields=[cpp.Field(name, "unsigned")]),
]
def test_one_table_overridden_underscore_named_field(generate_traps):
assert generate_traps([
dbscheme.Table(name="foos", columns=[dbscheme.Column("whatever_", "bar")]),
]) == [
cpp.Trap("foos", name="Foos", fields=[cpp.Field("whatever", "bar")]),
]
def test_one_table_no_tags(generate_tags):
assert generate_tags([
dbscheme.Table(name="foos", columns=[dbscheme.Column("bla", "int")]),
@@ -160,4 +165,4 @@ def test_multiple_union_tags(generate_tags):
if __name__ == '__main__':
sys.exit(pytest.main())
sys.exit(pytest.main([__file__] + sys.argv[1:]))

View File

@@ -3,7 +3,7 @@ from unittest import mock
import pytest
from swift.codegen.lib import render, schema
from swift.codegen.lib import render, schema, paths
schema_dir = pathlib.Path("a", "dir")
schema_file = schema_dir / "schema.yml"
@@ -18,12 +18,14 @@ def write(out, contents=""):
@pytest.fixture
def renderer():
return mock.Mock(spec=render.Renderer())
return mock.Mock(spec=render.Renderer(""))
@pytest.fixture
def opts():
return mock.MagicMock()
ret = mock.MagicMock()
ret.swift_dir = paths.swift_dir
return ret
@pytest.fixture(autouse=True)

View File

@@ -1,45 +1,24 @@
#!/usr/bin/env python3
import logging
import os
import re
import sys
import inflection
from toposort import toposort_flatten
sys.path.append(os.path.dirname(__file__))
from swift.codegen.lib import paths, dbscheme, generator, cpp
field_overrides = [
(re.compile(r"locations.*::(start|end).*|.*::(index|num_.*)"), {"type": "unsigned"}),
(re.compile(r".*::(.*)_"), lambda m: {"name": m[1]}),
]
from swift.codegen.lib import dbscheme, generator, cpp
log = logging.getLogger(__name__)
def get_field_override(table, field):
spec = f"{table}::{field}"
for r, o in field_overrides:
m = r.fullmatch(spec)
if m and callable(o):
return o(m)
elif m:
return o
return {}
def get_tag_name(s):
assert s.startswith("@")
return inflection.camelize(s[1:])
def get_cpp_type(schema_type):
def get_cpp_type(schema_type: str, trap_affix: str):
if schema_type.startswith("@"):
tag = get_tag_name(schema_type)
return f"TrapLabel<{tag}Tag>"
return f"{trap_affix}Label<{tag}Tag>"
if schema_type == "string":
return "std::string"
if schema_type == "boolean":
@@ -47,13 +26,13 @@ def get_cpp_type(schema_type):
return schema_type
def get_field(c: dbscheme.Column, table: str):
def get_field(c: dbscheme.Column, trap_affix: str):
args = {
"name": c.schema_name,
"type": c.type,
}
args.update(get_field_override(table, c.schema_name))
args["type"] = get_cpp_type(args["type"])
args.update(cpp.get_field_override(c.schema_name))
args["type"] = get_cpp_type(args["type"], trap_affix)
return cpp.Field(**args)
@@ -64,32 +43,33 @@ def get_binding_column(t: dbscheme.Table):
return None
def get_trap(t: dbscheme.Table):
def get_trap(t: dbscheme.Table, trap_affix: str):
id = get_binding_column(t)
if id:
id = get_field(id, t.name)
id = get_field(id, trap_affix)
return cpp.Trap(
table_name=t.name,
name=inflection.camelize(t.name),
fields=[get_field(c, t.name) for c in t.columns],
fields=[get_field(c, trap_affix) for c in t.columns],
id=id,
)
def generate(opts, renderer):
tag_graph = {}
out = opts.trap_output
out = opts.cpp_output
traps = []
for e in dbscheme.iterload(opts.dbscheme):
if e.is_table:
traps.append(get_trap(e))
traps.append(get_trap(e, opts.trap_affix))
elif e.is_union:
tag_graph.setdefault(e.lhs, set())
for d in e.rhs:
tag_graph.setdefault(d.type, set()).add(e.lhs)
renderer.render(cpp.TrapList(traps), out / "TrapEntries.h")
renderer.render(cpp.TrapList(traps, opts.cpp_namespace, opts.trap_affix, opts.cpp_include_dir),
out / f"{opts.trap_affix}Entries.h")
tags = []
for index, tag in enumerate(toposort_flatten(tag_graph)):
@@ -99,10 +79,10 @@ def generate(opts, renderer):
index=index,
id=tag,
))
renderer.render(cpp.TagList(tags), out / "TrapTags.h")
renderer.render(cpp.TagList(tags, opts.cpp_namespace), out / f"{opts.trap_affix}Tags.h")
tags = ("trap", "dbscheme")
tags = ("cpp", "dbscheme")
if __name__ == "__main__":
generator.run()

View File

@@ -12,7 +12,7 @@
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/Path.h>
#include "swift/extractor/trap/TrapEntries.h"
#include "swift/extractor/trap/TrapClasses.h"
using namespace codeql;
@@ -75,9 +75,10 @@ static void extractFile(const SwiftExtractorConfiguration& config, swift::Source
}
trap << "\n\n";
TrapLabel<FileTag> label{};
trap << label << "=*\n";
trap << FilesTrap{label, srcFilePath.str().str()} << "\n";
File f;
f.id = TrapLabel<FileTag>{};
f.name = srcFilePath.str().str();
trap << f.id << "=*\n" << f;
// TODO: Pick a better name to avoid collisions
std::string trapName = file.getFilename().str() + ".trap";

View File

@@ -1,16 +1,42 @@
genrule(
name = "gen",
name = "trapgen",
srcs = ["//swift:dbscheme"],
outs = [
"TrapEntries.h",
"TrapTags.h",
],
cmd = "$(location //swift/codegen:trapgen) --dbscheme $< --trap-output $(RULEDIR)",
cmd = " ".join([
"$(location //swift/codegen:trapgen)",
"--dbscheme $<",
"--cpp-include-dir " + package_name(),
"--cpp-output $(RULEDIR)",
]),
exec_tools = ["//swift/codegen:trapgen"],
)
genrule(
name = "cppgen",
srcs = [
"//swift/codegen:schema",
"//swift/codegen:schema_includes",
],
outs = [
"TrapClasses.h",
],
cmd = " ".join([
"$(location //swift/codegen:cppgen)",
"--schema $(location //swift/codegen:schema)",
"--cpp-include-dir " + package_name(),
"--cpp-output $(RULEDIR)",
]),
exec_tools = ["//swift/codegen:cppgen"],
)
cc_library(
name = "trap",
hdrs = glob(["*.h"]) + [":gen"],
hdrs = glob(["*.h"]) + [
":trapgen",
":cppgen",
],
visibility = ["//visibility:public"],
)

View File

@@ -54,6 +54,11 @@ inline auto trapQuoted(const std::string& s) {
return std::quoted(s, '"', '"');
}
template <typename Tag>
struct Binding {
TrapLabel<Tag> id;
};
} // namespace codeql
namespace std {

View File

@@ -15,8 +15,6 @@ class PatternBindingDeclBase extends @pattern_binding_decl, Decl {
Expr getAnInit() { result = getInit(_) }
int getNumberOfInits() { result = count(getAnInit()) }
Pattern getPattern(int index) {
exists(Pattern x |
pattern_binding_decl_patterns(this, index, x) and