diff --git a/extractor/src/extractor.rs b/extractor/src/extractor.rs index dec97617287..597de04b986 100644 --- a/extractor/src/extractor.rs +++ b/extractor/src/extractor.rs @@ -6,6 +6,148 @@ use std::path::Path; use tracing::{error, info, span, Level}; use tree_sitter::{Language, Node, Parser, Tree}; +struct TrapWriter { + /// The accumulated trap entries + trap_output: Vec, + /// A counter for generating fresh labels + counter: u32, + /// cache of global keys + global_keys: std::collections::HashMap, +} + +fn new_trap_writer() -> TrapWriter { + TrapWriter { + counter: 0, + trap_output: Vec::new(), + global_keys: std::collections::HashMap::new(), + } +} + +impl TrapWriter { + /// Gets a label that will hold the unique ID of the passed string at import time. + /// This can be used for incrementally importable TRAP files -- use globally unique + /// strings to compute a unique ID for table tuples. + /// + /// Note: You probably want to make sure that the key strings that you use are disjoint + /// for disjoint column types; the standard way of doing this is to prefix (or append) + /// the column type name to the ID. Thus, you might identify methods in Java by the + /// full ID "methods_com.method.package.DeclaringClass.method(argumentList)". + + fn fresh_id(&mut self) -> Label { + let label = Label(self.counter); + self.counter += 1; + self.trap_output.push(TrapEntry::FreshId(label)); + label + } + + fn global_id(&mut self, key: &str) -> (Label, bool) { + if let Some(label) = self.global_keys.get(key) { + return (*label, false); + } + let label = Label(self.counter); + self.counter += 1; + self.global_keys.insert(key.to_owned(), label); + self.trap_output + .push(TrapEntry::MapLabelToKey(label, key.to_owned())); + (label, true) + } + + fn add_tuple(&mut self, table_name: &str, args: Vec) { + self.trap_output + .push(TrapEntry::GenericTuple(table_name.to_owned(), args)) + } + + fn populate_file(&mut self, absolute_path: &Path) -> Label { + let (file_label, fresh) = self.global_id(&full_id_for_file(absolute_path)); + if fresh { + self.add_tuple( + "files", + vec![ + Arg::Label(file_label), + Arg::String(normalize_path(absolute_path)), + Arg::String(match absolute_path.file_name() { + None => "".to_owned(), + Some(file_name) => format!("{}", file_name.to_string_lossy()), + }), + Arg::String(match absolute_path.extension() { + None => "".to_owned(), + Some(ext) => format!("{}", ext.to_string_lossy()), + }), + Arg::Int(1), // 1 = from source + ], + ); + self.populate_parent_folders(file_label, absolute_path.parent()); + } + file_label + } + + fn populate_parent_folders(&mut self, child_label: Label, path: Option<&Path>) { + let mut path = path; + let mut child_label = child_label; + loop { + match path { + None => break, + Some(folder) => { + let (folder_label, fresh) = self.global_id(&full_id_for_folder(folder)); + self.add_tuple( + "containerparent", + vec![Arg::Label(folder_label), Arg::Label(child_label)], + ); + if fresh { + self.add_tuple( + "folders", + vec![ + Arg::Label(folder_label), + Arg::String(normalize_path(folder)), + Arg::String(match folder.file_name() { + None => "".to_owned(), + Some(file_name) => format!("{}", file_name.to_string_lossy()), + }), + ], + ); + path = folder.parent(); + child_label = folder_label; + } else { + break; + } + } + } + } + } + + fn location( + &mut self, + file_label: Label, + start_line: usize, + start_column: usize, + end_line: usize, + end_column: usize, + ) -> Label { + let (loc_label, fresh) = self.global_id(&format!( + "loc,{{{}}},{},{},{},{}", + file_label, start_line, start_column, end_line, end_column + )); + if fresh { + self.add_tuple( + "locations_default", + vec![ + Arg::Label(loc_label), + Arg::Label(file_label), + Arg::Int(start_line), + Arg::Int(start_column), + Arg::Int(end_line), + Arg::Int(end_column), + ], + ); + } + loc_label + } + + fn comment(&mut self, text: String) { + self.trap_output.push(TrapEntry::Comment(text)); + } +} + pub struct Extractor { pub parser: Parser, pub schema: Vec, @@ -36,37 +178,15 @@ impl Extractor { .parser .parse(&source, None) .expect("Failed to parse file"); - let mut counter = -1; - // Create a label for the current file and increment the counter so that - // label doesn't get redefined. - counter += 1; - let file_label = Label::Normal(counter); + let mut trap_writer = new_trap_writer(); + trap_writer.comment(format!("Auto-generated TRAP file for {}", path.display())); + let file_label = &trap_writer.populate_file(path); let mut visitor = Visitor { source: &source, - trap_output: vec![ - TrapEntry::Comment(format!("Auto-generated TRAP file for {}", path.display())), - TrapEntry::MapLabelToKey(file_label, full_id_for_file(path)), - TrapEntry::GenericTuple( - "files".to_owned(), - vec![ - Arg::Label(file_label), - Arg::String(normalize_path(path)), - Arg::String(match path.file_name() { - None => "".to_owned(), - Some(file_name) => format!("{}", file_name.to_string_lossy()), - }), - Arg::String(match path.extension() { - None => "".to_owned(), - Some(ext) => format!("{}", ext.to_string_lossy()), - }), - Arg::Int(1), // 1 = from source - ], - ), - ], - counter, + trap_writer: trap_writer, // TODO: should we handle path strings that are not valid UTF8 better? path: format!("{}", path.display()), - file_label, + file_label: *file_label, stack: Vec::new(), tables: build_schema_lookup(&self.schema), union_types: build_union_type_lookup(&self.schema), @@ -74,7 +194,7 @@ impl Extractor { traverse(&tree, &mut visitor); &self.parser.reset(); - Ok(Program(visitor.trap_output)) + Ok(Program(visitor.trap_writer.trap_output)) } } @@ -121,6 +241,10 @@ fn full_id_for_file(path: &Path) -> String { format!("{};sourcefile", normalize_path(path)) } +fn full_id_for_folder(path: &Path) -> String { + format!("{};folder", normalize_path(path)) +} + fn build_schema_lookup<'a>(schema: &'a Vec) -> Map<&'a TypeName, &'a Entry> { let mut map = std::collections::BTreeMap::new(); for entry in schema { @@ -149,10 +273,8 @@ struct Visitor<'a> { file_label: Label, /// The source code as a UTF-8 byte array source: &'a Vec, - /// The accumulated trap entries - trap_output: Vec, - /// A counter for generating fresh labels - counter: i32, + /// A TrapWriter to accumulate trap entries + trap_writer: TrapWriter, /// A lookup table from type name to dbscheme table entries tables: Map<&'a TypeName, &'a Entry>, /// A lookup table for union types mapping a type name to its direct members @@ -199,15 +321,19 @@ impl Visitor<'_> { named: node.is_named(), }); if let Some(Entry::Table { fields, .. }) = table { - self.counter += 1; - let id = Label::Normal(self.counter); - let loc = Label::Location(self.counter); - self.trap_output.push(TrapEntry::FreshId(id)); - let (loc_label_def, loc_tuple) = - location_for(&self.source, &self.file_label, loc, node); - self.trap_output.push(loc_label_def); - self.trap_output.push(loc_tuple); - let table_name = node_type_name(node.kind(), node.is_named()); + let id = self.trap_writer.fresh_id(); + let (start_line, start_column, end_line, end_column) = location_for(&self.source, node); + let loc = self.trap_writer.location( + self.file_label.clone(), + start_line, + start_column, + end_line, + end_column, + ); + let table_name = escape_name(&format!( + "{}_def", + node_type_name(node.kind(), node.is_named()) + )); let args: Option>; if fields.is_empty() { args = Some(vec![sliced_source_arg(self.source, node)]); @@ -215,8 +341,11 @@ impl Visitor<'_> { args = self.complex_node(&node, fields, child_nodes, id); } if let Some(args) = args { - self.trap_output - .push(TrapEntry::Definition(table_name, id, args, loc)); + let mut all_args = Vec::new(); + all_args.push(Arg::Label(id)); + all_args.extend(args); + all_args.push(Arg::Label(loc)); + self.trap_writer.add_tuple(&table_name, all_args); } if let Some(parent) = self.stack.last_mut() { parent.push(( @@ -313,13 +442,18 @@ impl Visitor<'_> { ); break; } - self.trap_output.push(TrapEntry::ChildOf( + let table_name = escape_name(&format!( + "{}_{}", node_type_name(&field.parent.kind, field.parent.named), - parent_id, - field.get_name(), - if *has_index { Some(Index(index)) } else { None }, - *child_id, + field.get_name() )); + let mut args = Vec::new(); + args.push(Arg::Label(parent_id)); + if *has_index { + args.push(Arg::Int(index)) + } + args.push(Arg::Label(*child_id)); + self.trap_writer.add_tuple(&table_name, args); } } } @@ -356,12 +490,7 @@ fn sliced_source_arg(source: &Vec, n: Node) -> Arg { // Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated. // The first is the location and label definition, and the second is the // 'Located' entry. -fn location_for<'a>( - source: &Vec, - file_label: &Label, - label: Label, - n: Node, -) -> (TrapEntry, TrapEntry) { +fn location_for<'a>(source: &Vec, n: Node) -> (usize, usize, usize, usize) { // Tree-sitter row, column values are 0-based while CodeQL starts // counting at 1. In addition Tree-sitter's row and column for the // end position are exclusive while CodeQL's end positions are inclusive. @@ -403,23 +532,7 @@ fn location_for<'a>( ); } } - ( - TrapEntry::MapLabelToKey( - label, - format!( - "loc,{{{}}},{},{},{},{}", - file_label, start_line, start_col, end_line, end_col - ), - ), - TrapEntry::Located(vec![ - Arg::Label(label), - Arg::Label(file_label.clone()), - Arg::Int(start_line), - Arg::Int(start_col), - Arg::Int(end_line), - Arg::Int(end_col), - ]), - ) + (start_line, start_col, end_line, end_col) } fn traverse(tree: &Tree, visitor: &mut Visitor) { @@ -460,12 +573,6 @@ enum TrapEntry { FreshId(Label), /// Maps the label to a key, e.g. `#7 = @"foo"`. MapLabelToKey(Label, String), - // @node_def(self, arg?, location)@ - Definition(String, Label, Vec, Label), - // @node_child(self, index, parent)@ - ChildOf(String, Label, String, Option, Label), - // @location(loc, path, r1, c1, r2, c2) - Located(Vec), /// foo_bar(arg*) GenericTuple(String, Vec), Comment(String), @@ -474,48 +581,9 @@ impl fmt::Display for TrapEntry { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { TrapEntry::FreshId(label) => write!(f, "{} = *", label), - TrapEntry::MapLabelToKey(label, key) => write!(f, "{} = @\"{}\"", label, key), - TrapEntry::Definition(n, id, args, loc) => { - let mut args_str = String::new(); - for arg in args { - args_str.push_str(&format!("{}, ", arg)); - } - write!( - f, - "{}({}, {}{})", - escape_name(&format!("{}_def", &n)), - id, - args_str, - loc - ) + TrapEntry::MapLabelToKey(label, key) => { + write!(f, "{} = @\"{}\"", label, key.replace("\"", "\"\"")) } - TrapEntry::ChildOf(pname, id, fname, idx, p) => match idx { - Some(idx) => write!( - f, - "{}({}, {}, {})", - escape_name(&format!("{}_{}", &pname, &fname)), - id, - idx, - p - ), - None => write!( - f, - "{}({}, {})", - escape_name(&format!("{}_{}", &pname, &fname)), - id, - p - ), - }, - TrapEntry::Located(args) => write!( - f, - "locations_default({}, {}, {}, {}, {}, {})", - args.get(0).unwrap(), - args.get(1).unwrap(), - args.get(2).unwrap(), - args.get(3).unwrap(), - args.get(4).unwrap(), - args.get(5).unwrap(), - ), TrapEntry::GenericTuple(name, args) => { write!(f, "{}(", name)?; for (index, arg) in args.iter().enumerate() { @@ -532,19 +600,12 @@ impl fmt::Display for TrapEntry { } #[derive(Debug, Copy, Clone)] -enum Label { - // Identifiers of the form #0, #1... - Normal(i32), // #0, #1, etc. - // Location identifiers of the form #0_loc, #1_loc... - Location(i32), // #0_loc, #1_loc, etc. -} +// Identifiers of the form #0, #1... +struct Label(u32); impl fmt::Display for Label { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Label::Normal(x) => write!(f, "#{}", x), - Label::Location(x) => write!(f, "#{}_loc", x), - } + write!(f, "#{}", self.0) } } @@ -566,12 +627,47 @@ enum Arg { String(String), } +const MAX_STRLEN: usize = 1048576; + impl fmt::Display for Arg { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Arg::Label(x) => write!(f, "{}", x), Arg::Int(x) => write!(f, "{}", x), - Arg::String(x) => write!(f, "\"{}\"", x.replace("\"", "\"\"")), + Arg::String(x) => write!( + f, + "\"{}\"", + limit_string(x, MAX_STRLEN).replace("\"", "\"\"") + ), } } } + +/// Limit the length (in bytes) of a string. If the string's length in bytes is +/// less than or equal to the limit then the entire string is returned. Otherwise +/// the string is sliced at the provided limit. If there is a multi-byte character +/// at the limit then the returned slice will be slightly shorter than the limit to +/// avoid splitting that multi-byte character. +fn limit_string(string: &String, max_size: usize) -> &str { + if string.len() <= max_size { + return string; + } + let p = string.as_ptr(); + let mut index = max_size; + // We want to clip the string at [max_size]; however, the character at that position + // may span several bytes. We need to find the first byte of the character. In UTF-8 + // encoded data any byte that matches the bit pattern 10XXXXXX is not a start byte. + // Therefore we decrement the index as long as there are bytes matching this pattern. + // This ensures we cut the string at the border between one character and another. + while index > 0 && unsafe { (*p.offset(index as isize) & 0b11000000) == 0b10000000 } { + index -= 1; + } + &string[0..index] +} + +#[test] +fn limit_string_test() { + assert_eq!("hello", limit_string(&"hello world".to_owned(), 5)); + assert_eq!("hi ☹", limit_string(&"hi ☹☹".to_owned(), 6)); + assert_eq!("hi ", limit_string(&"hi ☹☹".to_owned(), 5)); +}