Rust: simplify rust doc test annotation

This commit is contained in:
Paolo Tranquilli
2024-09-09 08:59:17 +02:00
parent 928f3f11f1
commit 3cd8aaf4b0
6 changed files with 20 additions and 37 deletions

View File

@@ -36,7 +36,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
else:
table_name = inflection.tableize(table_name)
args = dict(
field_name=p.name + ("_" if p.name in rust.keywords else ""),
field_name=rust.avoid_keywords(p.name),
base_type=_get_type(p.type),
is_optional=p.is_optional,
is_repeated=p.is_repeated,

View File

@@ -1,5 +1,6 @@
import dataclasses
import typing
import inflection
from misc.codegen.loaders import schemaloader
from . import qlgen
@@ -15,19 +16,7 @@ class Param:
@dataclasses.dataclass
class Function:
name: str
generic_params: list[Param]
params: list[Param]
return_type: str
def __post_init__(self):
if self.generic_params:
self.generic_params[0].first = True
if self.params:
self.params[0].first = True
@property
def has_generic_params(self) -> bool:
return bool(self.generic_params)
signature: str
@dataclasses.dataclass
@@ -48,27 +37,28 @@ def generate(opts, renderer):
for cls in schema.classes.values():
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_test_from_doc" in cls.pragmas or
not cls.doc
):
not cls.doc):
continue
fn = cls.rust_doc_test_function
if fn:
generic_params = [Param(k, v) for k, v in fn.params.items() if k[0].isupper() or k[0] == "'"]
params = [Param(k, v) for k, v in fn.params.items() if k[0].islower()]
fn = Function(fn.name, generic_params, params, fn.return_type)
code = []
adding_code = False
has_code = False
for line in cls.doc:
match line, adding_code:
case "```", _:
adding_code = not adding_code
has_code = True
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
if not has_code:
continue
test_name = inflection.underscore(cls.name)
signature = cls.rust_doc_test_function
fn = signature and Function(f"test_{test_name}", signature)
if fn:
indent = 4 * " "
code = [indent + l for l in code]
test_with = schema.classes[cls.test_with] if cls.test_with else cls
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{cls.name.lower()}.rs"
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
renderer.render(TestCode(code="\n".join(code), function=fn), test)