Swift: added trapgen

This checks in the trapgen script generating trap entries in C++.

The codegen suite has been slightly reorganized, moving the templates
directory up one level and chopping everything into smaller bazel
packages. Running tests is now done via
```
bazel run //swift/codegen/test
```

With respect to the PoC, the nested `codeql::trap` namespace has been
dropped in favour of a `Trap` prefix (or suffix in case of entries)
within the `codeql` namespace. Also, generated C++ code is not checked
in in git any more, and generated during build. Finally, labels get
printed in hex in the trap file.

`TrapLabel` is for the moment only default-constructible, so only one
single label is possible. `TrapArena`, that is responsible for creating
disjoint labels will come in a later commit.
This commit is contained in:
Paolo Tranquilli
2022-04-28 10:47:48 +02:00
parent 604a5fc71f
commit 773ef62406
31 changed files with 558 additions and 73 deletions

View File

@@ -13,18 +13,18 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: ./.github/actions/fetch-codeql
- uses: bazelbuild/setup-bazelisk@v2
- uses: actions/setup-python@v3 - uses: actions/setup-python@v3
with: with:
python-version: '~3.8' python-version: '~3.8'
cache: 'pip' cache: 'pip'
- uses: ./.github/actions/fetch-codeql - name: Install python dependencies
- uses: bazelbuild/setup-bazelisk@v2
- name: Install dependencies
run: | run: |
pip install -r swift/codegen/requirements.txt pip install -r swift/codegen/requirements.txt
- name: Run unit tests - name: Run unit tests
run: | run: |
bazel test //swift/codegen:tests --test_output=errors bazel test //swift/codegen/test --test_output=errors
- name: Check that code was generated - name: Check that code was generated
run: | run: |
bazel run //swift/codegen bazel run //swift/codegen

View File

@@ -29,6 +29,13 @@ jobs:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: ./.github/actions/fetch-codeql - uses: ./.github/actions/fetch-codeql
- uses: bazelbuild/setup-bazelisk@v2 - uses: bazelbuild/setup-bazelisk@v2
- uses: actions/setup-python@v3
with:
python-version: '~3.8'
cache: 'pip'
- name: Install python dependencies
run: |
pip install -r codegen/requirements.txt
- name: Build Swift extractor - name: Build Swift extractor
run: | run: |
bazel run //swift:create-extractor-pack bazel run //swift:create-extractor-pack

View File

@@ -47,5 +47,5 @@ repos:
name: Run Swift code generation unit tests name: Run Swift code generation unit tests
files: ^swift/codegen/.*\.py$ files: ^swift/codegen/.*\.py$
language: system language: system
entry: bazel test //swift/codegen:tests entry: bazel test //swift/codegen/test
pass_filenames: false pass_filenames: false

6
swift/.gitignore vendored
View File

@@ -1 +1,5 @@
extractor-pack # directory created by bazel run //swift:create-extractor-pack
/extractor-pack
# output files created by running tests
*.o

View File

