Rust: take test code also from property descriptions

This commit is contained in:
Paolo Tranquilli
2024-09-20 15:12:13 +02:00
parent d7aa5f1022
commit 2a95068a0a
2 changed files with 38 additions and 16 deletions

View File

@@ -1,5 +1,7 @@
import dataclasses import dataclasses
import typing import typing
from collections.abc import Iterable
import inflection import inflection
from misc.codegen.loaders import schemaloader from misc.codegen.loaders import schemaloader
@@ -27,6 +29,25 @@ class TestCode:
function: Function | None = None function: Function | None = None
def _get_code(doc: list[str]) -> list[str]:
adding_code = False
has_code = False
code = []
for line in doc:
match line, adding_code:
case ("```", _) | ("```rust", _):
adding_code = not adding_code
has_code = True
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(doc)
if has_code:
return code
return []
def generate(opts, renderer): def generate(opts, renderer):
assert opts.ql_test_output assert opts.ql_test_output
schema = schemaloader.load_file(opts.schema) schema = schemaloader.load_file(opts.schema)
@@ -36,24 +57,18 @@ def generate(opts, renderer):
force=opts.force) as renderer: force=opts.force) as renderer:
for cls in schema.classes.values(): for cls in schema.classes.values():
if (qlgen.should_skip_qltest(cls, schema.classes) or if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_test_from_doc" in cls.pragmas or "rust_skip_doc_test" in cls.pragmas):
not cls.doc):
continue continue
code = [] code = _get_code(cls.doc)
adding_code = False for p in schema.iter_properties(cls.name):
has_code = False if "rust_skip_doc_test" in p.pragmas:
for line in cls.doc: continue
match line, adding_code: property_code = _get_code(p.description)
case ("```", _) | ("```rust", _): if property_code:
adding_code = not adding_code code.append(f"// # {p.name}")
has_code = True code += property_code
case _, False: if not code:
code.append(f"// {line}")
case _, True:
code.append(line)
if not has_code:
continue continue
assert not adding_code, "Unterminated code block in docstring: " + "\n".join(cls.doc)
test_name = inflection.underscore(cls.name) test_name = inflection.underscore(cls.name)
signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()") signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()")
fn = signature and Function(f"test_{test_name}", signature) fn = signature and Function(f"test_{test_name}", signature)

View File

@@ -1,6 +1,7 @@
""" schema format representation """ """ schema format representation """
import abc import abc
import typing import typing
from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Set, Union, Dict, Optional from typing import List, Set, Union, Dict, Optional
from enum import Enum, auto from enum import Enum, auto
@@ -143,6 +144,12 @@ class Schema:
def null_class(self): def null_class(self):
return self.classes[self.null] if self.null else None return self.classes[self.null] if self.null else None
def iter_properties(self, cls: str) -> Iterable[Property]:
cls = self.classes[cls]
for b in cls.bases:
yield from self.iter_properties(b)
yield from cls.properties
predicate_marker = object() predicate_marker = object()