QL: Use shared extractor library

This commit is contained in:
Harry Maclean
2023-03-16 15:05:13 +13:00
parent c90299baee
commit 6171eae7a2
14 changed files with 48 additions and 2430 deletions

View File

@@ -7,8 +7,6 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
flate2 = "1.0"
node-types = { path = "../node-types" }
tree-sitter = ">= 0.20, < 0.21"
tree-sitter-ql = { git = "https://github.com/tree-sitter/tree-sitter-ql.git", rev = "d08db734f8dc52f6bc04db53a966603122bc6985"}
tree-sitter-ql-dbscheme = { git = "https://github.com/erik-krogh/tree-sitter-ql-dbscheme.git", rev = "63e1344353f63931e88bfbc2faa2e78e1421b213"}
@@ -19,5 +17,5 @@ clap = "2.33"
tracing = "0.1"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
rayon = "1.7.0"
num_cpus = "1.14.0"
regex = "1.7.2"
codeql-extractor = { path = "../../shared/extractor" }

View File

@@ -1,650 +0,0 @@
use crate::trap;
use node_types::{EntryKind, Field, NodeTypeMap, Storage, TypeName};
use std::collections::BTreeMap as Map;
use std::collections::BTreeSet as Set;
use std::fmt;
use std::path::Path;
use tracing::{error, info, span, Level};
use tree_sitter::{Language, Node, Parser, Range, Tree};
/// Returns the label of the `@file` entity for `absolute_path`.
///
/// The first time a path is seen, a `files` row is emitted for it and its
/// chain of parent folders is populated as well.
pub fn populate_file(writer: &mut trap::Writer, absolute_path: &Path) -> trap::Label {
    let normalized = normalize_path(absolute_path);
    let (file_label, fresh) = writer.global_id(&trap::full_id_for_file(&normalized));
    if !fresh {
        // Already registered in this TRAP file; nothing more to emit.
        return file_label;
    }
    let row = vec![
        trap::Arg::Label(file_label),
        trap::Arg::String(normalized),
    ];
    writer.add_tuple("files", row);
    populate_parent_folders(writer, file_label, absolute_path.parent());
    file_label
}
/// Returns the label of the placeholder file entity whose path is the empty
/// string, emitting its `files` row the first time it is requested.
fn populate_empty_file(writer: &mut trap::Writer) -> trap::Label {
    let (file_label, fresh) = writer.global_id("empty;sourcefile");
    if !fresh {
        return file_label;
    }
    let row = vec![
        trap::Arg::Label(file_label),
        trap::Arg::String(String::new()),
    ];
    writer.add_tuple("files", row);
    file_label
}
/// Emits the all-zero location attached to the placeholder empty file.
pub fn populate_empty_location(writer: &mut trap::Writer) {
    let empty_file = populate_empty_file(writer);
    // Only the emitted rows matter here; the location label itself is unused.
    let _ = location(writer, empty_file, 0, 0, 0, 0);
}
/// Walks from `path` up through its ancestors, linking each folder to the
/// container below it via `containerparent`.
///
/// A `folders` row is emitted for every folder seen for the first time. The
/// walk stops at the first folder that is already known, after the link to it
/// has been emitted.
pub fn populate_parent_folders(
    writer: &mut trap::Writer,
    child_label: trap::Label,
    path: Option<&Path>,
) {
    let mut current = path;
    let mut child = child_label;
    while let Some(folder) = current {
        let normalized = normalize_path(folder);
        let (folder_label, fresh) = writer.global_id(&trap::full_id_for_folder(&normalized));
        writer.add_tuple(
            "containerparent",
            vec![trap::Arg::Label(folder_label), trap::Arg::Label(child)],
        );
        if !fresh {
            break;
        }
        writer.add_tuple(
            "folders",
            vec![
                trap::Arg::Label(folder_label),
                trap::Arg::String(normalized),
            ],
        );
        current = folder.parent();
        child = folder_label;
    }
}
/// Returns the label of the location described by the file label and the four
/// line/column values, emitting a `locations_default` row the first time that
/// exact location is requested.
fn location(
    writer: &mut trap::Writer,
    file_label: trap::Label,
    start_line: usize,
    start_column: usize,
    end_line: usize,
    end_column: usize,
) -> trap::Label {
    let key = format!(
        "loc,{{{}}},{},{},{},{}",
        file_label, start_line, start_column, end_line, end_column
    );
    let (loc_label, fresh) = writer.global_id(&key);
    if fresh {
        let row = vec![
            trap::Arg::Label(loc_label),
            trap::Arg::Label(file_label),
            trap::Arg::Int(start_line),
            trap::Arg::Int(start_column),
            trap::Arg::Int(end_line),
            trap::Arg::Int(end_column),
        ];
        writer.add_tuple("locations_default", row);
    }
    loc_label
}
/// Parses `source` with the given tree-sitter `language` and records the
/// resulting syntax tree in `trap_writer`.
///
/// `path` is assumed to be canonicalized already. `language_prefix` is used to
/// build the names of the AST-info and token-info tables, and `ranges` is
/// handed to the parser as its included ranges.
///
/// Panics if the language or the ranges are rejected by the parser, or if
/// parsing yields no tree.
pub fn extract(
    language: Language,
    language_prefix: &str,
    schema: &NodeTypeMap,
    trap_writer: &mut trap::Writer,
    path: &Path,
    source: &[u8],
    ranges: &[Range],
) -> std::io::Result<()> {
    let path_str = path.display().to_string();
    let extract_span = span!(
        Level::TRACE,
        "extract",
        file = %path_str
    );
    let _guard = extract_span.enter();
    info!("extracting: {}", path_str);

    let mut parser = Parser::new();
    parser.set_language(language).unwrap();
    parser.set_included_ranges(ranges).unwrap();
    let tree = parser.parse(&source, None).expect("Failed to parse file");

    trap_writer.comment(format!("Auto-generated TRAP file for {}", path_str));
    let file_label = populate_file(trap_writer, path);

    // TODO: should we handle path strings that are not valid UTF8 better?
    let mut visitor = Visitor::new(
        source,
        trap_writer,
        &path_str,
        file_label,
        language_prefix,
        schema,
    );
    traverse(&tree, &mut visitor);

    parser.reset();
    Ok(())
}
/// Converts `path` into the string form used for CodeQL file and folder keys.
///
/// The path is assumed to have been canonicalized with
/// `std::fs::canonicalize` beforehand. On non-Windows targets the path is
/// rendered unchanged. On Windows, prefixes are reduced to their drive letter
/// or server/share names and the components are joined with forward slashes.
fn normalize_path(path: &Path) -> String {
    use std::path::{Component, Prefix};

    if !cfg!(windows) {
        // Canonicalized paths on other operating systems need no rewriting.
        return path.display().to_string();
    }

    // Rust's canonical Windows paths carry prefixes and backslashes that do
    // not match the CodeQL convention, so rebuild the path piece by piece.
    let mut parts: Vec<String> = Vec::new();
    for component in path.components() {
        match component {
            Component::Prefix(prefix) => match prefix.kind() {
                Prefix::Disk(letter) | Prefix::VerbatimDisk(letter) => {
                    parts.push(format!("{}:", letter as char));
                }
                Prefix::Verbatim(name) | Prefix::DeviceNS(name) => {
                    parts.push(name.to_string_lossy().into_owned());
                }
                Prefix::UNC(server, share) | Prefix::VerbatimUNC(server, share) => {
                    parts.push(server.to_string_lossy().into_owned());
                    parts.push(share.to_string_lossy().into_owned());
                }
            },
            Component::Normal(name) => parts.push(name.to_string_lossy().into_owned()),
            Component::RootDir | Component::CurDir | Component::ParentDir => {}
        }
    }
    parts.join("/")
}
/// A child recorded while visiting a node, kept until the parent is left so it
/// can be matched against the parent's schema fields.
struct ChildNode {
/// The field name under which the child was reached, if any.
field_name: Option<&'static str>,
/// The TRAP label allocated for the child.
label: trap::Label,
/// The kind of the child node and whether it is a named node.
type_name: TypeName,
}
/// State carried through the traversal of one parsed source file: the inputs
/// (path, source bytes, schema), the TRAP writer being filled, and the stack
/// of nodes that have been entered but not yet left.
struct Visitor<'a> {
/// The file path of the source code (as string)
path: &'a str,
/// The label to use whenever we need to refer to the `@file` entity of this
/// source file.
file_label: trap::Label,
/// The source code as a UTF-8 byte array
source: &'a [u8],
/// A trap::Writer to accumulate trap entries
trap_writer: &'a mut trap::Writer,
/// A counter for top-level child nodes
toplevel_child_counter: usize,
/// Language-specific name of the AST info table
ast_node_info_table_name: String,
/// Language-specific name of the tokeninfo table
tokeninfo_table_name: String,
/// A lookup table from type name to node types
schema: &'a NodeTypeMap,
/// A stack for gathering information from child nodes. Whenever a node is
/// entered the parent's [Label], child counter, and an empty list is pushed.
/// All children append their data to the list. When the visitor leaves a
/// node the list containing the child data is popped from the stack and
/// matched against the dbscheme for the node. If the expectations are met
/// the corresponding row definitions are added to the trap_output.
stack: Vec<(trap::Label, usize, Vec<ChildNode>)>,
}
impl<'a> Visitor<'a> {
fn new(
source: &'a [u8],
trap_writer: &'a mut trap::Writer,
path: &'a str,
file_label: trap::Label,
language_prefix: &str,
schema: &'a NodeTypeMap,
) -> Visitor<'a> {
Visitor {
path,
file_label,
source,
trap_writer,
toplevel_child_counter: 0,
ast_node_info_table_name: format!("{}_ast_node_info", language_prefix),
tokeninfo_table_name: format!("{}_tokeninfo", language_prefix),
schema,
stack: Vec::new(),
}
}
/// Logs `full_error_message` and emits a `diagnostics` row of kind
/// `parse_error` that points at the location `loc`.
fn record_parse_error(
    &mut self,
    error_message: String,
    full_error_message: String,
    loc: trap::Label,
) {
    error!("{}", full_error_message);
    let id = self.trap_writer.fresh_id();
    let row = vec![
        trap::Arg::Label(id),
        trap::Arg::Int(40), // severity 40 = error
        trap::Arg::String("parse_error".to_string()),
        trap::Arg::String(error_message),
        trap::Arg::String(full_error_message),
        trap::Arg::Label(loc),
    ];
    self.trap_writer.add_tuple("diagnostics", row);
}
/// Records a parse error whose location is the source range of `node`.
fn record_parse_error_for_node(
    &mut self,
    error_message: String,
    full_error_message: String,
    node: Node,
) {
    let (first_line, first_col, last_line, last_col) = location_for(self.source, node);
    let loc = location(
        self.trap_writer,
        self.file_label,
        first_line,
        first_col,
        last_line,
        last_col,
    );
    self.record_parse_error(error_message, full_error_message, loc);
}
/// Starts the visit of `node`.
///
/// For an ordinary node a fresh label is allocated, a new frame is pushed on
/// the stack and `true` is returned. Error and missing nodes are reported as
/// parse errors instead; no frame is pushed and `false` is returned.
fn enter_node(&mut self, node: Node) -> bool {
    let missing = node.is_missing();
    if !(node.is_error() || missing) {
        let id = self.trap_writer.fresh_id();
        self.stack.push((id, 0, Vec::new()));
        return true;
    }

    let error_message = if missing {
        format!("parse error: expecting '{}'", node.kind())
    } else {
        "parse error".to_string()
    };
    let line = node.start_position().row + 1;
    let full_error_message = format!("{}:{}: {}", &self.path, line, error_message);
    self.record_parse_error_for_node(error_message, full_error_message, node);
    false
}
/// Finishes the visit of `node`: pops its stack frame, emits its location and
/// table rows, and registers it as a child of the enclosing node.
///
/// `field_name` is the field under which `node` was reached, if any. Error and
/// missing nodes are ignored here because `enter_node` never pushed a frame
/// for them.
///
/// Panics if the stack is empty or if the node's type is not in the schema.
fn leave_node(&mut self, field_name: Option<&'static str>, node: Node) {
if node.is_error() || node.is_missing() {
return;
}
// The frame pushed by `enter_node`: this node's label and its collected children.
let (id, _, child_nodes) = self.stack.pop().expect("Vistor: empty stack");
let (start_line, start_column, end_line, end_column) = location_for(self.source, node);
let loc = location(
self.trap_writer,
self.file_label,
start_line,
start_column,
end_line,
end_column,
);
let table = self
.schema
.get(&TypeName {
kind: node.kind().to_owned(),
named: node.is_named(),
})
.unwrap();
let mut valid = true;
// Pick the parent label and this node's index among the parent's children.
// Extra nodes, and nodes with no enclosing frame, are attached to the file
// itself and numbered with the top-level counter.
let (parent_id, parent_index) = match self.stack.last_mut() {
Some(p) if !node.is_extra() => {
p.1 += 1;
(p.0, p.1 - 1)
}
_ => {
self.toplevel_child_counter += 1;
(self.file_label, self.toplevel_child_counter - 1)
}
};
match &table.kind {
// Tokens get an AST-info row plus a token-info row holding the kind id
// and the token's source text.
EntryKind::Token { kind_id, .. } => {
self.trap_writer.add_tuple(
&self.ast_node_info_table_name,
vec![
trap::Arg::Label(id),
trap::Arg::Label(parent_id),
trap::Arg::Int(parent_index),
trap::Arg::Label(loc),
],
);
self.trap_writer.add_tuple(
&self.tokeninfo_table_name,
vec![
trap::Arg::Label(id),
trap::Arg::Int(*kind_id),
sliced_source_arg(self.source, node),
],
);
}
// Nodes with fields: rows are only emitted when `complex_node` accepted
// the children; otherwise it has already recorded diagnostics.
EntryKind::Table {
fields,
name: table_name,
} => {
if let Some(args) = self.complex_node(&node, fields, &child_nodes, id) {
self.trap_writer.add_tuple(
&self.ast_node_info_table_name,
vec![
trap::Arg::Label(id),
trap::Arg::Label(parent_id),
trap::Arg::Int(parent_index),
trap::Arg::Label(loc),
],
);
let mut all_args = vec![trap::Arg::Label(id)];
all_args.extend(args);
self.trap_writer.add_tuple(table_name, all_args);
}
}
// Any other entry kind is reported as a diagnostic and the node is not
// registered with its parent.
_ => {
let error_message = format!("unknown table type: '{}'", node.kind());
let full_error_message = format!(
"{}:{}: {}",
&self.path,
node.start_position().row + 1,
error_message
);
self.record_parse_error(error_message, full_error_message, loc);
valid = false;
}
}
if valid && !node.is_extra() {
// Extra nodes are independent root nodes and do not belong to the parent node
// Therefore we should not register them in the parent vector
if let Some(parent) = self.stack.last_mut() {
parent.2.push(ChildNode {
field_name,
label: id,
type_name: TypeName {
kind: node.kind().to_owned(),
named: node.is_named(),
},
});
};
}
}
/// Matches the collected `child_nodes` of `node` against the schema `fields`.
///
/// Values of fields stored in a separate table (`Storage::Table`) are written
/// to the TRAP writer right away, keyed by `parent_id`. Values of fields
/// stored as columns (`Storage::Column`) are returned in field order so the
/// caller can emit them in the node's own row.
///
/// Returns `None` when any column field has no value or more than one value;
/// a diagnostic is recorded for each such field. Type mismatches and values
/// for unknown fields are recorded as diagnostics but do not by themselves
/// make the result `None`.
fn complex_node(
&mut self,
node: &Node,
fields: &[Field],
child_nodes: &[ChildNode],
parent_id: trap::Label,
) -> Option<Vec<trap::Arg>> {
// Collected values per field, keyed by the (optional) field name.
let mut map: Map<&Option<String>, (&Field, Vec<trap::Arg>)> = Map::new();
for field in fields {
map.insert(&field.name, (field, Vec::new()));
}
for child_node in child_nodes {
if let Some((field, values)) = map.get_mut(&child_node.field_name.map(|x| x.to_owned()))
{
//TODO: handle error and missing nodes
if self.type_matches(&child_node.type_name, &field.type_info) {
if let node_types::FieldTypeInfo::ReservedWordInt(int_mapping) =
&field.type_info
{
// We can safely unwrap because type_matches checks the key is in the map.
let (int_value, _) = int_mapping.get(&child_node.type_name.kind).unwrap();
values.push(trap::Arg::Int(*int_value));
} else {
values.push(trap::Arg::Label(child_node.label));
}
} else if field.name.is_some() {
// A mismatch is only reported for named fields; a mismatching child of
// the unnamed field is dropped without a diagnostic.
let error_message = format!(
"type mismatch for field {}::{} with type {:?} != {:?}",
node.kind(),
child_node.field_name.unwrap_or("child"),
child_node.type_name,
field.type_info
);
let full_error_message = format!(
"{}:{}: {}",
&self.path,
node.start_position().row + 1,
error_message
);
self.record_parse_error_for_node(error_message, full_error_message, *node);
}
} else if child_node.field_name.is_some() || child_node.type_name.named {
// No schema field for this child. Unnamed children without a field name
// are ignored; everything else is reported.
let error_message = format!(
"value for unknown field: {}::{} and type {:?}",
node.kind(),
&child_node.field_name.unwrap_or("child"),
&child_node.type_name
);
let full_error_message = format!(
"{}:{}: {}",
&self.path,
node.start_position().row + 1,
error_message
);
self.record_parse_error_for_node(error_message, full_error_message, *node);
}
}
let mut args = Vec::new();
let mut is_valid = true;
for field in fields {
let child_values = &map.get(&field.name).unwrap().1;
match &field.storage {
// Column storage requires exactly one value.
Storage::Column { name: column_name } => {
if child_values.len() == 1 {
args.push(child_values.first().unwrap().clone());
} else {
is_valid = false;
let error_message = format!(
"{} for field: {}::{}",
if child_values.is_empty() {
"missing value"
} else {
"too many values"
},
node.kind(),
column_name
);
let full_error_message = format!(
"{}:{}: {}",
&self.path,
node.start_position().row + 1,
error_message
);
self.record_parse_error_for_node(error_message, full_error_message, *node);
}
}
// Table storage: one row per value; without an index column only the
// first value is kept and the surplus is logged (no diagnostic row).
Storage::Table {
name: table_name,
has_index,
column_name: _,
} => {
for (index, child_value) in child_values.iter().enumerate() {
if !*has_index && index > 0 {
error!(
"{}:{}: too many values for field: {}::{}",
&self.path,
node.start_position().row + 1,
node.kind(),
table_name,
);
break;
}
let mut args = vec![trap::Arg::Label(parent_id)];
if *has_index {
args.push(trap::Arg::Int(index))
}
args.push(child_value.clone());
self.trap_writer.add_tuple(table_name, args);
}
}
}
}
if is_valid {
Some(args)
} else {
None
}
}
/// Reports whether a child of type `tp` is acceptable for a field described
/// by `type_info`.
///
/// A single type matches directly or, when it is a union in the schema,
/// through its members. A multiple-type field matches through its type set.
/// A reserved-word field matches only unnamed nodes whose kind is a key of
/// the mapping.
fn type_matches(&self, tp: &TypeName, type_info: &node_types::FieldTypeInfo) -> bool {
    match type_info {
        node_types::FieldTypeInfo::Single(single_type) => {
            tp == single_type
                || matches!(
                    &self.schema.get(single_type).unwrap().kind,
                    EntryKind::Union { members } if self.type_matches_set(tp, members)
                )
        }
        node_types::FieldTypeInfo::Multiple { types, .. } => self.type_matches_set(tp, types),
        node_types::FieldTypeInfo::ReservedWordInt(int_mapping) => {
            !tp.named && int_mapping.contains_key(&tp.kind)
        }
    }
}
/// Reports whether `tp` is contained in `types`, either directly or through
/// any member of `types` that the schema declares as a union (searched
/// recursively).
fn type_matches_set(&self, tp: &TypeName, types: &Set<TypeName>) -> bool {
    types.contains(tp)
        || types.iter().any(|candidate| {
            matches!(
                &self.schema.get(candidate).unwrap().kind,
                EntryKind::Union { members } if self.type_matches_set(tp, members)
            )
        })
}
}
/// Returns the source text covered by node `n` as a string argument, with
/// invalid UTF-8 sequences replaced lossily.
fn sliced_source_arg(source: &[u8], n: Node) -> trap::Arg {
    let bytes = &source[n.byte_range()];
    let text = String::from_utf8_lossy(bytes);
    trap::Arg::String(text.into_owned())
}
/// Computes `(start_line, start_column, end_line, end_column)` for node `n`
/// in CodeQL conventions.
///
/// Tree-sitter rows and columns are 0-based and its end position is
/// exclusive, whereas CodeQL counts from 1 and uses inclusive end positions.
/// Start values are therefore incremented by one, the end row is incremented
/// by one, and the end column is kept as is (increment, then step one to the
/// left). An end column of 0 means the node ends at the start of a line; the
/// end is then moved to the last column of the previous line, which is found
/// by scanning `source` backwards from the node's end byte.
fn location_for(source: &[u8], n: Node) -> (usize, usize, usize, usize) {
    let start = n.start_position();
    let end = n.end_position();
    let start_line = start.row + 1;
    let start_col = start.column + 1;
    let mut end_line = end.row + 1;
    let mut end_col = end.column;

    let is_empty = start_line > end_line || (start_line == end_line && start_col > end_col);
    if is_empty {
        // the range is empty, clip it to sensible values
        end_line = start_line;
        end_col = start_col - 1;
    } else if end_col == 0 {
        // Column 0 is not a valid column number, so the end has to be moved
        // to the end of the previous line.
        let end_byte = n.end_byte();
        if end_byte == 0 || end_byte > source.len() {
            error!(
                "cannot correct end column value: end_byte index {} is not in range [1,{}]",
                end_byte,
                source.len()
            );
        } else {
            let newline_at = end_byte - 1;
            if source[newline_at] != b'\n' {
                error!("expecting a line break symbol, but none found while correcting end column value");
            }
            end_line -= 1;
            // The previous line starts right after the nearest earlier line
            // break, or at the beginning of the source if there is none.
            let line_start = source[..newline_at]
                .iter()
                .rposition(|&byte| byte == b'\n')
                .map_or(0, |pos| pos + 1);
            end_col = newline_at - line_start + 1;
        }
    }
    (start_line, start_col, end_line, end_col)
}
/// Walks `tree` depth-first with a tree cursor, calling `enter_node` when a
/// node is first reached and `leave_node` once its subtree has been handled.
fn traverse(tree: &Tree, visitor: &mut Visitor) {
let cursor = &mut tree.walk();
// The result for the root node is not consulted; `recurse` starts out true.
visitor.enter_node(cursor.node());
// Whether the children of the current node should be visited next. It is
// false when `enter_node` rejected the node, and after moving back up to a
// parent whose children have already been visited.
let mut recurse = true;
loop {
if recurse && cursor.goto_first_child() {
recurse = visitor.enter_node(cursor.node());
} else {
// No (more) children to visit here: finish the current node, then move
// to its next sibling or, failing that, back up to its parent.
visitor.leave_node(cursor.field_name(), cursor.node());
if cursor.goto_next_sibling() {
recurse = visitor.enter_node(cursor.node());
} else if cursor.goto_parent() {
recurse = false;
} else {
break;
}
}
}
}
/// A numeric index that displays as its plain decimal value.
#[derive(Debug, Copy, Clone)]
struct Index(usize);

