Files
codeql/swift/codegen/trapgen.py
2022-04-28 16:39:27 +02:00

137 lines
3.5 KiB
Python
Executable File

#!/usr/bin/env python3
import collections
import logging
import os
import re
import sys
import inflection
sys.path.append(os.path.dirname(__file__))
from lib import paths, dbscheme, generator, cpp
def _strip_trailing_underscore(match):
    # A field spelled `foo_` in the dbscheme is emitted as `foo` (presumably the trailing
    # underscore only exists to avoid a reserved word — TODO confirm).
    return {"name": match[1]}


# Per-field tweaks, as (pattern, override) pairs. Each pattern is full-matched against
# "table::field" and the first matching pair wins. An override is either a dict of field
# arguments to merge in, or a callable that receives the match object and returns such a dict.
field_overrides = [
    # location coordinates, indexes and counts are emitted as `unsigned`
    (re.compile(r"locations.*::(start|end).*|.*::(index|num_.*)"), {"type": "unsigned"}),
    (re.compile(r".*::(.*)_"), _strip_trailing_underscore),
]

log = logging.getLogger(__name__)
def get_field_override(table, field):
    """Return the override for ``table::field`` taken from ``field_overrides``.

    Patterns are tried in list order and the first full match wins. A callable
    override is invoked with the match object and its result returned; any other
    override is returned as-is. Returns an empty dict when no pattern matches.
    """
    qualified = f"{table}::{field}"
    for pattern, override in field_overrides:
        match = pattern.fullmatch(qualified)
        if not match:
            continue
        return override(match) if callable(override) else override
    return {}
def get_tag_name(s):
    """Turn a dbscheme type name such as ``@foo_bar`` into its C++ tag base name.

    The leading ``@`` is required (checked with ``assert``); the remainder is
    passed through ``inflection.camelize``.
    """
    assert s.startswith("@")
    bare_name = s[1:]
    return inflection.camelize(bare_name)
def get_cpp_type(schema_type):
    """Map a dbscheme column type to the C++ type used in the generated TRAP code.

    ``@``-prefixed types become ``TrapLabel<...Tag>``; ``string`` and ``boolean``
    become ``std::string`` and ``bool``; any other type (e.g. an ``unsigned``
    override) is returned unchanged.
    """
    if schema_type.startswith("@"):
        return f"TrapLabel<{get_tag_name(schema_type)}Tag>"
    primitives = {"string": "std::string", "boolean": "bool"}
    return primitives.get(schema_type, schema_type)
def get_field(c: dbscheme.Column, table: str):
    """Build the ``cpp.Field`` for column ``c`` of dbscheme table ``table``.

    Starts from the column's schema name and type, lets any matching
    ``get_field_override`` entry rename the field or replace its type, then
    converts the final type to its C++ spelling via ``get_cpp_type``.
    """
    field_args = {
        "name": c.schema_name,
        "type": c.type,
        **get_field_override(table, c.schema_name),
    }
    field_args["type"] = get_cpp_type(field_args["type"])
    return cpp.Field(**field_args)
def get_binding_column(t: dbscheme.Table):
    """Return the first column of ``t`` whose ``binding`` attribute is truthy.

    Returns ``None`` when no column of ``t`` is a binding column.
    """
    return next((column for column in t.columns if column.binding), None)
def get_trap(t: dbscheme.Table):
    """Build the ``cpp.Trap`` description of dbscheme table ``t``.

    Every column of ``t`` becomes a field. The binding column, when present, is
    also converted to a field and passed as the trap's ``id``; otherwise the
    (falsy) result of ``get_binding_column`` is passed through unchanged.
    """
    binding = get_binding_column(t)
    # local renamed from `id` so the builtin is not shadowed; the keyword below is unchanged
    id_field = get_field(binding, t.name) if binding else binding
    return cpp.Trap(
        table_name=t.name,
        name=inflection.camelize(t.name),
        fields=[get_field(column, t.name) for column in t.columns],
        id=id_field,
    )
