mirror of
https://github.com/github/codeql.git
synced 2025-12-16 16:53:25 +01:00
This improves the implementation of the generated parent/child relationship by adding a new `all_children` field to `ql.Class`, which lists all children (both direct and inherited) of a class, carefully avoiding duplicate children in case of diamond inheritance. This:
* simplifies the generated code,
* avoids children ambiguities in case of diamond inheritance.

The only visible effect is a change in the order of children in the generated tests (we were previously sorting bases alphabetically there). Otherwise, this should be a non-functional change.
94 lines
2.7 KiB
Python
94 lines
2.7 KiB
Python
import dataclasses
|
|
import typing
|
|
from collections.abc import Iterable
|
|
|
|
import inflection
|
|
|
|
from misc.codegen.loaders import schemaloader
|
|
from . import qlgen
|
|
|
|
|
|
@dataclasses.dataclass()
class Param:
    """A name/type pair describing one parameter.

    NOTE(review): not referenced in the visible part of this file; presumably
    consumed by a template elsewhere — confirm where `first` is read.
    """

    # Parameter name.
    name: str
    # Parameter type, kept as a plain string. The field name shadows the
    # builtin `type`, but it is part of the public attribute interface.
    type: str
    # Presumably marks the first entry of a parameter list (e.g. for separator
    # handling in a template) — confirm. Defaults to False.
    first: bool = dataclasses.field(default=False)
|
|
|
|
|
|
@dataclasses.dataclass()
class Function:
    """Wrapper function around a generated Rust doc test.

    `generate` builds it as `Function(f"test_{name}", signature)`, where the
    signature comes from the `rust_doc_test_signature` pragma.
    """

    # Function identifier, e.g. "test_some_class".
    name: str
    # Signature text as given by the pragma, e.g. "() -> ()".
    signature: str
|
|
|
|
|
|
@dataclasses.dataclass()
class TestCode:
    """Data handed to the renderer to produce one generated `.rs` test file."""

    # Name of the template used to render this object (class-level, not a field).
    template: typing.ClassVar[str] = "rust_test_code"

    # Newline-joined test lines (already indented when `function` is set).
    code: str
    # Optional wrapper function; None when no signature is used.
    function: Function | None = dataclasses.field(default=None)
|
|
|
|
|
|
def _get_code(doc: list[str]) -> list[str]:
|
|
adding_code = False
|
|
has_code = False
|
|
code = []
|
|
for line in doc:
|
|
match line, adding_code:
|
|
case ("```", _) | ("```rust", _):
|
|
adding_code = not adding_code
|
|
has_code = True
|
|
case _, False:
|
|
code.append(f"// {line}")
|
|
case _, True:
|
|
code.append(line)
|
|
assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(
|
|
doc
|
|
)
|
|
if has_code:
|
|
return code
|
|
return []
|
|
|
|
|
|
def generate(opts, renderer):
    """Render one Rust doc-test file per schema class that documents code.

    A class is processed unless it is imported, skipped by
    `resolver.should_skip_qltest`, or carries the `rust_skip_doc_test` pragma.
    Its docstring and the descriptions of its properties go through
    `_get_code`; properties with the `rust_skip_doc_test` pragma are ignored.
    Each property's code is preceded by a `// # <property name>` line. Classes
    that yield no code are skipped.

    The rest are rendered with `TestCode` to
    `<ql_test_output>/<group>/<name>/gen_<snake_case class name>.rs`. Group
    and name come from the class named by the `qltest_test_with` pragma when
    set, otherwise from the class itself.

    Args:
        opts: options object; reads `ql_test_output` (must be truthy),
            `schema` and `force`.
        renderer: object whose `manage(...)` context manager receives the
            existing `gen_*.rs` files and the `.generated_tests.list` registry
            path, and yields the object used to `render` each test.
    """
    assert opts.ql_test_output
    schema = schemaloader.load_file(opts.schema)
    indent = 4 * " "
    with renderer.manage(
        generated=opts.ql_test_output.rglob("gen_*.rs"),
        stubs=(),
        registry=opts.ql_test_output / ".generated_tests.list",
        force=opts.force,
    ) as managed:
        resolver = qlgen.Resolver(schema.classes)
        for cls in schema.classes.values():
            if cls.imported:
                continue
            if resolver.should_skip_qltest(cls) or "rust_skip_doc_test" in cls.pragmas:
                continue
            # `_get_code` always returns a fresh list, so extending it in place is safe.
            code = _get_code(cls.doc)
            for prop in schema.iter_properties(cls.name):
                if "rust_skip_doc_test" in prop.pragmas:
                    continue
                property_code = _get_code(prop.description)
                if property_code:
                    code.append(f"// # {prop.name}")
                    code += property_code
            if not code:
                continue
            test_name = inflection.underscore(cls.name)
            signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()")
            # A falsy signature (e.g. an empty pragma value) means no wrapper
            # function: pass None (not "") so `function` matches
            # `Function | None`, and leave the lines unindented.
            fn = Function(f"test_{test_name}", signature) if signature else None
            if fn is not None:
                code = [indent + line for line in code]
            test_with_name = cls.pragmas.get("qltest_test_with")
            test_with = schema.classes[test_with_name] if test_with_name else cls
            test = (
                opts.ql_test_output
                / test_with.group
                / test_with.name
                / f"gen_{test_name}.rs"
            )
            managed.render(TestCode(code="\n".join(code), function=fn), test)
|