impl fmt::Display for Index {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Index(value) = *self;
        f.write_fmt(format_args!("{}", value))
    }
}

View File

@@ -1,41 +1,9 @@
mod extractor;
mod trap;
extern crate num_cpus;
use rayon::prelude::*;
use std::fs;
use std::io::BufRead;
use std::path::{Path, PathBuf};
/// Gets the number of threads the extractor should use, by reading the
/// `CODEQL_THREADS` environment variable and using it as described in the
/// extractor spec:
///
/// "If the number is positive, it indicates the number of threads that should
/// be used. If the number is negative or zero, it should be added to the number
/// of cores available on the machine to determine how many threads to use
/// (minimum of 1). If unspecified, should be considered as set to -1."
///
/// If the value cannot be parsed as an integer, an error is logged and a
/// single thread is used.
fn num_codeql_threads() -> usize {
    let threads_str = std::env::var("CODEQL_THREADS").unwrap_or_else(|_| "-1".to_owned());
    match threads_str.parse::<i32>() {
        Ok(num) if num <= 0 => {
            // `unsigned_abs` is well-defined even for `i32::MIN`, where plain
            // negation would overflow.
            let reduction = num.unsigned_abs() as usize;
            // Saturate instead of underflowing when the requested reduction
            // exceeds the number of cores; the spec's minimum of one thread
            // then applies.
            std::cmp::max(1, num_cpus::get().saturating_sub(reduction))
        }
        Ok(num) => num as usize,
        Err(_) => {
            tracing::error!(
                "Unable to parse CODEQL_THREADS value '{}'; defaulting to 1 thread.",
                &threads_str
            );
            1
        }
    }
}
use codeql_extractor::{diagnostics, extractor, node_types, trap};
fn main() -> std::io::Result<()> {
tracing_subscriber::fmt()
@@ -45,7 +13,23 @@ fn main() -> std::io::Result<()> {
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.init();
let num_threads = num_codeql_threads();
let diagnostics = diagnostics::DiagnosticLoggers::new("ql");
let mut main_thread_logger = diagnostics.logger();
let num_threads = match codeql_extractor::options::num_threads() {
Ok(num) => num,
Err(e) => {
main_thread_logger.write(
main_thread_logger
.new_entry("configuration-error", "Configuration error")
.message(
"{}; defaulting to 1 thread.",
&[diagnostics::MessageArg::Code(&e)],
)
.severity(diagnostics::Severity::Warning),
);
1
}
};
tracing::info!(
"Using {} {}",
num_threads,
@@ -55,6 +39,20 @@ fn main() -> std::io::Result<()> {
"threads"
}
);
let trap_compression = match trap::Compression::from_env("CODEQL_QL_TRAP_COMPRESSION") {
Ok(x) => x,
Err(e) => {
main_thread_logger.write(
main_thread_logger
.new_entry("configuration-error", "Configuration error")
.message("{}; using gzip.", &[diagnostics::MessageArg::Code(&e)])
.severity(diagnostics::Severity::Warning),
);
trap::Compression::Gzip
}
};
drop(main_thread_logger);
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build_global()
@@ -79,7 +77,6 @@ fn main() -> std::io::Result<()> {
.value_of("output-dir")
.expect("missing --output-dir");
let trap_dir = PathBuf::from(trap_dir);
let trap_compression = trap::Compression::from_env("CODEQL_QL_TRAP_COMPRESSION");
let file_list = matches.value_of("file-list").expect("missing --file-list");
let file_list = fs::File::open(file_list)?;
@@ -119,26 +116,29 @@ fn main() -> std::io::Result<()> {
let source = std::fs::read(&path)?;
let code_ranges = vec![];
let mut trap_writer = trap::Writer::new();
let mut diagnostics_writer = diagnostics.logger();
if line.ends_with(".dbscheme") {
extractor::extract(
dbscheme,
"dbscheme",
&dbscheme_schema,
&mut diagnostics_writer,
&mut trap_writer,
&path,
&source,
&code_ranges,
)?
)
} else if line.ends_with("qlpack.yml") {
extractor::extract(
yaml,
"yaml",
&yaml_schema,
&mut diagnostics_writer,
&mut trap_writer,
&path,
&source,
&code_ranges,
)?
)
} else if line.ends_with(".json")
|| line.ends_with(".jsonl")
|| line.ends_with(".jsonc")
@@ -147,31 +147,34 @@ fn main() -> std::io::Result<()> {
json,
"json",
&json_schema,
&mut diagnostics_writer,
&mut trap_writer,
&path,
&source,
&code_ranges,
)?
)
} else if line.ends_with(".blame") {
extractor::extract(
blame,
"blame",
&blame_schema,
&mut diagnostics_writer,
&mut trap_writer,
&path,
&source,
&code_ranges,
)?
)
} else {
extractor::extract(
language,
"ql",
&schema,
&mut diagnostics_writer,
&mut trap_writer,
&path,
&source,
&code_ranges,
)?
)
}
std::fs::create_dir_all(&src_archive_file.parent().unwrap())?;
std::fs::copy(&path, &src_archive_file)?;

