diff --git a/extractor/src/extractor.rs b/extractor/src/extractor.rs
index dec97617287..cc91e8d7a2d 100644
--- a/extractor/src/extractor.rs
+++ b/extractor/src/extractor.rs
@@ -6,6 +6,99 @@ use std::path::Path;
 use tracing::{error, info, span, Level};
 use tree_sitter::{Language, Node, Parser, Tree};
 
+/// Accumulates TRAP entries and hands out the labels those entries refer to.
+pub struct TrapWriter {
+    /// The accumulated trap entries
+    trap_output: Vec<TrapEntry>,
+    /// A counter for generating fresh labels
+    counter: i32,
+}
+
+impl TrapWriter {
+    /// Allocates the next label and records it as a fresh ID (`#n = *`), i.e. one
+    /// that is not tied to any key.
+    fn fresh_id(&mut self) -> Label {
+        self.counter += 1;
+        let label = Label(self.counter);
+        self.trap_output.push(TrapEntry::FreshId(label));
+        label
+    }
+
+    /// Gets a label that will hold the unique ID of the passed string at import time.
+    /// This can be used for incrementally importable TRAP files -- use globally unique
+    /// strings to compute a unique ID for table tuples.
+    ///
+    /// Note: You probably want to make sure that the key strings that you use are disjoint
+    /// for disjoint column types; the standard way of doing this is to prefix (or append)
+    /// the column type name to the ID. Thus, you might identify methods in Java by the
+    /// full ID "methods_com.method.package.DeclaringClass.method(argumentList)".
+    fn global_id(&mut self, key: &str) -> Label {
+        self.counter += 1;
+        let label = Label(self.counter);
+        self.trap_output
+            .push(TrapEntry::MapLabelToKey(label, key.to_owned()));
+        label
+    }
+
+    /// Appends the tuple `table_name(args...)` to the accumulated output.
+    fn add_tuple(&mut self, table_name: &str, args: Vec<Arg>) {
+        self.trap_output
+            .push(TrapEntry::GenericTuple(table_name.to_owned(), args))
+    }
+
+    /// Emits the `files` tuple for `absolute_path` and returns the file's label.
+    fn populate_file(&mut self, absolute_path: &Path) -> Label {
+        let file_label = self.global_id(&full_id_for_file(absolute_path));
+        self.add_tuple(
+            "files",
+            vec![
+                Arg::Label(file_label),
+                Arg::String(normalize_path(absolute_path)),
+                Arg::String(match absolute_path.file_name() {
+                    None => String::new(),
+                    Some(file_name) => file_name.to_string_lossy().into_owned(),
+                }),
+                Arg::String(match absolute_path.extension() {
+                    None => String::new(),
+                    Some(ext) => ext.to_string_lossy().into_owned(),
+                }),
+                Arg::Int(1), // 1 = from source
+            ],
+        );
+        file_label
+    }
+
+    /// Emits a `locations_default` tuple for the given span inside `file_label` and
+    /// returns the location's label. The label is a global ID keyed on the file
+    /// label and the four coordinates.
+    fn location(
+        &mut self,
+        file_label: Label,
+        start_line: usize,
+        start_column: usize,
+        end_line: usize,
+        end_column: usize,
+    ) -> Label {
+        let key = format!(
+            "loc,{{{}}},{},{},{},{}",
+            file_label, start_line, start_column, end_line, end_column
+        );
+        let loc_label = self.global_id(&key);
+        self.add_tuple(
+            "locations_default",
+            vec![
+                Arg::Label(loc_label),
+                Arg::Label(file_label),
+                Arg::Int(start_line),
+                Arg::Int(start_column),
+                Arg::Int(end_line),
+                Arg::Int(end_column),
+            ],
+        );
+        loc_label
+    }
+}
+
 pub struct Extractor {
     pub parser: Parser,
     pub schema: Vec<Entry>,
@@ -36,37 +129,20 @@ impl Extractor {
             .parser
.parse(&source, None) .expect("Failed to parse file"); - let mut counter = -1; - // Create a label for the current file and increment the counter so that - // label doesn't get redefined. - counter += 1; - let file_label = Label::Normal(counter); + let mut trap_writer = TrapWriter { + counter: -1, + trap_output: vec![TrapEntry::Comment(format!( + "Auto-generated TRAP file for {}", + path.display() + ))], + }; + let file_label = &trap_writer.populate_file(path); let mut visitor = Visitor { source: &source, - trap_output: vec![ - TrapEntry::Comment(format!("Auto-generated TRAP file for {}", path.display())), - TrapEntry::MapLabelToKey(file_label, full_id_for_file(path)), - TrapEntry::GenericTuple( - "files".to_owned(), - vec![ - Arg::Label(file_label), - Arg::String(normalize_path(path)), - Arg::String(match path.file_name() { - None => "".to_owned(), - Some(file_name) => format!("{}", file_name.to_string_lossy()), - }), - Arg::String(match path.extension() { - None => "".to_owned(), - Some(ext) => format!("{}", ext.to_string_lossy()), - }), - Arg::Int(1), // 1 = from source - ], - ), - ], - counter, + trap_writer: trap_writer, // TODO: should we handle path strings that are not valid UTF8 better? 
path: format!("{}", path.display()), - file_label, + file_label: *file_label, stack: Vec::new(), tables: build_schema_lookup(&self.schema), union_types: build_union_type_lookup(&self.schema), @@ -74,7 +150,7 @@ impl Extractor { traverse(&tree, &mut visitor); &self.parser.reset(); - Ok(Program(visitor.trap_output)) + Ok(Program(visitor.trap_writer.trap_output)) } } @@ -149,10 +225,8 @@ struct Visitor<'a> { file_label: Label, /// The source code as a UTF-8 byte array source: &'a Vec, - /// The accumulated trap entries - trap_output: Vec, - /// A counter for generating fresh labels - counter: i32, + /// A TrapWriter to accumulate trap entries + trap_writer: TrapWriter, /// A lookup table from type name to dbscheme table entries tables: Map<&'a TypeName, &'a Entry>, /// A lookup table for union types mapping a type name to its direct members @@ -199,15 +273,19 @@ impl Visitor<'_> { named: node.is_named(), }); if let Some(Entry::Table { fields, .. }) = table { - self.counter += 1; - let id = Label::Normal(self.counter); - let loc = Label::Location(self.counter); - self.trap_output.push(TrapEntry::FreshId(id)); - let (loc_label_def, loc_tuple) = - location_for(&self.source, &self.file_label, loc, node); - self.trap_output.push(loc_label_def); - self.trap_output.push(loc_tuple); - let table_name = node_type_name(node.kind(), node.is_named()); + let id = self.trap_writer.fresh_id(); + let (start_line, start_column, end_line, end_column) = location_for(&self.source, node); + let loc = self.trap_writer.location( + self.file_label.clone(), + start_line, + start_column, + end_line, + end_column, + ); + let table_name = escape_name(&format!( + "{}_def", + node_type_name(node.kind(), node.is_named()) + )); let args: Option>; if fields.is_empty() { args = Some(vec![sliced_source_arg(self.source, node)]); @@ -215,8 +293,11 @@ impl Visitor<'_> { args = self.complex_node(&node, fields, child_nodes, id); } if let Some(args) = args { - self.trap_output - 
.push(TrapEntry::Definition(table_name, id, args, loc)); + let mut all_args = Vec::new(); + all_args.push(Arg::Label(id)); + all_args.extend(args); + all_args.push(Arg::Label(loc)); + self.trap_writer.add_tuple(&table_name, all_args); } if let Some(parent) = self.stack.last_mut() { parent.push(( @@ -313,13 +394,18 @@ impl Visitor<'_> { ); break; } - self.trap_output.push(TrapEntry::ChildOf( + let table_name = escape_name(&format!( + "{}_{}", node_type_name(&field.parent.kind, field.parent.named), - parent_id, - field.get_name(), - if *has_index { Some(Index(index)) } else { None }, - *child_id, + field.get_name() )); + let mut args = Vec::new(); + args.push(Arg::Label(parent_id)); + if *has_index { + args.push(Arg::Int(index)) + } + args.push(Arg::Label(*child_id)); + self.trap_writer.add_tuple(&table_name, args); } } } @@ -356,12 +442,7 @@ fn sliced_source_arg(source: &Vec, n: Node) -> Arg { // Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated. // The first is the location and label definition, and the second is the // 'Located' entry. -fn location_for<'a>( - source: &Vec, - file_label: &Label, - label: Label, - n: Node, -) -> (TrapEntry, TrapEntry) { +fn location_for<'a>(source: &Vec, n: Node) -> (usize, usize, usize, usize) { // Tree-sitter row, column values are 0-based while CodeQL starts // counting at 1. In addition Tree-sitter's row and column for the // end position are exclusive while CodeQL's end positions are inclusive. 
@@ -403,23 +484,7 @@ fn location_for<'a>( ); } } - ( - TrapEntry::MapLabelToKey( - label, - format!( - "loc,{{{}}},{},{},{},{}", - file_label, start_line, start_col, end_line, end_col - ), - ), - TrapEntry::Located(vec![ - Arg::Label(label), - Arg::Label(file_label.clone()), - Arg::Int(start_line), - Arg::Int(start_col), - Arg::Int(end_line), - Arg::Int(end_col), - ]), - ) + (start_line, start_col, end_line, end_col) } fn traverse(tree: &Tree, visitor: &mut Visitor) { @@ -460,12 +525,6 @@ enum TrapEntry { FreshId(Label), /// Maps the label to a key, e.g. `#7 = @"foo"`. MapLabelToKey(Label, String), - // @node_def(self, arg?, location)@ - Definition(String, Label, Vec, Label), - // @node_child(self, index, parent)@ - ChildOf(String, Label, String, Option, Label), - // @location(loc, path, r1, c1, r2, c2) - Located(Vec), /// foo_bar(arg*) GenericTuple(String, Vec), Comment(String), @@ -475,47 +534,6 @@ impl fmt::Display for TrapEntry { match self { TrapEntry::FreshId(label) => write!(f, "{} = *", label), TrapEntry::MapLabelToKey(label, key) => write!(f, "{} = @\"{}\"", label, key), - TrapEntry::Definition(n, id, args, loc) => { - let mut args_str = String::new(); - for arg in args { - args_str.push_str(&format!("{}, ", arg)); - } - write!( - f, - "{}({}, {}{})", - escape_name(&format!("{}_def", &n)), - id, - args_str, - loc - ) - } - TrapEntry::ChildOf(pname, id, fname, idx, p) => match idx { - Some(idx) => write!( - f, - "{}({}, {}, {})", - escape_name(&format!("{}_{}", &pname, &fname)), - id, - idx, - p - ), - None => write!( - f, - "{}({}, {})", - escape_name(&format!("{}_{}", &pname, &fname)), - id, - p - ), - }, - TrapEntry::Located(args) => write!( - f, - "locations_default({}, {}, {}, {}, {}, {})", - args.get(0).unwrap(), - args.get(1).unwrap(), - args.get(2).unwrap(), - args.get(3).unwrap(), - args.get(4).unwrap(), - args.get(5).unwrap(), - ), TrapEntry::GenericTuple(name, args) => { write!(f, "{}(", name)?; for (index, arg) in args.iter().enumerate() { @@ 
-532,19 +550,12 @@ impl fmt::Display for TrapEntry {
 }
 
+/// Identifiers of the form #0, #1...
 #[derive(Debug, Copy, Clone)]
-enum Label {
-    // Identifiers of the form #0, #1...
-    Normal(i32), // #0, #1, etc.
-    // Location identifiers of the form #0_loc, #1_loc...
-    Location(i32), // #0_loc, #1_loc, etc.
-}
+struct Label(i32);
 
 impl fmt::Display for Label {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            Label::Normal(x) => write!(f, "#{}", x),
-            Label::Location(x) => write!(f, "#{}_loc", x),
-        }
+        write!(f, "#{}", self.0)
     }
 }
 