Files
codeql/misc/codegen/generators/qlgen.py
2025-06-26 12:14:53 +02:00

697 lines
25 KiB
Python
Executable File

"""
QL code generation
`generate(opts, renderer)` will generate in the library directory:
* `generated/Raw.qll` with thin class wrappers around DB types
* `generated/Synth.qll` with the base algebraic datatypes for AST entities
* `generated/<group>/<Class>.qll` with generated properties for each class
* if not already modified, an `elements/<group>/internal/<Class>Impl.qll` stub to customize the above classes
* `elements/<group>/<Class>.qll` that wraps the internal `<Class>Impl.qll` file in a public `final` class.
* `elements.qll` importing all the above public classes
* if not already modified, an `elements/<group>/internal/<Class>Constructor.qll` stub to customize the algebraic datatype
characteristic predicate
* `generated/SynthConstructors.qll` importing all the above constructor stubs
* `generated/PureSynthConstructors.qll` importing constructor stubs for pure synthesized types (that is, not
corresponding to raw types)
Moreover in the test directory for each <Class> in <group> it will generate beneath the
`extractor-tests/generated/<group>/<Class>` directory either
* a `MISSING_SOURCE.txt` explanation file if no source is present, or
* one `<Class>.ql` test query for all single properties and one `<Class>_<property>.ql` test query for each optional or
repeated property
"""
# TODO this should probably be split in different generators now: ql, qltest, maybe qlsynth
import logging
import pathlib
import re
import shutil
import subprocess
import typing
import itertools
import os
import dataclasses
import inflection
from misc.codegen.lib import schema, ql
from misc.codegen.loaders import schemaloader
# Module-level logger; `format` and `generate` report through it.
log = logging.getLogger(__name__)
class Error(Exception):
    """Base class of all errors raised by this generator.

    The string form is the first constructor argument. Unlike a bare
    `self.args[0]`, this also works when the error is raised without
    arguments (the class name is used instead) and when the first argument
    is not a string (it is converted with `str`), so that `str(error)` never
    fails itself.
    """

    def __str__(self):
        if not self.args:
            # e.g. `raise NoClasses`: there is no message to show
            return type(self).__name__
        # `__str__` must return a `str`, while the argument may be any object
        return str(self.args[0])


class FormatError(Error):
    """Raised when the CodeQL binary cannot be found or `codeql query format` fails."""


class RootElementHasChildren(Error):
    """Raised when the root class of the schema declares child properties."""


class NoClasses(Error):
    """Raised when the schema does not define any non-imported class."""
abbreviations = {
"expr": "expression",
"arg": "argument",
"stmt": "statement",
"decl": "declaration",
"repr": "representation",
"param": "parameter",
"int": "integer",
"var": "variable",
"ref": "reference",
"pat": "pattern",
}
abbreviations.update({f"{k}s": f"{v}s" for k, v in abbreviations.items()})
_abbreviations_re = re.compile("|".join(rf"\b{abbr}\b" for abbr in abbreviations))
def _humanize(s: str) -> str:
    """Turn an identifier into lowercase-initial prose with abbreviations expanded.

    The identifier is first passed through `inflection.humanize`, its first
    character is lowercased, and every whole-word abbreviation known to
    `abbreviations` is replaced by its expansion.
    """
    humanized = inflection.humanize(s)
    lowered = humanized[0].lower() + humanized[1:]

    def expand(match):
        return abbreviations[match.group(0)]

    return _abbreviations_re.sub(expand, lowered)
# matches `{noun}` placeholders in property doc strings
_format_re = re.compile(r"\{(\w+)\}")


