Yeast: Add one-shot phase kind

This commit is contained in:
Asger F
2026-05-08 12:02:56 +02:00
parent a049850c51
commit c3a9218dcf
4 changed files with 323 additions and 29 deletions

View File

@@ -349,8 +349,8 @@ to enable rewriting:
```rust
let desugar = yeast::DesugaringConfig::new()
.add_phase("cleanup", cleanup_rules())
.add_phase("desugar", desugar_rules())
.add_phase("cleanup", yeast::PhaseKind::Repeating, cleanup_rules())
.add_phase("translate", yeast::PhaseKind::OneShot, translate_rules())
.with_output_node_types_yaml(include_str!("output-node-types.yml"));
let lang = simple::LanguageSpec {
@@ -365,6 +365,15 @@ let lang = simple::LanguageSpec {
A single-phase config is just `.add_phase(...)` called once. Phase names
appear in error messages so you can tell which phase failed.
There are two kinds of phases:
- **Repeating**:
Each node is re-processed until none of the rules in the phase matches.
When a node no longer matches any rules, its children are recursively processed. In practice this is used to desugar or simplify an AST, while staying mostly within the same schema.
- **One-shot**:
Each node is processed by the first matching rule; if no rule matches, the phase fails with an error.
Every captured node is recursively translated before the rule's transform runs, so the transform only sees already-translated captures.
In practice this is used when translating from one AST schema to another, where an exhaustive match is required.
The same node-types YAML is used for both the runtime yeast `Schema` (so
rules can refer to output-only kinds and fields) and TRAP validation (it
is converted to JSON internally).

View File

@@ -61,6 +61,21 @@ impl Captures {
}
}
}
/// Rewrite every captured id, under every capture name, by passing it
/// through the fallible function `f`.
///
/// Ids are visited in the capture map's iteration order. The first `Err`
/// returned by `f` aborts the walk and is handed back to the caller; ids
/// rewritten before that point keep their new values.
pub fn try_map_all_captures<E>(
    &mut self,
    mut f: impl FnMut(Id) -> Result<Id, E>,
) -> Result<(), E> {
    self.captures
        .values_mut()
        .flatten()
        .try_for_each(|slot| {
            *slot = f(*slot)?;
            Ok(())
        })
}
pub fn map_captures_to(&mut self, from: &str, to: &'static str, f: &mut impl FnMut(Id) -> Id) {
if let Some(from_ids) = self.captures.get(from) {
let new_values = from_ids.iter().copied().map(f).collect();

View File

@@ -526,18 +526,39 @@ impl Rule {
node: Id,
fresh: &tree_builder::FreshScope,
) -> Result<Option<Vec<Id>>, String> {
match self.try_match(ast, node)? {
Some(captures) => Ok(Some(self.run_transform(ast, captures, node, fresh))),
None => Ok(None),
}
}
/// Attempt to match this rule's query against `node`, returning the
/// resulting captures on success. Does not invoke the transform.
fn try_match(&self, ast: &Ast, node: Id) -> Result<Option<Captures>, String> {
let mut captures = Captures::new();
if self.query.do_match(ast, node, &mut captures)? {
fresh.next_scope();
let source_range = ast.get_node(node).and_then(|n| match n.content {
NodeContent::Range(r) => Some(r),
_ => n.source_range,
});
Ok(Some((self.transform)(ast, captures, fresh, source_range)))
Ok(Some(captures))
} else {
Ok(None)
}
}
/// Invoke this rule's transform on already-computed `captures`.
///
/// A new fresh-name scope is opened first. The source range handed to the
/// transform is taken from `node`: the range stored in its content when the
/// content is a `NodeContent::Range`, otherwise the node's `source_range`
/// field, and `None` when `node` cannot be found in `ast`.
fn run_transform(
    &self,
    ast: &mut Ast,
    captures: Captures,
    node: Id,
    fresh: &tree_builder::FreshScope,
) -> Vec<Id> {
    fresh.next_scope();
    let source_range = match ast.get_node(node) {
        Some(n) => match n.content {
            NodeContent::Range(r) => Some(r),
            _ => n.source_range,
        },
        None => None,
    };
    (self.transform)(ast, captures, fresh, source_range)
}
}
const MAX_REWRITE_DEPTH: usize = 100;
@@ -572,17 +593,17 @@ impl<'a> RuleIndex<'a> {
}
}
fn apply_rules(
fn apply_repeating_rules(
rules: &[Rule],
ast: &mut Ast,
id: Id,
fresh: &tree_builder::FreshScope,
) -> Result<Vec<Id>, String> {
let index = RuleIndex::new(rules);
apply_rules_inner(&index, ast, id, fresh, 0, None)
apply_repeating_rules_inner(&index, ast, id, fresh, 0, None)
}
fn apply_rules_inner(
fn apply_repeating_rules_inner(
index: &RuleIndex,
ast: &mut Ast,
id: Id,
@@ -611,7 +632,7 @@ fn apply_rules_inner(
let next_skip = if rule.repeated { None } else { Some(rule_ptr) };
let mut results = Vec::new();
for node in result_node {
results.extend(apply_rules_inner(
results.extend(apply_repeating_rules_inner(
index,
ast,
node,
@@ -636,7 +657,7 @@ fn apply_rules_inner(
for children in fields.values_mut() {
let mut new_children: Option<Vec<Id>> = None;
for (i, &child_id) in children.iter().enumerate() {
let result = apply_rules_inner(index, ast, child_id, fresh, rewrite_depth, None)?;
let result = apply_repeating_rules_inner(index, ast, child_id, fresh, rewrite_depth, None)?;
let unchanged = result.len() == 1 && result[0] == child_id;
match (&mut new_children, unchanged) {
(None, true) => {} // unchanged so far, no allocation needed
@@ -661,6 +682,75 @@ fn apply_rules_inner(
Ok(vec![id])
}
/// Entry point for a `OneShot` phase rooted at `id`.
///
/// Under `OneShot` semantics the first matching rule fires on each visited
/// node, and recursion follows the nodes that rule captured rather than the
/// input node's children. A visited node that no rule matches is reported
/// as an error.
fn apply_one_shot_rules(
    rules: &[Rule],
    ast: &mut Ast,
    id: Id,
    fresh: &tree_builder::FreshScope,
) -> Result<Vec<Id>, String> {
    // The root is visited at depth zero; each level of captured-node
    // recursion below it adds one.
    const ROOT_DEPTH: usize = 0;
    apply_one_shot_rules_inner(&RuleIndex::new(rules), ast, id, fresh, ROOT_DEPTH)
}
/// Recursive worker behind `apply_one_shot_rules`.
///
/// Looks up the rules indexed under the kind of node `id` and fires the
/// first one whose query matches. Before that rule's transform runs, every
/// captured node is itself translated by a recursive call, and each such
/// call must yield exactly one node.
///
/// `rewrite_depth` is the number of captured-node recursion levels above
/// this call; it is bounded by `MAX_REWRITE_DEPTH`.
///
/// # Errors
///
/// Returns `Err` when the depth bound is exceeded, when matching a rule
/// fails, when translating a captured node does not produce exactly one
/// node, or when no rule matches `id`.
fn apply_one_shot_rules_inner(
    index: &RuleIndex,
    ast: &mut Ast,
    id: Id,
    fresh: &tree_builder::FreshScope,
    rewrite_depth: usize,
) -> Result<Vec<Id>, String> {
    if rewrite_depth > MAX_REWRITE_DEPTH {
        // Unlike a Repeating phase, depth here grows with every level of
        // captured-node recursion, so the limit is also reached by deeply
        // nested input and not only by a runaway rule.
        return Err(format!(
            "OneShot: exceeded maximum rewrite depth ({MAX_REWRITE_DEPTH}) while \
             translating nested captures. This indicates either very deeply nested \
             input or a capture that leads back to an already-visited node."
        ));
    }

    // A node that cannot be found is looked up under the empty kind.
    let node_kind = ast.get_node(id).map(|n| n.kind()).unwrap_or("");

    for rule in index.rules_for_kind(node_kind) {
        if let Some(mut captures) = rule.try_match(ast, id)? {
            // Recursively translate every captured node before invoking the
            // transform. The transform's output uses output-schema kinds, so
            // captured input-schema nodes must be replaced by their
            // output-schema equivalents first.
            captures.try_map_all_captures(|captured_id| {
                let result =
                    apply_one_shot_rules_inner(index, ast, captured_id, fresh, rewrite_depth + 1)?;
                // A capture occupies a single slot, so its translation must
                // be a single node.
                if result.len() != 1 {
                    return Err(format!(
                        "OneShot: recursion on captured node produced {} results, expected exactly 1",
                        result.len()
                    ));
                }
                Ok(result[0])
            })?;
            // First match wins: later rules for this kind are not tried.
            return Ok(rule.run_transform(ast, captures, id, fresh));
        }
    }

    Err(format!(
        "OneShot: no rule matched node of kind '{node_kind}'"
    ))
}
/// Selects how the rules of a phase are applied to the tree.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PhaseKind {
/// A node is re-processed until none of the rules in the phase matches,
/// except that a single rule cannot be applied twice in a row unless that rule is marked as repeated.
/// When a node no longer matches any rules, its children are recursively processed (top down).
Repeating,
/// A node is processed by the first matching rule, and the phase fails with an error if no rule matches.
/// Every captured node is recursively translated before the rule's transform runs.
/// In practice this is used when translating from one AST schema to another, where every node must be rewritten,
/// and it would be a type error to match the rule patterns (based on the input schema) against the output nodes (which conform to the output schema).
OneShot,
}
/// One phase of a desugaring pass: a named bundle of rules that runs to
/// completion (a full traversal applying its rules) before the next phase
/// starts. Rules within a phase compete for matches as usual; rules in
@@ -670,13 +760,15 @@ pub struct Phase {
/// Name used in error messages.
pub name: String,
pub rules: Vec<Rule>,
pub kind: PhaseKind,
}
impl Phase {
pub fn new(name: impl Into<String>, rules: Vec<Rule>) -> Self {
pub fn new(name: impl Into<String>, kind: PhaseKind, rules: Vec<Rule>) -> Self {
Self {
name: name.into(),
rules,
kind,
}
}
}
@@ -694,8 +786,8 @@ impl Phase {
///
/// ```ignore
/// let config = yeast::DesugaringConfig::new()
/// .add_phase("cleanup", cleanup_rules)
/// .add_phase("desugar", desugar_rules)
/// .add_phase("cleanup", PhaseKind::Repeating, cleanup_rules)
/// .add_phase("desugar", PhaseKind::Repeating, desugar_rules)
/// .with_output_node_types_yaml(yaml);
/// ```
#[derive(Default)]
@@ -715,9 +807,14 @@ impl DesugaringConfig {
Self::default()
}
/// Append a new phase with the given name and rules.
pub fn add_phase(mut self, name: impl Into<String>, rules: Vec<Rule>) -> Self {
self.phases.push(Phase::new(name, rules));
/// Append a new phase with the given name, kind, and rules.
pub fn add_phase(
mut self,
name: impl Into<String>,
kind: PhaseKind,
rules: Vec<Rule>,
) -> Self {
self.phases.push(Phase::new(name, kind, rules));
self
}
@@ -806,8 +903,11 @@ impl<'a> Runner<'a> {
let fresh = tree_builder::FreshScope::new();
let mut root = ast.get_root();
for phase in self.phases {
let res = apply_rules(&phase.rules, ast, root, &fresh)
.map_err(|e| format!("Phase `{}`: {e}", phase.name))?;
let res = match phase.kind {
PhaseKind::Repeating => apply_repeating_rules(&phase.rules, ast, root, &fresh),
PhaseKind::OneShot => apply_one_shot_rules(&phase.rules, ast, root, &fresh),
}
.map_err(|e| format!("Phase `{}`: {e}", phase.name))?;
if res.len() != 1 {
return Err(format!(
"Phase `{}`: expected exactly one result node, got {}",

View File

@@ -15,7 +15,7 @@ fn parse_and_dump(input: &str) -> String {
/// Helper: parse Ruby source with a custom output schema and a single
/// phase of rules, return dump.
fn run_and_dump(input: &str, rules: Vec<Rule>) -> String {
run_phased_and_dump(input, vec![Phase::new("test", rules)])
run_phased_and_dump(input, vec![Phase::new("test", PhaseKind::Repeating, rules)])
}
/// Helper: parse Ruby source with a custom output schema and multiple
@@ -35,7 +35,7 @@ fn run_and_get_error(input: &str, rules: Vec<Rule>) -> String {
let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
let schema =
yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap();
let phases = vec![Phase::new("test", rules)];
let phases = vec![Phase::new("test", PhaseKind::Repeating, rules)];
let runner = Runner::with_schema(lang, &schema, &phases);
runner
.run(input)
@@ -65,7 +65,7 @@ fn parse_and_dump_typed_with_language(input: &str, schema_yaml: &str) -> String
fn run_and_dump_typed(input: &str, rules: Vec<Rule>, schema_yaml: &str) -> String {
let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
let schema = yeast::node_types_yaml::schema_from_yaml(schema_yaml).unwrap();
let phases = vec![Phase::new("test", rules)];
let phases = vec![Phase::new("test", PhaseKind::Repeating, rules)];
let runner = Runner::with_schema(lang, &schema, &phases);
let ast = runner.run(input).unwrap();
dump_ast_with_type_errors(&ast, ast.get_root(), input, &schema)
@@ -279,8 +279,12 @@ fn test_reachable_nodes_excludes_orphaned_rewrite_nodes() {
let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang)
.unwrap();
let rules = vec![yeast::rule!((integer) => (identifier "replaced"))];
let runner = Runner::with_schema(lang, &schema, &rules);
let phases = vec![Phase::new(
"test",
PhaseKind::Repeating,
vec![yeast::rule!((integer) => (identifier "replaced"))],
)];
let runner = Runner::with_schema(lang, &schema, &phases);
let input = "x = 1";
let ast = runner.run(input).unwrap();
@@ -783,8 +787,8 @@ fn test_phased_desugaring() {
let dump = run_phased_and_dump(
"x = 1",
vec![
Phase::new("cleanup", cleanup),
Phase::new("desugar", desugar),
Phase::new("cleanup", PhaseKind::Repeating, cleanup),
Phase::new("desugar", PhaseKind::Repeating, desugar),
],
);
assert_dump_eq(
@@ -805,7 +809,11 @@ fn test_phase_error_includes_phase_name() {
let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
let schema =
yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap();
let phases = vec![Phase::new("buggy", vec![swap_assignment_rule().repeated()])];
let phases = vec![Phase::new(
"buggy",
PhaseKind::Repeating,
vec![swap_assignment_rule().repeated()],
)];
let runner = Runner::with_schema(lang, &schema, &phases);
let err = runner
.run("x = 1")
@@ -820,6 +828,168 @@ fn test_phase_error_includes_phase_name() {
);
}
/// Helper: OneShot rules for every node kind reached (via captures) while
/// translating `"x = 1"`: `program`, `assignment`, `identifier` and
/// `integer`, returned in that order.
fn one_shot_xeq1_rules() -> Vec<Rule> {
    let program = yeast::rule!(
        (program (_)* @stmts)
        =>
        (program stmt: {..stmts})
    );
    let assignment = yeast::rule!(
        (assignment left: (_) @left right: (_) @right)
        =>
        (first_node left: {left} right: {right})
    );
    let identifier = yeast::rule!((identifier) => (identifier "ID"));
    let integer = yeast::rule!((integer) => (integer "INT"));
    // Keep `integer` last: one caller removes it with `pop()`.
    vec![program, assignment, identifier, integer]
}
/// A OneShot phase with an exhaustive rule set translates `"x = 1"` into
/// the output-schema tree built by the rules' transforms.
#[test]
fn test_one_shot_phase() {
    let input = "x = 1";
    let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
    let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang)
        .unwrap();

    let translate = Phase::new("translate", PhaseKind::OneShot, one_shot_xeq1_rules());
    let phases = vec![translate];
    let runner = Runner::with_schema(lang, &schema, &phases);

    let ast = runner.run(input).unwrap();
    let dump = dump_ast(&ast, ast.get_root(), input);
    assert_dump_eq(
        &dump,
        r#"
program
stmt:
first_node
left: identifier "ID"
right: integer "INT"
"#,
    );
}
/// A OneShot phase must report an error, naming both the phase and the
/// offending node kind, when a visited node has no matching rule.
#[test]
fn test_one_shot_phase_errors_when_no_rule_matches() {
    let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
    let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang)
        .unwrap();

    // The helper returns the `integer` rule last; removing it leaves the
    // recursion with nothing that can translate an `integer` node.
    let mut rules = one_shot_xeq1_rules();
    rules.pop();

    let phases = vec![Phase::new("translate", PhaseKind::OneShot, rules)];
    let runner = Runner::with_schema(lang, &schema, &phases);
    let outcome = runner.run("x = 1");
    let err = outcome.expect_err("expected OneShot to error on unmatched node");

    assert!(
        err.contains("Phase `translate`"),
        "error should name the phase, got: {err}"
    );
    assert!(
        err.contains("no rule matched") && err.contains("integer"),
        "error should describe the unmatched node kind, got: {err}"
    );
}
/// Captured nodes must be translated even when a rule hands one of them
/// back unchanged as its whole result. An implementation that recursed into
/// the children of the rule's output, instead of into the captures, would
/// leave that returned capture untranslated.
#[test]
fn test_one_shot_recurses_into_returned_capture() {
    let input = "x = 1";
    let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
    let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang)
        .unwrap();

    let program = yeast::rule!(
        (program (_)* @stmts)
        =>
        (program stmt: {..stmts})
    );
    // Yields the captured `left` as-is and drops `right`.
    let assignment = yeast::rule!(
        (assignment left: (_) @left right: (_) @right)
        =>
        {left}
    );
    let identifier = yeast::rule!((identifier) => (identifier "ID"));
    let integer = yeast::rule!((integer) => (integer "INT"));

    let phases = vec![Phase::new(
        "translate",
        PhaseKind::OneShot,
        vec![program, assignment, identifier, integer],
    )];
    let runner = Runner::with_schema(lang, &schema, &phases);
    let ast = runner.run(input).unwrap();
    let dump = dump_ast(&ast, ast.get_root(), input);

    // `left` is an `identifier`, so the identifier rule must already have
    // been applied to it by the time the assignment transform returns it.
    assert_dump_eq(
        &dump,
        r#"
program
stmt: identifier "ID"
"#,
    );
}
/// The output of a rule must not be traversed by OneShot recursion.
/// Transforms are free to wrap captures in freshly built output-schema
/// nodes, and since rule patterns target the input schema those wrapper
/// kinds have no rules of their own. Descending into the output would look
/// for such rules and fail.
#[test]
fn test_one_shot_does_not_recurse_into_wrapper_output() {
    let input = "x = 1";
    let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into();
    let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang)
        .unwrap();

    let program = yeast::rule!(
        (program (_)* @stmts)
        =>
        (program stmt: {..stmts})
    );
    // `left` ends up inside nested `first_node`/`second_node` wrappers.
    // No rule matches either wrapper kind, so recursing into the wrapper's
    // children would produce an error.
    let assignment = yeast::rule!(
        (assignment left: (_) @left right: (_) @right)
        =>
        (first_node
            left: (second_node left: {left} right: {right})
            right: {left}
        )
    );
    let identifier = yeast::rule!((identifier) => (identifier "ID"));
    let integer = yeast::rule!((integer) => (integer "INT"));

    let phases = vec![Phase::new(
        "translate",
        PhaseKind::OneShot,
        vec![program, assignment, identifier, integer],
    )];
    let runner = Runner::with_schema(lang, &schema, &phases);
    let ast = runner.run(input).unwrap();
    let dump = dump_ast(&ast, ast.get_root(), input);

    assert_dump_eq(
        &dump,
        r#"
program
stmt:
first_node
left:
second_node
left: identifier "ID"
right: integer "INT"
right: identifier "ID"
"#,
    );
}
// ---- Cursor tests ----
#[test]