Integrate yeast with shared tree-sitter extractor

extract() gains `rules` and `output_schema` parameters. When `rules` is
empty, it uses tree-sitter's native traversal (no behavior change). When
`rules` is non-empty, it runs yeast desugaring and extracts via traverse_yeast.

Adds AstNode trait abstracting over tree_sitter::Node and yeast::Node,
with minimal changes to existing Visitor methods (Node -> &N in 6
signatures). If desugaring fails, the error is logged and extraction falls
back to the un-desugared AST.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Taus
2026-05-04 13:13:58 +00:00
parent d6234f2ccd
commit 34e4144a14
6 changed files with 114 additions and 18 deletions

View File

@@ -123,6 +123,8 @@ pub fn run(options: Options) -> std::io::Result<()> {
&path,
&source,
&[],
vec![],
None,
);
let (ranges, line_breaks) = scan_erb(
@@ -211,6 +213,8 @@ pub fn run(options: Options) -> std::io::Result<()> {
&path,
&source,
&code_ranges,
vec![],
None,
);
std::fs::create_dir_all(src_archive_file.parent().unwrap())?;
if needs_conversion {

View File

@@ -20,6 +20,7 @@ serde_json = "1.0"
chrono = { version = "0.4.42", features = ["serde"] }
num_cpus = "1.17.0"
zstd = "0.13.3"
yeast = { path = "../yeast" }
[dev-dependencies]
tree-sitter-ql = "0.23.1"

View File

@@ -18,6 +18,45 @@ use tree_sitter::{Language, Node, Parser, Range, Tree};
pub mod simple;
/// Read-only node interface shared by the tree-sitter and yeast node types,
/// so the extraction code can be written once and run over either tree.
trait AstNode {
    /// Name of the node's kind.
    fn kind(&self) -> &str;

    /// Position at which the node starts.
    fn start_position(&self) -> tree_sitter::Point;

    /// Position at which the node ends.
    fn end_position(&self) -> tree_sitter::Point;

    /// Byte range the node occupies.
    fn byte_range(&self) -> std::ops::Range<usize>;

    /// Whether the node is a named node.
    fn is_named(&self) -> bool;

    /// Whether the node is an extra node.
    fn is_extra(&self) -> bool;

    /// Whether the node is an error node.
    fn is_error(&self) -> bool;

    /// Whether the node is a missing node.
    fn is_missing(&self) -> bool;

    /// First byte offset of the node; defaults to the start of `byte_range()`.
    fn start_byte(&self) -> usize {
        let std::ops::Range { start, .. } = self.byte_range();
        start
    }

    /// End byte offset of the node; defaults to the end of `byte_range()`.
    fn end_byte(&self) -> usize {
        let std::ops::Range { end, .. } = self.byte_range();
        end
    }

    /// Synthetic string content carried by the node itself, if any.
    ///
    /// The default is `None`; implementations whose nodes can hold their own
    /// content (rather than referring to a slice of the source) override it.
    fn opt_string_content(&self) -> Option<String> {
        None
    }
}
// Every method forwards to the identically named `tree_sitter::Node` method.
// The calls are path-qualified so it is explicit which `Node` function is
// being invoked. `opt_string_content` keeps the trait default (`None`).
impl<'a> AstNode for Node<'a> {
    fn kind(&self) -> &str {
        Node::kind(self)
    }

    fn start_position(&self) -> tree_sitter::Point {
        Node::start_position(self)
    }

    fn end_position(&self) -> tree_sitter::Point {
        Node::end_position(self)
    }

    fn byte_range(&self) -> std::ops::Range<usize> {
        Node::byte_range(self)
    }

    fn is_named(&self) -> bool {
        Node::is_named(self)
    }

    fn is_extra(&self) -> bool {
        Node::is_extra(self)
    }

    fn is_error(&self) -> bool {
        Node::is_error(self)
    }

    fn is_missing(&self) -> bool {
        Node::is_missing(self)
    }
}
// Every method forwards to the identically named `yeast::Node` function,
// path-qualified so it is explicit which function is being invoked. Unlike
// the tree-sitter implementation, this one also forwards
// `opt_string_content` instead of relying on the trait default.
impl AstNode for yeast::Node {
    fn kind(&self) -> &str {
        yeast::Node::kind(self)
    }

    fn start_position(&self) -> tree_sitter::Point {
        yeast::Node::start_position(self)
    }

    fn end_position(&self) -> tree_sitter::Point {
        yeast::Node::end_position(self)
    }

    fn byte_range(&self) -> std::ops::Range<usize> {
        yeast::Node::byte_range(self)
    }

    fn opt_string_content(&self) -> Option<String> {
        yeast::Node::opt_string_content(self)
    }

    fn is_named(&self) -> bool {
        yeast::Node::is_named(self)
    }

    fn is_extra(&self) -> bool {
        yeast::Node::is_extra(self)
    }

    fn is_error(&self) -> bool {
        yeast::Node::is_error(self)
    }

    fn is_missing(&self) -> bool {
        yeast::Node::is_missing(self)
    }
}
/// Sets the tracing level based on the environment variables
/// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order),
/// falling back to `warn` if neither is set.
@@ -204,6 +243,9 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
}
/// Extracts the source file at `path`, which is assumed to be canonicalized.
/// When `rules` is non-empty, the parsed tree is first transformed through
/// yeast before TRAP extraction. If `output_schema` is provided, it is used
/// by the yeast runner to resolve output-only node kinds and fields.
pub fn extract(
language: &Language,
language_prefix: &str,
@@ -214,6 +256,8 @@ pub fn extract(
path: &Path,
source: &[u8],
ranges: &[Range],
rules: Vec<yeast::Rule>,
output_schema: Option<yeast::schema::Schema>,
) {
let path_str = file_paths::normalize_and_transform_path(path, transformer);
let span = tracing::span!(
@@ -236,13 +280,26 @@ pub fn extract(
source,
diagnostics_writer,
trap_writer,
// TODO: should we handle path strings that are not valid UTF8 better?
&path_str,
file_label,
language_prefix,
schema,
);
traverse(&tree, &mut visitor);
if rules.is_empty() {
traverse(&tree, &mut visitor);
} else {
let runner = match output_schema {
Some(schema) => yeast::Runner::with_schema(language.clone(), schema, rules),
None => yeast::Runner::new(language.clone(), rules),
};
let ast = runner.run_from_tree(&tree)
.unwrap_or_else(|e| {
tracing::error!("Desugaring failed: {e}");
yeast::Ast::from_tree(language.clone(), &tree)
});
traverse_yeast(&ast, &mut visitor);
}
parser.reset();
}
@@ -329,11 +386,11 @@ impl<'a> Visitor<'a> {
);
}
fn record_parse_error_for_node(
fn record_parse_error_for_node<N: AstNode>(
&mut self,
message: &str,
args: &[diagnostics::MessageArg],
node: Node,
node: &N,
status_page: bool,
) {
let loc = location_for(self, self.file_label, node);
@@ -357,7 +414,7 @@ impl<'a> Visitor<'a> {
self.record_parse_error(loc_label, &mesg);
}
fn enter_node(&mut self, node: Node) -> bool {
fn enter_node<N: AstNode>(&mut self, node: &N) -> bool {
if node.is_missing() {
self.record_parse_error_for_node(
"A parse error occurred (expected {} symbol). Check the syntax of the file. If the file is invalid, correct the error or {} the file from analysis.",
@@ -383,7 +440,7 @@ impl<'a> Visitor<'a> {
true
}
fn leave_node(&mut self, field_name: Option<&'static str>, node: Node) {
fn leave_node<N: AstNode>(&mut self, field_name: Option<&'static str>, node: &N) {
if node.is_error() || node.is_missing() {
return;
}
@@ -434,7 +491,7 @@ impl<'a> Visitor<'a> {
fields,
name: table_name,
} => {
if let Some(args) = self.complex_node(&node, fields, &child_nodes, id) {
if let Some(args) = self.complex_node(node, fields, &child_nodes, id) {
self.trap_writer.add_tuple(
&self.ast_node_location_table_name,
vec![trap::Arg::Label(id), trap::Arg::Label(loc_label)],
@@ -495,9 +552,9 @@ impl<'a> Visitor<'a> {
}
}
fn complex_node(
fn complex_node<N: AstNode>(
&mut self,
node: &Node,
node: &N,
fields: &[Field],
child_nodes: &[ChildNode],
parent_id: trap::Label,
@@ -529,7 +586,7 @@ impl<'a> Visitor<'a> {
diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)),
diagnostics::MessageArg::Code(&format!("{:?}", field.type_info)),
],
*node,
node,
false,
);
}
@@ -541,7 +598,7 @@ impl<'a> Visitor<'a> {
diagnostics::MessageArg::Code(child_node.field_name.unwrap_or("child")),
diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)),
],
*node,
node,
false,
);
}
@@ -566,7 +623,7 @@ impl<'a> Visitor<'a> {
node.kind(),
column_name
);
self.record_parse_error_for_node(&error_message, &[], *node, false);
self.record_parse_error_for_node(&error_message, &[], node, false);
}
}
Storage::Table {
@@ -582,7 +639,7 @@ impl<'a> Visitor<'a> {
diagnostics::MessageArg::Code(node.kind()),
diagnostics::MessageArg::Code(table_name),
],
*node,
node,
false,
);
break;
@@ -639,15 +696,17 @@ impl<'a> Visitor<'a> {
}
// Emit a slice of a source file as an Arg.
fn sliced_source_arg(source: &[u8], n: Node) -> trap::Arg {
let range = n.byte_range();
trap::Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned())
fn sliced_source_arg<N: AstNode>(source: &[u8], n: &N) -> trap::Arg {
trap::Arg::String(n.opt_string_content().unwrap_or_else(|| {
let range = n.byte_range();
String::from_utf8_lossy(&source[range.start..range.end]).into_owned()
}))
}
// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
// The first is the location and label definition, and the second is the
// 'Located' entry.
fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap::Location {
fn location_for<N: AstNode>(visitor: &mut Visitor, file_label: trap::Label, n: &N) -> trap::Location {
// Tree-sitter row, column values are 0-based while CodeQL starts
// counting at 1. In addition Tree-sitter's row and column for the
// end position are exclusive while CodeQL's end positions are inclusive.
@@ -715,6 +774,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap
/// Walks `tree` with a tree-sitter cursor, driving `visitor`.
///
/// `enter_node` is called when a node is first reached and `leave_node` once
/// the walk is finished with that node (after its children, if they were
/// visited). The boolean returned by `enter_node` decides whether the walk
/// descends into that node's children.
fn traverse(tree: &Tree, visitor: &mut Visitor) {
let cursor = &mut tree.walk();
// The root's `enter_node` result is discarded: `recurse` starts out `true`,
// so the walk always attempts to descend into the root's children.
visitor.enter_node(&cursor.node());
let mut recurse = true;
loop {
// Descend only if the node just entered asked for it and has a child.
if recurse && cursor.goto_first_child() {
recurse = visitor.enter_node(&cursor.node());
} else {
// Nothing (more) to visit below the current node: leave it, passing the
// field name the cursor reports for it.
visitor.leave_node(cursor.field_name(), &cursor.node());
if cursor.goto_next_sibling() {
recurse = visitor.enter_node(&cursor.node());
} else if cursor.goto_parent() {
// Back at the parent, whose children have all been handled: clear
// `recurse` so the next iteration leaves the parent instead of
// descending into it again.
recurse = false;
} else {
// No sibling and no parent: the root has been left, so the walk is done.
break;
}
}
}
}
fn traverse_yeast(tree: &yeast::Ast, visitor: &mut Visitor) {
use yeast::Cursor;
let mut cursor = tree.walk();
visitor.enter_node(cursor.node());
let mut recurse = true;
loop {

View File

@@ -12,6 +12,10 @@ pub struct LanguageSpec {
pub prefix: &'static str,
pub ts_language: tree_sitter::Language,
pub node_types: &'static str,
/// If set, the extractor validates TRAP output against these node types
/// instead of `node_types`. Use when desugaring produces an AST that
/// differs from the tree-sitter grammar.
pub output_node_types: Option<&'static str>,
pub file_globs: Vec<String>,
}
@@ -86,7 +90,8 @@ impl Extractor {
let mut schemas = vec![];
for lang in &self.languages {
let schema = node_types::read_node_types_str(lang.prefix, lang.node_types)?;
let effective_node_types = lang.output_node_types.unwrap_or(lang.node_types);
let schema = node_types::read_node_types_str(lang.prefix, effective_node_types)?;
schemas.push(schema);
}
@@ -162,6 +167,8 @@ impl Extractor {
&path,
&source,
&[],
vec![],
None,
);
std::fs::create_dir_all(src_archive_file.parent().unwrap())?;
std::fs::copy(&path, &src_archive_file)?;

View File

@@ -13,6 +13,7 @@ fn simple_extractor() {
prefix: "ql",
ts_language: tree_sitter_ql::LANGUAGE.into(),
node_types: tree_sitter_ql::NODE_TYPES,
output_node_types: None,
file_globs: vec!["*.qll".into()],
};

View File

@@ -13,12 +13,14 @@ fn multiple_language_extractor() {
prefix: "ql",
ts_language: tree_sitter_ql::LANGUAGE.into(),
node_types: tree_sitter_ql::NODE_TYPES,
output_node_types: None,
file_globs: vec!["*.qll".into()],
};
let lang_json = simple::LanguageSpec {
prefix: "json",
ts_language: tree_sitter_json::LANGUAGE.into(),
node_types: tree_sitter_json::NODE_TYPES,
output_node_types: None,
file_globs: vec!["*.json".into(), "*Jsonfile".into()],
};