Parallelize extraction

Use the Rayon library to do parallel iteration over the file list. The
number of threads used respects the CODEQL_THREADS environment variable.
This commit is contained in:
Nick Rolfe
2020-12-22 19:04:24 +00:00
parent c35283cefb
commit bf4eac5113
4 changed files with 206 additions and 52 deletions

View File

@@ -4,7 +4,7 @@ use std::collections::BTreeSet as Set;
use std::fmt;
use std::path::Path;
use tracing::{error, info, span, Level};
use tree_sitter::{Language, Node, Parser, Tree};
use tree_sitter::{Node, Parser, Tree};
struct TrapWriter {
/// The accumulated trap entries
@@ -148,55 +148,38 @@ impl TrapWriter {
}
}
pub struct Extractor {
pub parser: Parser,
pub schema: NodeTypeMap,
}
/// Extracts the source file at `path`, which is assumed to be canonicalized.
pub fn extract(parser: &mut Parser, schema: &NodeTypeMap, path: &Path) -> std::io::Result<Program> {
let span = span!(
Level::TRACE,
"extract",
file = %path.display()
);
pub fn create(language: Language, schema: NodeTypeMap) -> Extractor {
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let _enter = span.enter();
Extractor { parser, schema }
}
info!("extracting: {}", path.display());
impl Extractor {
/// Extracts the source file at `path`, which is assumed to be canonicalized.
pub fn extract<'a>(&'a mut self, path: &Path) -> std::io::Result<Program> {
let span = span!(
Level::TRACE,
"extract",
file = %path.display()
);
let source = std::fs::read(&path)?;
let tree = parser.parse(&source, None).expect("Failed to parse file");
let mut trap_writer = new_trap_writer();
trap_writer.comment(format!("Auto-generated TRAP file for {}", path.display()));
let file_label = &trap_writer.populate_file(path);
let mut visitor = Visitor {
source: &source,
trap_writer: trap_writer,
// TODO: should we handle path strings that are not valid UTF8 better?
path: format!("{}", path.display()),
file_label: *file_label,
token_counter: 0,
toplevel_child_counter: 0,
stack: Vec::new(),
schema,
};
traverse(&tree, &mut visitor);
let _enter = span.enter();
info!("extracting: {}", path.display());
let source = std::fs::read(&path)?;
let tree = &self
.parser
.parse(&source, None)
.expect("Failed to parse file");
let mut trap_writer = new_trap_writer();
trap_writer.comment(format!("Auto-generated TRAP file for {}", path.display()));
let file_label = &trap_writer.populate_file(path);
let mut visitor = Visitor {
source: &source,
trap_writer: trap_writer,
// TODO: should we handle path strings that are not valid UTF8 better?
path: format!("{}", path.display()),
file_label: *file_label,
token_counter: 0,
toplevel_child_counter: 0,
stack: Vec::new(),
schema: &self.schema,
};
traverse(&tree, &mut visitor);
&self.parser.reset();
Ok(Program(visitor.trap_writer.trap_output))
}
parser.reset();
Ok(Program(visitor.trap_writer.trap_output))
}
/// Normalizes the path according the common CodeQL specification. Assumes that

View File

@@ -1,10 +1,14 @@
mod extractor;
extern crate num_cpus;
use clap;
use flate2::write::GzEncoder;
use rayon::prelude::*;
use std::fs;
use std::io::{BufRead, BufWriter, Write};
use std::path::{Path, PathBuf};
use tree_sitter::Parser;
enum TrapCompression {
None,
@@ -42,6 +46,41 @@ impl TrapCompression {
}
}
/**
* Gets the number of threads the extractor should use, by reading the
* CODEQL_THREADS environment variable and using it as follows:
*
* If the number is positive, it indicates the number of threads that should be
* used. If the number is negative or zero, it should be added to the number of
* cores available on the machine to determine how many threads to use (minimum
* of 1). If unspecified, should be considered as set to 1.
*/
fn num_codeql_threads() -> usize {
match std::env::var("CODEQL_THREADS") {
Ok(num) => match num.parse::<i32>() {
Ok(num) => {
if num <= 0 {
let reduction = -num as usize;
num_cpus::get() - reduction
} else {
num as usize
}
}
Err(_) => {
tracing::error!(
"Unable to parse CODEQL_THREADS value '{}'; defaulting to 1 thread.",
&num
);
1
}
},
// Use 1 thread if the environment variable isn't set.
Err(_) => 1,
}
}
fn main() -> std::io::Result<()> {
tracing_subscriber::fmt()
.with_target(false)
@@ -50,6 +89,21 @@ fn main() -> std::io::Result<()> {
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.init();
let num_threads = num_codeql_threads();
tracing::info!(
"Using {} {}",
num_threads,
if num_threads == 1 {
"thread"
} else {
"threads"
}
);
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build_global()
.unwrap();
let matches = clap::App::new("Ruby extractor")
.version("1.0")
.author("GitHub")
@@ -76,12 +130,15 @@ fn main() -> std::io::Result<()> {
let language = tree_sitter_ruby::language();
let schema = node_types::read_node_types_str(tree_sitter_ruby::NODE_TYPES)?;
let mut extractor = extractor::create(language, schema);
for line in std::io::BufReader::new(file_list).lines() {
let path = PathBuf::from(line?).canonicalize()?;
let lines: std::io::Result<Vec<String>> = std::io::BufReader::new(file_list).lines().collect();
let lines = lines?;
lines.par_iter().try_for_each(|line| {
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let path = PathBuf::from(line).canonicalize()?;
let trap_file = path_for(&trap_dir, &path, trap_compression.extension());
let src_archive_file = path_for(&src_archive_dir, &path, "");
let trap = extractor.extract(&path)?;
let trap = extractor::extract(&mut parser, &schema, &path)?;
std::fs::create_dir_all(&src_archive_file.parent().unwrap())?;
std::fs::copy(&path, &src_archive_file)?;
std::fs::create_dir_all(&trap_file.parent().unwrap())?;
@@ -96,8 +153,10 @@ fn main() -> std::io::Result<()> {
write!(compressed_writer, "{}", trap)?;
}
}
}
return Ok(());
std::io::Result::Ok(())
})?;
Ok(())
}
fn path_for(dir: &Path, path: &Path, ext: &str) -> PathBuf {