def get_guard(path, root=None):
    """Return the include-guard macro name for the generated file at ``path``.

    The name is ``path`` taken relative to ``root`` (``paths.swift_dir`` when
    ``root`` is omitted), with its last suffix removed, path separators replaced
    by ``_``, and upper-cased. For example, ``<root>/extractor/trap/TrapEntries.h``
    becomes ``EXTRACTOR_TRAP_TRAPENTRIES``.

    Args:
        path: a ``pathlib`` path located under ``root``.
        root: base directory the guard is computed relative to; defaults to
            ``paths.swift_dir`` for backward compatibility.

    Raises:
        ValueError: if ``path`` is not under ``root`` (raised by ``relative_to``).
    """
    if root is None:
        root = paths.swift_dir
    relative = path.relative_to(root).with_suffix("")
    # as_posix() normalizes separators: str() of a Windows path uses backslashes,
    # which a plain replace("/", ...) would leave inside the macro name.
    return relative.as_posix().replace("/", "_").upper()
def get_topologically_ordered_tags(tags):
    """Yield the keys of ``tags`` so that every tag comes after all of its bases.

    ``tags`` maps each tag to a dict with ``"bases"`` and ``"derived"`` lists.
    The two lists are assumed to be mirror images of each other, as built by
    ``generate`` — TODO confirm for other callers. This is Kahn's algorithm over
    in-degree buckets: tags whose bases have all been emitted are released in
    waves, and each wave is emitted in sorted order.

    Raises:
        ValueError: once the acyclic part has been yielded, if some tags could
            never be released (the base/derived relation has a cycle).
    """
    pending_bases = {tag: len(info["bases"]) for tag, info in tags.items()}
    tags_by_pending = collections.defaultdict(set)
    for tag, count in pending_bases.items():
        tags_by_pending[count].add(tag)

    while tags_by_pending[0]:
        for ready in sorted(tags_by_pending.pop(0)):
            yield ready
            # one of each child's bases is now emitted: move it down one bucket
            for child in tags[ready]["derived"]:
                current = pending_bases[child]
                tags_by_pending[current].remove(child)
                pending_bases[child] = current - 1
                tags_by_pending[current - 1].add(child)

    if any(tags_by_pending.values()):
        raise ValueError("not a dag!")
def generate(opts, renderer):
    """Generate the C++ TRAP headers from the dbscheme file at ``opts.dbscheme``.

    Renders ``TrapEntries.h`` (one trap per dbscheme table) and ``TrapTags.h``
    (one tag per type appearing in a union declaration, in topological order so
    that bases precede derived tags) into the directory ``opts.trap_output``,
    using ``renderer.render(data, path)``.

    Both outputs are fully computed before anything is rendered. A cyclic tag
    hierarchy, which makes ``get_topologically_ordered_tags`` raise ``ValueError``,
    therefore leaves no half-generated output behind.
    """
    # tag -> its direct bases and direct derived tags, as declared by unions
    tag_graph = collections.defaultdict(lambda: {"bases": [], "derived": []})
    out = opts.trap_output
    traps = []
    with open(opts.dbscheme) as dbscheme_file:
        for e in dbscheme.iterload(dbscheme_file):
            if e.is_table:
                traps.append(get_trap(e))
            elif e.is_union:
                for d in e.rhs:
                    tag_graph[e.lhs]["derived"].append(d.type)
                    tag_graph[d.type]["bases"].append(e.lhs)
    # Materialize the whole ordering (which raises on cycles) before writing any file.
    tags = [
        cpp.Tag(
            name=get_tag_name(tag),
            bases=[get_tag_name(b) for b in sorted(tag_graph[tag]["bases"])],
            index=index,
            id=tag,
        )
        for index, tag in enumerate(get_topologically_ordered_tags(tag_graph))
    ]
    renderer.render(cpp.TrapList(traps), out / "TrapEntries.h")
    renderer.render(cpp.TagList(tags), out / "TrapTags.h")
# Module-level labels for this generator. Nothing in this file reads them; presumably
# `generator.run()` inspects them (e.g. to pick or describe this generator) —
# NOTE(review): confirm against lib/generator.py before renaming or removing.
tags = ("trap", "dbscheme")
if __name__ == "__main__":
    # Entry point: delegate to the shared driver in lib/generator.py.
    generator.run()