def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
    """Build the documentation phrase describing `prop` of `cls`.

    `plural` is `None` for non repeated properties, and `True`/`False` to get
    the plural/singular phrasing of a repeated property.

    If the property has an explicit `doc`, it is used: `{noun}` placeholders
    are inflected according to `plural` (when no placeholder is present the
    first word is treated as the noun). Otherwise a default phrase is built
    from the humanized property and class names.
    """
    doc = prop.doc
    if doc:
        if plural is None:
            # for consistency, ignore format in non repeated properties
            return _format_re.sub(lambda m: m.group(1), doc)
        template = doc
        nouns = _format_re.findall(doc)
        if not nouns:
            # no explicit placeholder: the first word is the noun to inflect
            first_word, _, remainder = doc.partition(" ")
            template = f"{{{first_word}}} {remainder}"
            nouns = [first_word]
        if plural:
            inflect = inflection.pluralize
        else:
            inflect = inflection.singularize
        return template.format(**{n: inflect(n) for n in nouns})

    prop_name = _humanize(prop.name)
    default_class_name = _humanize(inflection.underscore(cls.name))
    class_name = cls.pragmas.get("ql_default_doc_name", default_class_name)
    if prop.is_predicate:
        return f"this {class_name} {prop_name}"
    if plural is True or (plural is not None and plural):
        prop_name = inflection.pluralize(prop_name)
    elif plural is not None:
        prop_name = inflection.singularize(prop_name)
    return f"{prop_name} of this {class_name}"