View File

@@ -1,275 +0,0 @@
use std::borrow::Cow;
use std::fmt;
use std::io::{BufWriter, Write};
use std::path::Path;
use flate2::write::GzEncoder;
/// Accumulates TRAP entries in memory and hands out the labels used in them.
pub struct Writer {
/// The accumulated trap entries
trap_output: Vec<Entry>,
/// A counter for generating fresh labels
counter: u32,
/// cache of global keys
global_keys: std::collections::HashMap<String, Label>,
}
impl Writer {
pub fn new() -> Writer {
Writer {
counter: 0,
trap_output: Vec::new(),
global_keys: std::collections::HashMap::new(),
}
}
/// Allocates the next label and records an entry that binds it to a fresh id.
pub fn fresh_id(&mut self) -> Label {
    let id = self.counter;
    self.counter += 1;
    let label = Label(id);
    self.trap_output.push(Entry::FreshId(label));
    label
}
/// Gets a label that will hold the unique ID of the passed string at import time.
/// This can be used for incrementally importable TRAP files -- use globally unique
/// strings to compute a unique ID for table tuples.
///
/// The returned flag is `true` when the key was seen for the first time (a new
/// label was allocated and a key-mapping entry recorded) and `false` when the
/// previously allocated label is returned.
///
/// Note: You probably want to make sure that the key strings that you use are disjoint
/// for disjoint column types; the standard way of doing this is to prefix (or append)
/// the column type name to the ID. Thus, you might identify methods in Java by the
/// full ID "methods_com.method.package.DeclaringClass.method(argumentList)".
pub fn global_id(&mut self, key: &str) -> (Label, bool) {
    match self.global_keys.get(key).copied() {
        Some(existing) => (existing, false),
        None => {
            let label = Label(self.counter);
            self.counter += 1;
            self.global_keys.insert(key.to_owned(), label);
            let entry = Entry::MapLabelToKey(label, key.to_owned());
            self.trap_output.push(entry);
            (label, true)
        }
    }
}
/// Appends a tuple with the given arguments for table `table_name`.
pub fn add_tuple(&mut self, table_name: &str, args: Vec<Arg>) {
    let entry = Entry::GenericTuple(String::from(table_name), args);
    self.trap_output.push(entry);
}
/// Appends a comment entry containing `text`.
pub fn comment(&mut self, text: String) {
    let entry = Entry::Comment(text);
    self.trap_output.push(entry);
}
pub fn write_to_file(&self, path: &Path, compression: Compression) -> std::io::Result<()> {
let trap_file = std::fs::File::create(path)?;
match compression {
Compression::None => {
let mut trap_file = BufWriter::new(trap_file);
self.write_trap_entries(&mut trap_file)
}
Compression::Gzip => {
let trap_file = GzEncoder::new(trap_file, flate2::Compression::fast());
let mut trap_file = BufWriter::new(trap_file);
self.write_trap_entries(&mut trap_file)
}
}
}
/// Writes every accumulated entry to `file`, one per line, stopping at the
/// first write error.
fn write_trap_entries<W: Write>(&self, file: &mut W) -> std::io::Result<()> {
    self.trap_output
        .iter()
        .try_for_each(|entry| writeln!(file, "{}", entry))
}
}
/// A single line of TRAP output.
pub enum Entry {
/// Maps the label to a fresh id, e.g. `#123=*`.
FreshId(Label),
/// Maps the label to a key, e.g. `#7=@"foo"`.
MapLabelToKey(Label, String),
/// foo_bar(arg*)
GenericTuple(String, Vec<Arg>),
/// A comment line, rendered as `// text`.
Comment(String),
}
/// Renders an entry in TRAP syntax. Double quotes inside keys are escaped by
/// doubling them; tuple arguments are separated by commas.
impl fmt::Display for Entry {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Entry::FreshId(label) => write!(f, "{}=*", label),
            Entry::MapLabelToKey(label, key) => {
                let escaped = key.replace('"', "\"\"");
                write!(f, "{}=@\"{}\"", label, escaped)
            }
            Entry::GenericTuple(name, args) => {
                write!(f, "{}(", name)?;
                let mut separator = "";
                for arg in args {
                    write!(f, "{}{}", separator, arg)?;
                    separator = ",";
                }
                f.write_str(")")
            }
            Entry::Comment(line) => write!(f, "// {}", line),
        }
    }
}
/// A TRAP label. It is displayed as `#` followed by the number in lowercase
/// hexadecimal: `#0`, `#1`, ..., `#a`, ...
#[derive(Debug, Copy, Clone)]
pub struct Label(u32);

