From 5862ab66296cee2c92c290ff297718b47a8dc7eb Mon Sep 17 00:00:00 2001 From: Taus Date: Mon, 4 May 2026 13:13:58 +0000 Subject: [PATCH] yeast: Integrate yeast with shared tree-sitter extractor extract() gains an optional desugar parameter (a yeast::DesugaringConfig). When None, uses tree-sitter native traversal (no behavior change). When Some, runs yeast desugaring and extracts via traverse_yeast. Adds AstNode trait abstracting over tree_sitter::Node and yeast::Node, with minimal changes to existing Visitor methods (Node -> &N in 6 signatures). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ql/extractor/src/extractor.rs | 4 + ruby/extractor/src/extractor.rs | 2 + shared/tree-sitter-extractor/Cargo.toml | 1 + .../src/extractor/mod.rs | 152 ++++++++++++++++-- .../src/extractor/simple.rs | 19 ++- .../tests/integration_test.rs | 1 + .../tests/multiple_languages.rs | 2 + 7 files changed, 163 insertions(+), 18 deletions(-) diff --git a/ql/extractor/src/extractor.rs b/ql/extractor/src/extractor.rs index 8383d8424ee..66096442e85 100644 --- a/ql/extractor/src/extractor.rs +++ b/ql/extractor/src/extractor.rs @@ -29,24 +29,28 @@ pub fn run(options: Options) -> std::io::Result<()> { prefix: "ql", ts_language: tree_sitter_ql::LANGUAGE.into(), node_types: tree_sitter_ql::NODE_TYPES, + desugar: None, file_globs: vec!["*.ql".into(), "*.qll".into()], }, simple::LanguageSpec { prefix: "dbscheme", ts_language: tree_sitter_ql_dbscheme::LANGUAGE.into(), node_types: tree_sitter_ql_dbscheme::NODE_TYPES, + desugar: None, file_globs: vec!["*.dbscheme".into()], }, simple::LanguageSpec { prefix: "json", ts_language: tree_sitter_json::LANGUAGE.into(), node_types: tree_sitter_json::NODE_TYPES, + desugar: None, file_globs: vec!["*.json".into(), "*.jsonl".into(), "*.jsonc".into()], }, simple::LanguageSpec { prefix: "blame", ts_language: tree_sitter_blame::LANGUAGE.into(), node_types: tree_sitter_blame::NODE_TYPES, + desugar: None, file_globs: vec!["*.blame".into()], }, ], diff --git
a/ruby/extractor/src/extractor.rs b/ruby/extractor/src/extractor.rs index 6807d09e9be..4849f473ccb 100644 --- a/ruby/extractor/src/extractor.rs +++ b/ruby/extractor/src/extractor.rs @@ -123,6 +123,7 @@ pub fn run(options: Options) -> std::io::Result<()> { &path, &source, &[], + None, ); let (ranges, line_breaks) = scan_erb( @@ -211,6 +212,7 @@ pub fn run(options: Options) -> std::io::Result<()> { &path, &source, &code_ranges, + None, ); std::fs::create_dir_all(src_archive_file.parent().unwrap())?; if needs_conversion { diff --git a/shared/tree-sitter-extractor/Cargo.toml b/shared/tree-sitter-extractor/Cargo.toml index d02f02fd588..1ad18a6df5a 100644 --- a/shared/tree-sitter-extractor/Cargo.toml +++ b/shared/tree-sitter-extractor/Cargo.toml @@ -20,6 +20,7 @@ serde_json = "1.0" chrono = { version = "0.4.42", features = ["serde"] } num_cpus = "1.17.0" zstd = "0.13.3" +yeast = { path = "../yeast" } [dev-dependencies] tree-sitter-ql = "0.23.1" diff --git a/shared/tree-sitter-extractor/src/extractor/mod.rs b/shared/tree-sitter-extractor/src/extractor/mod.rs index 0ace3831881..83060fc40d4 100644 --- a/shared/tree-sitter-extractor/src/extractor/mod.rs +++ b/shared/tree-sitter-extractor/src/extractor/mod.rs @@ -18,6 +18,82 @@ use tree_sitter::{Language, Node, Parser, Range, Tree}; pub mod simple; +/// Trait abstracting over tree-sitter and yeast node types for extraction. +trait AstNode { + fn kind(&self) -> &str; + fn is_named(&self) -> bool; + fn is_missing(&self) -> bool; + fn is_error(&self) -> bool; + fn is_extra(&self) -> bool; + fn start_position(&self) -> tree_sitter::Point; + fn end_position(&self) -> tree_sitter::Point; + fn byte_range(&self) -> std::ops::Range; + fn end_byte(&self) -> usize { + self.byte_range().end + } + /// For yeast nodes with synthetic content, return it. Otherwise None. 
+ fn opt_string_content(&self) -> Option { + None + } +} + +impl<'a> AstNode for Node<'a> { + fn kind(&self) -> &str { + Node::kind(self) + } + fn is_named(&self) -> bool { + Node::is_named(self) + } + fn is_missing(&self) -> bool { + Node::is_missing(self) + } + fn is_error(&self) -> bool { + Node::is_error(self) + } + fn is_extra(&self) -> bool { + Node::is_extra(self) + } + fn start_position(&self) -> tree_sitter::Point { + Node::start_position(self) + } + fn end_position(&self) -> tree_sitter::Point { + Node::end_position(self) + } + fn byte_range(&self) -> std::ops::Range { + Node::byte_range(self) + } +} + +impl AstNode for yeast::Node { + fn kind(&self) -> &str { + yeast::Node::kind(self) + } + fn is_named(&self) -> bool { + yeast::Node::is_named(self) + } + fn is_missing(&self) -> bool { + yeast::Node::is_missing(self) + } + fn is_error(&self) -> bool { + yeast::Node::is_error(self) + } + fn is_extra(&self) -> bool { + yeast::Node::is_extra(self) + } + fn start_position(&self) -> tree_sitter::Point { + yeast::Node::start_position(self) + } + fn end_position(&self) -> tree_sitter::Point { + yeast::Node::end_position(self) + } + fn byte_range(&self) -> std::ops::Range { + yeast::Node::byte_range(self) + } + fn opt_string_content(&self) -> Option { + yeast::Node::opt_string_content(self) + } +} + /// Sets the tracing level based on the environment variables /// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order), /// falling back to `warn` if neither is set. @@ -204,6 +280,10 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr } /// Extracts the source file at `path`, which is assumed to be canonicalized. +/// When `desugar` is `Some`, the parsed tree is first transformed through +/// yeast before TRAP extraction, using the rules and (optional) output +/// schema from the [`yeast::DesugaringConfig`]. 
+#[allow(clippy::too_many_arguments)] pub fn extract( language: &Language, language_prefix: &str, @@ -214,6 +294,7 @@ pub fn extract( path: &Path, source: &[u8], ranges: &[Range], + desugar: Option<&yeast::DesugaringConfig>, ) { let path_str = file_paths::normalize_and_transform_path(path, transformer); let span = tracing::span!( @@ -236,13 +317,22 @@ pub fn extract( source, diagnostics_writer, trap_writer, - // TODO: should we handle path strings that are not valid UTF8 better? &path_str, file_label, language_prefix, schema, ); - traverse(&tree, &mut visitor); + + if let Some(config) = desugar { + let runner = yeast::Runner::from_config(language.clone(), config) + .unwrap_or_else(|e| panic!("Failed to build desugaring runner for {path_str}: {e}")); + let ast = runner + .run_from_tree(&tree) + .unwrap_or_else(|e| panic!("Desugaring failed for {path_str}: {e}")); + traverse_yeast(&ast, &mut visitor); + } else { + traverse(&tree, &mut visitor); + } parser.reset(); } @@ -329,11 +419,11 @@ impl<'a> Visitor<'a> { ); } - fn record_parse_error_for_node( + fn record_parse_error_for_node( &mut self, message: &str, args: &[diagnostics::MessageArg], - node: Node, + node: &N, status_page: bool, ) { let loc = location_for(self, self.file_label, node); @@ -357,7 +447,7 @@ impl<'a> Visitor<'a> { self.record_parse_error(loc_label, &mesg); } - fn enter_node(&mut self, node: Node) -> bool { + fn enter_node(&mut self, node: &N) -> bool { if node.is_missing() { self.record_parse_error_for_node( "A parse error occurred (expected {} symbol). Check the syntax of the file. 
If the file is invalid, correct the error or {} the file from analysis.", @@ -383,7 +473,7 @@ impl<'a> Visitor<'a> { true } - fn leave_node(&mut self, field_name: Option<&'static str>, node: Node) { + fn leave_node(&mut self, field_name: Option<&'static str>, node: &N) { if node.is_error() || node.is_missing() { return; } @@ -434,7 +524,7 @@ impl<'a> Visitor<'a> { fields, name: table_name, } => { - if let Some(args) = self.complex_node(&node, fields, &child_nodes, id) { + if let Some(args) = self.complex_node(node, fields, &child_nodes, id) { self.trap_writer.add_tuple( &self.ast_node_location_table_name, vec![trap::Arg::Label(id), trap::Arg::Label(loc_label)], @@ -495,9 +585,9 @@ impl<'a> Visitor<'a> { } } - fn complex_node( + fn complex_node( &mut self, - node: &Node, + node: &N, fields: &[Field], child_nodes: &[ChildNode], parent_id: trap::Label, @@ -529,7 +619,7 @@ impl<'a> Visitor<'a> { diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)), diagnostics::MessageArg::Code(&format!("{:?}", field.type_info)), ], - *node, + node, false, ); } @@ -541,7 +631,7 @@ impl<'a> Visitor<'a> { diagnostics::MessageArg::Code(child_node.field_name.unwrap_or("child")), diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)), ], - *node, + node, false, ); } @@ -566,7 +656,7 @@ impl<'a> Visitor<'a> { node.kind(), column_name ); - self.record_parse_error_for_node(&error_message, &[], *node, false); + self.record_parse_error_for_node(&error_message, &[], node, false); } } Storage::Table { @@ -582,7 +672,7 @@ impl<'a> Visitor<'a> { diagnostics::MessageArg::Code(node.kind()), diagnostics::MessageArg::Code(table_name), ], - *node, + node, false, ); break; @@ -639,15 +729,21 @@ impl<'a> Visitor<'a> { } // Emit a slice of a source file as an Arg. 
-fn sliced_source_arg(source: &[u8], n: Node) -> trap::Arg { - let range = n.byte_range(); - trap::Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned()) +fn sliced_source_arg(source: &[u8], n: &N) -> trap::Arg { + trap::Arg::String(n.opt_string_content().unwrap_or_else(|| { + let range = n.byte_range(); + String::from_utf8_lossy(&source[range.start..range.end]).into_owned() + })) } // Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated. // The first is the location and label definition, and the second is the // 'Located' entry. -fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap::Location { +fn location_for( + visitor: &mut Visitor, + file_label: trap::Label, + n: &N, +) -> trap::Location { // Tree-sitter row, column values are 0-based while CodeQL starts // counting at 1. In addition Tree-sitter's row and column for the // end position are exclusive while CodeQL's end positions are inclusive. @@ -715,6 +811,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap fn traverse(tree: &Tree, visitor: &mut Visitor) { let cursor = &mut tree.walk(); + visitor.enter_node(&cursor.node()); + let mut recurse = true; + loop { + if recurse && cursor.goto_first_child() { + recurse = visitor.enter_node(&cursor.node()); + } else { + visitor.leave_node(cursor.field_name(), &cursor.node()); + + if cursor.goto_next_sibling() { + recurse = visitor.enter_node(&cursor.node()); + } else if cursor.goto_parent() { + recurse = false; + } else { + break; + } + } + } +} + +fn traverse_yeast(tree: &yeast::Ast, visitor: &mut Visitor) { + use yeast::Cursor; + let mut cursor = tree.walk(); visitor.enter_node(cursor.node()); let mut recurse = true; loop { diff --git a/shared/tree-sitter-extractor/src/extractor/simple.rs b/shared/tree-sitter-extractor/src/extractor/simple.rs index b8446d02f89..120ed9da7e3 100644 --- a/shared/tree-sitter-extractor/src/extractor/simple.rs +++ 
b/shared/tree-sitter-extractor/src/extractor/simple.rs @@ -7,11 +7,17 @@ use std::path::{Path, PathBuf}; use crate::diagnostics; use crate::node_types; +use yeast; pub struct LanguageSpec { pub prefix: &'static str, pub ts_language: tree_sitter::Language, pub node_types: &'static str, + /// Optional yeast desugaring configuration. When set, the parsed + /// tree is rewritten through yeast before TRAP extraction. The + /// config's `output_node_types_yaml` (if set) provides the schema + /// used both at runtime (for the rewriter) and for TRAP validation. + pub desugar: Option, pub file_globs: Vec, } @@ -86,7 +92,17 @@ impl Extractor { let mut schemas = vec![]; for lang in &self.languages { - let schema = node_types::read_node_types_str(lang.prefix, lang.node_types)?; + let effective_node_types: String = + match lang.desugar.as_ref().and_then(|c| c.output_node_types_yaml) { + Some(yaml) => yeast::node_types_yaml::convert(yaml).map_err(|e| { + std::io::Error::other(format!( + "Failed to convert YAML node-types to JSON for {}: {e}", + lang.prefix + )) + })?, + None => lang.node_types.to_string(), + }; + let schema = node_types::read_node_types_str(lang.prefix, &effective_node_types)?; schemas.push(schema); } @@ -162,6 +178,7 @@ impl Extractor { &path, &source, &[], + lang.desugar.as_ref(), ); std::fs::create_dir_all(src_archive_file.parent().unwrap())?; std::fs::copy(&path, &src_archive_file)?; diff --git a/shared/tree-sitter-extractor/tests/integration_test.rs b/shared/tree-sitter-extractor/tests/integration_test.rs index 2b243ff7945..694eb526f39 100644 --- a/shared/tree-sitter-extractor/tests/integration_test.rs +++ b/shared/tree-sitter-extractor/tests/integration_test.rs @@ -13,6 +13,7 @@ fn simple_extractor() { prefix: "ql", ts_language: tree_sitter_ql::LANGUAGE.into(), node_types: tree_sitter_ql::NODE_TYPES, + desugar: None, file_globs: vec!["*.qll".into()], }; diff --git a/shared/tree-sitter-extractor/tests/multiple_languages.rs 
b/shared/tree-sitter-extractor/tests/multiple_languages.rs index 2e45e56754a..e345eec5828 100644 --- a/shared/tree-sitter-extractor/tests/multiple_languages.rs +++ b/shared/tree-sitter-extractor/tests/multiple_languages.rs @@ -13,12 +13,14 @@ fn multiple_language_extractor() { prefix: "ql", ts_language: tree_sitter_ql::LANGUAGE.into(), node_types: tree_sitter_ql::NODE_TYPES, + desugar: None, file_globs: vec!["*.qll".into()], }; let lang_json = simple::LanguageSpec { prefix: "json", ts_language: tree_sitter_json::LANGUAGE.into(), node_types: tree_sitter_json::NODE_TYPES, + desugar: None, file_globs: vec!["*.json".into(), "*Jsonfile".into()], };