@dataclasses.dataclass
class Resolver:
    """Translate schema classes and properties into their `ql` counterparts.

    `lookup` maps class names to their schema definition. Translated
    properties and classes are memoized, keyed on the identity of the schema
    objects they were built from.
    """

    lookup: typing.Dict[str, schema.ClassBase]
    _property_cache: typing.Dict[tuple[int, int], ql.Property] = dataclasses.field(
        default_factory=dict, init=False
    )
    _class_cache: typing.Dict[int, ql.Class] = dataclasses.field(
        default_factory=dict, init=False
    )

    def _type_is_hideable(self, t: str) -> bool:
        """Tell whether `t` names a schema class carrying the `ql_hideable` pragma."""
        found = self.lookup.get(t)
        if isinstance(found, schema.Class):
            return "ql_hideable" in found.pragmas
        return False

    def get_ql_property(
        self,
        cls: schema.Class,
        prop: schema.Property,
    ) -> ql.Property:
        """Return the (memoized) `ql.Property` for property `prop` of class `cls`.

        Raises `Error` if a single property carries the `ql_db_table_name`
        pragma, and `ValueError` if the property is of no known kind.
        """
        cache_key = (id(cls), id(prop))
        if cache_key in self._property_cache:
            return self._property_cache[cache_key]

        type_is_known = prop.type in self.lookup
        kwargs = {
            "type": "predicate" if prop.is_predicate else prop.type,
            "qltest_skip": "qltest_skip" in prop.pragmas,
            "is_optional": prop.is_optional,
            "is_predicate": prop.is_predicate,
            "is_unordered": prop.is_unordered,
            "description": prop.description,
            "synth": bool(cls.synth) or prop.synth,
            "type_is_hideable": self._type_is_hideable(prop.type),
            "type_is_codegen_class": type_is_known
            and not self.lookup[prop.type].imported,
            "internal": "ql_internal" in prop.pragmas,
            "cfg": prop.is_child and type_is_known and self.lookup[prop.type].cfg,
            "is_child": prop.is_child,
        }
        ql_name = prop.pragmas.get("ql_name", prop.name)
        db_table_name = prop.pragmas.get("ql_db_table_name")
        if db_table_name and prop.is_single:
            raise Error(
                f"`db_table_name` pragma is not supported for single properties, but {cls.name}.{prop.name} has it"
            )

        if prop.is_single:
            # single properties are columns of the class table itself: select
            # the column of this property and ignore the other ones
            columns = [
                "result" if other is prop else "_"
                for other in cls.properties
                if other.is_single
            ]
            kwargs |= dict(
                singular=inflection.camelize(ql_name),
                tablename=inflection.tableize(cls.name),
                tableparams=["this", *columns],
                doc=_get_doc(cls, prop),
            )
        elif prop.is_repeated:
            if prop.is_unordered:
                tableparams = ["this", "result"]
            else:
                tableparams = ["this", "index", "result"]
            kwargs |= dict(
                singular=inflection.singularize(inflection.camelize(ql_name)),
                plural=inflection.pluralize(inflection.camelize(ql_name)),
                tablename=db_table_name
                or inflection.tableize(f"{cls.name}_{prop.name}"),
                tableparams=tableparams,
                doc=_get_doc(cls, prop, plural=False),
                doc_plural=_get_doc(cls, prop, plural=True),
            )
        elif prop.is_optional:
            kwargs |= dict(
                singular=inflection.camelize(ql_name),
                tablename=db_table_name
                or inflection.tableize(f"{cls.name}_{prop.name}"),
                tableparams=["this", "result"],
                doc=_get_doc(cls, prop),
            )
        elif prop.is_predicate:
            kwargs |= dict(
                singular=inflection.camelize(ql_name, uppercase_first_letter=False),
                tablename=db_table_name
                or inflection.underscore(f"{cls.name}_{prop.name}"),
                tableparams=["this"],
                doc=_get_doc(cls, prop),
            )
        else:
            raise ValueError(f"unknown property kind for {prop.name} from {cls.name}")

        ql_property = ql.Property(**kwargs)
        self._property_cache[cache_key] = ql_property
        return ql_property

    def get_ql_class(self, cls: schema.Class) -> ql.Class:
        """Return the (memoized) `ql.Class` for schema class `cls`.

        Raises `Error` if the class carries the `ql_name` pragma.
        """
        cache_key = id(cls)
        if cache_key in self._class_cache:
            return self._class_cache[cache_key]

        if "ql_name" in cls.pragmas:
            raise Error("ql_name is not supported yet for classes, only for properties")
        properties = [self.get_ql_property(cls, p) for p in cls.properties]
        all_children = [
            ql.Child(self.get_ql_property(owner, p))
            for owner, p in self._get_all_properties(cls)
            if p.is_child
        ]
        # chain each child to the singular name of the one preceding it (the
        # first child gets an empty `prev`)
        for prev, child in zip([""] + all_children, all_children):
            child.prev = prev and prev.property.singular

        ql_class = ql.Class(
            name=cls.name,
            bases=cls.bases,
            bases_impl=[f"{base}Impl::{base}" for base in cls.bases],
            final=not cls.derived,
            properties=properties,
            all_children=all_children,
            dir=pathlib.Path(cls.group or ""),
            doc=cls.doc,
            hideable="ql_hideable" in cls.pragmas,
            internal="ql_internal" in cls.pragmas,
            cfg=cls.cfg,
        )
        self._class_cache[cache_key] = ql_class
        return ql_class

    def get_ql_cfg_class(
        self,
        cls: schema.Class,
    ) -> ql.CfgClass:
        """Build the `ql.CfgClass` for `cls`, keeping only bases whose class is `cfg`."""
        # NOTE(review): bases are accessed through `.base`, and `generate` passes
        # in the result of `get_ql_class`, so the `schema.Class` annotation
        # looks stale - confirm
        cfg_bases = [b for b in cls.bases if self.lookup[b.base].cfg]
        return ql.CfgClass(
            name=cls.name,
            bases=cfg_bases,
            properties=cls.properties,
            doc=cls.doc,
        )

    def _get_all_properties(
        self,
        cls: schema.Class,
        already_seen: typing.Optional[typing.Set[int]] = None,
    ) -> typing.Iterable[typing.Tuple[schema.Class, schema.Property]]:
        """Yield `(owner, property)` pairs for `cls`, inherited properties first.

        A property reachable through more than one base is yielded only once.
        """
        # deduplicate using ids
        seen = set() if already_seen is None else already_seen
        for base_name in cls.bases:
            yield from self._get_all_properties(self.lookup[base_name], seen)
        for prop in cls.properties:
            if id(prop) in seen:
                continue
            seen.add(id(prop))
            yield cls, prop

    def get_all_properties_to_be_tested(
        self,
        cls: schema.Class,
    ) -> typing.Iterable[ql.PropertyForTest]:
        """Yield test descriptors for all own and inherited properties of `cls`.

        Properties carrying the `qltest_skip` pragma, or declared in a class
        carrying it, are left out.
        """
        for owner, prop in self._get_all_properties(cls):
            if "qltest_skip" in owner.pragmas or "qltest_skip" in prop.pragmas:
                continue
            ql_prop = self.get_ql_property(owner, prop)
            yield ql.PropertyForTest(
                ql_prop.getter,
                is_total=ql_prop.is_single or ql_prop.is_predicate,
                type=None if ql_prop.is_predicate else ql_prop.type,
                is_indexed=ql_prop.is_indexed,
            )

    def _is_in_qltest_collapsed_hierarchy(
        self,
        cls: schema.Class,
    ) -> bool:
        """Tell whether `cls` starts, or lies under, a collapsed test hierarchy."""
        if "qltest_collapse_hierarchy" in cls.pragmas:
            return True
        return self._is_under_qltest_collapsed_hierarchy(cls)

    def _is_under_qltest_collapsed_hierarchy(
        self,
        cls: schema.Class,
    ) -> bool:
        """Tell whether some base of `cls` is in a collapsed test hierarchy.

        The `qltest_uncollapse_hierarchy` pragma on `cls` opts it out.
        """
        if "qltest_uncollapse_hierarchy" in cls.pragmas:
            return False
        return any(
            self._is_in_qltest_collapsed_hierarchy(self.lookup[base_name])
            for base_name in cls.bases
        )

    def should_skip_qltest(self, cls: schema.Class) -> bool:
        """Tell whether no test should be generated for `cls`.

        That is the case if it is explicitly skipped, if it is neither final
        nor the head of a collapsed hierarchy, or if it lies under a collapsed
        hierarchy.
        """
        if "qltest_skip" in cls.pragmas:
            return True
        if not cls.final and "qltest_collapse_hierarchy" not in cls.pragmas:
            return True
        return self._is_under_qltest_collapsed_hierarchy(cls)