@@ -2,11 +2,17 @@ load("@rules_pkg//:mappings.bzl", "pkg_attributes", "pkg_filegroup", "pkg_files"
load("@rules_pkg//:install.bzl", "pkg_install") load("@rules_pkg//:install.bzl", "pkg_install")
load("//:defs.bzl", "codeql_platform") load("//:defs.bzl", "codeql_platform")
pkg_files( filegroup(
name = "dbscheme", name = "dbscheme",
srcs = ["ql/lib/swift.dbscheme"],
visibility = ["//visibility:public"],
)
pkg_files(
name = "dbscheme_files",
srcs = [ srcs = [
"ql/lib/swift.dbscheme",
"ql/lib/swift.dbscheme.stats", "ql/lib/swift.dbscheme.stats",
":dbscheme",
], ],
) )
@@ -25,7 +31,7 @@ pkg_files(
pkg_filegroup( pkg_filegroup(
name = "extractor-pack-generic", name = "extractor-pack-generic",
srcs = [ srcs = [
":dbscheme", ":dbscheme_files",
":manifest", ":manifest",
":qltest", ":qltest",
], ],
@@ -58,7 +64,7 @@ pkg_filegroup(
name = "extractor-pack-arch", name = "extractor-pack-arch",
srcs = [ srcs = [
":extractor", ":extractor",
":swift-test-sdk-arch" ":swift-test-sdk-arch",
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )

View File

@@ -1,31 +1,16 @@
py_binary( py_binary(
name = "codegen", name = "codegen",
srcs = glob([ srcs = glob(["*.py"]),
"lib/*.py", visibility = ["//swift/codegen/test:__pkg__"],
"*.py", deps = ["//swift/codegen/lib"],
]),
) )
py_library( # as opposed to the above, that is meant to only be run with bazel run,
name = "test_utils", # we need to be precise with data dependencies of this which is meant be run during build
testonly = True, py_binary(
srcs = ["test/utils.py"], name = "trapgen",
deps = [":codegen"], srcs = ["trapgen.py"],
) data = ["//swift/codegen/templates:cpp"],
visibility = ["//swift:__subpackages__"],
[ deps = ["//swift/codegen/lib"],
py_test(
name = src[len("test/"):-len(".py")],
size = "small",
srcs = [src],
deps = [
":codegen",
":test_utils",
],
)
for src in glob(["test/test_*.py"])
]
test_suite(
name = "tests",
) )

View File

@@ -6,4 +6,4 @@ import dbschemegen
import qlgen import qlgen
if __name__ == "__main__": if __name__ == "__main__":
generator.run(dbschemegen.generate, qlgen.generate) generator.run(dbschemegen, qlgen)

View File

@@ -79,11 +79,13 @@ def generate(opts, renderer):
data = schema.load(input) data = schema.load(input)
dbscheme = Scheme(src=input.relative_to(paths.swift_dir), dbscheme = Scheme(src=input.relative_to(paths.swift_dir),
includes=get_includes(data, include_dir=input.parent), includes=get_includes(data, include_dir=input.parent),
declarations=get_declarations(data)) declarations=get_declarations(data))
renderer.render(dbscheme, out) renderer.render(dbscheme, out)
tags = ("schema", "dbscheme")
if __name__ == "__main__": if __name__ == "__main__":
generator.run(generate, tags=["schema", "dbscheme"]) generator.run()

View File

@@ -0,0 +1,5 @@
py_library(
name = "lib",
srcs = glob(["*.py"]),
visibility = ["//swift/codegen:__subpackages__"],
)

83
swift/codegen/lib/cpp.py Normal file
View File

@@ -0,0 +1,83 @@
from dataclasses import dataclass, field
from typing import List, ClassVar
# taken from https://en.cppreference.com/w/cpp/keyword
cpp_keywords = {"alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "atomic_commit", "atomic_noexcept",
"auto", "bitand", "bitor", "bool", "break", "case", "catch", "char", "char8_t", "char16_t", "char32_t",
"class", "compl", "concept", "const", "consteval", "constexpr", "constinit", "const_cast", "continue",
"co_await", "co_return", "co_yield", "decltype", "default", "delete", "do", "double", "dynamic_cast",
"else", "enum", "explicit", "export", "extern", "false", "float", "for", "friend", "goto", "if",
"inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr",
"operator", "or", "or_eq", "private", "protected", "public", "reflexpr", "register", "reinterpret_cast",
"requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct",
"switch", "synchronized", "template", "this", "thread_local", "throw", "true", "try", "typedef",
"typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while",
"xor", "xor_eq"}
@dataclass
class Field:
name: str
type: str
first: bool = False
def cpp_name(self):
if self.name in cpp_keywords:
return self.name + "_"
return self.name
def stream(self):
if self.type == "std::string":
return lambda x: f"trapQuoted({x})"
elif self.type == "bool":
return lambda x: f'({x} ? "true" : "false")'
else:
return lambda x: x
@dataclass
class Trap:
table_name: str
name: str
fields: List[Field]
id: Field = None
def __post_init__(self):
assert self.fields
self.fields[0].first = True
@dataclass
class TagBase:
base: str
first: bool = False
@dataclass
class Tag:
name: str
bases: List[TagBase]
index: int
id: str
def __post_init__(self):
if self.bases:
self.bases = [TagBase(b) for b in self.bases]
self.bases[0].first = True
def has_bases(self):
return bool(self.bases)
@dataclass
class TrapList:
template: ClassVar = 'cpp_traps'
traps: List[Trap] = field(default_factory=list)
@dataclass
class TagList:
template: ClassVar = 'cpp_tags'
tags: List[Tag] = field(default_factory=list)

View File

@@ -1,6 +1,7 @@
""" dbscheme format representation """ """ dbscheme format representation """
import logging import logging
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, List from typing import ClassVar, List
@@ -102,3 +103,54 @@ class Scheme:
src: str src: str
includes: List[SchemeInclude] includes: List[SchemeInclude]
declarations: List[Decl] declarations: List[Decl]
class Re:
entity = re.compile(
"(?m)"
r"(?:^#keyset\[(?P<tablekeys>[\w\s,]+)\][\s\n]*)?^(?P<table>\w+)\((?P<tablebody>[^\)]*)\);?"
"|"
r"^(?P<union>@\w+)\s*=\s*(?P<unionbody>@\w+(?:\s*\|\s*@\w+)*)\s*;?"
)
field = re.compile(r"(?m)[\w\s]*\s(?P<field>\w+)\s*:\s*(?P<type>@?\w+)(?P<ref>\s+ref)?")
key = re.compile(r"@\w+")
comment = re.compile(r"(?m)(?s)/\*.*?\*/|//[^\n]*$")
def get_column(match):
return Column(
schema_name=match["field"].rstrip("_"),
type=match["type"],
binding=not match["ref"],
)
def get_table(match):
keyset = None
if match["tablekeys"]:
keyset = KeySet(k.strip() for k in match["tablekeys"].split(","))
return Table(
name=match["table"],
columns=[get_column(f) for f in Re.field.finditer(match["tablebody"])],
keyset=keyset,
)
def get_union(match):
return Union(
lhs=match["union"],
rhs=(d[0] for d in Re.key.finditer(match["unionbody"])),
)
def iterload(file):
data = Re.comment.sub("", file.read())
for e in Re.entity.finditer(data):
if e["table"]:
yield get_table(e)
elif e["union"]:
yield get_union(e)
def load(file):
return list(iterload(file))

View File

@@ -17,11 +17,12 @@ def _parse(tags):
return ret return ret
def run(*generators, tags=None): def run(*modules):
""" run generation functions in `generators`, parsing options tagged with `tags` (all if unspecified) """ run generation functions in specified in `modules`, or in current module by default
`generators` should be callables taking as input an option namespace and a `render.Renderer` instance
""" """
opts = _parse(tags) if modules:
for g in generators: opts = _parse({t for m in modules for t in m.tags})
g(opts, render.Renderer()) for m in modules:
m.generate(opts, render.Renderer())
else:
run(sys.modules["__main__"])

View File

@@ -3,7 +3,7 @@
import argparse import argparse
import collections import collections
import pathlib import pathlib
from typing import Tuple from typing import Set
from . import paths from . import paths
@@ -15,6 +15,7 @@ def _init_options():
Option("--ql-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/generated") Option("--ql-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/generated")
Option("--ql-stub-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/elements") Option("--ql-stub-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/elements")
Option("--codeql-binary", tags=["ql"], default="codeql") Option("--codeql-binary", tags=["ql"], default="codeql")
Option("--trap-output", tags=["trap"], type=_abspath, required=True)
def _abspath(x): def _abspath(x):
@@ -42,13 +43,10 @@ class Option:
_init_options() _init_options()
def get(tags: Tuple[str]): def get(tags: Set[str]):
""" get options marked by `tags` """ get options marked by `tags`
Return all options if tags is falsy. Options tagged by wildcard '*' are always returned Options tagged by wildcard '*' are always returned
""" """
if not tags: # use specifically tagged options + those tagged with wildcard *
return (o for tagged_opts in _options.values() for o in tagged_opts) return (o for tag in ('*',) + tuple(tags) for o in _options[tag])
else:
# use specifically tagged options + those tagged with wildcard *
return (o for tag in ('*',) + tags for o in _options[tag])

View File

@@ -5,14 +5,15 @@ import sys
import os import os
try: try:
_workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve() # <- means we are using bazel run workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve() # <- means we are using bazel run
swift_dir = _workspace_dir / 'swift' swift_dir = workspace_dir / 'swift'
except KeyError: except KeyError:
_this_file = pathlib.Path(__file__).resolve() _this_file = pathlib.Path(__file__).resolve()
swift_dir = _this_file.parents[2] swift_dir = _this_file.parents[2]
workspace_dir = swift_dir.parent
lib_dir = swift_dir / 'codegen' / 'lib' lib_dir = swift_dir / 'codegen' / 'lib'
templates_dir = lib_dir / 'templates' templates_dir = swift_dir / 'codegen' / 'templates'
try: try:
exe_file = pathlib.Path(sys.argv[0]).resolve().relative_to(swift_dir) exe_file = pathlib.Path(sys.argv[0]).resolve().relative_to(swift_dir)

View File

@@ -19,19 +19,25 @@ class Renderer:
""" Template renderer using mustache templates in the `templates` directory """ """ Template renderer using mustache templates in the `templates` directory """
def __init__(self): def __init__(self):
self._r = pystache.Renderer(search_dirs=str(paths.lib_dir / "templates"), escape=lambda u: u) self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u)
self.written = set() self.written = set()
def render(self, data, output: pathlib.Path): def render(self, data, output: pathlib.Path, guard_base: pathlib.Path = None):
""" Render `data` to `output`. """ Render `data` to `output`.
`data` must have a `template` attribute denoting which template to use from the template directory. `data` must have a `template` attribute denoting which template to use from the template directory.
If the file is unchanged, then no write is performed (and `done_something` remains unchanged) If the file is unchanged, then no write is performed (and `done_something` remains unchanged)
If `guard_base` is provided, it must be a path at the root of `output` and a header guard will be injected in
the template based off of the relative path of `output` in `guard_base`
""" """
mnemonic = type(data).__name__ mnemonic = type(data).__name__
output.parent.mkdir(parents=True, exist_ok=True) output.parent.mkdir(parents=True, exist_ok=True)
data = self._r.render_name(data.template, data, generator=paths.exe_file) guard = None
if guard_base is not None:
guard = str(output.relative_to(guard_base)).replace("/", "_").replace(".", "_").upper()
data = self._r.render_name(data.template, data, generator=paths.exe_file, guard=guard)
with open(output, "w") as out: with open(output, "w") as out:
out.write(data) out.write(data)
log.debug(f"generated {mnemonic} {output.name}") log.debug(f"generated {mnemonic} {output.name}")

View File

@@ -110,5 +110,7 @@ def generate(opts, renderer):
format(opts.codeql_binary, renderer.written) format(opts.codeql_binary, renderer.written)
tags = ("schema", "ql")
if __name__ == "__main__": if __name__ == "__main__":
generator.run(generate, tags=["schema", "ql"]) generator.run()

View File

@@ -0,0 +1,5 @@
filegroup(
name = "cpp",
srcs = glob(["cpp_*.mustache"]),
visibility = ["//swift:__subpackages__"],
)

View File

@@ -0,0 +1,15 @@
// generated by {{generator}}
// clang-format off
#ifndef SWIFT_EXTRACTOR_TRAP_{{guard}}
#define SWIFT_EXTRACTOR_TRAP_{{guard}}
namespace codeql {
{{#tags}}
// {{id}}
struct {{name}}Tag {{#has_bases}}: {{#bases}}{{^first}}, {{/first}}{{base}}Tag{{/bases}} {{/has_bases}}{
static constexpr const char* prefix = "{{index}}";
};
{{/tags}}
}
#endif

View File

@@ -0,0 +1,35 @@
// generated by {{generator}}
// clang-format off
#ifndef SWIFT_EXTRACTOR_TRAP_{{guard}}
#define SWIFT_EXTRACTOR_TRAP_{{guard}}
#include <iostream>
#include <string>
#include "swift/extractor/trap/TrapLabel.h"
#include "swift/extractor/trap/TrapTags.h"
namespace codeql {
{{#traps}}
// {{table_name}}
struct {{name}}Trap {
static constexpr bool is_binding = {{#id}}true{{/id}}{{^id}}false{{/id}};
{{#id}}
{{type}} getBoundLabel() const { return {{cpp_name}}; }
{{/id}}
{{#fields}}
{{type}} {{cpp_name}}{};
{{/fields}}
};
inline std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e) {
out << "{{table_name}}("{{#fields}}{{^first}} << ", "{{/first}}
<< {{#stream}}e.{{cpp_name}}{{/stream}}{{/fields}} << ")";
return out;
}
{{/traps}}
}
#endif

View File

@@ -0,0 +1,23 @@
py_library(
name = "utils",
testonly = True,
srcs = ["utils.py"],
deps = ["//swift/codegen/lib"],
)
[
py_test(
name = src[:-len(".py")],
size = "small",
srcs = [src],
deps = [
":utils",
"//swift/codegen",
],
)
for src in glob(["test_*.py"])
]
test_suite(
name = "test",
)

View File

@@ -1,4 +1,5 @@
import sys import sys
import pathlib
from unittest import mock from unittest import mock
import pytest import pytest
@@ -40,8 +41,8 @@ def test_render(pystache_renderer, sut):
with mock.patch("builtins.open", mock.mock_open()) as output_stream: with mock.patch("builtins.open", mock.mock_open()) as output_stream:
sut.render(data, output) sut.render(data, output)
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file), mock.call.render_name(data.template, data, generator=paths.exe_file, guard=None),
], pystache_renderer.mock_calls ]
assert output_stream.mock_calls == [ assert output_stream.mock_calls == [
mock.call(output, 'w'), mock.call(output, 'w'),
mock.call().__enter__(), mock.call().__enter__(),
@@ -51,6 +52,17 @@ def test_render(pystache_renderer, sut):
assert sut.written == {output} assert sut.written == {output}
def test_render_with_guard(pystache_renderer, sut):
guard_base = pathlib.Path("test", "guard")
data = mock.Mock()
output = guard_base / "this" / "is" / "a" / "header.h"
with mock.patch("builtins.open", mock.mock_open()) as output_stream:
sut.render(data, output, guard_base=guard_base)
assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file, guard="THIS_IS_A_HEADER_H"),
]
def test_written(sut): def test_written(sut):
data = [mock.Mock() for _ in range(4)] data = [mock.Mock() for _ in range(4)]
output = [mock.Mock() for _ in data] output = [mock.Mock() for _ in data]

136
swift/codegen/trapgen.py Executable file
View File

@@ -0,0 +1,136 @@
#!/usr/bin/env python3
import collections
import logging
import os
import re
import sys
import inflection
sys.path.append(os.path.dirname(__file__))
from lib import paths, dbscheme, generator, cpp
field_overrides = [
(re.compile(r"locations.*::(start|end).*|.*::(index|num_.*)"), {"type": "unsigned"}),
(re.compile(r".*::(.*)_"), lambda m: {"name": m[1]}),
]
log = logging.getLogger(__name__)
def get_field_override(table, field):
spec = f"{table}::{field}"
for r, o in field_overrides:
m = r.fullmatch(spec)
if m and callable(o):
return o(m)
elif m:
return o
return {}
def get_tag_name(s):
assert s.startswith("@")
return inflection.camelize(s[1:])
def get_cpp_type(schema_type):
if schema_type.startswith("@"):
tag = get_tag_name(schema_type)
return f"TrapLabel<{tag}Tag>"
if schema_type == "string":
return "std::string"
if schema_type == "boolean":
return "bool"
return schema_type
def get_field(c: dbscheme.Column, table: str):
args = {
"name": c.schema_name,
"type": c.type,
}
args.update(get_field_override(table, c.schema_name))
args["type"] = get_cpp_type(args["type"])
return cpp.Field(**args)
def get_binding_column(t: dbscheme.Table):
try:
return next(c for c in t.columns if c.binding)
except StopIteration:
return None
def get_trap(t: dbscheme.Table):
id = get_binding_column(t)
if id:
id = get_field(id, t.name)
return cpp.Trap(
table_name=t.name,
name=inflection.camelize(t.name),
fields=[get_field(c, t.name) for c in t.columns],
id=id,
)
def get_guard(path):
path = path.relative_to(paths.swift_dir)
return str(path.with_suffix("")).replace("/", "_").upper()
def get_topologically_ordered_tags(tags):
degree_to_nodes = collections.defaultdict(set)
nodes_to_degree = {}
lookup = {}
for name, t in tags.items():
degree = len(t["bases"])
degree_to_nodes[degree].add(name)
nodes_to_degree[name] = degree
while degree_to_nodes[0]:
sinks = degree_to_nodes.pop(0)
for sink in sorted(sinks):
yield sink
for d in tags[sink]["derived"]:
degree = nodes_to_degree[d]
degree_to_nodes[degree].remove(d)
degree -= 1
nodes_to_degree[d] = degree
degree_to_nodes[degree].add(d)
if any(degree_to_nodes.values()):
raise ValueError("not a dag!")
def generate(opts, renderer):
tag_graph = collections.defaultdict(lambda: {"bases": [], "derived": []})
out = opts.trap_output
traps = []
with open(opts.dbscheme) as input:
for e in dbscheme.iterload(input):
if e.is_table:
traps.append(get_trap(e))
elif e.is_union:
for d in e.rhs:
tag_graph[e.lhs]["derived"].append(d.type)
tag_graph[d.type]["bases"].append(e.lhs)
renderer.render(cpp.TrapList(traps), out / "TrapEntries.h", guard_base=out)
tags = []
for index, tag in enumerate(get_topologically_ordered_tags(tag_graph)):
tags.append(cpp.Tag(
name=get_tag_name(tag),
bases=[get_tag_name(b) for b in sorted(tag_graph[tag]["bases"])],
index=index,
id=tag,
))
renderer.render(cpp.TagList(tags), out / "TrapTags.h", guard_base=out)
tags = ("trap", "dbscheme")
if __name__ == "__main__":
generator.run()

View File

@@ -9,17 +9,20 @@ alias(
cc_binary( cc_binary(
name = "extractor", name = "extractor",
srcs = [ srcs = [
"SwiftExtractor.h",
"SwiftExtractor.cpp", "SwiftExtractor.cpp",
"SwiftExtractor.h",
"SwiftExtractorConfiguration.h", "SwiftExtractorConfiguration.h",
"main.cpp", "main.cpp",
], ],
features = ["-universal_binaries"],
target_compatible_with = select({ target_compatible_with = select({
"@platforms//os:linux": [], "@platforms//os:linux": [],
"@platforms//os:macos": [], "@platforms//os:macos": [],
"//conditions:default": ["@platforms//:incompatible"], "//conditions:default": ["@platforms//:incompatible"],
}), }),
visibility = ["//swift:__pkg__"], visibility = ["//swift:__pkg__"],
deps = [":swift-llvm-support"], deps = [
features = ["-universal_binaries"], ":swift-llvm-support",
"//swift/extractor/trap",
],
) )

View File

@@ -12,6 +12,8 @@
#include <llvm/Support/FileSystem.h> #include <llvm/Support/FileSystem.h>
#include <llvm/Support/Path.h> #include <llvm/Support/Path.h>
#include "swift/extractor/trap/TrapEntries.h"
using namespace codeql; using namespace codeql;
static void extractFile(const SwiftExtractorConfiguration& config, swift::SourceFile& file) { static void extractFile(const SwiftExtractorConfiguration& config, swift::SourceFile& file) {
@@ -60,15 +62,15 @@ static void extractFile(const SwiftExtractorConfiguration& config, swift::Source
<< "': " << ec.message() << "\n"; << "': " << ec.message() << "\n";
return; return;
} }
std::stringstream ss; trap << "// extractor-args: ";
for (auto opt : config.frontendOptions) { for (auto opt : config.frontendOptions) {
ss << std::quoted(opt) << " "; trap << std::quoted(opt) << " ";
} }
ss << "\n"; trap << "\n\n";
trap << "// extractor-args: " << ss.str();
trap << "#0=*\n"; TrapLabel<FileTag> label{};
trap << "files(#0, " << std::quoted(srcFilePath.str().str()) << ")\n"; trap << label << "=*\n";
trap << FilesTrap{label, srcFilePath.str().str()} << "\n";
// TODO: Pick a better name to avoid collisions // TODO: Pick a better name to avoid collisions
std::string trapName = file.getFilename().str() + ".trap"; std::string trapName = file.getFilename().str() + ".trap";

View File

@@ -0,0 +1,16 @@
genrule(
name = "gen",
srcs = ["//swift:dbscheme"],
outs = [
"TrapEntries.h",
"TrapTags.h",
],
cmd = "$(location //swift/codegen:trapgen) --dbscheme $< --trap-output $(RULEDIR)",
exec_tools = ["//swift/codegen:trapgen"],
)
cc_library(
name = "trap",
hdrs = glob(["*.h"]) + [":gen"],
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,68 @@
#ifndef SWIFT_EXTRACTOR_TRAP_LABEL_H
#define SWIFT_EXTRACTOR_TRAP_LABEL_H
#include <iomanip>
#include <iostream>
#include <string>
#include "swift/extractor/trap/TrapTagTraits.h"
#include "swift/extractor/trap/TrapTags.h"
namespace codeql {
class UntypedTrapLabel {
uint64_t id_;
friend class std::hash<UntypedTrapLabel>;
protected:
UntypedTrapLabel() : id_{0xffffffffffffffff} {}
UntypedTrapLabel(uint64_t id) : id_{id} {}
public:
friend std::ostream& operator<<(std::ostream& out, UntypedTrapLabel l) {
out << '#' << std::hex << l.id_ << std::dec;
return out;
}
friend bool operator==(UntypedTrapLabel lhs, UntypedTrapLabel rhs) { return lhs.id_ == rhs.id_; }
};
template <typename Tag>
class TrapLabel : public UntypedTrapLabel {
template <typename OtherTag>
friend class TrapLabel;
using UntypedTrapLabel::UntypedTrapLabel;
public:
TrapLabel() = default;
template <typename OtherTag>
TrapLabel(const TrapLabel<OtherTag>& other) : UntypedTrapLabel(other) {
// we temporarily need to bypass the label type system for unknown AST nodes and types
if constexpr (std::is_same_v<Tag, UnknownAstNodeTag>) {
static_assert(std::is_base_of_v<AstNodeTag, OtherTag>, "wrong label assignment!");
} else if constexpr (std::is_same_v<Tag, UnknownTypeTag>) {
static_assert(std::is_base_of_v<TypeTag, OtherTag>, "wrong label assignment!");
} else {
static_assert(std::is_base_of_v<Tag, OtherTag>, "wrong label assignment!");
}
}
};
inline auto trapQuoted(const std::string& s) {
return std::quoted(s, '"', '"');
}
} // namespace codeql
namespace std {
template <>
struct hash<codeql::UntypedTrapLabel> {
size_t operator()(const codeql::UntypedTrapLabel& l) const noexcept {
return std::hash<uint64_t>{}(l.id_);
}
};
} // namespace std
#endif // SWIFT_EXTRACTOR_LIB_EXTRACTOR_CPP_TRAP_LABEL_H_

View File

@@ -0,0 +1,18 @@
#ifndef SWIFT_EXTRACTOR_INCLUDE_EXTRACTOR_TRAP_TAGTRAITS_H
#define SWIFT_EXTRACTOR_INCLUDE_EXTRACTOR_TRAP_TAGTRAITS_H
#include <type_traits>
namespace codeql::trap {
template <typename T>
struct ToTagFunctor;
template <typename T>
struct ToTagOverride : ToTagFunctor<T> {};
template <typename T>
using ToTag = typename ToTagOverride<std::remove_const_t<T>>::type;
} // namespace codeql::trap
#endif // SWIFT_EXTRACTOR_INCLUDE_EXTRACTOR_TRAP_TAGTRAITS_H