impl fmt::Display for Label {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Label(number) = *self;
        f.write_fmt(format_args!("#{:x}", number))
    }
}
/// Some untyped argument to a TRAP tuple.
#[derive(Debug, Clone)]
pub enum Arg {
/// A reference to a label.
Label(Label),
/// An integer value.
Int(usize),
/// A string value; it is quoted, escaped and length-limited when displayed.
String(String),
}
/// Maximum number of bytes of a string argument that is written out.
const MAX_STRLEN: usize = 1048576;

/// Renders an argument in TRAP syntax. Strings are clipped to `MAX_STRLEN`
/// bytes, wrapped in double quotes, and embedded double quotes are doubled.
impl fmt::Display for Arg {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Arg::Label(label) => write!(f, "{}", label),
            Arg::Int(value) => write!(f, "{}", value),
            Arg::String(text) => {
                let clipped = limit_string(text, MAX_STRLEN);
                write!(f, "\"{}\"", clipped.replace('"', "\"\""))
            }
        }
    }
}
/// A sequence of TRAP entries that displays as one entry per line.
pub struct Program(Vec<Entry>);

impl fmt::Display for Program {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text: String = self
            .0
            .iter()
            .map(|entry| format!("{}\n", entry))
            .collect();
        f.write_str(&text)
    }
}
/// Builds the global key for a source file from its normalized path: the
/// escaped path followed by `;sourcefile`.
pub fn full_id_for_file(normalized_path: &str) -> String {
    let escaped = escape_key(normalized_path);
    format!("{};sourcefile", escaped)
}
/// Builds the global key for a folder from its normalized path: the escaped
/// path followed by `;folder`.
pub fn full_id_for_folder(normalized_path: &str) -> String {
    let escaped = escape_key(normalized_path);
    format!("{};folder", escaped)
}
/// Escapes a string for use in a TRAP key, by replacing the characters
/// `&`, `{`, `}`, `"`, `@` and `#` with HTML entities. A key that contains
/// none of them is returned unchanged, without allocating.
fn escape_key<'a, S: Into<Cow<'a, str>>>(key: S) -> Cow<'a, str> {
    /// The entity that replaces `c`, or `None` if `c` needs no escaping.
    fn entity_for(c: char) -> Option<&'static str> {
        match c {
            '&' => Some("&amp;"),
            '{' => Some("&lbrace;"),
            '}' => Some("&rbrace;"),
            '"' => Some("&quot;"),
            '@' => Some("&commat;"),
            '#' => Some("&num;"),
            _ => None,
        }
    }

    let key = key.into();
    if !key.chars().any(|c| entity_for(c).is_some()) {
        return key;
    }
    let mut escaped = String::with_capacity(2 * key.len());
    for c in key.chars() {
        match entity_for(c) {
            Some(entity) => escaped.push_str(entity),
            None => escaped.push(c),
        }
    }
    Cow::Owned(escaped)
}
/// Limit the length (in bytes) of a string. A string that already fits within
/// `max_size` bytes is returned whole. Otherwise the longest prefix of at most
/// `max_size` bytes that ends on a character boundary is returned, so a
/// multi-byte character straddling the limit is dropped rather than split.
fn limit_string(string: &str, max_size: usize) -> &str {
    if string.len() <= max_size {
        return string;
    }
    // Step back from the limit to the nearest character boundary. Offset 0 is
    // always a boundary, so the search cannot come up empty.
    let cut = (0..=max_size)
        .rev()
        .find(|&offset| string.is_char_boundary(offset))
        .unwrap_or(0);
    &string[..cut]
}
/// The compression applied to a TRAP file when it is written.
#[derive(Clone, Copy)]
pub enum Compression {
/// No compression; the file extension is `trap`.
None,
/// Gzip compression; the file extension is `trap.gz`.
Gzip,
}
impl Compression {
    /// Reads the compression method from the environment variable `var_name`.
    ///
    /// Gzip is used when the variable cannot be read, and also when its value
    /// is not a recognised method (an error is logged in that case).
    pub fn from_env(var_name: &str) -> Compression {
        let method = match std::env::var(var_name) {
            Ok(method) => method,
            // Default compression method if the env var isn't set:
            Err(_) => return Compression::Gzip,
        };
        Compression::from_string(&method).unwrap_or_else(|| {
            tracing::error!("Unknown compression method '{}'; using gzip.", &method);
            Compression::Gzip
        })
    }