def _to_db_type(x: str) -> str:
if x[0].isupper():
return "Raw::" + x
return x
# final DB synth classes created so far, by name
_final_db_class_lookup = {}


def get_ql_synth_class_db(name: str) -> ql.Synth.FinalClassDb:
    """Return the final DB synth class named `name`, registering it on first use.

    Repeated calls with the same name return the same object.
    """
    id_param = ql.Synth.Param("id", _to_db_type(name))
    candidate = ql.Synth.FinalClassDb(name=name, params=[id_param])
    return _final_db_class_lookup.setdefault(name, candidate)
def get_ql_synth_class(cls: schema.Class):
    """Return the `ql.Synth` type describing schema class `cls`.

    * a class with derived classes gives a non final class;
    * a class synthesized from another class gives a derived synth class, and
      is subtracted from the DB class of its source;
    * a class synthesized on arguments gives a fresh synth class;
    * any other class gives a final DB class.
    """
    if cls.derived:
        return ql.Synth.NonFinalClass(
            name=cls.name,
            derived=sorted(cls.derived),
            root=not cls.bases,
        )

    synth = cls.synth
    if synth and synth.from_class is not None:
        source = synth.from_class
        get_ql_synth_class_db(source).subtract_type(cls.name)
        id_param = ql.Synth.Param("id", _to_db_type(source))
        return ql.Synth.FinalClassDerivedSynth(name=cls.name, params=[id_param])

    if synth and synth.on_arguments is not None:
        params = [
            ql.Synth.Param(arg_name, _to_db_type(arg_type))
            for arg_name, arg_type in synth.on_arguments.items()
        ]
        return ql.Synth.FinalClassFreshSynth(name=cls.name, params=params)

    return get_ql_synth_class_db(cls.name)
