diff --git a/ruby/extractor/src/extractor.rs b/ruby/extractor/src/extractor.rs index 8cdaff1b738..6c462e11bb4 100644 --- a/ruby/extractor/src/extractor.rs +++ b/ruby/extractor/src/extractor.rs @@ -1,161 +1,112 @@ +use crate::trap; use node_types::{EntryKind, Field, NodeTypeMap, Storage, TypeName}; -use std::borrow::Cow; use std::collections::BTreeMap as Map; use std::collections::BTreeSet as Set; use std::fmt; -use std::io::Write; use std::path::Path; use tracing::{error, info, span, Level}; use tree_sitter::{Language, Node, Parser, Range, Tree}; -pub struct TrapWriter { - /// The accumulated trap entries - trap_output: Vec, - /// A counter for generating fresh labels - counter: u32, - /// cache of global keys - global_keys: std::collections::HashMap, +pub fn populate_file(writer: &mut trap::Writer, absolute_path: &Path) -> trap::Label { + let (file_label, fresh) = + writer.global_id(&trap::full_id_for_file(&normalize_path(absolute_path))); + if fresh { + writer.add_tuple( + "files", + vec![ + trap::Arg::Label(file_label), + trap::Arg::String(normalize_path(absolute_path)), + ], + ); + populate_parent_folders(writer, file_label, absolute_path.parent()); + } + file_label } -pub fn new_trap_writer() -> TrapWriter { - TrapWriter { - counter: 0, - trap_output: Vec::new(), - global_keys: std::collections::HashMap::new(), +fn populate_empty_file(writer: &mut trap::Writer) -> trap::Label { + let (file_label, fresh) = writer.global_id("empty;sourcefile"); + if fresh { + writer.add_tuple( + "files", + vec![ + trap::Arg::Label(file_label), + trap::Arg::String("".to_string()), + ], + ); } + file_label } -impl TrapWriter { - /// Gets a label that will hold the unique ID of the passed string at import time. - /// This can be used for incrementally importable TRAP files -- use globally unique - /// strings to compute a unique ID for table tuples. - /// - /// Note: You probably want to make sure that the key strings that you use are disjoint - /// for disjoint column types; the standard way of doing this is to prefix (or append) - /// the column type name to the ID. Thus, you might identify methods in Java by the - /// full ID "methods_com.method.package.DeclaringClass.method(argumentList)". +pub fn populate_empty_location(writer: &mut trap::Writer) { + let file_label = populate_empty_file(writer); + location(writer, file_label, 0, 0, 0, 0); +} - fn fresh_id(&mut self) -> Label { - let label = Label(self.counter); - self.counter += 1; - self.trap_output.push(TrapEntry::FreshId(label)); - label - } - - fn global_id(&mut self, key: &str) -> (Label, bool) { - if let Some(label) = self.global_keys.get(key) { - return (*label, false); - } - let label = Label(self.counter); - self.counter += 1; - self.global_keys.insert(key.to_owned(), label); - self.trap_output - .push(TrapEntry::MapLabelToKey(label, key.to_owned())); - (label, true) - } - - fn add_tuple(&mut self, table_name: &str, args: Vec) { - self.trap_output - .push(TrapEntry::GenericTuple(table_name.to_owned(), args)) - } - - fn populate_file(&mut self, absolute_path: &Path) -> Label { - let (file_label, fresh) = self.global_id(&full_id_for_file(absolute_path)); - if fresh { - self.add_tuple( - "files", - vec![ - Arg::Label(file_label), - Arg::String(normalize_path(absolute_path)), - ], - ); - self.populate_parent_folders(file_label, absolute_path.parent()); - } - file_label - } - - fn populate_empty_file(&mut self) -> Label { - let (file_label, fresh) = self.global_id("empty;sourcefile"); - if fresh { - self.add_tuple( - "files", - vec![Arg::Label(file_label), Arg::String("".to_string())], - ); - } - file_label - } - - pub fn populate_empty_location(&mut self) { - let file_label = self.populate_empty_file(); - self.location(file_label, 0, 0, 0, 0); - } - - fn populate_parent_folders(&mut self, child_label: Label, path: Option<&Path>) { - let mut path = path; - let mut child_label = child_label; - loop { - match path { - None => break, - Some(folder) => { - let (folder_label, fresh) = self.global_id(&full_id_for_folder(folder)); - self.add_tuple( - "containerparent", - vec![Arg::Label(folder_label), Arg::Label(child_label)], +pub fn populate_parent_folders( + writer: &mut trap::Writer, + child_label: trap::Label, + path: Option<&Path>, +) { + let mut path = path; + let mut child_label = child_label; + loop { + match path { + None => break, + Some(folder) => { + let (folder_label, fresh) = + writer.global_id(&trap::full_id_for_folder(&normalize_path(folder))); + writer.add_tuple( + "containerparent", + vec![ + trap::Arg::Label(folder_label), + trap::Arg::Label(child_label), + ], + ); + if fresh { + writer.add_tuple( + "folders", + vec![ + trap::Arg::Label(folder_label), + trap::Arg::String(normalize_path(folder)), + ], ); - if fresh { - self.add_tuple( - "folders", - vec![ - Arg::Label(folder_label), - Arg::String(normalize_path(folder)), - ], - ); - path = folder.parent(); - child_label = folder_label; - } else { - break; - } + path = folder.parent(); + child_label = folder_label; + } else { + break; } } } } +} - fn location( - &mut self, - file_label: Label, - start_line: usize, - start_column: usize, - end_line: usize, - end_column: usize, - ) -> Label { - let (loc_label, fresh) = self.global_id(&format!( - "loc,{{{}}},{},{},{},{}", - file_label, start_line, start_column, end_line, end_column - )); - if fresh { - self.add_tuple( - "locations_default", - vec![ - Arg::Label(loc_label), - Arg::Label(file_label), - Arg::Int(start_line), - Arg::Int(start_column), - Arg::Int(end_line), - Arg::Int(end_column), - ], - ); - } - loc_label - } - - fn comment(&mut self, text: String) { - self.trap_output.push(TrapEntry::Comment(text)); - } - - pub fn output(self, writer: &mut dyn Write) -> std::io::Result<()> { - write!(writer, "{}", Program(self.trap_output)) +fn location( + writer: &mut trap::Writer, + file_label: trap::Label, + start_line: usize, + start_column: usize, + end_line: usize, + end_column: usize, +) -> trap::Label { + let (loc_label, fresh) = writer.global_id(&format!( + "loc,{{{}}},{},{},{},{}", + file_label, start_line, start_column, end_line, end_column + )); + if fresh { + writer.add_tuple( + "locations_default", + vec![ + trap::Arg::Label(loc_label), + trap::Arg::Label(file_label), + trap::Arg::Int(start_line), + trap::Arg::Int(start_column), + trap::Arg::Int(end_line), + trap::Arg::Int(end_column), + ], + ); } + loc_label } /// Extracts the source file at `path`, which is assumed to be canonicalized. @@ -163,7 +114,7 @@ pub fn extract( language: Language, language_prefix: &str, schema: &NodeTypeMap, - trap_writer: &mut TrapWriter, + trap_writer: &mut trap::Writer, path: &Path, source: &[u8], ranges: &[Range], @@ -183,13 +134,13 @@ pub fn extract( parser.set_included_ranges(ranges).unwrap(); let tree = parser.parse(&source, None).expect("Failed to parse file"); trap_writer.comment(format!("Auto-generated TRAP file for {}", path.display())); - let file_label = &trap_writer.populate_file(path); + let file_label = populate_file(trap_writer, path); let mut visitor = Visitor { source, trap_writer, // TODO: should we handle path strings that are not valid UTF8 better? path: format!("{}", path.display()), - file_label: *file_label, + file_label, toplevel_child_counter: 0, stack: Vec::new(), language_prefix, @@ -201,33 +152,6 @@ pub fn extract( Ok(()) } -/// Escapes a string for use in a TRAP key, by replacing special characters with -/// HTML entities. -fn escape_key<'a, S: Into>>(key: S) -> Cow<'a, str> { - fn needs_escaping(c: char) -> bool { - matches!(c, '&' | '{' | '}' | '"' | '@' | '#') - } - - let key = key.into(); - if key.contains(needs_escaping) { - let mut escaped = String::with_capacity(2 * key.len()); - for c in key.chars() { - match c { - '&' => escaped.push_str("&"), - '{' => escaped.push_str("{"), - '}' => escaped.push_str("}"), - '"' => escaped.push_str("""), - '@' => escaped.push_str("@"), - '#' => escaped.push_str("#"), - _ => escaped.push(c), - } - } - Cow::Owned(escaped) - } else { - key - } -} - /// Normalizes the path according the common CodeQL specification. Assumes that /// `path` has already been canonicalized using `std::fs::canonicalize`. fn normalize_path(path: &Path) -> String { @@ -267,17 +191,9 @@ fn normalize_path(path: &Path) -> String { } } -fn full_id_for_file(path: &Path) -> String { - format!("{};sourcefile", escape_key(&normalize_path(path))) -} - -fn full_id_for_folder(path: &Path) -> String { - format!("{};folder", escape_key(&normalize_path(path))) -} - struct ChildNode { field_name: Option<&'static str>, - label: Label, + label: trap::Label, type_name: TypeName, } @@ -286,11 +202,11 @@ struct Visitor<'a> { path: String, /// The label to use whenever we need to refer to the `@file` entity of this /// source file. - file_label: Label, + file_label: trap::Label, /// The source code as a UTF-8 byte array source: &'a [u8], - /// A TrapWriter to accumulate trap entries - trap_writer: &'a mut TrapWriter, + /// A trap::Writer to accumulate trap entries + trap_writer: &'a mut trap::Writer, /// A counter for top-level child nodes toplevel_child_counter: usize, /// Language prefix @@ -303,7 +219,7 @@ struct Visitor<'a> { /// node the list containing the child data is popped from the stack and /// matched against the dbscheme for the node. If the expectations are met /// the corresponding row definitions are added to the trap_output. - stack: Vec<(Label, usize, Vec)>, + stack: Vec<(trap::Label, usize, Vec)>, } impl Visitor<'_> { @@ -311,19 +227,19 @@ impl Visitor<'_> { &mut self, error_message: String, full_error_message: String, - loc: Label, + loc: trap::Label, ) { error!("{}", full_error_message); let id = self.trap_writer.fresh_id(); self.trap_writer.add_tuple( "diagnostics", vec![ - Arg::Label(id), - Arg::Int(40), // severity 40 = error - Arg::String("parse_error".to_string()), - Arg::String(error_message), - Arg::String(full_error_message), - Arg::Label(loc), + trap::Arg::Label(id), + trap::Arg::Int(40), // severity 40 = error + trap::Arg::String("parse_error".to_string()), + trap::Arg::String(error_message), + trap::Arg::String(full_error_message), + trap::Arg::Label(loc), ], ); } @@ -335,7 +251,8 @@ impl Visitor<'_> { node: Node, ) { let (start_line, start_column, end_line, end_column) = location_for(self.source, node); - let loc = self.trap_writer.location( + let loc = location( + self.trap_writer, self.file_label, start_line, start_column, @@ -374,7 +291,8 @@ impl Visitor<'_> { } let (id, _, child_nodes) = self.stack.pop().expect("Vistor: empty stack"); let (start_line, start_column, end_line, end_column) = location_for(self.source, node); - let loc = self.trap_writer.location( + let loc = location( + self.trap_writer, self.file_label, start_line, start_column, @@ -404,18 +322,19 @@ impl Visitor<'_> { self.trap_writer.add_tuple( &format!("{}_ast_node_info", self.language_prefix), vec![ - Arg::Label(id), - Arg::Label(parent_id), - Arg::Int(parent_index), - Arg::Label(loc), + trap::Arg::Label(id), + trap::Arg::Label(parent_id), + trap::Arg::Int(parent_index), + trap::Arg::Label(loc), ], ); self.trap_writer.add_tuple( &format!("{}_tokeninfo", self.language_prefix), vec![ - Arg::Label(id), - Arg::Int(*kind_id), + trap::Arg::Label(id), + trap::Arg::Int(*kind_id), sliced_source_arg(self.source, node), + trap::Arg::Label(loc), ], ); } @@ -427,13 +346,13 @@ impl Visitor<'_> { self.trap_writer.add_tuple( &format!("{}_ast_node_info", self.language_prefix), vec![ - Arg::Label(id), - Arg::Label(parent_id), - Arg::Int(parent_index), - Arg::Label(loc), + trap::Arg::Label(id), + trap::Arg::Label(parent_id), + trap::Arg::Int(parent_index), + trap::Arg::Label(loc), ], ); - let mut all_args = vec![Arg::Label(id)]; + let mut all_args = vec![trap::Arg::Label(id)]; all_args.extend(args); self.trap_writer.add_tuple(table_name, all_args); } @@ -472,9 +391,9 @@ impl Visitor<'_> { node: &Node, fields: &[Field], child_nodes: &[ChildNode], - parent_id: Label, - ) -> Option> { - let mut map: Map<&Option, (&Field, Vec)> = Map::new(); + parent_id: trap::Label, + ) -> Option> { + let mut map: Map<&Option, (&Field, Vec)> = Map::new(); for field in fields { map.insert(&field.name, (field, Vec::new())); } @@ -488,9 +407,9 @@ impl Visitor<'_> { { // We can safely unwrap because type_matches checks the key is in the map. let (int_value, _) = int_mapping.get(&child_node.type_name.kind).unwrap(); - values.push(Arg::Int(*int_value)); + values.push(trap::Arg::Int(*int_value)); } else { - values.push(Arg::Label(child_node.label)); + values.push(trap::Arg::Label(child_node.label)); } } else if field.name.is_some() { let error_message = format!( @@ -569,9 +488,9 @@ impl Visitor<'_> { ); break; } - let mut args = vec![Arg::Label(parent_id)]; + let mut args = vec![trap::Arg::Label(parent_id)]; if *has_index { - args.push(Arg::Int(index)) + args.push(trap::Arg::Int(index)) } args.push(child_value.clone()); self.trap_writer.add_tuple(table_name, args); @@ -625,9 +544,9 @@ impl Visitor<'_> { } // Emit a slice of a source file as an Arg. -fn sliced_source_arg(source: &[u8], n: Node) -> Arg { +fn sliced_source_arg(source: &[u8], n: Node) -> trap::Arg { let range = n.byte_range(); - Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned()) + trap::Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned()) } // Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated. @@ -699,59 +618,6 @@ fn traverse(tree: &Tree, visitor: &mut Visitor) { } } -pub struct Program(Vec); - -impl fmt::Display for Program { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut text = String::new(); - for trap_entry in &self.0 { - text.push_str(&format!("{}\n", trap_entry)); - } - write!(f, "{}", text) - } -} - -enum TrapEntry { - /// Maps the label to a fresh id, e.g. `#123=*`. - FreshId(Label), - /// Maps the label to a key, e.g. `#7=@"foo"`. - MapLabelToKey(Label, String), - /// foo_bar(arg*) - GenericTuple(String, Vec), - Comment(String), -} -impl fmt::Display for TrapEntry { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - TrapEntry::FreshId(label) => write!(f, "{}=*", label), - TrapEntry::MapLabelToKey(label, key) => { - write!(f, "{}=@\"{}\"", label, key.replace("\"", "\"\"")) - } - TrapEntry::GenericTuple(name, args) => { - write!(f, "{}(", name)?; - for (index, arg) in args.iter().enumerate() { - if index > 0 { - write!(f, ",")?; - } - write!(f, "{}", arg)?; - } - write!(f, ")") - } - TrapEntry::Comment(line) => write!(f, "// {}", line), - } - } -} - -#[derive(Debug, Copy, Clone)] -// Identifiers of the form #0, #1... -struct Label(u32); - -impl fmt::Display for Label { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "#{:x}", self.0) - } -} - // Numeric indices. #[derive(Debug, Copy, Clone)] struct Index(usize); @@ -761,69 +627,3 @@ impl fmt::Display for Index { write!(f, "{}", self.0) } } - -// Some untyped argument to a TrapEntry. -#[derive(Debug, Clone)] -enum Arg { - Label(Label), - Int(usize), - String(String), -} - -const MAX_STRLEN: usize = 1048576; - -impl fmt::Display for Arg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Arg::Label(x) => write!(f, "{}", x), - Arg::Int(x) => write!(f, "{}", x), - Arg::String(x) => write!( - f, - "\"{}\"", - limit_string(x, MAX_STRLEN).replace("\"", "\"\"") - ), - } - } -} - -/// Limit the length (in bytes) of a string. If the string's length in bytes is -/// less than or equal to the limit then the entire string is returned. Otherwise -/// the string is sliced at the provided limit. If there is a multi-byte character -/// at the limit then the returned slice will be slightly shorter than the limit to -/// avoid splitting that multi-byte character. -fn limit_string(string: &str, max_size: usize) -> &str { - if string.len() <= max_size { - return string; - } - let p = string.as_bytes(); - let mut index = max_size; - // We want to clip the string at [max_size]; however, the character at that position - // may span several bytes. We need to find the first byte of the character. In UTF-8 - // encoded data any byte that matches the bit pattern 10XXXXXX is not a start byte. - // Therefore we decrement the index as long as there are bytes matching this pattern. - // This ensures we cut the string at the border between one character and another. - while index > 0 && (p[index] & 0b11000000) == 0b10000000 { - index -= 1; - } - &string[0..index] -} - -#[test] -fn limit_string_test() { - assert_eq!("hello", limit_string(&"hello world".to_owned(), 5)); - assert_eq!("hi ☹", limit_string(&"hi ☹☹".to_owned(), 6)); - assert_eq!("hi ", limit_string(&"hi ☹☹".to_owned(), 5)); -} - -#[test] -fn escape_key_test() { - assert_eq!("foo!", escape_key("foo!")); - assert_eq!("foo{}", escape_key("foo{}")); - assert_eq!("{}", escape_key("{}")); - assert_eq!("", escape_key("")); - assert_eq!("/path/to/foo.rb", escape_key("/path/to/foo.rb")); - assert_eq!( - "/path/to/foo&{}"@#.rb", - escape_key("/path/to/foo&{}\"@#.rb") - ); -} diff --git a/ruby/extractor/src/main.rs b/ruby/extractor/src/main.rs index 7e4aa973518..4fc886a50ad 100644 --- a/ruby/extractor/src/main.rs +++ b/ruby/extractor/src/main.rs @@ -1,51 +1,15 @@ mod extractor; +mod trap; extern crate num_cpus; use clap::arg; -use flate2::write::GzEncoder; use rayon::prelude::*; use std::fs; -use std::io::{BufRead, BufWriter}; +use std::io::BufRead; use std::path::{Path, PathBuf}; use tree_sitter::{Language, Parser, Range}; -enum TrapCompression { - None, - Gzip, -} - -impl TrapCompression { - fn from_env() -> TrapCompression { - match std::env::var("CODEQL_RUBY_TRAP_COMPRESSION") { - Ok(method) => match TrapCompression::from_string(&method) { - Some(c) => c, - None => { - tracing::error!("Unknown compression method '{}'; using gzip.", &method); - TrapCompression::Gzip - } - }, - // Default compression method if the env var isn't set: - Err(_) => TrapCompression::Gzip, - } - } - - fn from_string(s: &str) -> Option { - match s.to_lowercase().as_ref() { - "none" => Some(TrapCompression::None), - "gzip" => Some(TrapCompression::Gzip), - _ => None, - } - } - - fn extension(&self) -> &str { - match self { - TrapCompression::None => "trap", - TrapCompression::Gzip => "trap.gz", - } - } -} - /** * Gets the number of threads the extractor should use, by reading the * CODEQL_THREADS environment variable and using it as described in the @@ -118,7 +82,7 @@ fn main() -> std::io::Result<()> { .value_of("output-dir") .expect("missing --output-dir"); let trap_dir = PathBuf::from(trap_dir); - let trap_compression = TrapCompression::from_env(); + let trap_compression = trap::Compression::from_env("CODEQL_RUBY_TRAP_COMPRESSION"); let file_list = matches.value_of("file-list").expect("missing --file-list"); let file_list = fs::File::open(file_list)?; @@ -141,7 +105,7 @@ fn main() -> std::io::Result<()> { let src_archive_file = path_for(&src_archive_dir, &path, ""); let mut source = std::fs::read(&path)?; let code_ranges; - let mut trap_writer = extractor::new_trap_writer(); + let mut trap_writer = trap::Writer::new(); if path.extension().map_or(false, |x| x == "erb") { tracing::info!("scanning: {}", path.display()); extractor::extract( @@ -181,33 +145,25 @@ fn main() -> std::io::Result<()> { )?; std::fs::create_dir_all(&src_archive_file.parent().unwrap())?; std::fs::copy(&path, &src_archive_file)?; - write_trap(&trap_dir, path, trap_writer, &trap_compression) + write_trap(&trap_dir, path, &trap_writer, trap_compression) }) .expect("failed to extract files"); let path = PathBuf::from("extras"); - let mut trap_writer = extractor::new_trap_writer(); - trap_writer.populate_empty_location(); - write_trap(&trap_dir, path, trap_writer, &trap_compression) + let mut trap_writer = trap::Writer::new(); + extractor::populate_empty_location(&mut trap_writer); + write_trap(&trap_dir, path, &trap_writer, trap_compression) } fn write_trap( trap_dir: &Path, path: PathBuf, - trap_writer: extractor::TrapWriter, - trap_compression: &TrapCompression, + trap_writer: &trap::Writer, + trap_compression: trap::Compression, ) -> std::io::Result<()> { let trap_file = path_for(trap_dir, &path, trap_compression.extension()); std::fs::create_dir_all(&trap_file.parent().unwrap())?; - let trap_file = std::fs::File::create(&trap_file)?; - let mut trap_file = BufWriter::new(trap_file); - match trap_compression { - TrapCompression::None => trap_writer.output(&mut trap_file), - TrapCompression::Gzip => { - let mut compressed_writer = GzEncoder::new(trap_file, flate2::Compression::fast()); - trap_writer.output(&mut compressed_writer) - } - } + trap_writer.write_to_file(&trap_file, trap_compression) } fn scan_erb( diff --git a/ruby/extractor/src/trap.rs b/ruby/extractor/src/trap.rs new file mode 100644 index 00000000000..d64c520c4cc --- /dev/null +++ b/ruby/extractor/src/trap.rs @@ -0,0 +1,272 @@ +use std::borrow::Cow; +use std::fmt; +use std::io::{BufWriter, Write}; +use std::path::Path; + +use flate2::write::GzEncoder; + +pub struct Writer { + /// The accumulated trap entries + trap_output: Vec, + /// A counter for generating fresh labels + counter: u32, + /// cache of global keys + global_keys: std::collections::HashMap, +} + +impl Writer { + pub fn new() -> Writer { + Writer { + counter: 0, + trap_output: Vec::new(), + global_keys: std::collections::HashMap::new(), + } + } + + pub fn fresh_id(&mut self) -> Label { + let label = Label(self.counter); + self.counter += 1; + self.trap_output.push(Entry::FreshId(label)); + label + } + + /// Gets a label that will hold the unique ID of the passed string at import time. + /// This can be used for incrementally importable TRAP files -- use globally unique + /// strings to compute a unique ID for table tuples. + /// + /// Note: You probably want to make sure that the key strings that you use are disjoint + /// for disjoint column types; the standard way of doing this is to prefix (or append) + /// the column type name to the ID. Thus, you might identify methods in Java by the + /// full ID "methods_com.method.package.DeclaringClass.method(argumentList)". + pub fn global_id(&mut self, key: &str) -> (Label, bool) { + if let Some(label) = self.global_keys.get(key) { + return (*label, false); + } + let label = Label(self.counter); + self.counter += 1; + self.global_keys.insert(key.to_owned(), label); + self.trap_output + .push(Entry::MapLabelToKey(label, key.to_owned())); + (label, true) + } + + pub fn add_tuple(&mut self, table_name: &str, args: Vec) { + self.trap_output + .push(Entry::GenericTuple(table_name.to_owned(), args)) + } + + pub fn comment(&mut self, text: String) { + self.trap_output.push(Entry::Comment(text)); + } + + pub fn write_to_file(&self, path: &Path, compression: Compression) -> std::io::Result<()> { + let trap_file = std::fs::File::create(path)?; + let mut trap_file = BufWriter::new(trap_file); + match compression { + Compression::None => self.write_trap_entries(&mut trap_file), + Compression::Gzip => { + let mut compressed_writer = GzEncoder::new(trap_file, flate2::Compression::fast()); + self.write_trap_entries(&mut compressed_writer) + } + } + } + + fn write_trap_entries(&self, file: &mut W) -> std::io::Result<()> { + for trap_entry in &self.trap_output { + writeln!(file, "{}", trap_entry)?; + } + std::io::Result::Ok(()) + } +} + +pub enum Entry { + /// Maps the label to a fresh id, e.g. `#123=*`. + FreshId(Label), + /// Maps the label to a key, e.g. `#7=@"foo"`. + MapLabelToKey(Label, String), + /// foo_bar(arg*) + GenericTuple(String, Vec), + Comment(String), +} + +impl fmt::Display for Entry { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Entry::FreshId(label) => write!(f, "{}=*", label), + Entry::MapLabelToKey(label, key) => { + write!(f, "{}=@\"{}\"", label, key.replace("\"", "\"\"")) + } + Entry::GenericTuple(name, args) => { + write!(f, "{}(", name)?; + for (index, arg) in args.iter().enumerate() { + if index > 0 { + write!(f, ",")?; + } + write!(f, "{}", arg)?; + } + write!(f, ")") + } + Entry::Comment(line) => write!(f, "// {}", line), + } + } +} + +#[derive(Debug, Copy, Clone)] +// Identifiers of the form #0, #1... +pub struct Label(u32); + +impl fmt::Display for Label { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "#{:x}", self.0) + } +} + +// Some untyped argument to a TrapEntry. +#[derive(Debug, Clone)] +pub enum Arg { + Label(Label), + Int(usize), + String(String), +} + +const MAX_STRLEN: usize = 1048576; + +impl fmt::Display for Arg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Arg::Label(x) => write!(f, "{}", x), + Arg::Int(x) => write!(f, "{}", x), + Arg::String(x) => write!( + f, + "\"{}\"", + limit_string(x, MAX_STRLEN).replace("\"", "\"\"") + ), + } + } +} + +pub struct Program(Vec); + +impl fmt::Display for Program { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut text = String::new(); + for trap_entry in &self.0 { + text.push_str(&format!("{}\n", trap_entry)); + } + write!(f, "{}", text) + } +} + +pub fn full_id_for_file(normalized_path: &str) -> String { + format!("{};sourcefile", escape_key(normalized_path)) +} + +pub fn full_id_for_folder(normalized_path: &str) -> String { + format!("{};folder", escape_key(normalized_path)) +} + +/// Escapes a string for use in a TRAP key, by replacing special characters with +/// HTML entities. +fn escape_key<'a, S: Into>>(key: S) -> Cow<'a, str> { + fn needs_escaping(c: char) -> bool { + matches!(c, '&' | '{' | '}' | '"' | '@' | '#') + } + + let key = key.into(); + if key.contains(needs_escaping) { + let mut escaped = String::with_capacity(2 * key.len()); + for c in key.chars() { + match c { + '&' => escaped.push_str("&"), + '{' => escaped.push_str("{"), + '}' => escaped.push_str("}"), + '"' => escaped.push_str("""), + '@' => escaped.push_str("@"), + '#' => escaped.push_str("#"), + _ => escaped.push(c), + } + } + Cow::Owned(escaped) + } else { + key + } +} + +/// Limit the length (in bytes) of a string. If the string's length in bytes is +/// less than or equal to the limit then the entire string is returned. Otherwise +/// the string is sliced at the provided limit. If there is a multi-byte character +/// at the limit then the returned slice will be slightly shorter than the limit to +/// avoid splitting that multi-byte character. +fn limit_string(string: &str, max_size: usize) -> &str { + if string.len() <= max_size { + return string; + } + let p = string.as_bytes(); + let mut index = max_size; + // We want to clip the string at [max_size]; however, the character at that position + // may span several bytes. We need to find the first byte of the character. In UTF-8 + // encoded data any byte that matches the bit pattern 10XXXXXX is not a start byte. + // Therefore we decrement the index as long as there are bytes matching this pattern. + // This ensures we cut the string at the border between one character and another. + while index > 0 && (p[index] & 0b11000000) == 0b10000000 { + index -= 1; + } + &string[0..index] +} + +#[derive(Clone, Copy)] +pub enum Compression { + None, + Gzip, +} + +impl Compression { + pub fn from_env(var_name: &str) -> Compression { + match std::env::var(var_name) { + Ok(method) => match Compression::from_string(&method) { + Some(c) => c, + None => { + tracing::error!("Unknown compression method '{}'; using gzip.", &method); + Compression::Gzip + } + }, + // Default compression method if the env var isn't set: + Err(_) => Compression::Gzip, + } + } + + pub fn from_string(s: &str) -> Option { + match s.to_lowercase().as_ref() { + "none" => Some(Compression::None), + "gzip" => Some(Compression::Gzip), + _ => None, + } + } + + pub fn extension(&self) -> &str { + match self { + Compression::None => "trap", + Compression::Gzip => "trap.gz", + } + } +} + +#[test] +fn limit_string_test() { + assert_eq!("hello", limit_string(&"hello world".to_owned(), 5)); + assert_eq!("hi ☹", limit_string(&"hi ☹☹".to_owned(), 6)); + assert_eq!("hi ", limit_string(&"hi ☹☹".to_owned(), 5)); +} + +#[test] +fn escape_key_test() { + assert_eq!("foo!", escape_key("foo!")); + assert_eq!("foo{}", escape_key("foo{}")); + assert_eq!("{}", escape_key("{}")); + assert_eq!("", escape_key("")); + assert_eq!("/path/to/foo.rb", escape_key("/path/to/foo.rb")); + assert_eq!( + "/path/to/foo&{}"@#.rb", + escape_key("/path/to/foo&{}\"@#.rb") + ); +}