diff --git a/shared/yeast-macros/src/parse.rs b/shared/yeast-macros/src/parse.rs index 4b27b980439..594a59e1b5d 100644 --- a/shared/yeast-macros/src/parse.rs +++ b/shared/yeast-macros/src/parse.rs @@ -296,10 +296,10 @@ fn parse_query_list(tokens: &mut Tokens) -> Result> { // tree! / trees! parsing — direct code generation against BuildCtx // --------------------------------------------------------------------------- -const IMPLICIT_CTX: &str = "__yeast_ctx"; +const IMPLICIT_CTX: &str = "ctx"; /// Determine the context identifier: either explicit `ctx,` or the implicit -/// `__yeast_ctx` from an enclosing `rule!`. +/// `ctx` from an enclosing `rule!`. fn parse_ctx_or_implicit(tokens: &mut Tokens) -> Ident { // Check if first token is an ident followed by a comma let mut lookahead = tokens.clone(); @@ -888,9 +888,9 @@ pub fn parse_rule_top(input: TokenStream) -> Result { Ok(quote! { { let __query = #query_code; - yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option| { + yeast::Rule::new(__query, Box::new(|__ast: &mut yeast::Ast, __captures: yeast::captures::Captures, __fresh: &yeast::tree_builder::FreshScope, __source_range: Option, __user_ctx: &mut _| { #(#bindings)* - let mut #ctx_ident = yeast::build::BuildCtx::with_source_range(__ast, &__captures, __fresh, __source_range); + let mut #ctx_ident = yeast::build::BuildCtx::with_source_range(__ast, &__captures, __fresh, __source_range, __user_ctx); #transform_body })) } diff --git a/shared/yeast/src/bin/main.rs b/shared/yeast/src/bin/main.rs index 975c8e8b25f..978be21cc00 100644 --- a/shared/yeast/src/bin/main.rs +++ b/shared/yeast/src/bin/main.rs @@ -20,7 +20,7 @@ fn main() { let args = Cli::parse(); let language = get_language(&args.language); let source = std::fs::read_to_string(&args.file).unwrap(); - let runner = yeast::Runner::new(language, &[]); + let runner: yeast::Runner = yeast::Runner::new(language, &[]); let ast = runner.run(&source).unwrap(); println!("{}", ast.print(&source, ast.get_root())); } diff --git a/shared/yeast/src/build.rs b/shared/yeast/src/build.rs index d0f1394ca6d..6c8b392fb8a 100644 --- a/shared/yeast/src/build.rs +++ b/shared/yeast/src/build.rs @@ -7,23 +7,46 @@ use crate::{Ast, FieldId, Id, NodeContent}; /// Context for building new AST nodes during a transformation. /// /// Used by the `tree!` and `trees!` macros. Holds a mutable reference to the -/// AST, a reference to the captures from a query match, and a `FreshScope` for -/// generating unique identifiers. -pub struct BuildCtx<'a> { +/// AST, a reference to the captures from a query match, a `FreshScope` for +/// generating unique identifiers, and a mutable reference to a user-defined +/// context of type `C`. +/// +/// The user context `C` is shared across rules via the framework's driver: +/// outer rules can write to it before recursive translation, and inner rules +/// can read (or further mutate) it during their transforms. The framework +/// snapshots and restores the user context around each rule application, so +/// mutations made by a rule are visible to its descendants (via recursive +/// translation) but not to its parent's siblings. +/// +/// `BuildCtx` implements [`Deref`] and [`DerefMut`] targeting `C`, so user +/// context fields are accessible as `ctx.my_field` directly (provided they +/// don't collide with `BuildCtx`'s own fields like `ast`, `captures`, etc.). +/// +/// The default `C = ()` means rules that don't need any user context don't +/// pay any cost. +pub struct BuildCtx<'a, C: 'a = ()> { pub ast: &'a mut Ast, pub captures: &'a Captures, pub fresh: &'a FreshScope, /// Source range of the matched node, inherited by synthetic nodes. pub source_range: Option, + /// User-supplied context, accessible directly via `ctx.field` (via Deref). + pub user_ctx: &'a mut C, } -impl<'a> BuildCtx<'a> { - pub fn new(ast: &'a mut Ast, captures: &'a Captures, fresh: &'a FreshScope) -> Self { +impl<'a, C> BuildCtx<'a, C> { + pub fn new( + ast: &'a mut Ast, + captures: &'a Captures, + fresh: &'a FreshScope, + user_ctx: &'a mut C, + ) -> Self { Self { ast, captures, fresh, source_range: None, + user_ctx, } } @@ -32,12 +55,14 @@ impl<'a> BuildCtx<'a> { captures: &'a Captures, fresh: &'a FreshScope, source_range: Option, + user_ctx: &'a mut C, ) -> Self { Self { ast, captures, fresh, source_range, + user_ctx, } } @@ -113,3 +138,16 @@ impl<'a> BuildCtx<'a> { self.ast.prepend_field_child(node_id, field_id, value_id); } } + +impl std::ops::Deref for BuildCtx<'_, C> { + type Target = C; + fn deref(&self) -> &C { + &*self.user_ctx + } +} + +impl std::ops::DerefMut for BuildCtx<'_, C> { + fn deref_mut(&mut self) -> &mut C { + &mut *self.user_ctx + } +} diff --git a/shared/yeast/src/lib.rs b/shared/yeast/src/lib.rs index 9c3a4ad4114..d93a72221a9 100644 --- a/shared/yeast/src/lib.rs +++ b/shared/yeast/src/lib.rs @@ -701,17 +701,24 @@ impl From for NodeContent { } /// The transform function for a rule: takes the AST, captured variables, a -/// fresh-name scope, and the source range of the matched node, and returns -/// the IDs of the replacement nodes. -pub type Transform = Box< - dyn Fn(&mut Ast, Captures, &tree_builder::FreshScope, Option) -> Vec +/// fresh-name scope, the source range of the matched node, and a mutable +/// reference to the user context of type `C`. Returns the IDs of the +/// replacement nodes. +pub type Transform = Box< + dyn Fn( + &mut Ast, + Captures, + &tree_builder::FreshScope, + Option, + &mut C, + ) -> Vec + Send + Sync, >; -pub struct Rule { +pub struct Rule { query: QueryNode, - transform: Transform, + transform: Transform, /// If true, after this rule fires on a node the engine will try to /// re-apply this same rule on the result root. Defaults to false: /// each rule fires at most once on a given node, which prevents @@ -719,8 +726,8 @@ pub struct Rule { repeated: bool, } -impl Rule { - pub fn new(query: QueryNode, transform: Transform) -> Self { +impl Rule { + pub fn new(query: QueryNode, transform: Transform) -> Self { Self { query, transform, @@ -742,9 +749,10 @@ impl Rule { ast: &mut Ast, node: Id, fresh: &tree_builder::FreshScope, + user_ctx: &mut C, ) -> Result>, String> { match self.try_match(ast, node)? { - Some(captures) => Ok(Some(self.run_transform(ast, captures, node, fresh))), + Some(captures) => Ok(Some(self.run_transform(ast, captures, node, fresh, user_ctx))), None => Ok(None), } } @@ -768,29 +776,30 @@ impl Rule { captures: Captures, node: Id, fresh: &tree_builder::FreshScope, + user_ctx: &mut C, ) -> Vec { fresh.next_scope(); let source_range = ast.get_node(node).and_then(|n| match n.content { NodeContent::Range(r) => Some(r), _ => n.source_range, }); - (self.transform)(ast, captures, fresh, source_range) + (self.transform)(ast, captures, fresh, source_range, user_ctx) } } const MAX_REWRITE_DEPTH: usize = 100; /// Index of rules by their root query kind for fast lookup. -struct RuleIndex<'a> { +struct RuleIndex<'a, C> { /// Rules indexed by root node kind name. - by_kind: BTreeMap<&'static str, Vec<&'a Rule>>, + by_kind: BTreeMap<&'static str, Vec<&'a Rule>>, /// Rules with wildcard queries (Any) that apply to all nodes. - wildcard: Vec<&'a Rule>, + wildcard: Vec<&'a Rule>, } -impl<'a> RuleIndex<'a> { - fn new(rules: &'a [Rule]) -> Self { - let mut by_kind: BTreeMap<&'static str, Vec<&'a Rule>> = BTreeMap::new(); +impl<'a, C> RuleIndex<'a, C> { + fn new(rules: &'a [Rule]) -> Self { + let mut by_kind: BTreeMap<&'static str, Vec<&'a Rule>> = BTreeMap::new(); let mut wildcard = Vec::new(); for rule in rules { match rule.query.root_kind() { @@ -801,7 +810,7 @@ impl<'a> RuleIndex<'a> { Self { by_kind, wildcard } } - fn rules_for_kind(&self, kind: &str) -> impl Iterator { + fn rules_for_kind(&self, kind: &str) -> impl Iterator> { self.by_kind .get(kind) .into_iter() @@ -810,23 +819,25 @@ impl<'a> RuleIndex<'a> { } } -fn apply_repeating_rules( - rules: &[Rule], +fn apply_repeating_rules( + rules: &[Rule], ast: &mut Ast, + user_ctx: &mut C, id: Id, fresh: &tree_builder::FreshScope, ) -> Result, String> { let index = RuleIndex::new(rules); - apply_repeating_rules_inner(&index, ast, id, fresh, 0, None) + apply_repeating_rules_inner(&index, ast, user_ctx, id, fresh, 0, None) } -fn apply_repeating_rules_inner( - index: &RuleIndex, +fn apply_repeating_rules_inner( + index: &RuleIndex, ast: &mut Ast, + user_ctx: &mut C, id: Id, fresh: &tree_builder::FreshScope, rewrite_depth: usize, - skip_rule: Option<*const Rule>, + skip_rule: Option<*const Rule>, ) -> Result, String> { if rewrite_depth > MAX_REWRITE_DEPTH { return Err(format!( @@ -837,11 +848,16 @@ fn apply_repeating_rules_inner( let node_kind = ast.get_node(id).map(|n| n.kind()).unwrap_or(""); for rule in index.rules_for_kind(node_kind) { - let rule_ptr = *rule as *const Rule; + let rule_ptr = *rule as *const Rule; if Some(rule_ptr) == skip_rule { continue; } - if let Some(result_node) = rule.try_rule(ast, id, fresh)? { + // Snapshot the user context before invoking the rule so that any + // mutations the rule makes are visible during recursive translation + // of its result, but not leaked to the parent's siblings. + let snapshot = user_ctx.clone(); + let try_result = rule.try_rule(ast, id, fresh, user_ctx)?; + if let Some(result_node) = try_result { // For non-repeated rules, suppress further application of *this* // rule on the result root, so a rule whose output matches its own // query doesn't loop. Other rules and child traversal are @@ -852,14 +868,19 @@ fn apply_repeating_rules_inner( results.extend(apply_repeating_rules_inner( index, ast, + user_ctx, node, fresh, rewrite_depth + 1, next_skip, )?); } + *user_ctx = snapshot; return Ok(results); } + // Rule didn't match; restore any speculative changes (none expected + // since try_rule only mutates on match, but be defensive). + *user_ctx = snapshot; } // Take the parent's fields by ownership: the recursion will rewrite @@ -874,7 +895,7 @@ fn apply_repeating_rules_inner( for children in fields.values_mut() { let mut new_children: Option> = None; for (i, &child_id) in children.iter().enumerate() { - let result = apply_repeating_rules_inner(index, ast, child_id, fresh, rewrite_depth, None)?; + let result = apply_repeating_rules_inner(index, ast, user_ctx, child_id, fresh, rewrite_depth, None)?; let unchanged = result.len() == 1 && result[0] == child_id; match (&mut new_children, unchanged) { (None, true) => {} // unchanged so far, no allocation needed @@ -903,19 +924,21 @@ fn apply_repeating_rules_inner( /// each visited node, recursion proceeds only through captured nodes (not /// through the input node's children directly), and an error is returned if /// no rule matches a visited node. -fn apply_one_shot_rules( - rules: &[Rule], +fn apply_one_shot_rules( + rules: &[Rule], ast: &mut Ast, + user_ctx: &mut C, id: Id, fresh: &tree_builder::FreshScope, ) -> Result, String> { let index = RuleIndex::new(rules); - apply_one_shot_rules_inner(&index, ast, id, fresh, 0) + apply_one_shot_rules_inner(&index, ast, user_ctx, id, fresh, 0) } -fn apply_one_shot_rules_inner( - index: &RuleIndex, +fn apply_one_shot_rules_inner( + index: &RuleIndex, ast: &mut Ast, + user_ctx: &mut C, id: Id, fresh: &tree_builder::FreshScope, rewrite_depth: usize, @@ -932,6 +955,11 @@ fn apply_one_shot_rules_inner( for rule in index.rules_for_kind(node_kind) { if let Some(mut captures) = rule.try_match(ast, id)? { + // Snapshot the user context before invoking the rule so that any + // mutations the rule (or its transitively-translated captures) + // make are visible during this rule's transform, but not leaked + // to the parent's siblings. + let snapshot = user_ctx.clone(); // Recursively translate every captured node before invoking the // transform. The transform's output uses output-schema kinds, so // we must translate captured input-schema nodes to their @@ -944,9 +972,11 @@ fn apply_one_shot_rules_inner( if captured_id == id { return Ok(vec![captured_id]); } - apply_one_shot_rules_inner(index, ast, captured_id, fresh, rewrite_depth + 1) + apply_one_shot_rules_inner(index, ast, user_ctx, captured_id, fresh, rewrite_depth + 1) })?; - return Ok(rule.run_transform(ast, captures, id, fresh)); + let result = rule.run_transform(ast, captures, id, fresh, user_ctx); + *user_ctx = snapshot; + return Ok(result); } } @@ -974,15 +1004,15 @@ pub enum PhaseKind { /// starts. Rules within a phase compete for matches as usual; rules in /// different phases never compete because each traversal only considers the /// current phase's rules. -pub struct Phase { +pub struct Phase { /// Name used in error messages. pub name: String, - pub rules: Vec, + pub rules: Vec>, pub kind: PhaseKind, } -impl Phase { - pub fn new(name: impl Into, kind: PhaseKind, rules: Vec) -> Self { +impl Phase { + pub fn new(name: impl Into, kind: PhaseKind, rules: Vec>) -> Self { Self { name: name.into(), rules, @@ -1008,17 +1038,30 @@ impl Phase { /// .add_phase("desugar", PhaseKind::Repeating, desugar_rules) /// .with_output_node_types_yaml(yaml); /// ``` -#[derive(Default)] -pub struct DesugaringConfig { +/// +/// The optional type parameter `C` is the user context type threaded through +/// rule transforms. Defaults to `()` (no user context). +pub struct DesugaringConfig { /// Phases of rule application, applied in order. - pub phases: Vec, + pub phases: Vec>, /// Output node-types in YAML format. If `None`, the input grammar's /// node types are used (i.e. the desugared AST has the same node types /// as the tree-sitter grammar). pub output_node_types_yaml: Option<&'static str>, } -impl DesugaringConfig { +// Manual `Default` impl so users with a custom `C` that doesn't implement +// `Default` can still construct an empty config. +impl Default for DesugaringConfig { + fn default() -> Self { + Self { + phases: Vec::new(), + output_node_types_yaml: None, + } + } +} + +impl DesugaringConfig { /// Create an empty configuration. Add phases via [`add_phase`] and an /// optional output schema via [`with_output_node_types_yaml`]. pub fn new() -> Self { @@ -1030,7 +1073,7 @@ impl DesugaringConfig { mut self, name: impl Into, kind: PhaseKind, - rules: Vec, + rules: Vec>, ) -> Self { self.phases.push(Phase::new(name, kind, rules)); self @@ -1052,15 +1095,15 @@ impl DesugaringConfig { } } -pub struct Runner<'a> { +pub struct Runner<'a, C = ()> { language: tree_sitter::Language, schema: schema::Schema, - phases: &'a [Phase], + phases: &'a [Phase], } -impl<'a> Runner<'a> { +impl<'a, C> Runner<'a, C> { /// Create a runner using the input grammar's schema for output. - pub fn new(language: tree_sitter::Language, phases: &'a [Phase]) -> Self { + pub fn new(language: tree_sitter::Language, phases: &'a [Phase]) -> Self { let schema = schema::Schema::from_language(&language); Self { language, @@ -1073,7 +1116,7 @@ impl<'a> Runner<'a> { pub fn with_schema( language: tree_sitter::Language, schema: &schema::Schema, - phases: &'a [Phase], + phases: &'a [Phase], ) -> Self { Self { language, @@ -1085,7 +1128,7 @@ impl<'a> Runner<'a> { /// Create a runner from a [`DesugaringConfig`]. pub fn from_config( language: tree_sitter::Language, - config: &'a DesugaringConfig, + config: &'a DesugaringConfig, ) -> Result { let schema = config.build_schema(&language)?; Ok(Self { @@ -1094,11 +1137,17 @@ impl<'a> Runner<'a> { phases: &config.phases, }) } +} - pub fn run_from_tree( +impl<'a, C: Clone> Runner<'a, C> { + /// Parse `tree` against `source` and run all phases, threading + /// `user_ctx` through every rule transform. The caller owns the + /// initial context state. + pub fn run_from_tree_with_ctx( &self, tree: &tree_sitter::Tree, source: &[u8], + user_ctx: &mut C, ) -> Result { let mut ast = Ast::from_tree_with_schema_and_source( self.schema.clone(), @@ -1106,11 +1155,13 @@ impl<'a> Runner<'a> { &self.language, source.to_vec(), ); - self.run_phases(&mut ast)?; + self.run_phases(&mut ast, user_ctx)?; Ok(ast) } - pub fn run(&self, input: &str) -> Result { + /// Parse `input` and run all phases, threading `user_ctx` through + /// every rule transform. The caller owns the initial context state. + pub fn run_with_ctx(&self, input: &str, user_ctx: &mut C) -> Result { let mut parser = tree_sitter::Parser::new(); parser .set_language(&self.language) @@ -1124,20 +1175,20 @@ impl<'a> Runner<'a> { &self.language, input.as_bytes().to_vec(), ); - self.run_phases(&mut ast)?; + self.run_phases(&mut ast, user_ctx)?; Ok(ast) } /// Apply each phase in turn to the AST, threading the root through. /// A single `FreshScope` is shared across phases so that fresh /// identifiers generated in different phases don't collide. - fn run_phases(&self, ast: &mut Ast) -> Result<(), String> { + fn run_phases(&self, ast: &mut Ast, user_ctx: &mut C) -> Result<(), String> { let fresh = tree_builder::FreshScope::new(); let mut root = ast.get_root(); for phase in self.phases { let res = match phase.kind { - PhaseKind::Repeating => apply_repeating_rules(&phase.rules, ast, root, &fresh), - PhaseKind::OneShot => apply_one_shot_rules(&phase.rules, ast, root, &fresh), + PhaseKind::Repeating => apply_repeating_rules(&phase.rules, ast, user_ctx, root, &fresh), + PhaseKind::OneShot => apply_one_shot_rules(&phase.rules, ast, user_ctx, root, &fresh), } .map_err(|e| format!("Phase `{}`: {e}", phase.name))?; if res.len() != 1 { @@ -1153,3 +1204,23 @@ impl<'a> Runner<'a> { Ok(()) } } + +impl<'a, C: Clone + Default> Runner<'a, C> { + /// Parse `tree` against `source` and run all phases, using the + /// default context (`C::default()`) as the initial context state. + pub fn run_from_tree( + &self, + tree: &tree_sitter::Tree, + source: &[u8], + ) -> Result { + let mut user_ctx = C::default(); + self.run_from_tree_with_ctx(tree, source, &mut user_ctx) + } + + /// Parse `input` and run all phases, using the default context + /// (`C::default()`) as the initial context state. + pub fn run(&self, input: &str) -> Result { + let mut user_ctx = C::default(); + self.run_with_ctx(input, &mut user_ctx) + } +} diff --git a/shared/yeast/tests/test.rs b/shared/yeast/tests/test.rs index 069132d0923..308c72b725f 100644 --- a/shared/yeast/tests/test.rs +++ b/shared/yeast/tests/test.rs @@ -7,7 +7,7 @@ const OUTPUT_SCHEMA_YAML: &str = include_str!("node-types.yml"); /// Helper: parse Ruby source with no rules, return dump. fn parse_and_dump(input: &str) -> String { - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run(input).unwrap(); dump_ast(&ast, ast.get_root(), input) } @@ -24,7 +24,7 @@ fn run_and_ast(input: &str, rules: Vec) -> Ast { let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap(); let phases = vec![Phase::new("test", PhaseKind::Repeating, rules)]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); runner.run(input).unwrap() } @@ -34,7 +34,7 @@ fn run_phased_and_dump(input: &str, phases: Vec) -> String { let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into(); let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap(); - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); let ast = runner.run(input).unwrap(); dump_ast(&ast, ast.get_root(), input) } @@ -46,7 +46,7 @@ fn run_and_get_error(input: &str, rules: Vec) -> String { let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap(); let phases = vec![Phase::new("test", PhaseKind::Repeating, rules)]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); runner .run(input) .expect_err("expected runner to return an error") @@ -54,7 +54,7 @@ fn run_and_get_error(input: &str, rules: Vec) -> String { /// Helper: parse Ruby source with no rules and dump with schema type errors. fn parse_and_dump_typed(input: &str, schema_yaml: &str) -> String { - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run(input).unwrap(); let schema = yeast::node_types_yaml::schema_from_yaml(schema_yaml).unwrap(); dump_ast_with_type_errors(&ast, ast.get_root(), input, &schema) @@ -64,7 +64,7 @@ fn parse_and_dump_typed(input: &str, schema_yaml: &str) -> String { /// building schema with language IDs so field checks align with parser fields. fn parse_and_dump_typed_with_language(input: &str, schema_yaml: &str) -> String { let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into(); - let runner = Runner::new(lang.clone(), &[]); + let runner: Runner = Runner::new(lang.clone(), &[]); let ast = runner.run(input).unwrap(); let schema = yeast::node_types_yaml::schema_from_yaml_with_language(schema_yaml, &lang) .unwrap(); @@ -76,7 +76,7 @@ fn run_and_dump_typed(input: &str, rules: Vec, schema_yaml: &str) -> Strin let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into(); let schema = yeast::node_types_yaml::schema_from_yaml(schema_yaml).unwrap(); let phases = vec![Phase::new("test", PhaseKind::Repeating, rules)]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); let ast = runner.run(input).unwrap(); dump_ast_with_type_errors(&ast, ast.get_root(), input, &schema) } @@ -194,7 +194,7 @@ named: // This rewrite runs and preserves the RHS node kind via capture. // With schema above, preserving `integer` should be reported inline. - let rules = vec![yeast::rule!( + let rules: Vec = vec![yeast::rule!( (assignment left: (_) @left right: (_) @right) => (assignment @@ -247,7 +247,7 @@ named: #[test] fn test_query_match() { - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x = 1").unwrap(); let query = yeast::query!( @@ -268,7 +268,7 @@ fn test_query_match() { #[test] fn test_query_no_match() { - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x = 1").unwrap(); let query = yeast::query!( @@ -293,7 +293,7 @@ fn test_query_skips_extras_in_positional_match() { // captured comment to nothing (a common idiom, e.g. // `(comment) => ()` in Swift) leaves the capture's match-list empty // and causes the transform to fail with "Variable X has 0 matches". - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("[1, # comment\n2]").unwrap(); // Navigate to the `array` node: program -> array. @@ -327,12 +327,12 @@ fn test_reachable_nodes_excludes_orphaned_rewrite_nodes() { let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into(); let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang) .unwrap(); - let phases = vec![Phase::new( + let phases: Vec = vec![Phase::new( "test", PhaseKind::Repeating, vec![yeast::rule!((integer) => (identifier "replaced"))], )]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); let input = "x = 1"; let ast = runner.run(input).unwrap(); @@ -350,7 +350,7 @@ fn test_reachable_nodes_excludes_orphaned_rewrite_nodes() { #[test] fn test_query_repeated_capture() { - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x, y, z = 1").unwrap(); let query = yeast::query!( @@ -375,7 +375,7 @@ fn test_query_repeated_capture() { #[test] fn test_capture_unnamed_node_parenthesized() { // `("=") @op` captures the unnamed `=` token between left and right. - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x = 1").unwrap(); let query = yeast::query!( @@ -403,7 +403,7 @@ fn test_capture_unnamed_node_parenthesized() { fn test_capture_bare_underscore_repeated() { // `_` matches named and unnamed nodes in bare-child position. On this // assignment shape, bare children correspond to unnamed tokens (the `=`). - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x = 1").unwrap(); let query = yeast::query!((assignment _* @all)); @@ -425,7 +425,7 @@ fn test_capture_bare_underscore_repeated() { #[test] fn test_capture_unnamed_node_bare_literal() { // `"=" @op` (without surrounding parens) is the same as `("=") @op`. - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x = 1").unwrap(); let query = yeast::query!( @@ -454,7 +454,7 @@ fn test_bare_underscore_matches_unnamed() { // Bare `_` matches any node, including unnamed tokens, while `(_)` // matches only named nodes. Demonstrate by matching the unnamed `=` // token in the implicit `child` field of an `assignment`. - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x = 1").unwrap(); let mut cursor = AstCursor::new(&ast); @@ -493,7 +493,7 @@ fn test_bare_forms_in_field_position() { // field's value, not just in the bare-children position. This is // syntactic sugar for `(_)` / `("…")` and goes through the same // code paths. - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x = 1").unwrap(); let mut cursor = AstCursor::new(&ast); @@ -532,7 +532,7 @@ fn test_forward_scan_finds_unnamed_token_late() { // query for `("end")` skip past the first two and match the third. // Without forward-scan, the matcher took the first child unconditionally // and failed. - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("for x in list do\n y\nend").unwrap(); // Navigate: program > for > do (the body wrapper). @@ -559,7 +559,7 @@ fn test_forward_scan_preserves_order() { // order. A query for ("end") then ("do") should fail because `do` // appears before `end` in the source order; once forward-scan has // consumed `end`, the iterator is exhausted. - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("for x in list do\n y\nend").unwrap(); let mut cursor = AstCursor::new(&ast); @@ -580,7 +580,7 @@ fn test_forward_scan_preserves_order() { #[test] fn test_tree_builder() { - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let mut ast = runner.run("x = 1").unwrap(); let input = "x = 1"; @@ -598,7 +598,8 @@ fn test_tree_builder() { // Swap left and right let fresh = yeast::tree_builder::FreshScope::new(); - let mut ctx = yeast::build::BuildCtx::new(&mut ast, &captures, &fresh); + let mut user_ctx = (); + let mut ctx = yeast::build::BuildCtx::new(&mut ast, &captures, &fresh, &mut user_ctx); let new_id = yeast::tree!(ctx, (program child: (assignment @@ -626,7 +627,7 @@ fn test_tree_builder() { // tree-sitter-ruby grammar with named fields for nodes that only have // unnamed children in tree-sitter (e.g. block_body.stmt, block_parameters.parameter). fn ruby_rules() -> Vec { - let assign_rule = yeast::rule!( + let assign_rule: Rule = yeast::rule!( (assignment left: (left_assignment_list (identifier)* @left @@ -651,7 +652,7 @@ fn ruby_rules() -> Vec { )} ); - let for_rule = yeast::rule!( + let for_rule: Rule = yeast::rule!( (for pattern: (_) @pat value: (in (_) @val) @@ -733,7 +734,7 @@ fn test_desugar_for_loop() { #[test] fn test_shorthand_rule() { - let rule = yeast::rule!( + let rule: Rule = yeast::rule!( (assignment left: (_) @method right: (_) @receiver @@ -885,7 +886,7 @@ fn test_phase_error_includes_phase_name() { PhaseKind::Repeating, vec![swap_assignment_rule().repeated()], )]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); let err = runner .run("x = 1") .expect_err("expected runner to return an error"); @@ -928,7 +929,7 @@ fn test_one_shot_phase() { PhaseKind::OneShot, one_shot_xeq1_rules(), )]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); let input = "x = 1"; let ast = runner.run(input).unwrap(); @@ -954,7 +955,7 @@ fn test_one_shot_phase_errors_when_no_rule_matches() { let mut rules = one_shot_xeq1_rules(); rules.pop(); let phases = vec![Phase::new("translate", PhaseKind::OneShot, rules)]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); let err = runner .run("x = 1") @@ -978,7 +979,7 @@ fn test_one_shot_recurses_into_returned_capture() { let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into(); let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap(); - let rules = vec![ + let rules: Vec = vec![ yeast::rule!( (program (_)* @stmts) => @@ -994,7 +995,7 @@ fn test_one_shot_recurses_into_returned_capture() { yeast::rule!((integer) => (integer "INT")), ]; let phases = vec![Phase::new("translate", PhaseKind::OneShot, rules)]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); let input = "x = 1"; let ast = runner.run(input).unwrap(); @@ -1020,7 +1021,7 @@ fn test_one_shot_does_not_recurse_into_wrapper_output() { let lang: tree_sitter::Language = tree_sitter_ruby::LANGUAGE.into(); let schema = yeast::node_types_yaml::schema_from_yaml_with_language(OUTPUT_SCHEMA_YAML, &lang).unwrap(); - let rules = vec![ + let rules: Vec = vec![ yeast::rule!( (program (_)* @stmts) => @@ -1041,7 +1042,7 @@ fn test_one_shot_does_not_recurse_into_wrapper_output() { yeast::rule!((integer) => (integer "INT")), ]; let phases = vec![Phase::new("translate", PhaseKind::OneShot, rules)]; - let runner = Runner::with_schema(lang, &schema, &phases); + let runner: Runner = Runner::with_schema(lang, &schema, &phases); let input = "x = 1"; let ast = runner.run(input).unwrap(); @@ -1065,7 +1066,7 @@ fn test_one_shot_does_not_recurse_into_wrapper_output() { #[test] fn test_cursor_navigation() { - let runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); + let runner: Runner = Runner::new(tree_sitter_ruby::LANGUAGE.into(), &[]); let ast = runner.run("x = 1").unwrap(); let mut cursor = AstCursor::new(&ast); @@ -1139,7 +1140,7 @@ fn test_desugar_for_with_multiple_assignment() { /// resolves to the captured node's source text via `YeastDisplay`. #[test] fn test_hash_brace_renders_capture_source_text() { - let rule = rule!( + let rule: Rule = rule!( (call method: (identifier) @name receiver: (identifier) @recv @@ -1168,7 +1169,7 @@ fn test_hash_brace_renders_capture_source_text() { /// `Display` impl (covered by `YeastDisplay`'s blanket impls for primitives). #[test] fn test_hash_brace_renders_integer_expression() { - let rule = rule!( + let rule: Rule = rule!( (identifier) @_ => (identifier #{1 + 2}) @@ -1187,7 +1188,7 @@ fn test_hash_brace_renders_integer_expression() { /// source location, not the full source range of the matched rule root. #[test] fn test_hash_brace_uses_capture_location_for_leaf() { - let rule = rule!( + let rule: Rule = rule!( (call method: (identifier) @name receiver: (identifier) @recv diff --git a/unified/extractor/src/languages/swift/swift.rs b/unified/extractor/src/languages/swift/swift.rs index 79f0e65b02f..2c786810e49 100644 --- a/unified/extractor/src/languages/swift/swift.rs +++ b/unified/extractor/src/languages/swift/swift.rs @@ -1,5 +1,5 @@ use codeql_extractor::extractor::simple; -use yeast::{rule, DesugaringConfig, PhaseKind}; +use yeast::{rule, tree, DesugaringConfig, PhaseKind}; fn translation_rules() -> Vec { vec![ @@ -99,17 +99,15 @@ fn translation_rules() -> Vec { computed_value: (computed_property accessor: _+ @accessors)) => {..{ - let name_text = __yeast_ctx.ast.source_text(pattern.into()); - let ty_ids: Vec = ty.iter().map(|&t| t.into()).collect(); - let acc_ids: Vec = accessors.iter().map(|&a| a.into()).collect(); - for &acc_id in &acc_ids { - let ident = __yeast_ctx.literal("identifier", &name_text); - __yeast_ctx.prepend_field(acc_id, "name", ident); - for &ty_id in ty_ids.iter().rev() { - __yeast_ctx.prepend_field(acc_id, "type", ty_id); + for &acc in &accessors { + let acc_id: usize = acc.into(); + for &t in ty.iter().rev() { + ctx.prepend_field(acc_id, "type", t.into()); } + let name_id = tree!((identifier #{pattern})); + ctx.prepend_field(acc_id, "name", name_id); } - acc_ids + accessors }} ), // Computed property: shorthand getter (no explicit get/set, just statements) → @@ -137,30 +135,19 @@ fn translation_rules() -> Vec { value: _? @val observers: (willset_didset_block willset: _? @ws didset: _? @ds)) => + (variable_declaration + pattern: (name_pattern identifier: (identifier #{name})) + type: {..ty} + value: {..val}) {..{ - let name_text = __yeast_ctx.ast.source_text(name.into()); - let val_ids: Vec = val.iter().map(|&v| v.into()).collect(); - let ty_ids: Vec = ty.iter().map(|&t| t.into()).collect(); - let mut obs_ids: Vec = Vec::new(); - obs_ids.extend(ws.iter().map(|&o| { let id: usize = o.into(); id })); - obs_ids.extend(ds.iter().map(|&o| { let id: usize = o.into(); id })); - let ident_for_var = __yeast_ctx.literal("identifier", &name_text); - let pat = __yeast_ctx.node("name_pattern", vec![("identifier", vec![ident_for_var])]); - let mut var_fields: Vec<(&str, Vec)> = vec![("pattern", vec![pat])]; - if !ty_ids.is_empty() { - var_fields.push(("type", ty_ids)); + let mut obs_ids = Vec::new(); + for &obs in ws.iter().chain(ds.iter()) { + let obs_id: usize = obs.into(); + let ident = tree!((identifier #{name})); + ctx.prepend_field(obs_id, "name", ident); + obs_ids.push(obs_id); } - if !val_ids.is_empty() { - var_fields.push(("value", val_ids)); - } - let var_id = __yeast_ctx.node("variable_declaration", var_fields); - let mut result = vec![var_id]; - for obs_id in obs_ids { - let ident = __yeast_ctx.literal("identifier", &name_text); - __yeast_ctx.prepend_field(obs_id, "name", ident); - result.push(obs_id); - } - result + obs_ids }} ), // property_binding with any pattern name (identifier or destructuring) @@ -186,19 +173,19 @@ fn translation_rules() -> Vec { (modifiers)* @mods) => {..{ - let binding_text = __yeast_ctx.ast.source_text(binding_kind.into()); + let binding_text = ctx.ast.source_text(binding_kind.into()); let mod_ids: Vec = mods.iter().map(|&m| m.into()).collect(); let decl_ids: Vec = decls.iter().map(|&d| d.into()).collect(); for (i, &decl_id) in decl_ids.iter().enumerate() { if i > 0 { - let chained = __yeast_ctx.literal("modifier", "chained_declaration"); - __yeast_ctx.prepend_field(decl_id, "modifier", chained); + let chained = ctx.literal("modifier", "chained_declaration"); + ctx.prepend_field(decl_id, "modifier", chained); } for &mod_id in mod_ids.iter().rev() { - __yeast_ctx.prepend_field(decl_id, "modifier", mod_id); + ctx.prepend_field(decl_id, "modifier", mod_id); } - let binding_mod = __yeast_ctx.literal("modifier", &binding_text); - __yeast_ctx.prepend_field(decl_id, "modifier", binding_mod); + let binding_mod = ctx.literal("modifier", &binding_text); + ctx.prepend_field(decl_id, "modifier", binding_mod); } decl_ids }} @@ -256,11 +243,11 @@ fn translation_rules() -> Vec { let case_ids: Vec = cases.iter().map(|&c| c.into()).collect(); for (i, &case_id) in case_ids.iter().enumerate() { if i > 0 { - let chained = __yeast_ctx.literal("modifier", "chained_declaration"); - __yeast_ctx.prepend_field(case_id, "modifier", chained); + let chained = ctx.literal("modifier", "chained_declaration"); + ctx.prepend_field(case_id, "modifier", chained); } for &mod_id in mod_ids.iter().rev() { - __yeast_ctx.prepend_field(case_id, "modifier", mod_id); + ctx.prepend_field(case_id, "modifier", mod_id); } } case_ids @@ -343,7 +330,7 @@ fn translation_rules() -> Vec { {..{ let p_id: usize = p.into(); for &d in def.iter().rev() { - __yeast_ctx.prepend_field(p_id, "default", d.into()); + ctx.prepend_field(p_id, "default", d.into()); } vec![p_id] }} @@ -585,9 +572,9 @@ fn translation_rules() -> Vec { ), // Labeled statement (e.g. `outer: for ...`). Strip the trailing ':' from the label token. rule!((labeled_statement label: (statement_label) @lbl statement: @stmt) => {..{ - let text = __yeast_ctx.ast.source_text(lbl.into()); - let name = __yeast_ctx.literal("identifier", &text[..text.len() - 1]); - vec![__yeast_ctx.node("labeled_stmt", vec![("label", vec![name]), ("stmt", vec![stmt.into()])])] + let text = ctx.ast.source_text(lbl.into()); + let name = ctx.literal("identifier", &text[..text.len() - 1]); + vec![ctx.node("labeled_stmt", vec![("label", vec![name]), ("stmt", vec![stmt.into()])])] }}), // ---- Collections ---- // Array literal @@ -602,7 +589,7 @@ fn translation_rules() -> Vec { keys.iter().zip(vals.iter()).map(|(&k, &v)| { let k_id: usize = k.into(); let v_id: usize = v.into(); - __yeast_ctx.node("key_value_pair", vec![ + ctx.node("key_value_pair", vec![ ("key", vec![k_id]), ("value", vec![v_id]), ]) @@ -885,23 +872,23 @@ fn translation_rules() -> Vec { (modifiers)* @mods) => {..{ - let name_text = __yeast_ctx.ast.source_text(pattern.into()); + let name_text = ctx.ast.source_text(pattern.into()); let mod_ids: Vec = mods.iter().map(|&m| m.into()).collect(); let ty_ids: Vec = ty.iter().map(|&t| t.into()).collect(); let acc_ids: Vec = accessors.iter().map(|&a| a.into()).collect(); for (i, &acc_id) in acc_ids.iter().enumerate() { if i > 0 { - let chained = __yeast_ctx.literal("modifier", "chained_declaration"); - __yeast_ctx.prepend_field(acc_id, "modifier", chained); + let chained = ctx.literal("modifier", "chained_declaration"); + ctx.prepend_field(acc_id, "modifier", chained); } for &mod_id in mod_ids.iter().rev() { - __yeast_ctx.prepend_field(acc_id, "modifier", mod_id); + ctx.prepend_field(acc_id, "modifier", mod_id); } for &ty_id in ty_ids.iter().rev() { - __yeast_ctx.prepend_field(acc_id, "type", ty_id); + ctx.prepend_field(acc_id, "type", ty_id); } - let ident = __yeast_ctx.literal("identifier", &name_text); - __yeast_ctx.prepend_field(acc_id, "name", ident); + let ident = ctx.literal("identifier", &name_text); + ctx.prepend_field(acc_id, "name", ident); } acc_ids }} diff --git a/unified/extractor/tests/corpus/swift/variables.txt b/unified/extractor/tests/corpus/swift/variables.txt index f1da058eef2..78b80d9a509 100644 --- a/unified/extractor/tests/corpus/swift/variables.txt +++ b/unified/extractor/tests/corpus/swift/variables.txt @@ -319,3 +319,130 @@ top_level name_expr identifier: identifier "x" value: int_literal "1" + +=== +Property with willSet and didSet observers +=== + +class C { + var x: Int = 0 { + willSet { print(newValue) } + didSet { print(oldValue) } + } +} + +--- + +source_file + statement: + class_declaration + body: + class_body + member: + property_declaration + binding: + value_binding_pattern + mutability: var + declarator: + property_binding + name: + pattern + bound_identifier: simple_identifier "x" + observers: + willset_didset_block + didset: + didset_clause + body: + block + statement: + call_expression + function: simple_identifier "print" + suffix: + call_suffix + arguments: + value_arguments + argument: + value_argument + value: simple_identifier "oldValue" + willset: + willset_clause + body: + block + statement: + call_expression + function: simple_identifier "print" + suffix: + call_suffix + arguments: + value_arguments + argument: + value_argument + value: simple_identifier "newValue" + type: + type_annotation + type: + type + name: + user_type + part: + simple_user_type + name: type_identifier "Int" + value: integer_literal "0" + declaration_kind: class + name: type_identifier "C" + +--- + +top_level + body: + block + stmt: + class_like_declaration + member: + variable_declaration + modifier: modifier "var" + pattern: + name_pattern + identifier: identifier "x" + type: + named_type_expr + name: identifier "Int" + value: int_literal "0" + accessor_declaration + body: + block + stmt: + call_expr + argument: + argument + value: + name_expr + identifier: identifier "newValue" + callee: + name_expr + identifier: identifier "print" + modifier: + modifier "var" + modifier "chained_declaration" + name: identifier "x" + accessor_kind: accessor_kind "willSet" + accessor_declaration + body: + block + stmt: + call_expr + argument: + argument + value: + name_expr + identifier: identifier "oldValue" + callee: + name_expr + identifier: identifier "print" + modifier: + modifier "var" + modifier "chained_declaration" + name: identifier "x" + accessor_kind: accessor_kind "didSet" + modifier: modifier "class" + name: identifier "C" diff --git a/unified/extractor/tests/corpus_tests.rs b/unified/extractor/tests/corpus_tests.rs index 0f1057a8e5b..85a62726d87 100644 --- a/unified/extractor/tests/corpus_tests.rs +++ b/unified/extractor/tests/corpus_tests.rs @@ -168,7 +168,7 @@ fn dump_raw_parse( lang: &simple::LanguageSpec, input: &str, ) -> Result { - let runner = Runner::new(lang.ts_language.clone(), &[]); + let runner: Runner = Runner::new(lang.ts_language.clone(), &[]); let ast = runner .run(input) .map_err(|e| format!("Failed to parse input: {e}"))?;