def get_import(file: pathlib.Path, root_dir: pathlib.Path):
    """Return the dotted QL module name of `file`.

    The name is the path of `file` relative to `<root_dir>/ql/lib`, without
    its suffix and with path components joined by dots.

    Raises `ValueError` (from `relative_to`) if `file` is not inside
    `<root_dir>/ql/lib`.
    """
    stem = file.relative_to(root_dir / "ql/lib").with_suffix("")
    # `as_posix` gives forward slashes whatever the platform separator is, so
    # that Windows paths are turned into dotted names as well
    return stem.as_posix().replace("/", ".")
def get_types_used_by(cls: ql.Class, is_impl: bool) -> typing.Iterable[str]:
    """Yield the names of the types `cls` refers to.

    These are its bases (with an `Impl` suffix if `is_impl` is set), the types
    of its properties and, for a root class, the class itself.
    """
    suffix = "Impl" if is_impl else ""
    yield from (b.base + suffix for b in cls.bases)
    yield from (p.type for p in cls.properties)
    if cls.root:
        yield cls.name  # used in `getResolveStep` and `resolve`
def get_classes_used_by(cls: ql.Class, is_impl: bool) -> typing.List[str]:
    """Return the sorted, deduplicated names of the classes `cls` refers to.

    Only capitalized type names are considered classes. Unless `is_impl` is
    set, `cls` itself is left out.
    """
    used = {
        name
        for name in get_types_used_by(cls, is_impl)
        if name[0].isupper() and (is_impl or name != cls.name)
    }
    return sorted(used)
def format(codeql, files):
    """Format in place the QL files among `files` with `codeql query format`.

    `codeql` is either a path to the CodeQL binary (if it contains a path
    separator) or a command name to be looked up in `PATH`. Files not ending
    in `.ql` or `.qll` are ignored, and nothing is run if none is left.

    Raises `FormatError` if the binary cannot be found or if formatting
    fails, in which case the formatter's stderr is logged as errors.
    """
    ql_files = [str(f) for f in files if f.suffix in (".qll", ".ql")]
    if not ql_files:
        return
    if "/" in codeql or "\\" in codeql:
        if not pathlib.Path(codeql).exists():
            raise FormatError(f"Provided CodeQL binary `{codeql}` does not exist")
    else:
        codeql_path = shutil.which(codeql)
        if not codeql_path:
            raise FormatError(
                f"`{codeql}` not found in PATH. Either install it, or pass `-- --codeql-binary` with a full path"
            )
        codeql = codeql_path
    # the command must be built only now, so that it uses the resolved binary
    format_cmd = [codeql, "query", "format", "--in-place", "--", *ql_files]
    res = subprocess.run(format_cmd, stderr=subprocess.PIPE, text=True)
    if res.returncode:
        for line in res.stderr.splitlines():
            log.error(line.strip())
        raise FormatError("QL format failed")
    for line in res.stderr.splitlines():
        log.debug(line.strip())
def _get_path(cls: schema.Class) -> pathlib.Path:
    """Return the relative path `<group>/<Class>.qll` of the file generated for `cls`."""
    group_dir = pathlib.Path(cls.group or "")
    return (group_dir / cls.name).with_suffix(".qll")
def _get_path_impl(cls: schema.Class) -> pathlib.Path:
    """Return the relative path `<group>/internal/<Class>Impl.qll` of the stub of `cls`."""
    impl_dir = pathlib.Path(cls.group or "", "internal")
    return (impl_dir / f"{cls.name}Impl").with_suffix(".qll")
