From 34e4144a14931fae302b7443a3763fc6bef83a54 Mon Sep 17 00:00:00 2001 From: Taus Date: Mon, 4 May 2026 13:13:58 +0000 Subject: [PATCH] Integrate yeast with shared tree-sitter extractor extract() gains two parameters, rules and output_schema. When rules is empty, it uses tree-sitter native traversal (no behavior change). When rules is non-empty, it runs yeast desugaring (using output_schema, if provided, to construct the runner) and extracts via traverse_yeast. If desugaring fails, it logs an error and falls back to the un-desugared AST. Adds AstNode trait abstracting over tree_sitter::Node and yeast::Node, with minimal changes to existing Visitor methods (Node -> &N in 6 signatures). simple::LanguageSpec gains an optional output_node_types field; when set, it is used instead of node_types to build the schema. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ruby/extractor/src/extractor.rs | 4 + shared/tree-sitter-extractor/Cargo.toml | 1 + .../src/extractor/mod.rs | 115 +++++++++++++++--- .../src/extractor/simple.rs | 9 +- .../tests/integration_test.rs | 1 + .../tests/multiple_languages.rs | 2 + 6 files changed, 114 insertions(+), 18 deletions(-) diff --git a/ruby/extractor/src/extractor.rs b/ruby/extractor/src/extractor.rs index 6807d09e9be..9ab217e2aea 100644 --- a/ruby/extractor/src/extractor.rs +++ b/ruby/extractor/src/extractor.rs @@ -123,6 +123,8 @@ pub fn run(options: Options) -> std::io::Result<()> { &path, &source, &[], + vec![], + None, ); let (ranges, line_breaks) = scan_erb( @@ -211,6 +213,8 @@ pub fn run(options: Options) -> std::io::Result<()> { &path, &source, &code_ranges, + vec![], + None, ); std::fs::create_dir_all(src_archive_file.parent().unwrap())?; if needs_conversion { diff --git a/shared/tree-sitter-extractor/Cargo.toml b/shared/tree-sitter-extractor/Cargo.toml index d02f02fd588..1ad18a6df5a 100644 --- a/shared/tree-sitter-extractor/Cargo.toml +++ b/shared/tree-sitter-extractor/Cargo.toml @@ -20,6 +20,7 @@ serde_json = "1.0" chrono = { version = "0.4.42", features = ["serde"] } num_cpus = "1.17.0" zstd = "0.13.3" +yeast = { path = "../yeast" } [dev-dependencies] tree-sitter-ql = "0.23.1" diff --git a/shared/tree-sitter-extractor/src/extractor/mod.rs 
b/shared/tree-sitter-extractor/src/extractor/mod.rs index 0ace3831881..a362bc37fd3 100644 --- a/shared/tree-sitter-extractor/src/extractor/mod.rs +++ b/shared/tree-sitter-extractor/src/extractor/mod.rs @@ -18,6 +18,45 @@ use tree_sitter::{Language, Node, Parser, Range, Tree}; pub mod simple; +/// Trait abstracting over tree-sitter and yeast node types for extraction. +trait AstNode { + fn kind(&self) -> &str; + fn is_named(&self) -> bool; + fn is_missing(&self) -> bool; + fn is_error(&self) -> bool; + fn is_extra(&self) -> bool; + fn start_position(&self) -> tree_sitter::Point; + fn end_position(&self) -> tree_sitter::Point; + fn byte_range(&self) -> std::ops::Range; + fn start_byte(&self) -> usize { self.byte_range().start } + fn end_byte(&self) -> usize { self.byte_range().end } + /// For yeast nodes with synthetic content, return it. Otherwise None. + fn opt_string_content(&self) -> Option { None } +} + +impl<'a> AstNode for Node<'a> { + fn kind(&self) -> &str { Node::kind(self) } + fn is_named(&self) -> bool { Node::is_named(self) } + fn is_missing(&self) -> bool { Node::is_missing(self) } + fn is_error(&self) -> bool { Node::is_error(self) } + fn is_extra(&self) -> bool { Node::is_extra(self) } + fn start_position(&self) -> tree_sitter::Point { Node::start_position(self) } + fn end_position(&self) -> tree_sitter::Point { Node::end_position(self) } + fn byte_range(&self) -> std::ops::Range { Node::byte_range(self) } +} + +impl AstNode for yeast::Node { + fn kind(&self) -> &str { yeast::Node::kind(self) } + fn is_named(&self) -> bool { yeast::Node::is_named(self) } + fn is_missing(&self) -> bool { yeast::Node::is_missing(self) } + fn is_error(&self) -> bool { yeast::Node::is_error(self) } + fn is_extra(&self) -> bool { yeast::Node::is_extra(self) } + fn start_position(&self) -> tree_sitter::Point { yeast::Node::start_position(self) } + fn end_position(&self) -> tree_sitter::Point { yeast::Node::end_position(self) } + fn byte_range(&self) -> std::ops::Range { 
yeast::Node::byte_range(self) } + fn opt_string_content(&self) -> Option { yeast::Node::opt_string_content(self) } +} + /// Sets the tracing level based on the environment variables /// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order), /// falling back to `warn` if neither is set. @@ -204,6 +243,9 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr } /// Extracts the source file at `path`, which is assumed to be canonicalized. +/// When `rules` is non-empty, the parsed tree is first transformed through +/// yeast before TRAP extraction. If `output_schema` is provided, it is used +/// by the yeast runner to resolve output-only node kinds and fields. pub fn extract( language: &Language, language_prefix: &str, @@ -214,6 +256,8 @@ pub fn extract( path: &Path, source: &[u8], ranges: &[Range], + rules: Vec, + output_schema: Option, ) { let path_str = file_paths::normalize_and_transform_path(path, transformer); let span = tracing::span!( @@ -236,13 +280,26 @@ pub fn extract( source, diagnostics_writer, trap_writer, - // TODO: should we handle path strings that are not valid UTF8 better? 
&path_str, file_label, language_prefix, schema, ); - traverse(&tree, &mut visitor); + + if rules.is_empty() { + traverse(&tree, &mut visitor); + } else { + let runner = match output_schema { + Some(schema) => yeast::Runner::with_schema(language.clone(), schema, rules), + None => yeast::Runner::new(language.clone(), rules), + }; + let ast = runner.run_from_tree(&tree) + .unwrap_or_else(|e| { + tracing::error!("Desugaring failed: {e}"); + yeast::Ast::from_tree(language.clone(), &tree) + }); + traverse_yeast(&ast, &mut visitor); + } parser.reset(); } @@ -329,11 +386,11 @@ impl<'a> Visitor<'a> { ); } - fn record_parse_error_for_node( + fn record_parse_error_for_node( &mut self, message: &str, args: &[diagnostics::MessageArg], - node: Node, + node: &N, status_page: bool, ) { let loc = location_for(self, self.file_label, node); @@ -357,7 +414,7 @@ impl<'a> Visitor<'a> { self.record_parse_error(loc_label, &mesg); } - fn enter_node(&mut self, node: Node) -> bool { + fn enter_node(&mut self, node: &N) -> bool { if node.is_missing() { self.record_parse_error_for_node( "A parse error occurred (expected {} symbol). Check the syntax of the file. 
If the file is invalid, correct the error or {} the file from analysis.", @@ -383,7 +440,7 @@ impl<'a> Visitor<'a> { true } - fn leave_node(&mut self, field_name: Option<&'static str>, node: Node) { + fn leave_node(&mut self, field_name: Option<&'static str>, node: &N) { if node.is_error() || node.is_missing() { return; } @@ -434,7 +491,7 @@ impl<'a> Visitor<'a> { fields, name: table_name, } => { - if let Some(args) = self.complex_node(&node, fields, &child_nodes, id) { + if let Some(args) = self.complex_node(node, fields, &child_nodes, id) { self.trap_writer.add_tuple( &self.ast_node_location_table_name, vec![trap::Arg::Label(id), trap::Arg::Label(loc_label)], @@ -495,9 +552,9 @@ impl<'a> Visitor<'a> { } } - fn complex_node( + fn complex_node( &mut self, - node: &Node, + node: &N, fields: &[Field], child_nodes: &[ChildNode], parent_id: trap::Label, @@ -529,7 +586,7 @@ impl<'a> Visitor<'a> { diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)), diagnostics::MessageArg::Code(&format!("{:?}", field.type_info)), ], - *node, + node, false, ); } @@ -541,7 +598,7 @@ impl<'a> Visitor<'a> { diagnostics::MessageArg::Code(child_node.field_name.unwrap_or("child")), diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)), ], - *node, + node, false, ); } @@ -566,7 +623,7 @@ impl<'a> Visitor<'a> { node.kind(), column_name ); - self.record_parse_error_for_node(&error_message, &[], *node, false); + self.record_parse_error_for_node(&error_message, &[], node, false); } } Storage::Table { @@ -582,7 +639,7 @@ impl<'a> Visitor<'a> { diagnostics::MessageArg::Code(node.kind()), diagnostics::MessageArg::Code(table_name), ], - *node, + node, false, ); break; @@ -639,15 +696,17 @@ impl<'a> Visitor<'a> { } // Emit a slice of a source file as an Arg. 
-fn sliced_source_arg(source: &[u8], n: Node) -> trap::Arg { - let range = n.byte_range(); - trap::Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned()) +fn sliced_source_arg(source: &[u8], n: &N) -> trap::Arg { + trap::Arg::String(n.opt_string_content().unwrap_or_else(|| { + let range = n.byte_range(); + String::from_utf8_lossy(&source[range.start..range.end]).into_owned() + })) } // Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated. // The first is the location and label definition, and the second is the // 'Located' entry. -fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap::Location { +fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: &N) -> trap::Location { // Tree-sitter row, column values are 0-based while CodeQL starts // counting at 1. In addition Tree-sitter's row and column for the // end position are exclusive while CodeQL's end positions are inclusive. @@ -715,6 +774,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap fn traverse(tree: &Tree, visitor: &mut Visitor) { let cursor = &mut tree.walk(); + visitor.enter_node(&cursor.node()); + let mut recurse = true; + loop { + if recurse && cursor.goto_first_child() { + recurse = visitor.enter_node(&cursor.node()); + } else { + visitor.leave_node(cursor.field_name(), &cursor.node()); + + if cursor.goto_next_sibling() { + recurse = visitor.enter_node(&cursor.node()); + } else if cursor.goto_parent() { + recurse = false; + } else { + break; + } + } + } +} + +fn traverse_yeast(tree: &yeast::Ast, visitor: &mut Visitor) { + use yeast::Cursor; + let mut cursor = tree.walk(); visitor.enter_node(cursor.node()); let mut recurse = true; loop { diff --git a/shared/tree-sitter-extractor/src/extractor/simple.rs b/shared/tree-sitter-extractor/src/extractor/simple.rs index b8446d02f89..18466af5510 100644 --- a/shared/tree-sitter-extractor/src/extractor/simple.rs +++ 
b/shared/tree-sitter-extractor/src/extractor/simple.rs @@ -12,6 +12,10 @@ pub struct LanguageSpec { pub prefix: &'static str, pub ts_language: tree_sitter::Language, pub node_types: &'static str, + /// If set, the extractor validates TRAP output against these node types + /// instead of `node_types`. Use when desugaring produces an AST that + /// differs from the tree-sitter grammar. + pub output_node_types: Option<&'static str>, pub file_globs: Vec, } @@ -86,7 +90,8 @@ impl Extractor { let mut schemas = vec![]; for lang in &self.languages { - let schema = node_types::read_node_types_str(lang.prefix, lang.node_types)?; + let effective_node_types = lang.output_node_types.unwrap_or(lang.node_types); + let schema = node_types::read_node_types_str(lang.prefix, effective_node_types)?; schemas.push(schema); } @@ -162,6 +167,8 @@ impl Extractor { &path, &source, &[], + vec![], + None, ); std::fs::create_dir_all(src_archive_file.parent().unwrap())?; std::fs::copy(&path, &src_archive_file)?; diff --git a/shared/tree-sitter-extractor/tests/integration_test.rs b/shared/tree-sitter-extractor/tests/integration_test.rs index 2b243ff7945..eae069dd33e 100644 --- a/shared/tree-sitter-extractor/tests/integration_test.rs +++ b/shared/tree-sitter-extractor/tests/integration_test.rs @@ -13,6 +13,7 @@ fn simple_extractor() { prefix: "ql", ts_language: tree_sitter_ql::LANGUAGE.into(), node_types: tree_sitter_ql::NODE_TYPES, + output_node_types: None, file_globs: vec!["*.qll".into()], }; diff --git a/shared/tree-sitter-extractor/tests/multiple_languages.rs b/shared/tree-sitter-extractor/tests/multiple_languages.rs index 2e45e56754a..3e3936e6abd 100644 --- a/shared/tree-sitter-extractor/tests/multiple_languages.rs +++ b/shared/tree-sitter-extractor/tests/multiple_languages.rs @@ -13,12 +13,14 @@ fn multiple_language_extractor() { prefix: "ql", ts_language: tree_sitter_ql::LANGUAGE.into(), node_types: tree_sitter_ql::NODE_TYPES, + output_node_types: None, file_globs: vec!["*.qll".into()], 
}; let lang_json = simple::LanguageSpec { prefix: "json", ts_language: tree_sitter_json::LANGUAGE.into(), node_types: tree_sitter_json::NODE_TYPES, + output_node_types: None, file_globs: vec!["*.json".into(), "*Jsonfile".into()], };