Rust: take test code also from property descriptions

This commit is contained in:
Paolo Tranquilli
2024-09-20 15:12:13 +02:00
parent d7aa5f1022
commit 2a95068a0a
2 changed files with 38 additions and 16 deletions

View File

@@ -1,5 +1,7 @@
import dataclasses
import typing
from collections.abc import Iterable
import inflection
from misc.codegen.loaders import schemaloader
@@ -27,6 +29,25 @@ class TestCode:
function: Function | None = None
def _get_code(doc: list[str]) -> list[str]:
adding_code = False
has_code = False
code = []
for line in doc:
match line, adding_code:
case ("```", _) | ("```rust", _):
adding_code = not adding_code
has_code = True
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
assert not adding_code, "Unterminated code block in docstring:\n " + "\n ".join(doc)
if has_code:
return code
return []
def generate(opts, renderer):
assert opts.ql_test_output
schema = schemaloader.load_file(opts.schema)
@@ -36,24 +57,18 @@ def generate(opts, renderer):
force=opts.force) as renderer:
for cls in schema.classes.values():
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_test_from_doc" in cls.pragmas or
not cls.doc):
"rust_skip_doc_test" in cls.pragmas):
continue
code = []
adding_code = False
has_code = False
for line in cls.doc:
match line, adding_code:
case ("```", _) | ("```rust", _):
adding_code = not adding_code
has_code = True
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
if not has_code:
code = _get_code(cls.doc)
for p in schema.iter_properties(cls.name):
if "rust_skip_doc_test" in p.pragmas:
continue
property_code = _get_code(p.description)
if property_code:
code.append(f"// # {p.name}")
code += property_code
if not code:
continue
assert not adding_code, "Unterminated code block in docstring: " + "\n".join(cls.doc)
test_name = inflection.underscore(cls.name)
signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()")
fn = signature and Function(f"test_{test_name}", signature)

View File

@@ -1,6 +1,7 @@
""" schema format representation """
import abc
import typing
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import List, Set, Union, Dict, Optional
from enum import Enum, auto
@@ -143,6 +144,12 @@ class Schema:
def null_class(self):
return self.classes[self.null] if self.null else None
def iter_properties(self, cls: str) -> Iterable[Property]:
cls = self.classes[cls]
for b in cls.bases:
yield from self.iter_properties(b)
yield from cls.properties
predicate_marker = object()