    /// Parses a compression method name, ignoring case. Returns `None` for
    /// anything other than `none` or `gzip`.
    pub fn from_string(s: &str) -> Option<Compression> {
        let lowered = s.to_lowercase();
        if lowered == "none" {
            Some(Compression::None)
        } else if lowered == "gzip" {
            Some(Compression::Gzip)
        } else {
            None
        }
    }

    /// The file extension (without a leading dot) for files written with
    /// this compression.
    pub fn extension(&self) -> &str {
        match *self {
            Compression::None => "trap",
            Compression::Gzip => "trap.gz",
        }
    }
}
// Checks that `limit_string` clips at the byte limit without splitting a
// multi-byte character.
#[test]
fn limit_string_test() {
    let cases = [
        ("hello world", 5, "hello"),
        ("hi ☹☹", 6, "hi ☹"),
        ("hi ☹☹", 5, "hi "),
    ];
    for (input, limit, expected) in cases.iter() {
        assert_eq!(*expected, limit_string(&input.to_string(), *limit));
    }
}
// Checks that `escape_key` leaves ordinary text alone and replaces each
// special character with its entity.
#[test]
fn escape_key_test() {
    let cases = [
        ("foo!", "foo!"),
        ("foo{}", "foo&lbrace;&rbrace;"),
        ("{}", "&lbrace;&rbrace;"),
        ("", ""),
        ("/path/to/foo.rb", "/path/to/foo.rb"),
        (
            "/path/to/foo&{}\"@#.rb",
            "/path/to/foo&amp;&lbrace;&rbrace;&quot;&commat;&num;.rb",
        ),
    ];
    for (input, expected) in cases.iter() {
        assert_eq!(*expected, escape_key(*input));
    }
}