Rust: generate test code from schema docstrings

This generates test source files from code blocks in class docstrings.

By default the test code is generated as is, but it can optionally:
* be wrapped in a function providing an adequate context using
  `@rust.doc_test_function(name, *, lifetimes=(), return_type="()", **kwargs)`,
  with `kwargs` providing both generic and normal params depending on
  capitalization
* be skipped altogether using `@rust.skip_doc_test`

So for example an annotation like
```python
@rust.doc_test_function("foo",
                        lifetimes=("a",),
                        T="Eq",
                        x="&'a T",
                        y="&'a T",
                        return_type="&'a T")
```
will result in the following wrapper:
```rust
fn foo<'a, T: Eq>(x: &'a T, y: &'a T) -> &'a T {
    // example code here
}
```
This commit is contained in:
Paolo Tranquilli
2024-09-06 13:58:49 +02:00
parent 122e5a7598
commit 8c5cc2efdc
19 changed files with 199 additions and 35 deletions

View File

@@ -1,4 +1,4 @@
from . import dbschemegen, qlgen, trapgen, cppgen, rustgen
from . import dbschemegen, trapgen, cppgen, rustgen, rusttestgen, qlgen
def generate(target, opts, renderer):

View File

@@ -287,7 +287,7 @@ def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)
def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
return "qltest_skip" in cls.pragmas or not (
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
cls, lookup)
@@ -413,7 +413,7 @@ def generate(opts, renderer):
if test_out:
for c in data.classes.values():
if _should_skip_qltest(c, data.classes):
if should_skip_qltest(c, data.classes):
continue
test_with = data.classes[c.test_with] if c.test_with else c
test_dir = test_out / test_with.group / test_with.name

View File

@@ -86,20 +86,24 @@ def generate(opts, renderer):
processor = Processor(schemaloader.load_file(opts.schema))
out = opts.rust_output
groups = set()
for group, classes in processor.get_classes().items():
group = group or "top"
groups.add(group)
with renderer.manage(generated=out.rglob("*.rs"),
stubs=(),
registry=opts.generated_registry,
force=opts.force) as renderer:
for group, classes in processor.get_classes().items():
group = group or "top"
groups.add(group)
renderer.render(
rust.ClassList(
classes,
opts.schema,
),
out / f"{group}.rs",
)
renderer.render(
rust.ClassList(
classes,
rust.ModuleList(
groups,
opts.schema,
),
out / f"{group}.rs",
out / f"mod.rs",
)
renderer.render(
rust.ModuleList(
groups,
opts.schema,
),
out / f"mod.rs",
)

View File

@@ -0,0 +1,70 @@
import dataclasses
import typing
from misc.codegen.loaders import schemaloader
from . import qlgen
@dataclasses.dataclass
class Param:
name: str
type: str
first: bool = False
@dataclasses.dataclass
class Function:
name: str
generic_params: list[Param]
params: list[Param]
return_type: str
def __post_init__(self):
if self.generic_params:
self.generic_params[0].first = True
if self.params:
self.params[0].first = True
@dataclasses.dataclass
class TestCode:
template: typing.ClassVar[str] = "rust_test_code"
code: str
function: Function | None = None
def generate(opts, renderer):
assert opts.ql_test_output
schema = schemaloader.load_file(opts.schema)
with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
stubs=(),
registry=opts.generated_registry,
force=opts.force) as renderer:
for cls in schema.classes.values():
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_test_from_doc" in cls.pragmas or
not cls.doc
):
continue
fn = cls.rust_doc_test_function
if fn:
generic_params = [Param(k, v) for k, v in fn.params.items() if k[0].isupper() or k[0] == "'"]
params = [Param(k, v) for k, v in fn.params.items() if k[0].islower()]
fn = Function(fn.name, generic_params, params, fn.return_type)
code = []
adding_code = False
for line in cls.doc:
match line, adding_code:
case "```", _:
adding_code = not adding_code
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
if fn:
indent = 4 * " "
code = [indent + l for l in code]
test_with = schema.classes[cls.test_with] if cls.test_with else cls
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{cls.name.lower()}.rs"
renderer.render(TestCode(code="\n".join(code), function=fn), test)