Files
codeql/misc/codegen/generators/rusttestgen.py
Paolo Tranquilli 1dcd60527c Codegen: improve implementation of generated parent/child relationship
This improves the implementation of the generated parent/child
relationship by adding a new `all_children` field to `ql.Class`, which
lists all children (both direct and inherited) of a class while carefully
avoiding duplicate children in case of diamond inheritance. This:
* simplifies the generated code, and
* avoids child ambiguities in case of diamond inheritance.

The only visible effect is a change in the order of children in the
generated tests (bases were previously sorted alphabetically there).
Otherwise, this should be a non-functional change.
2025-06-24 17:26:24 +02:00

94 lines
2.7 KiB
Python

import dataclasses
import typing
from collections.abc import Iterable
import inflection
from misc.codegen.loaders import schemaloader
from . import qlgen
@dataclasses.dataclass
class Param:
    """A named, typed parameter.

    Fields:
        name: the parameter name.
        type: the parameter type, as text.
        first: a flag that defaults to False. It presumably marks the first
            parameter of a list so a template can handle separators
            (TODO confirm against the templates using it).
    """

    name: str
    type: str
    first: bool = dataclasses.field(default=False)
@dataclasses.dataclass()
class Function:
    """A function that generated test code can be wrapped in.

    Fields:
        name: the function name.
        signature: the signature text that follows the name, e.g. `() -> ()`.
    """

    name: str
    signature: str
@dataclasses.dataclass
class TestCode:
    """Data for one generated Rust test file.

    Fields:
        code: the test body, already joined into a single string.
        function: the function to wrap `code` in, or None for no wrapping
            function.
    """

    # Template name tied to this data class. As a ClassVar it is not a
    # dataclass field. Presumably the renderer looks it up (TODO confirm).
    template: typing.ClassVar[str] = "rust_test_code"

    code: str
    function: Function | None = dataclasses.field(default=None)
def _get_code(doc: list[str]) -> list[str]:
adding_code = False
has_code = False
code = []
for line in doc:
match line, adding_code:
case ("```", _) | ("```rust", _):
adding_code = not adding_code
has_code = True
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(
doc
)
if has_code:
return code
return []
def _get_class_code(schema, cls) -> list[str]:
    """Collect the doc-test code lines for `cls` and its properties.

    The lines from `cls.doc` come first. Each property yielded by
    `schema.iter_properties(cls.name)` is then checked:
    - a property with the `rust_skip_doc_test` pragma is ignored;
    - otherwise, if its description yields code, it adds a
      `// # <property name>` header followed by that code.

    Returns an empty list when nothing yields code.
    """
    code = _get_code(cls.doc)
    for prop in schema.iter_properties(cls.name):
        if "rust_skip_doc_test" in prop.pragmas:
            continue
        property_code = _get_code(prop.description)
        if property_code:
            code.append(f"// # {prop.name}")
            code += property_code
    return code


def generate(opts, renderer):
    """Render one Rust test file per schema class that has doc-test code.

    Args:
        opts: options object. Reads `ql_test_output` (an output directory
            supporting `/` and `rglob`, which must be truthy), `schema` (the
            path given to `schemaloader.load_file`) and `force`.
        renderer: object whose `manage(...)` context manager yields something
            with a `render(data, path)` method.

    Skipped classes:
        - classes marked `imported`;
        - classes for which `qlgen.Resolver.should_skip_qltest` is true;
        - classes with the `rust_skip_doc_test` pragma;
        - classes whose docs yield no code.

    Output path: each remaining class is rendered to
    `<ql_test_output>/<group>/<name>/gen_<snake_case class name>.rs`.
    `<group>/<name>` come from the class named by the `qltest_test_with`
    pragma if it is set, otherwise from the class itself.
    """
    assert opts.ql_test_output
    schema = schemaloader.load_file(opts.schema)
    with renderer.manage(
        # Existing generated files plus a registry file are handed to the
        # renderer, presumably so stale outputs can be tracked (TODO confirm).
        generated=opts.ql_test_output.rglob("gen_*.rs"),
        stubs=(),
        registry=opts.ql_test_output / ".generated_tests.list",
        force=opts.force,
    ) as managed:
        resolver = qlgen.Resolver(schema.classes)
        for cls in schema.classes.values():
            if cls.imported:
                continue
            if resolver.should_skip_qltest(cls) or "rust_skip_doc_test" in cls.pragmas:
                continue
            code = _get_class_code(schema, cls)
            if not code:
                continue
            test_name = inflection.underscore(cls.name)
            signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()")
            # A falsy signature (e.g. the pragma explicitly set to None) means
            # no wrapping function. Normalize to None so `TestCode.function`
            # always matches its declared `Function | None` type, instead of
            # leaking the falsy pragma value itself.
            fn = Function(f"test_{test_name}", signature) if signature else None
            if fn is not None:
                # Indent the code so it sits inside the wrapping function body.
                code = ["    " + line for line in code]
            test_with_name = typing.cast(
                str | None, cls.pragmas.get("qltest_test_with")
            )
            test_with = schema.classes[test_with_name] if test_with_name else cls
            test = (
                opts.ql_test_output
                / test_with.group
                / test_with.name
                / f"gen_{test_name}.rs"
            )
            managed.render(TestCode(code="\n".join(code), function=fn), test)