def _get_path_public(cls: schema.Class) -> pathlib.Path:
    """Return the relative path of the public class file of `cls`.

    That is `<group>/<Class>.qll`, or `<group>/internal/<Class>.qll` for
    classes carrying the `ql_internal` pragma.
    """
    segments = [cls.group or ""]
    if "ql_internal" in cls.pragmas:
        segments.append("internal")
    segments.append(cls.name)
    return pathlib.Path(*segments).with_suffix(".qll")
def _get_stub(
    cls: schema.Class, base_import: str, generated_import_prefix: str
) -> ql.Stub:
    """Build the customization stub of `cls`.

    For synthesized classes the stub gets accessors to the underlying
    entities: a single `Entity` accessor if the class is synthesized from
    another class, or one accessor per argument if it is synthesized on
    arguments. All other classes get no accessor.
    """
    # start from no accessors, so that the name is bound also for a
    # `SynthInfo` with neither `from_class` nor `on_arguments` set
    accessors = []
    synth = cls.synth
    if isinstance(synth, schema.SynthInfo):
        if synth.from_class is not None:
            accessors = [
                ql.SynthUnderlyingAccessor(
                    argument="Entity",
                    type=_to_db_type(synth.from_class),
                    constructorparams=["result"],
                )
            ]
        elif synth.on_arguments is not None:
            accessors = [
                ql.SynthUnderlyingAccessor(
                    argument=inflection.camelize(arg),
                    type=_to_db_type(arg_type),
                    # select the constructor parameter of this argument only
                    constructorparams=[
                        "result" if a == arg else "_" for a in synth.on_arguments
                    ],
                )
                for arg, arg_type in synth.on_arguments.items()
            ]
    return ql.Stub(
        name=cls.name,
        base_import=base_import,
        import_prefix=generated_import_prefix,
        doc=cls.doc,
        synth_accessors=accessors,
    )
def _get_class_public(cls: schema.Class) -> ql.ClassPublic:
    """Build the public class wrapper of `cls`, internal if it has the `ql_internal` pragma."""
    is_internal = "ql_internal" in cls.pragmas
    return ql.ClassPublic(name=cls.name, doc=cls.doc, internal=is_internal)
