diff --git a/shared/yeast-macros/src/parse.rs b/shared/yeast-macros/src/parse.rs index fda419aefc7..c0f86887ba6 100644 --- a/shared/yeast-macros/src/parse.rs +++ b/shared/yeast-macros/src/parse.rs @@ -888,9 +888,15 @@ pub fn parse_rule_top(input: TokenStream) -> Result { Ok(quote! { { let __query = #query_code; - yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option, __user_ctx: &mut _| { + yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, mut __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option, __user_ctx: &mut _, __translator: yeast::TranslatorHandle<'_, _>| { + // Auto-translation prefix: recursively translate every + // captured node before invoking the user's transform body. + // For OneShot rules this preserves the legacy behaviour + // (input-schema captures translated to output-schema + // nodes); for Repeating rules it is a no-op. + __translator.auto_translate_captures(&mut __captures, __ast, __user_ctx)?; #(#bindings)* - let mut #ctx_ident = yeast::build::BuildCtx::with_source_range(__ast, &__captures, __fresh, __source_range, __user_ctx); + let mut #ctx_ident = yeast::build::BuildCtx::with_translator(__ast, &__captures, __fresh, __source_range, __user_ctx, __translator); let __result: Vec = { #transform_body }; Ok(__result) })) diff --git a/shared/yeast/src/build.rs b/shared/yeast/src/build.rs index 6c8b392fb8a..9fec7861a55 100644 --- a/shared/yeast/src/build.rs +++ b/shared/yeast/src/build.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::captures::Captures; use crate::tree_builder::FreshScope; -use crate::{Ast, FieldId, Id, NodeContent}; +use crate::{Ast, FieldId, Id, NodeContent, TranslatorHandle}; /// Context for building new AST nodes during a transformation. /// @@ -24,6 +24,11 @@ use crate::{Ast, FieldId, Id, NodeContent}; /// /// The default `C = ()` means rules that don't need any user context don't /// pay any cost. +/// +/// When constructed by the framework (via the rule! macro), `BuildCtx` also +/// carries a [`TranslatorHandle`] that the [`translate`] method delegates +/// to. When constructed by hand (e.g. in tests), the translator is `None` +/// and [`translate`] returns an error. pub struct BuildCtx<'a, C: 'a = ()> { pub ast: &'a mut Ast, pub captures: &'a Captures, @@ -32,6 +37,9 @@ pub struct BuildCtx<'a, C: 'a = ()> { pub source_range: Option, /// User-supplied context, accessible directly via `ctx.field` (via Deref). pub user_ctx: &'a mut C, + /// Optional translator handle, populated when the context is built by + /// the framework's rule driver. None when the context is built by hand. + pub(crate) translator: Option>, } impl<'a, C> BuildCtx<'a, C> { @@ -47,6 +55,7 @@ impl<'a, C> BuildCtx<'a, C> { fresh, source_range: None, user_ctx, + translator: None, } } @@ -63,6 +72,27 @@ impl<'a, C> BuildCtx<'a, C> { fresh, source_range, user_ctx, + translator: None, + } + } + + /// Construct a `BuildCtx` carrying a translator handle. Used by the + /// `rule!` macro to enable [`translate`] inside rule transforms. + pub fn with_translator( + ast: &'a mut Ast, + captures: &'a Captures, + fresh: &'a FreshScope, + source_range: Option, + user_ctx: &'a mut C, + translator: TranslatorHandle<'a, C>, + ) -> Self { + Self { + ast, + captures, + fresh, + source_range, + user_ctx, + translator: Some(translator), } } @@ -139,6 +169,24 @@ impl<'a, C> BuildCtx<'a, C> { } } +impl BuildCtx<'_, C> { + /// Recursively translate a node via the framework's rule machinery. + /// In a OneShot phase, applies OneShot rules to the given node and + /// returns the resulting node ids. In a Repeating phase, errors + /// (translation is not meaningful when input and output share a + /// schema). + /// + /// Errors if this `BuildCtx` was constructed by hand (without a + /// translator handle) — for example, in unit tests that don't go + /// through the rule driver. + pub fn translate(&mut self, id: Id) -> Result, String> { + match &self.translator { + Some(t) => t.translate(self.ast, self.user_ctx, id), + None => Err("translate() called on a BuildCtx without a translator handle".into()), + } + } +} + impl std::ops::Deref for BuildCtx<'_, C> { type Target = C; fn deref(&self) -> &C { diff --git a/shared/yeast/src/lib.rs b/shared/yeast/src/lib.rs index 0b0c00ec910..ac93ae1ab8c 100644 --- a/shared/yeast/src/lib.rs +++ b/shared/yeast/src/lib.rs @@ -700,12 +700,107 @@ impl From for NodeContent { } } -/// The transform function for a rule: takes the AST, captured variables, a -/// fresh-name scope, the source range of the matched node, and a mutable -/// reference to the user context of type `C`. Returns the IDs of the -/// replacement nodes, or an error message if the transform could not be -/// completed (for example, a required capture was missing, or a recursive -/// translation invoked by the transform failed). +/// A handle that lets a rule transform recursively translate AST nodes via +/// the framework's rule machinery. Constructed by the driver and passed as +/// the last argument of every [`Transform`] invocation. +/// +/// The `rule!` macro uses [`TranslatorHandle::auto_translate_captures`] in +/// its generated prefix to translate captures before running the user's +/// transform body. Manually-written transforms (using [`Rule::new`] +/// directly) can call [`TranslatorHandle::translate`] selectively on +/// specific node ids to control when translation happens. +pub struct TranslatorHandle<'a, C> { + inner: TranslatorImpl<'a, C>, +} + +/// Internal phase-specific translation state. Kept private — callers +/// interact with [`TranslatorHandle`] only. +enum TranslatorImpl<'a, C> { + /// OneShot phase translator: recursively applies OneShot rules. + OneShot { + index: &'a RuleIndex<'a, C>, + fresh: &'a tree_builder::FreshScope, + rewrite_depth: usize, + /// The id of the node the current rule is matching. Used by + /// [`auto_translate_captures`] to avoid infinite recursion when a + /// rule captures its own match root (e.g. via `(_) @_`). + matched_root: Id, + }, + /// Repeating phase translator: translation is not meaningful here + /// (input and output schemas are the same). [`translate`] errors; + /// [`auto_translate_captures`] is a no-op so the macro's auto-prefix + /// works unchanged for Repeating rules. + Repeating, +} + +impl<'a, C: Clone> TranslatorHandle<'a, C> { + /// Recursively apply OneShot rules to `id` and return the resulting + /// node ids. Errors in a Repeating phase (where translation is not + /// meaningful). + pub fn translate( + &self, + ast: &mut Ast, + user_ctx: &mut C, + id: Id, + ) -> Result, String> { + match &self.inner { + TranslatorImpl::OneShot { + index, + fresh, + rewrite_depth, + .. + } => apply_one_shot_rules_inner(index, ast, user_ctx, id, fresh, rewrite_depth + 1), + TranslatorImpl::Repeating => { + Err("translate() is not available in a Repeating phase".into()) + } + } + } + + /// Translate every captured node in `captures` in place (OneShot phase + /// only). In a Repeating phase this is a no-op — Repeating rules + /// receive raw captures. + /// + /// Used by the `rule!` macro's generated prefix to preserve the + /// pre-existing "auto-translate captures before running the transform + /// body" behavior. Manually-written transforms typically translate + /// captures selectively via [`translate`] instead. + /// + /// To avoid infinite recursion, a capture whose id matches the rule's + /// matched root (e.g. from a `(_) @_` pattern) is left unchanged. + pub fn auto_translate_captures( + &self, + captures: &mut Captures, + ast: &mut Ast, + user_ctx: &mut C, + ) -> Result<(), String> { + match &self.inner { + TranslatorImpl::OneShot { matched_root, .. } => { + let root = *matched_root; + captures.try_map_all_captures(|cid| { + if cid == root { + Ok(vec![cid]) + } else { + self.translate(ast, user_ctx, cid) + } + }) + } + TranslatorImpl::Repeating => Ok(()), + } + } +} + +/// The transform function for a rule. +/// +/// Takes the AST, the (raw, untranslated) captured variables, a fresh-name +/// scope, the source range of the matched node, a mutable reference to the +/// user context of type `C`, and a [`TranslatorHandle`] for recursively +/// translating nodes. Returns the IDs of the replacement nodes, or an +/// error message if the transform could not be completed. +/// +/// Transforms produced by [`Rule::new`] receive **raw** captures and must +/// translate them themselves (via the handle). Transforms produced by the +/// `rule!` macro have an auto-translation prefix injected for backward +/// compatibility. pub type Transform = Box< dyn Fn( &mut Ast, @@ -713,6 +808,7 @@ pub type Transform = Box< &tree_builder::FreshScope, Option, &mut C, + TranslatorHandle<'_, C>, ) -> Result, String> + Send + Sync, @@ -752,9 +848,12 @@ impl Rule { node: Id, fresh: &tree_builder::FreshScope, user_ctx: &mut C, + translator: TranslatorHandle<'_, C>, ) -> Result>, String> { match self.try_match(ast, node)? { - Some(captures) => Ok(Some(self.run_transform(ast, captures, node, fresh, user_ctx)?)), + Some(captures) => Ok(Some(self.run_transform( + ast, captures, node, fresh, user_ctx, translator, + )?)), None => Ok(None), } } @@ -779,13 +878,14 @@ impl Rule { node: Id, fresh: &tree_builder::FreshScope, user_ctx: &mut C, + translator: TranslatorHandle<'_, C>, ) -> Result, String> { fresh.next_scope(); let source_range = ast.get_node(node).and_then(|n| match n.content { NodeContent::Range(r) => Some(r), _ => n.source_range, }); - (self.transform)(ast, captures, fresh, source_range, user_ctx) + (self.transform)(ast, captures, fresh, source_range, user_ctx, translator) } } @@ -858,7 +958,14 @@ fn apply_repeating_rules_inner( // mutations the rule makes are visible during recursive translation // of its result, but not leaked to the parent's siblings. let snapshot = user_ctx.clone(); - let try_result = rule.try_rule(ast, id, fresh, user_ctx)?; + // Repeating rules don't need a real translator: their captures + // aren't auto-translated (Repeating preserves the input schema), + // and `ctx.translate(id)` errors if invoked from a Repeating + // transform. + let translator = TranslatorHandle { + inner: TranslatorImpl::Repeating, + }; + let try_result = rule.try_rule(ast, id, fresh, user_ctx, translator)?; if let Some(result_node) = try_result { // For non-repeated rules, suppress further application of *this* // rule on the result root, so a rule whose output matches its own @@ -956,27 +1063,25 @@ fn apply_one_shot_rules_inner( let node_kind = ast.get_node(id).map(|n| n.kind()).unwrap_or(""); for rule in index.rules_for_kind(node_kind) { - if let Some(mut captures) = rule.try_match(ast, id)? { + if let Some(captures) = rule.try_match(ast, id)? { // Snapshot the user context before invoking the rule so that any // mutations the rule (or its transitively-translated captures) // make are visible during this rule's transform, but not leaked // to the parent's siblings. let snapshot = user_ctx.clone(); - // Recursively translate every captured node before invoking the - // transform. The transform's output uses output-schema kinds, so - // we must translate captured input-schema nodes to their - // output-schema equivalents first. - captures.try_map_all_captures(|captured_id| { - // Avoid infinite recursion when a capture refers to the root - // node of the matched tree (e.g. an `@_` capture on the - // pattern root): re-analyzing it would match the same rule - // again indefinitely. - if captured_id == id { - return Ok(vec![captured_id]); - } - apply_one_shot_rules_inner(index, ast, user_ctx, captured_id, fresh, rewrite_depth + 1) - })?; - let result = rule.run_transform(ast, captures, id, fresh, user_ctx)?; + // Build the translator handle the transform will use to + // recursively translate captures (or, for macro-generated + // rules, the auto-translate prefix uses it to translate every + // capture up front, preserving the legacy behavior). + let translator = TranslatorHandle { + inner: TranslatorImpl::OneShot { + index, + fresh, + rewrite_depth, + matched_root: id, + }, + }; + let result = rule.run_transform(ast, captures, id, fresh, user_ctx, translator)?; *user_ctx = snapshot; return Ok(result); }