diff --git a/misc/codegen/generators/rusttestgen.py b/misc/codegen/generators/rusttestgen.py index a44c5f2b241..b47d6b8725a 100644 --- a/misc/codegen/generators/rusttestgen.py +++ b/misc/codegen/generators/rusttestgen.py @@ -1,5 +1,7 @@ import dataclasses import typing +from collections.abc import Iterable + import inflection from misc.codegen.loaders import schemaloader @@ -27,6 +29,25 @@ class TestCode: function: Function | None = None +def _get_code(doc: list[str]) -> list[str]: + adding_code = False + has_code = False + code = [] + for line in doc: + match line, adding_code: + case ("```", _) | ("```rust", _): + adding_code = not adding_code + has_code = True + case _, False: + code.append(f"// {line}") + case _, True: + code.append(line) + assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(doc) + if has_code: + return code + return [] + + def generate(opts, renderer): assert opts.ql_test_output schema = schemaloader.load_file(opts.schema) @@ -36,24 +57,18 @@ def generate(opts, renderer): force=opts.force) as renderer: for cls in schema.classes.values(): if (qlgen.should_skip_qltest(cls, schema.classes) or - "rust_skip_test_from_doc" in cls.pragmas or - not cls.doc): + "rust_skip_doc_test" in cls.pragmas): continue - code = [] - adding_code = False - has_code = False - for line in cls.doc: - match line, adding_code: - case ("```", _) | ("```rust", _): - adding_code = not adding_code - has_code = True - case _, False: - code.append(f"// {line}") - case _, True: - code.append(line) - if not has_code: + code = _get_code(cls.doc) + for p in schema.iter_properties(cls.name): + if "rust_skip_doc_test" in p.pragmas: + continue + property_code = _get_code(p.description) + if property_code: + code.append(f"// # {p.name}") + code += property_code + if not code: continue - assert not adding_code, "Unterminated code block in docstring: " + "\n".join(cls.doc) test_name = inflection.underscore(cls.name) signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()") fn = signature and Function(f"test_{test_name}", signature) diff --git a/misc/codegen/lib/schema.py b/misc/codegen/lib/schema.py index bdaaba32c20..5ee769ca3f2 100644 --- a/misc/codegen/lib/schema.py +++ b/misc/codegen/lib/schema.py @@ -1,6 +1,7 @@ """ schema format representation """ import abc import typing +from collections.abc import Iterable from dataclasses import dataclass, field from typing import List, Set, Union, Dict, Optional from enum import Enum, auto @@ -143,6 +144,12 @@ class Schema: def null_class(self): return self.classes[self.null] if self.null else None + def iter_properties(self, cls: str) -> Iterable[Property]: + cls = self.classes[cls] + for b in cls.bases: + yield from self.iter_properties(b) + yield from cls.properties + predicate_marker = object()