_stub_qldoc_header = "// the following QLdoc is generated: if you need to edit it, do it in the schema file\n "
_class_qldoc_re = re.compile(
rf"(?P<qldoc>(?:{re.escape(_stub_qldoc_header)})?/\*\*.*?\*/\s*|^\s*)(?:class\s+(?P<class>\w+))?",
re.MULTILINE | re.DOTALL,
)
def _patch_class_qldoc(cls: str, qldoc: str, stub_file: pathlib.Path):
"""Replace or insert `qldoc` as the QLdoc of class `cls` in `stub_file`"""
if not qldoc or not stub_file.exists():
return
qldoc = "\n ".join(l.rstrip() for l in qldoc.splitlines())
with open(stub_file) as input:
contents = input.read()
for match in _class_qldoc_re.finditer(contents):
if match["class"] == cls:
qldoc_start, qldoc_end = match.span("qldoc")
contents = f"{contents[:qldoc_start]}{_stub_qldoc_header}{qldoc}\n {contents[qldoc_end:]}"
tmp = stub_file.with_suffix(f"{stub_file.suffix}.bkp")
with open(tmp, "w") as out:
out.write(contents)
tmp.rename(stub_file)
return
def generate(opts, renderer):
    """Generate the QL library, stubs and tests described by the schema.

    The following attributes of `opts` are read: `schema`, `ql_output`,
    `ql_stub_output`, `ql_cfg_output`, `ql_test_output`, `root_dir`,
    `generated_registry`, `force`, `ql_format` and `codeql_binary`.
    `ql_cfg_output` and `ql_test_output` may be unset, in which case
    respectively `CfgNodes.qll` and the tests are not generated.

    All files are emitted through `renderer`. If `opts.ql_format` is set, the
    written files are finally formatted with the CodeQL binary.

    Raises `NoClasses` if the schema defines no non-imported class, and
    `RootElementHasChildren` if the first class has children.
    """
    schema_path = opts.schema
    out = opts.ql_output
    stub_out = opts.ql_stub_output
    cfg_out = opts.ql_cfg_output
    test_out = opts.ql_test_output
    missing_test_source_filename = "MISSING_SOURCE.txt"
    include_file = stub_out.with_suffix(".qll")

    # files currently on disk that are in the scope of this generation
    generated = {q for q in out.rglob("*.qll")}
    generated.add(include_file)
    if test_out:
        generated.update(q for q in test_out.rglob("*.ql"))
        generated.update(q for q in test_out.rglob(missing_test_source_filename))

    stubs = {q for q in stub_out.rglob("*.qll")}

    data = schemaloader.load_file(schema_path)

    resolver = Resolver(data.classes)
    classes = {
        name: resolver.get_ql_class(cls)
        for name, cls in data.classes.items()
        if not cls.imported
    }
    if not classes:
        raise NoClasses
    root = next(iter(classes.values()))
    if root.has_children:
        raise RootElementHasChildren(root)

    pre_imports = {n: cls.module for n, cls in data.classes.items() if cls.imported}
    imports = dict(pre_imports)
    imports_impl = {}
    classes_used_by = {}
    cfg_classes = []
    generated_import_prefix = get_import(out, opts.root_dir)
    # `test_out` is optional (see the checks on it in this function), so it
    # must be left out of the common path computation when it is not set
    registry = opts.generated_registry or pathlib.Path(
        os.path.commonpath([p for p in (out, stub_out, test_out) if p]),
        ".generated.list",
    )

    with renderer.manage(
        generated=generated, stubs=stubs, registry=registry, force=opts.force
    ) as renderer:
        db_classes = [
            cls for name, cls in classes.items() if not data.classes[name].synth
        ]
        renderer.render(
            ql.DbClasses(classes=db_classes, imports=sorted(set(pre_imports.values()))),
            out / "Raw.qll",
        )

        # first collect the import of every class, as classes may refer to
        # each other
        classes_by_dir_and_name = sorted(
            classes.values(), key=lambda cls: (cls.dir, cls.name)
        )
        for c in classes_by_dir_and_name:
            path = get_import(
                (
                    stub_out / c.dir / "internal" / c.name
                    if c.internal
                    else stub_out / c.path
                ),
                opts.root_dir,
            )
            imports[c.name] = path
            path_impl = get_import(
                stub_out / c.dir / "internal" / c.name, opts.root_dir
            )
            imports_impl[c.name + "Impl"] = path_impl + "Impl"
            if c.cfg:
                cfg_classes.append(resolver.get_ql_cfg_class(c))

        for c in classes.values():
            qll = out / c.path.with_suffix(".qll")
            c.imports = [
                imports[t] if t in imports else imports_impl[t] + "::Impl as " + t
                for t in get_classes_used_by(c, is_impl=True)
            ]
            classes_used_by[c.name] = get_classes_used_by(c, is_impl=False)
            c.import_prefix = generated_import_prefix
            renderer.render(c, qll)

        if cfg_out:
            cfg_classes_val = ql.CfgClasses(
                include_file_import=get_import(include_file, opts.root_dir),
                classes=cfg_classes,
            )
            cfg_qll = cfg_out / "CfgNodes.qll"
            renderer.render(cfg_classes_val, cfg_qll)

        for c in data.classes.values():
            if c.imported:
                continue
            path = _get_path(c)
            path_impl = _get_path_impl(c)
            stub_file = stub_out / path_impl
            base_import = get_import(out / path, opts.root_dir)
            stub = _get_stub(c, base_import, generated_import_prefix)

            if not renderer.is_customized_stub(stub_file):
                renderer.render(stub, stub_file)
            else:
                # customized stubs are not overwritten, only their QLdoc is
                # kept up to date
                qldoc = renderer.render_str(stub, template="ql_stub_class_qldoc")
                _patch_class_qldoc(c.name, qldoc, stub_file)
            class_public = _get_class_public(c)
            path_public = _get_path_public(c)
            class_public_file = stub_out / path_public
            class_public.imports = [imports[t] for t in classes_used_by[c.name]]
            renderer.render(class_public, class_public_file)

        # for example path/to/elements -> path/to/elements.qll
        renderer.render(
            ql.ImportList(
                [
                    i
                    for name, i in imports.items()
                    if name not in classes or not classes[name].internal
                ]
            ),
            include_file,
        )

        elements_module = get_import(include_file, opts.root_dir)

        renderer.render(
            ql.GetParentImplementation(
                classes=list(classes.values()),
                imports=[elements_module]
                + [
                    i
                    for name, i in imports.items()
                    if name in classes and classes[name].internal
                ],
            ),
            out / "ParentChild.qll",
        )

        if test_out:
            for c in data.classes.values():
                if c.imported:
                    continue
                if resolver.should_skip_qltest(c):
                    continue
                test_with_name = c.pragmas.get("qltest_test_with")
                test_with = data.classes[test_with_name] if test_with_name else c
                test_dir = test_out / test_with.group / test_with.name
                test_dir.mkdir(parents=True, exist_ok=True)
                # a directory containing only these kinds of files has no
                # test source
                if all(
                    f.suffix in (".txt", ".ql", ".actual", ".expected")
                    for f in test_dir.glob("*.*")
                ):
                    log.warning(f"no test source in {test_dir.relative_to(test_out)}")
                    renderer.render(
                        ql.MissingTestInstructions(),
                        test_dir / missing_test_source_filename,
                    )
                    continue
                renderer.render(
                    ql.ClassTester(
                        class_name=c.name,
                        properties=list(resolver.get_all_properties_to_be_tested(c)),
                        elements_module=elements_module,
                        # in case of collapsed hierarchies we want to see the actual QL class in results
                        show_ql_class="qltest_collapse_hierarchy" in c.pragmas,
                    ),
                    test_dir / f"{c.name}.ql",
                )

        final_synth_types = []
        non_final_synth_types = []
        constructor_imports = []
        synth_constructor_imports = []
        constructor_stubs = {}
        for cls in sorted(
            (cls for cls in data.classes.values() if not cls.imported),
            key=lambda cls: (cls.group, cls.name),
        ):
            synth_type = get_ql_synth_class(cls)
            if synth_type.is_final:
                final_synth_types.append(synth_type)
                if synth_type.has_params:
                    stub_file = (
                        stub_out / cls.group / "internal" / f"{cls.name}Constructor.qll"
                    )
                    if not renderer.is_customized_stub(stub_file):
                        # stub rendering must be postponed as we might not have yet all subtracted synth types in `synth_type`
                        constructor_stubs[stub_file] = ql.Synth.ConstructorStub(
                            synth_type, import_prefix=generated_import_prefix
                        )
                    constructor_import = get_import(stub_file, opts.root_dir)
                    constructor_imports.append(constructor_import)
                    if synth_type.is_synth:
                        synth_constructor_imports.append(constructor_import)
            else:
                non_final_synth_types.append(synth_type)

        for stub_file, constructor_stub in constructor_stubs.items():
            renderer.render(constructor_stub, stub_file)

        renderer.render(
            ql.Synth.Types(
                root.name,
                generated_import_prefix,
                final_synth_types,
                non_final_synth_types,
            ),
            out / "Synth.qll",
        )
        renderer.render(
            ql.ImportList(constructor_imports), out / "SynthConstructors.qll"
        )
        renderer.render(
            ql.ImportList(synth_constructor_imports), out / "PureSynthConstructors.qll"
        )
        if opts.ql_format:
            format(opts.codeql_binary, renderer.written)