mirror of
https://github.com/github/codeql.git
synced 2026-05-01 19:55:15 +02:00
Rust: generate test code from schema docstrings
This generates test source files from code blocks in class docstrings.
By default the test code is generated as is, but it can optionally:
* be wrapped in a function providing an adequate context using
`@rust.doc_test_function(name, *, lifetimes=(), return_type="()", **kwargs)`,
with `kwargs` providing both generic and normal params depending on
capitalization
* be skipped altogether using `@rust.skip_doc_test`
So for example an annotation like
```python
@rust.doc_test_function("foo",
lifetimes=("a",),
T="Eq",
x="&'a T",
y="&'a T",
return_type="&'a T")
```
will result in the following wrapper:
```rust
fn foo<'a, T: Eq>(x: &'a T, y: &'a T) -> &'a T {
// example code here
}
```
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from . import dbschemegen, qlgen, trapgen, cppgen, rustgen
|
||||
from . import dbschemegen, trapgen, cppgen, rustgen, rusttestgen, qlgen
|
||||
|
||||
|
||||
def generate(target, opts, renderer):
|
||||
|
||||
@@ -287,7 +287,7 @@ def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[
|
||||
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)
|
||||
|
||||
|
||||
def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
|
||||
def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
|
||||
return "qltest_skip" in cls.pragmas or not (
|
||||
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
|
||||
cls, lookup)
|
||||
@@ -413,7 +413,7 @@ def generate(opts, renderer):
|
||||
|
||||
if test_out:
|
||||
for c in data.classes.values():
|
||||
if _should_skip_qltest(c, data.classes):
|
||||
if should_skip_qltest(c, data.classes):
|
||||
continue
|
||||
test_with = data.classes[c.test_with] if c.test_with else c
|
||||
test_dir = test_out / test_with.group / test_with.name
|
||||
|
||||
@@ -86,20 +86,24 @@ def generate(opts, renderer):
|
||||
processor = Processor(schemaloader.load_file(opts.schema))
|
||||
out = opts.rust_output
|
||||
groups = set()
|
||||
for group, classes in processor.get_classes().items():
|
||||
group = group or "top"
|
||||
groups.add(group)
|
||||
with renderer.manage(generated=out.rglob("*.rs"),
|
||||
stubs=(),
|
||||
registry=opts.generated_registry,
|
||||
force=opts.force) as renderer:
|
||||
for group, classes in processor.get_classes().items():
|
||||
group = group or "top"
|
||||
groups.add(group)
|
||||
renderer.render(
|
||||
rust.ClassList(
|
||||
classes,
|
||||
opts.schema,
|
||||
),
|
||||
out / f"{group}.rs",
|
||||
)
|
||||
renderer.render(
|
||||
rust.ClassList(
|
||||
classes,
|
||||
rust.ModuleList(
|
||||
groups,
|
||||
opts.schema,
|
||||
),
|
||||
out / f"{group}.rs",
|
||||
out / f"mod.rs",
|
||||
)
|
||||
renderer.render(
|
||||
rust.ModuleList(
|
||||
groups,
|
||||
opts.schema,
|
||||
),
|
||||
out / f"mod.rs",
|
||||
)
|
||||
|
||||
70
misc/codegen/generators/rusttestgen.py
Normal file
70
misc/codegen/generators/rusttestgen.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import dataclasses
|
||||
import typing
|
||||
|
||||
from misc.codegen.loaders import schemaloader
|
||||
from . import qlgen
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Param:
|
||||
name: str
|
||||
type: str
|
||||
first: bool = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Function:
|
||||
name: str
|
||||
generic_params: list[Param]
|
||||
params: list[Param]
|
||||
return_type: str
|
||||
|
||||
def __post_init__(self):
|
||||
if self.generic_params:
|
||||
self.generic_params[0].first = True
|
||||
if self.params:
|
||||
self.params[0].first = True
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestCode:
|
||||
template: typing.ClassVar[str] = "rust_test_code"
|
||||
|
||||
code: str
|
||||
function: Function | None = None
|
||||
|
||||
|
||||
def generate(opts, renderer):
|
||||
assert opts.ql_test_output
|
||||
schema = schemaloader.load_file(opts.schema)
|
||||
with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
|
||||
stubs=(),
|
||||
registry=opts.generated_registry,
|
||||
force=opts.force) as renderer:
|
||||
for cls in schema.classes.values():
|
||||
if (qlgen.should_skip_qltest(cls, schema.classes) or
|
||||
"rust_skip_test_from_doc" in cls.pragmas or
|
||||
not cls.doc
|
||||
):
|
||||
continue
|
||||
fn = cls.rust_doc_test_function
|
||||
if fn:
|
||||
generic_params = [Param(k, v) for k, v in fn.params.items() if k[0].isupper() or k[0] == "'"]
|
||||
params = [Param(k, v) for k, v in fn.params.items() if k[0].islower()]
|
||||
fn = Function(fn.name, generic_params, params, fn.return_type)
|
||||
code = []
|
||||
adding_code = False
|
||||
for line in cls.doc:
|
||||
match line, adding_code:
|
||||
case "```", _:
|
||||
adding_code = not adding_code
|
||||
case _, False:
|
||||
code.append(f"// {line}")
|
||||
case _, True:
|
||||
code.append(line)
|
||||
if fn:
|
||||
indent = 4 * " "
|
||||
code = [indent + l for l in code]
|
||||
test_with = schema.classes[cls.test_with] if cls.test_with else cls
|
||||
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{cls.name.lower()}.rs"
|
||||
renderer.render(TestCode(code="\n".join(code), function=fn), test)
|
||||
Reference in New Issue
Block a user