Merge pull request #117 from github/tausbn/update-extractor-generator

Upgrade the extractor generator
This commit is contained in:
Taus
2021-10-15 13:59:32 +02:00
committed by GitHub
22 changed files with 2371 additions and 2394 deletions

View File

@@ -62,7 +62,7 @@ jobs:
- name: Release build
run: cargo build --release
- name: Generate dbscheme
run: target/release/ql-generator
run: target/release/ql-generator --dbscheme ql/src/ql.dbscheme --library ql/src/codeql_ql/ast/internal/TreeSitter.qll
- uses: actions/upload-artifact@v2
with:
name: ql.dbscheme

View File

@@ -43,7 +43,7 @@ jobs:
run: cargo build --release
- name: Generate dbscheme
if: ${{ matrix.os == 'ubuntu-latest' }}
run: target/release/ql-generator
run: target/release/ql-generator --dbscheme ql/src/ql.dbscheme --library ql/src/codeql_ql/ast/internal/TreeSitter.qll
- uses: actions/upload-artifact@v2
if: ${{ matrix.os == 'ubuntu-latest' }}
with:

View File

@@ -73,7 +73,7 @@ jobs:
path: stats
- run: |
python -m pip install --user lxml
find stats -name 'stats.xml' | sort | xargs python scripts/merge_stats.py --output ql/src/ql.dbscheme.stats --normalise tokeninfo
find stats -name 'stats.xml' | sort | xargs python scripts/merge_stats.py --output ql/src/ql.dbscheme.stats --normalise ql_tokeninfo
- uses: actions/upload-artifact@v2
with:
name: ql.dbscheme.stats

1
Cargo.lock generated
View File

@@ -329,6 +329,7 @@ dependencies = [
name = "ql-generator"
version = "0.1.0"
dependencies = [
"clap",
"node-types",
"tracing",
"tracing-subscriber",

View File

@@ -1,6 +1,6 @@
cargo build --release
cargo run --release -p ql-generator
cargo run --release -p ql-generator -- --dbscheme ql/src/ql.dbscheme --library ql/src/codeql_ql/ast/internal/
codeql query format -i ql\src\codeql_ql\ast\internal\TreeSitter.qll
if (Test-Path -Path extractor-pack) {

View File

@@ -12,7 +12,7 @@ fi
cargo build --release
cargo run --release -p ql-generator
cargo run --release -p ql-generator -- --dbscheme ql/src/ql.dbscheme --library ql/src/codeql_ql/ast/internal/TreeSitter.qll
codeql query format -i ql/src/codeql_ql/ast/internal/TreeSitter.qll
rm -rf extractor-pack

View File

@@ -3,11 +3,13 @@ use std::borrow::Cow;
use std::collections::BTreeMap as Map;
use std::collections::BTreeSet as Set;
use std::fmt;
use std::io::Write;
use std::path::Path;
use tracing::{error, info, span, Level};
use tree_sitter::{Language, Node, Parser, Range, Tree};
struct TrapWriter {
pub struct TrapWriter {
/// The accumulated trap entries
trap_output: Vec<TrapEntry>,
/// A counter for generating fresh labels
@@ -16,7 +18,7 @@ struct TrapWriter {
global_keys: std::collections::HashMap<String, Label>,
}
fn new_trap_writer() -> TrapWriter {
pub fn new_trap_writer() -> TrapWriter {
TrapWriter {
counter: 0,
trap_output: Vec::new(),
@@ -66,15 +68,6 @@ impl TrapWriter {
vec![
Arg::Label(file_label),
Arg::String(normalize_path(absolute_path)),
Arg::String(match absolute_path.file_name() {
None => "".to_owned(),
Some(file_name) => format!("{}", file_name.to_string_lossy()),
}),
Arg::String(match absolute_path.extension() {
None => "".to_owned(),
Some(ext) => format!("{}", ext.to_string_lossy()),
}),
Arg::Int(1), // 1 = from source
],
);
self.populate_parent_folders(file_label, absolute_path.parent());
@@ -82,6 +75,22 @@ impl TrapWriter {
file_label
}
/// Returns the label of a synthetic "empty" file entity, creating it on first use.
///
/// The label is looked up via the global key `"empty;sourcefile"`, so repeated
/// calls all resolve to the same label; `fresh` is true only the first time,
/// in which case a `files` tuple with an empty path string is emitted.
fn populate_empty_file(&mut self) -> Label {
let (file_label, fresh) = self.global_id("empty;sourcefile");
if fresh {
// Only emit the tuple once per trap file; later calls reuse the label.
self.add_tuple(
"files",
vec![Arg::Label(file_label), Arg::String("".to_string())],
);
}
file_label
}
/// Emits a zero-width location (`0:0` to `0:0`) attached to the synthetic
/// empty file. Presumably used as a placeholder location for entities that
/// have no real source position — TODO confirm against callers.
pub fn populate_empty_location(&mut self) {
let file_label = self.populate_empty_file();
self.location(file_label, 0, 0, 0, 0);
}
fn populate_parent_folders(&mut self, child_label: Label, path: Option<&Path>) {
let mut path = path;
let mut child_label = child_label;
@@ -100,10 +109,6 @@ impl TrapWriter {
vec![
Arg::Label(folder_label),
Arg::String(normalize_path(folder)),
Arg::String(match folder.file_name() {
None => "".to_owned(),
Some(file_name) => format!("{}", file_name.to_string_lossy()),
}),
],
);
path = folder.parent();
@@ -147,16 +152,22 @@ impl TrapWriter {
/// Appends a comment entry to the accumulated trap output.
fn comment(&mut self, text: String) {
self.trap_output.push(TrapEntry::Comment(text));
}
/// Consumes the writer and serializes every accumulated trap entry to
/// `writer`, formatted via `Program`'s `Display` implementation.
pub fn output(self, writer: &mut dyn Write) -> std::io::Result<()> {
write!(writer, "{}", Program(self.trap_output))
}
}
/// Extracts the source file at `path`, which is assumed to be canonicalized.
pub fn extract(
language: Language,
language_prefix: &str,
schema: &NodeTypeMap,
trap_writer: &mut TrapWriter,
path: &Path,
source: &Vec<u8>,
source: &[u8],
ranges: &[Range],
) -> std::io::Result<Program> {
) -> std::io::Result<()> {
let span = span!(
Level::TRACE,
"extract",
@@ -169,41 +180,32 @@ pub fn extract(
let mut parser = Parser::new();
parser.set_language(language).unwrap();
parser.set_included_ranges(&ranges).unwrap();
parser.set_included_ranges(ranges).unwrap();
let tree = parser.parse(&source, None).expect("Failed to parse file");
let mut trap_writer = new_trap_writer();
trap_writer.comment(format!("Auto-generated TRAP file for {}", path.display()));
let file_label = &trap_writer.populate_file(path);
let mut visitor = Visitor {
source: &source,
trap_writer: trap_writer,
source,
trap_writer,
// TODO: should we handle path strings that are not valid UTF8 better?
path: format!("{}", path.display()),
file_label: *file_label,
token_counter: 0,
toplevel_child_counter: 0,
stack: Vec::new(),
language_prefix,
schema,
};
traverse(&tree, &mut visitor);
parser.reset();
Ok(Program(visitor.trap_writer.trap_output))
Ok(())
}
/// Escapes a string for use in a TRAP key, by replacing special characters with
/// HTML entities.
fn escape_key<'a, S: Into<Cow<'a, str>>>(key: S) -> Cow<'a, str> {
fn needs_escaping(c: char) -> bool {
match c {
'&' => true,
'{' => true,
'}' => true,
'"' => true,
'@' => true,
'#' => true,
_ => false,
}
matches!(c, '&' | '{' | '}' | '"' | '@' | '#')
}
let key = key.into();
@@ -286,13 +288,13 @@ struct Visitor<'a> {
/// source file.
file_label: Label,
/// The source code as a UTF-8 byte array
source: &'a Vec<u8>,
source: &'a [u8],
/// A TrapWriter to accumulate trap entries
trap_writer: TrapWriter,
/// A counter for tokens
token_counter: usize,
trap_writer: &'a mut TrapWriter,
/// A counter for top-level child nodes
toplevel_child_counter: usize,
/// Language prefix
language_prefix: &'a str,
/// A lookup table from type name to node types
schema: &'a NodeTypeMap,
/// A stack for gathering information from child nodes. Whenever a node is
@@ -332,7 +334,7 @@ impl Visitor<'_> {
full_error_message: String,
node: Node,
) {
let (start_line, start_column, end_line, end_column) = location_for(&self.source, node);
let (start_line, start_column, end_line, end_column) = location_for(self.source, node);
let loc = self.trap_writer.location(
self.file_label,
start_line,
@@ -363,7 +365,7 @@ impl Visitor<'_> {
let id = self.trap_writer.fresh_id();
self.stack.push((id, 0, Vec::new()));
return true;
true
}
fn leave_node(&mut self, field_name: Option<&'static str>, node: Node) {
@@ -371,7 +373,7 @@ impl Visitor<'_> {
return;
}
let (id, _, child_nodes) = self.stack.pop().expect("Vistor: empty stack");
let (start_line, start_column, end_line, end_column) = location_for(&self.source, node);
let (start_line, start_column, end_line, end_column) = location_for(self.source, node);
let loc = self.trap_writer.location(
self.file_label,
start_line,
@@ -400,7 +402,7 @@ impl Visitor<'_> {
match &table.kind {
EntryKind::Token { kind_id, .. } => {
self.trap_writer.add_tuple(
"ast_node_parent",
&format!("{}_ast_node_parent", self.language_prefix),
vec![
Arg::Label(id),
Arg::Label(parent_id),
@@ -408,17 +410,14 @@ impl Visitor<'_> {
],
);
self.trap_writer.add_tuple(
"tokeninfo",
&format!("{}_tokeninfo", self.language_prefix),
vec![
Arg::Label(id),
Arg::Int(*kind_id),
Arg::Label(self.file_label),
Arg::Int(self.token_counter),
sliced_source_arg(self.source, node),
Arg::Label(loc),
],
);
self.token_counter += 1;
}
EntryKind::Table {
fields,
@@ -426,18 +425,17 @@ impl Visitor<'_> {
} => {
if let Some(args) = self.complex_node(&node, fields, &child_nodes, id) {
self.trap_writer.add_tuple(
"ast_node_parent",
&format!("{}_ast_node_parent", self.language_prefix),
vec![
Arg::Label(id),
Arg::Label(parent_id),
Arg::Int(parent_index),
],
);
let mut all_args = Vec::new();
all_args.push(Arg::Label(id));
let mut all_args = vec![Arg::Label(id)];
all_args.extend(args);
all_args.push(Arg::Label(loc));
self.trap_writer.add_tuple(&table_name, all_args);
self.trap_writer.add_tuple(table_name, all_args);
}
}
_ => {
@@ -472,8 +470,8 @@ impl Visitor<'_> {
fn complex_node(
&mut self,
node: &Node,
fields: &Vec<Field>,
child_nodes: &Vec<ChildNode>,
fields: &[Field],
child_nodes: &[ChildNode],
parent_id: Label,
) -> Option<Vec<Arg>> {
let mut map: Map<&Option<String>, (&Field, Vec<Arg>)> = Map::new();
@@ -510,22 +508,20 @@ impl Visitor<'_> {
);
self.record_parse_error_for_node(error_message, full_error_message, *node);
}
} else {
if child_node.field_name.is_some() || child_node.type_name.named {
let error_message = format!(
"value for unknown field: {}::{} and type {:?}",
node.kind(),
&child_node.field_name.unwrap_or("child"),
&child_node.type_name
);
let full_error_message = format!(
"{}:{}: {}",
&self.path,
node.start_position().row + 1,
error_message
);
self.record_parse_error_for_node(error_message, full_error_message, *node);
}
} else if child_node.field_name.is_some() || child_node.type_name.named {
let error_message = format!(
"value for unknown field: {}::{} and type {:?}",
node.kind(),
&child_node.field_name.unwrap_or("child"),
&child_node.type_name
);
let full_error_message = format!(
"{}:{}: {}",
&self.path,
node.start_position().row + 1,
error_message
);
self.record_parse_error_for_node(error_message, full_error_message, *node);
}
}
let mut args = Vec::new();
@@ -573,13 +569,12 @@ impl Visitor<'_> {
);
break;
}
let mut args = Vec::new();
args.push(Arg::Label(parent_id));
let mut args = vec![Arg::Label(parent_id)];
if *has_index {
args.push(Arg::Int(index))
}
args.push(child_value.clone());
self.trap_writer.add_tuple(&table_name, args);
self.trap_writer.add_tuple(table_name, args);
}
}
}
@@ -597,13 +592,10 @@ impl Visitor<'_> {
if tp == single_type {
return true;
}
match &self.schema.get(single_type).unwrap().kind {
EntryKind::Union { members } => {
if self.type_matches_set(tp, members) {
return true;
}
if let EntryKind::Union { members } = &self.schema.get(single_type).unwrap().kind {
if self.type_matches_set(tp, members) {
return true;
}
_ => {}
}
}
node_types::FieldTypeInfo::Multiple { types, .. } => {
@@ -633,7 +625,7 @@ impl Visitor<'_> {
}
// Emit a slice of a source file as an Arg.
fn sliced_source_arg(source: &Vec<u8>, n: Node) -> Arg {
fn sliced_source_arg(source: &[u8], n: Node) -> Arg {
let range = n.byte_range();
Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned())
}
@@ -641,7 +633,7 @@ fn sliced_source_arg(source: &Vec<u8>, n: Node) -> Arg {
// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
// The first is the location and label definition, and the second is the
// 'Located' entry.
fn location_for<'a>(source: &Vec<u8>, n: Node) -> (usize, usize, usize, usize) {
fn location_for(source: &[u8], n: Node) -> (usize, usize, usize, usize) {
// Tree-sitter row, column values are 0-based while CodeQL starts
// counting at 1. In addition Tree-sitter's row and column for the
// end position are exclusive while CodeQL's end positions are inclusive.
@@ -720,9 +712,9 @@ impl fmt::Display for Program {
}
enum TrapEntry {
/// Maps the label to a fresh id, e.g. `#123 = *`.
/// Maps the label to a fresh id, e.g. `#123=*`.
FreshId(Label),
/// Maps the label to a key, e.g. `#7 = @"foo"`.
/// Maps the label to a key, e.g. `#7=@"foo"`.
MapLabelToKey(Label, String),
/// foo_bar(arg*)
GenericTuple(String, Vec<Arg>),
@@ -731,15 +723,15 @@ enum TrapEntry {
impl fmt::Display for TrapEntry {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
TrapEntry::FreshId(label) => write!(f, "{} = *", label),
TrapEntry::FreshId(label) => write!(f, "{}=*", label),
TrapEntry::MapLabelToKey(label, key) => {
write!(f, "{} = @\"{}\"", label, key.replace("\"", "\"\""))
write!(f, "{}=@\"{}\"", label, key.replace("\"", "\"\""))
}
TrapEntry::GenericTuple(name, args) => {
write!(f, "{}(", name)?;
for (index, arg) in args.iter().enumerate() {
if index > 0 {
write!(f, ", ")?;
write!(f, ",")?;
}
write!(f, "{}", arg)?;
}
@@ -756,7 +748,7 @@ struct Label(u32);
impl fmt::Display for Label {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "#{}", self.0)
write!(f, "#{:x}", self.0)
}
}
@@ -799,18 +791,18 @@ impl fmt::Display for Arg {
/// the string is sliced at the provided limit. If there is a multi-byte character
/// at the limit then the returned slice will be slightly shorter than the limit to
/// avoid splitting that multi-byte character.
fn limit_string(string: &String, max_size: usize) -> &str {
fn limit_string(string: &str, max_size: usize) -> &str {
if string.len() <= max_size {
return string;
}
let p = string.as_ptr();
let p = string.as_bytes();
let mut index = max_size;
// We want to clip the string at [max_size]; however, the character at that position
// may span several bytes. We need to find the first byte of the character. In UTF-8
// encoded data any byte that matches the bit pattern 10XXXXXX is not a start byte.
// Therefore we decrement the index as long as there are bytes matching this pattern.
// This ensures we cut the string at the border between one character and another.
while index > 0 && unsafe { (*p.offset(index as isize) & 0b11000000) == 0b10000000 } {
while index > 0 && (p[index] & 0b11000000) == 0b10000000 {
index -= 1;
}
&string[0..index]
@@ -829,9 +821,9 @@ fn escape_key_test() {
assert_eq!("foo&lbrace;&rbrace;", escape_key("foo{}"));
assert_eq!("&lbrace;&rbrace;", escape_key("{}"));
assert_eq!("", escape_key(""));
assert_eq!("/path/to/foo.ql", escape_key("/path/to/foo.ql"));
assert_eq!("/path/to/foo.rb", escape_key("/path/to/foo.rb"));
assert_eq!(
"/path/to/foo&amp;&lbrace;&rbrace;&quot;&commat;&num;.ql",
escape_key("/path/to/foo&{}\"@#.ql")
"/path/to/foo&amp;&lbrace;&rbrace;&quot;&commat;&num;.rb",
escape_key("/path/to/foo&{}\"@#.rb")
);
}

View File

@@ -2,13 +2,11 @@ mod extractor;
extern crate num_cpus;
use clap;
use flate2::write::GzEncoder;
use rayon::prelude::*;
use std::fs;
use std::io::{BufRead, BufWriter, Write};
use std::io::{BufRead, BufWriter};
use std::path::{Path, PathBuf};
use tree_sitter::{Language, Parser, Range};
enum TrapCompression {
None,
@@ -40,8 +38,8 @@ impl TrapCompression {
fn extension(&self) -> &str {
match self {
TrapCompression::None => ".trap",
TrapCompression::Gzip => ".trap.gz",
TrapCompression::None => "trap",
TrapCompression::Gzip => "trap.gz",
}
}
}
@@ -54,28 +52,24 @@ impl TrapCompression {
* "If the number is positive, it indicates the number of threads that should
* be used. If the number is negative or zero, it should be added to the number
* of cores available on the machine to determine how many threads to use
* (minimum of 1). If unspecified, should be considered as set to 1."
* (minimum of 1). If unspecified, should be considered as set to -1."
*/
fn num_codeql_threads() -> usize {
match std::env::var("CODEQL_THREADS") {
// Use 1 thread if the environment variable isn't set.
Err(_) => 1,
let threads_str = std::env::var("CODEQL_THREADS").unwrap_or_else(|_| "-1".to_owned());
match threads_str.parse::<i32>() {
Ok(num) if num <= 0 => {
let reduction = -num as usize;
std::cmp::max(1, num_cpus::get() - reduction)
}
Ok(num) => num as usize,
Ok(num) => match num.parse::<i32>() {
Ok(num) if num <= 0 => {
let reduction = -num as usize;
num_cpus::get() - reduction
}
Ok(num) => num as usize,
Err(_) => {
tracing::error!(
"Unable to parse CODEQL_THREADS value '{}'; defaulting to 1 thread.",
&num
);
1
}
},
Err(_) => {
tracing::error!(
"Unable to parse CODEQL_THREADS value '{}'; defaulting to 1 thread.",
&threads_str
);
1
}
}
}
@@ -127,29 +121,55 @@ fn main() -> std::io::Result<()> {
let file_list = fs::File::open(file_list)?;
let language = tree_sitter_ql::language();
let schema = node_types::read_node_types_str(tree_sitter_ql::NODE_TYPES)?;
let schema = node_types::read_node_types_str("ql", tree_sitter_ql::NODE_TYPES)?;
let lines: std::io::Result<Vec<String>> = std::io::BufReader::new(file_list).lines().collect();
let lines = lines?;
lines.par_iter().try_for_each(|line| {
let path = PathBuf::from(line).canonicalize()?;
let trap_file = path_for(&trap_dir, &path, trap_compression.extension());
let src_archive_file = path_for(&src_archive_dir, &path, "");
let mut source = std::fs::read(&path)?;
let code_ranges = vec![];
let trap = extractor::extract(language, &schema, &path, &source, &code_ranges)?;
std::fs::create_dir_all(&src_archive_file.parent().unwrap())?;
std::fs::copy(&path, &src_archive_file)?;
std::fs::create_dir_all(&trap_file.parent().unwrap())?;
let trap_file = std::fs::File::create(&trap_file)?;
let mut trap_file = BufWriter::new(trap_file);
match trap_compression {
TrapCompression::None => write!(trap_file, "{}", trap),
TrapCompression::Gzip => {
let mut compressed_writer = GzEncoder::new(trap_file, flate2::Compression::fast());
write!(compressed_writer, "{}", trap)
}
lines
.par_iter()
.try_for_each(|line| {
let path = PathBuf::from(line).canonicalize()?;
let src_archive_file = path_for(&src_archive_dir, &path, "");
let source = std::fs::read(&path)?;
let code_ranges = vec![];
let mut trap_writer = extractor::new_trap_writer();
extractor::extract(
language,
"ql",
&schema,
&mut trap_writer,
&path,
&source,
&code_ranges,
)?;
std::fs::create_dir_all(&src_archive_file.parent().unwrap())?;
std::fs::copy(&path, &src_archive_file)?;
write_trap(&trap_dir, path, trap_writer, &trap_compression)
})
.expect("failed to extract files");
let path = PathBuf::from("extras");
let mut trap_writer = extractor::new_trap_writer();
trap_writer.populate_empty_location();
write_trap(&trap_dir, path, trap_writer, &trap_compression)
}
fn write_trap(
trap_dir: &Path,
path: PathBuf,
trap_writer: extractor::TrapWriter,
trap_compression: &TrapCompression,
) -> std::io::Result<()> {
let trap_file = path_for(trap_dir, &path, trap_compression.extension());
std::fs::create_dir_all(&trap_file.parent().unwrap())?;
let trap_file = std::fs::File::create(&trap_file)?;
let mut trap_file = BufWriter::new(trap_file);
match trap_compression {
TrapCompression::None => trap_writer.output(&mut trap_file),
TrapCompression::Gzip => {
let mut compressed_writer = GzEncoder::new(trap_file, flate2::Compression::fast());
trap_writer.output(&mut compressed_writer)
}
})
}
}
fn path_for(dir: &Path, path: &Path, ext: &str) -> PathBuf {
@@ -184,12 +204,18 @@ fn path_for(dir: &Path, path: &Path, ext: &str) -> PathBuf {
}
}
}
if let Some(x) = result.extension() {
let mut new_ext = x.to_os_string();
new_ext.push(ext);
result.set_extension(new_ext);
} else {
result.set_extension(ext);
if !ext.is_empty() {
match result.extension() {
Some(x) => {
let mut new_ext = x.to_os_string();
new_ext.push(".");
new_ext.push(ext);
result.set_extension(new_ext);
}
None => {
result.set_extension(ext);
}
}
}
result
}

View File

@@ -7,6 +7,7 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
clap = "2.33"
node-types = { path = "../node-types" }
tracing = "0.1"
tracing-subscriber = { version = "0.2", features = ["env-filter"] }

View File

@@ -68,10 +68,10 @@ impl<'a> fmt::Display for Table<'a> {
}
write!(f, "{}", key)?;
}
write!(f, "]\n")?;
writeln!(f, "]")?;
}
write!(f, "{}(\n", self.name)?;
writeln!(f, "{}(", self.name)?;
for (column_index, column) in self.columns.iter().enumerate() {
write!(f, " ")?;
if column.unique {
@@ -92,7 +92,7 @@ impl<'a> fmt::Display for Table<'a> {
if column_index + 1 != self.columns.len() {
write!(f, ",")?;
}
write!(f, "\n")?;
writeln!(f)?;
}
write!(f, ");")?;
@@ -117,17 +117,7 @@ impl<'a> fmt::Display for Union<'a> {
}
/// Generates the dbscheme by writing the given dbscheme `entries` to the `file`.
pub fn write<'a>(
language_name: &str,
file: &mut dyn std::io::Write,
entries: &'a [Entry],
) -> std::io::Result<()> {
write!(file, "// CodeQL database schema for {}\n", language_name)?;
write!(
file,
"// Automatically generated from the tree-sitter grammar; do not edit\n\n"
)?;
pub fn write<'a>(file: &mut dyn std::io::Write, entries: &'a [Entry]) -> std::io::Result<()> {
for entry in entries {
match entry {
Entry::Case(case) => write!(file, "{}\n\n", case)?,

View File

@@ -1,8 +1,4 @@
use std::path::PathBuf;
pub struct Language {
pub name: String,
pub node_types: &'static str,
pub dbscheme_path: PathBuf,
pub ql_library_path: PathBuf,
}

View File

@@ -8,8 +8,8 @@ use std::collections::BTreeMap as Map;
use std::collections::BTreeSet as Set;
use std::fs::File;
use std::io::LineWriter;
use std::io::Write;
use std::path::PathBuf;
use tracing::{error, info};
/// Given the name of the parent node, and its field information, returns a pair,
/// the first of which is the field's type. The second is an optional dbscheme
@@ -32,7 +32,7 @@ fn make_field_type<'a>(
.map(|t| nodes.get(t).unwrap().dbscheme_name.as_str())
.collect();
(
ql::Type::AtType(&dbscheme_union),
ql::Type::At(dbscheme_union),
Some(dbscheme::Entry::Union(dbscheme::Union {
name: dbscheme_union,
members,
@@ -40,14 +40,14 @@ fn make_field_type<'a>(
)
}
node_types::FieldTypeInfo::Single(t) => {
let dbscheme_name = &nodes.get(&t).unwrap().dbscheme_name;
(ql::Type::AtType(dbscheme_name), None)
let dbscheme_name = &nodes.get(t).unwrap().dbscheme_name;
(ql::Type::At(dbscheme_name), None)
}
node_types::FieldTypeInfo::ReservedWordInt(int_mapping) => {
// The field will be an `int` in the db, and we add a case split to
// create other db types for each integer value.
let mut branches: Vec<(usize, &'a str)> = Vec::new();
for (_, (value, name)) in int_mapping {
for (value, name) in int_mapping.values() {
branches.push((*value, name));
}
let case = dbscheme::Entry::Case(dbscheme::Case {
@@ -73,12 +73,12 @@ fn add_field_for_table_storage<'a>(
let parent_name = &nodes.get(&field.parent).unwrap().dbscheme_name;
// This field can appear zero or multiple times, so put
// it in an auxiliary table.
let (field_ql_type, field_type_entry) = make_field_type(parent_name, &field, nodes);
let (field_ql_type, field_type_entry) = make_field_type(parent_name, field, nodes);
let parent_column = dbscheme::Column {
unique: !has_index,
db_type: dbscheme::DbColumnType::Int,
name: &parent_name,
ql_type: ql::Type::AtType(&parent_name),
name: parent_name,
ql_type: ql::Type::At(parent_name),
ql_type_is_ref: true,
};
let index_column = dbscheme::Column {
@@ -96,7 +96,7 @@ fn add_field_for_table_storage<'a>(
ql_type_is_ref: true,
};
let field_table = dbscheme::Table {
name: &table_name,
name: table_name,
columns: if has_index {
vec![parent_column, index_column, field_column]
} else {
@@ -105,7 +105,7 @@ fn add_field_for_table_storage<'a>(
// In addition to the field being unique, the combination of
// parent+index is unique, so add a keyset for them.
keysets: if has_index {
Some(vec![&parent_name, "index"])
Some(vec![parent_name, "index"])
} else {
None
},
@@ -121,7 +121,7 @@ fn add_field_for_column_storage<'a>(
) -> (dbscheme::Column<'a>, Option<dbscheme::Entry<'a>>) {
// This field must appear exactly once, so we add it as
// a column to the main table for the node type.
let (field_ql_type, field_type_entry) = make_field_type(parent_name, &field, nodes);
let (field_ql_type, field_type_entry) = make_field_type(parent_name, field, nodes);
(
dbscheme::Column {
unique: false,
@@ -135,18 +135,16 @@ fn add_field_for_column_storage<'a>(
}
/// Converts the given tree-sitter node types into CodeQL dbscheme entries.
fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<dbscheme::Entry<'a>> {
let mut entries: Vec<dbscheme::Entry> = vec![
create_location_union(),
create_locations_default_table(),
create_sourceline_union(),
create_numlines_table(),
create_files_table(),
create_folders_table(),
create_container_union(),
create_containerparent_table(),
create_source_location_prefix_table(),
];
/// Returns a tuple containing:
///
/// 1. A vector of dbscheme entries.
/// 2. A set of names of the members of the `<lang>_ast_node` union.
/// 3. A map where the keys are the dbscheme names for token kinds, and the
/// values are their integer representations.
fn convert_nodes(
nodes: &node_types::NodeTypeMap,
) -> (Vec<dbscheme::Entry>, Set<&str>, Map<&str, usize>) {
let mut entries: Vec<dbscheme::Entry> = Vec::new();
let mut ast_node_members: Set<&str> = Set::new();
let token_kinds: Map<&str, usize> = nodes
.iter()
@@ -157,8 +155,7 @@ fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<dbscheme::Entry<
_ => None,
})
.collect();
ast_node_members.insert("token");
for (_, node) in nodes {
for node in nodes.values() {
match &node.kind {
node_types::EntryKind::Union { members: n_members } => {
// It's a tree-sitter supertype node, for which we create a union
@@ -175,12 +172,12 @@ fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<dbscheme::Entry<
node_types::EntryKind::Table { name, fields } => {
// It's a product type, defined by a table.
let mut main_table = dbscheme::Table {
name: &name,
name,
columns: vec![dbscheme::Column {
db_type: dbscheme::DbColumnType::Int,
name: "id",
unique: true,
ql_type: ql::Type::AtType(&node.dbscheme_name),
ql_type: ql::Type::At(&node.dbscheme_name),
ql_type_is_ref: false,
}],
keysets: None,
@@ -240,7 +237,7 @@ fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<dbscheme::Entry<
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "loc",
ql_type: ql::Type::AtType("location"),
ql_type: ql::Type::At("location"),
ql_type_is_ref: true,
});
@@ -250,48 +247,30 @@ fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<dbscheme::Entry<
}
}
// Add the tokeninfo table
let (token_case, token_table) = create_tokeninfo(token_kinds);
entries.push(dbscheme::Entry::Table(token_table));
entries.push(dbscheme::Entry::Case(token_case));
// Add the diagnostics table
let (diagnostics_case, diagnostics_table) = create_diagnostics();
entries.push(dbscheme::Entry::Table(diagnostics_table));
entries.push(dbscheme::Entry::Case(diagnostics_case));
// Create a union of all database types.
entries.push(dbscheme::Entry::Union(dbscheme::Union {
name: "ast_node",
members: ast_node_members,
}));
// Create the ast_node_parent union.
entries.push(dbscheme::Entry::Union(dbscheme::Union {
name: "ast_node_parent",
members: ["ast_node", "file"].iter().cloned().collect(),
}));
entries.push(dbscheme::Entry::Table(create_ast_node_parent_table()));
entries
(entries, ast_node_members, token_kinds)
}
fn create_ast_node_parent_table<'a>() -> dbscheme::Table<'a> {
/// Creates a dbscheme table entry representing the parent relation for AST nodes.
///
/// # Arguments
/// - `name` - the name of both the table to create and the node parent type.
/// - `ast_node_name` - the name of the node child type.
fn create_ast_node_parent_table<'a>(name: &'a str, ast_node_name: &'a str) -> dbscheme::Table<'a> {
dbscheme::Table {
name: "ast_node_parent",
name,
columns: vec![
dbscheme::Column {
db_type: dbscheme::DbColumnType::Int,
name: "child",
unique: false,
ql_type: ql::Type::AtType("ast_node"),
ql_type: ql::Type::At(ast_node_name),
ql_type_is_ref: true,
},
dbscheme::Column {
db_type: dbscheme::DbColumnType::Int,
name: "parent",
unique: false,
ql_type: ql::Type::AtType("ast_node_parent"),
ql_type: ql::Type::At(name),
ql_type_is_ref: true,
},
dbscheme::Column {
@@ -306,18 +285,16 @@ fn create_ast_node_parent_table<'a>() -> dbscheme::Table<'a> {
}
}
fn create_tokeninfo<'a>(
token_kinds: Map<&'a str, usize>,
) -> (dbscheme::Case<'a>, dbscheme::Table<'a>) {
let table = dbscheme::Table {
name: "tokeninfo",
fn create_tokeninfo<'a>(name: &'a str, type_name: &'a str) -> dbscheme::Table<'a> {
dbscheme::Table {
name,
keysets: None,
columns: vec![
dbscheme::Column {
db_type: dbscheme::DbColumnType::Int,
name: "id",
unique: true,
ql_type: ql::Type::AtType("token"),
ql_type: ql::Type::At(type_name),
ql_type_is_ref: false,
},
dbscheme::Column {
@@ -327,20 +304,6 @@ fn create_tokeninfo<'a>(
ql_type: ql::Type::Int,
ql_type_is_ref: true,
},
dbscheme::Column {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "file",
ql_type: ql::Type::AtType("file"),
ql_type_is_ref: true,
},
dbscheme::Column {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "idx",
ql_type: ql::Type::Int,
ql_type_is_ref: true,
},
dbscheme::Column {
unique: false,
db_type: dbscheme::DbColumnType::String,
@@ -352,35 +315,23 @@ fn create_tokeninfo<'a>(
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "loc",
ql_type: ql::Type::AtType("location"),
ql_type: ql::Type::At("location"),
ql_type_is_ref: true,
},
],
};
}
}
fn create_token_case<'a>(name: &'a str, token_kinds: Map<&'a str, usize>) -> dbscheme::Case<'a> {
let branches: Vec<(usize, &str)> = token_kinds
.iter()
.map(|(&name, kind_id)| (*kind_id, name))
.collect();
let case = dbscheme::Case {
name: "token",
dbscheme::Case {
name,
column: "kind",
branches: branches,
};
(case, table)
}
fn write_dbscheme(language: &Language, entries: &[dbscheme::Entry]) -> std::io::Result<()> {
info!(
"Writing database schema for {} to '{}'",
&language.name,
match language.dbscheme_path.to_str() {
None => "<undisplayable>",
Some(p) => p,
}
);
let file = File::create(&language.dbscheme_path)?;
let mut file = LineWriter::new(file);
dbscheme::write(&language.name, &mut file, &entries)
branches,
}
}
fn create_location_union<'a>() -> dbscheme::Entry<'a> {
@@ -399,7 +350,7 @@ fn create_files_table<'a>() -> dbscheme::Entry<'a> {
unique: true,
db_type: dbscheme::DbColumnType::Int,
name: "id",
ql_type: ql::Type::AtType("file"),
ql_type: ql::Type::At("file"),
ql_type_is_ref: false,
},
dbscheme::Column {
@@ -409,27 +360,6 @@ fn create_files_table<'a>() -> dbscheme::Entry<'a> {
ql_type: ql::Type::String,
ql_type_is_ref: true,
},
dbscheme::Column {
db_type: dbscheme::DbColumnType::String,
name: "simple",
unique: false,
ql_type: ql::Type::String,
ql_type_is_ref: true,
},
dbscheme::Column {
db_type: dbscheme::DbColumnType::String,
name: "ext",
unique: false,
ql_type: ql::Type::String,
ql_type_is_ref: true,
},
dbscheme::Column {
db_type: dbscheme::DbColumnType::Int,
name: "fromSource",
unique: false,
ql_type: ql::Type::Int,
ql_type_is_ref: true,
},
],
})
}
@@ -442,7 +372,7 @@ fn create_folders_table<'a>() -> dbscheme::Entry<'a> {
unique: true,
db_type: dbscheme::DbColumnType::Int,
name: "id",
ql_type: ql::Type::AtType("folder"),
ql_type: ql::Type::At("folder"),
ql_type_is_ref: false,
},
dbscheme::Column {
@@ -452,13 +382,6 @@ fn create_folders_table<'a>() -> dbscheme::Entry<'a> {
ql_type: ql::Type::String,
ql_type_is_ref: true,
},
dbscheme::Column {
db_type: dbscheme::DbColumnType::String,
name: "simple",
unique: false,
ql_type: ql::Type::String,
ql_type_is_ref: true,
},
],
})
}
@@ -472,14 +395,14 @@ fn create_locations_default_table<'a>() -> dbscheme::Entry<'a> {
unique: true,
db_type: dbscheme::DbColumnType::Int,
name: "id",
ql_type: ql::Type::AtType("location_default"),
ql_type: ql::Type::At("location_default"),
ql_type_is_ref: false,
},
dbscheme::Column {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "file",
ql_type: ql::Type::AtType("file"),
ql_type: ql::Type::At("file"),
ql_type_is_ref: true,
},
dbscheme::Column {
@@ -514,50 +437,6 @@ fn create_locations_default_table<'a>() -> dbscheme::Entry<'a> {
})
}
fn create_sourceline_union<'a>() -> dbscheme::Entry<'a> {
dbscheme::Entry::Union(dbscheme::Union {
name: "sourceline",
members: vec!["file"].into_iter().collect(),
})
}
fn create_numlines_table<'a>() -> dbscheme::Entry<'a> {
dbscheme::Entry::Table(dbscheme::Table {
name: "numlines",
columns: vec![
dbscheme::Column {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "element_id",
ql_type: ql::Type::AtType("sourceline"),
ql_type_is_ref: true,
},
dbscheme::Column {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "num_lines",
ql_type: ql::Type::Int,
ql_type_is_ref: true,
},
dbscheme::Column {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "num_code",
ql_type: ql::Type::Int,
ql_type_is_ref: true,
},
dbscheme::Column {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "num_comment",
ql_type: ql::Type::Int,
ql_type_is_ref: true,
},
],
keysets: None,
})
}
fn create_container_union<'a>() -> dbscheme::Entry<'a> {
dbscheme::Entry::Union(dbscheme::Union {
name: "container",
@@ -573,14 +452,14 @@ fn create_containerparent_table<'a>() -> dbscheme::Entry<'a> {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "parent",
ql_type: ql::Type::AtType("container"),
ql_type: ql::Type::At("container"),
ql_type_is_ref: true,
},
dbscheme::Column {
unique: true,
db_type: dbscheme::DbColumnType::Int,
name: "child",
ql_type: ql::Type::AtType("container"),
ql_type: ql::Type::At("container"),
ql_type_is_ref: true,
},
],
@@ -611,7 +490,7 @@ fn create_diagnostics<'a>() -> (dbscheme::Case<'a>, dbscheme::Table<'a>) {
unique: true,
db_type: dbscheme::DbColumnType::Int,
name: "id",
ql_type: ql::Type::AtType("diagnostic"),
ql_type: ql::Type::At("diagnostic"),
ql_type_is_ref: false,
},
dbscheme::Column {
@@ -646,7 +525,7 @@ fn create_diagnostics<'a>() -> (dbscheme::Case<'a>, dbscheme::Table<'a>) {
unique: false,
db_type: dbscheme::DbColumnType::Int,
name: "location",
ql_type: ql::Type::AtType("location_default"),
ql_type: ql::Type::At("location_default"),
ql_type_is_ref: true,
},
],
@@ -665,7 +544,7 @@ fn create_diagnostics<'a>() -> (dbscheme::Case<'a>, dbscheme::Table<'a>) {
(case, table)
}
fn main() {
fn main() -> std::io::Result<()> {
tracing_subscriber::fmt()
.with_target(false)
.without_time()
@@ -673,33 +552,114 @@ fn main() {
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.init();
// TODO: figure out proper dbscheme output path and/or take it from the
// command line.
let ql = Language {
let matches = clap::App::new("QL dbscheme generator")
.version("1.0")
.author("GitHub")
.about("CodeQL QL dbscheme generator")
.args_from_usage(
"--dbscheme=<FILE> 'Path of the generated dbscheme file'
--library=<FILE> 'Path of the generated QLL file'",
)
.get_matches();
let dbscheme_path = matches.value_of("dbscheme").expect("missing --dbscheme");
let dbscheme_path = PathBuf::from(dbscheme_path);
let ql_library_path = matches.value_of("library").expect("missing --library");
let ql_library_path = PathBuf::from(ql_library_path);
let languages = vec![Language {
name: "QL".to_owned(),
node_types: tree_sitter_ql::NODE_TYPES,
dbscheme_path: PathBuf::from("ql/src/ql.dbscheme"),
ql_library_path: PathBuf::from("ql/src/codeql_ql/ast/internal/TreeSitter.qll"),
};
match node_types::read_node_types_str(&ql.node_types) {
Err(e) => {
error!("Failed to read node-types JSON for {}: {}", ql.name, e);
std::process::exit(1);
}
Ok(nodes) => {
let dbscheme_entries = convert_nodes(&nodes);
}];
let mut dbscheme_writer = LineWriter::new(File::create(dbscheme_path)?);
write!(
dbscheme_writer,
"// CodeQL database schema for {}\n\
// Automatically generated from the tree-sitter grammar; do not edit\n\n",
languages[0].name
)?;
let (diagnostics_case, diagnostics_table) = create_diagnostics();
dbscheme::write(
&mut dbscheme_writer,
&[
create_location_union(),
create_locations_default_table(),
create_files_table(),
create_folders_table(),
create_container_union(),
create_containerparent_table(),
create_source_location_prefix_table(),
dbscheme::Entry::Table(diagnostics_table),
dbscheme::Entry::Case(diagnostics_case),
],
)?;
if let Err(e) = write_dbscheme(&ql, &dbscheme_entries) {
error!("Failed to write dbscheme: {}", e);
std::process::exit(2);
}
let mut ql_writer = LineWriter::new(File::create(ql_library_path)?);
write!(
ql_writer,
"/*\n\
* CodeQL library for {}
* Automatically generated from the tree-sitter grammar; do not edit\n\
*/\n\n",
languages[0].name
)?;
ql::write(
&mut ql_writer,
&[
ql::TopLevel::Import("codeql.files.FileSystem"),
ql::TopLevel::Import("codeql.Locations"),
],
)?;
let classes = ql_gen::convert_nodes(&nodes);
for language in languages {
let prefix = node_types::to_snake_case(&language.name);
let ast_node_name = format!("{}_ast_node", &prefix);
let ast_node_parent_name = format!("{}_ast_node_parent", &prefix);
let token_name = format!("{}_token", &prefix);
let tokeninfo_name = format!("{}_tokeninfo", &prefix);
let reserved_word_name = format!("{}_reserved_word", &prefix);
let nodes = node_types::read_node_types_str(&prefix, language.node_types)?;
let (dbscheme_entries, mut ast_node_members, token_kinds) = convert_nodes(&nodes);
ast_node_members.insert(&token_name);
dbscheme::write(&mut dbscheme_writer, &dbscheme_entries)?;
let token_case = create_token_case(&token_name, token_kinds);
dbscheme::write(
&mut dbscheme_writer,
&[
dbscheme::Entry::Table(create_tokeninfo(&tokeninfo_name, &token_name)),
dbscheme::Entry::Case(token_case),
dbscheme::Entry::Union(dbscheme::Union {
name: &ast_node_name,
members: ast_node_members,
}),
dbscheme::Entry::Union(dbscheme::Union {
name: &ast_node_parent_name,
members: [&ast_node_name, "file"].iter().cloned().collect(),
}),
dbscheme::Entry::Table(create_ast_node_parent_table(
&ast_node_parent_name,
&ast_node_name,
)),
],
)?;
if let Err(e) = ql_gen::write(&ql, &classes) {
println!("Failed to write QL library: {}", e);
std::process::exit(3);
}
}
let mut body = vec![
ql::TopLevel::Class(ql_gen::create_ast_node_class(
&ast_node_name,
&ast_node_parent_name,
)),
ql::TopLevel::Class(ql_gen::create_token_class(&token_name, &tokeninfo_name)),
ql::TopLevel::Class(ql_gen::create_reserved_word_class(&reserved_word_name)),
];
body.append(&mut ql_gen::convert_nodes(&nodes));
ql::write(
&mut ql_writer,
&[ql::TopLevel::Module(ql::Module {
qldoc: None,
name: &language.name,
body,
})],
)?;
}
Ok(())
}

View File

@@ -1,9 +1,11 @@
use std::collections::BTreeSet;
use std::fmt;
#[derive(Clone, Eq, PartialEq, Hash)]
pub enum TopLevel<'a> {
Class(Class<'a>),
Import(&'a str),
Module(Module<'a>),
}
impl<'a> fmt::Display for TopLevel<'a> {
@@ -11,6 +13,7 @@ impl<'a> fmt::Display for TopLevel<'a> {
match self {
TopLevel::Import(x) => write!(f, "private import {}", x),
TopLevel::Class(cls) => write!(f, "{}", cls),
TopLevel::Module(m) => write!(f, "{}", m),
}
}
}
@@ -40,15 +43,15 @@ impl<'a> fmt::Display for Class<'a> {
}
write!(f, "{}", supertype)?;
}
write!(f, " {{ \n")?;
writeln!(f, " {{ ")?;
if let Some(charpred) = &self.characteristic_predicate {
write!(
writeln!(
f,
" {}\n",
" {}",
Predicate {
qldoc: None,
name: self.name.clone(),
name: self.name,
overridden: false,
return_type: None,
formal_parameters: vec![],
@@ -58,7 +61,7 @@ impl<'a> fmt::Display for Class<'a> {
}
for predicate in &self.predicates {
write!(f, " {}\n", predicate)?;
writeln!(f, " {}", predicate)?;
}
write!(f, "}}")?;
@@ -67,6 +70,26 @@ impl<'a> fmt::Display for Class<'a> {
}
}
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct Module<'a> {
pub qldoc: Option<String>,
pub name: &'a str,
pub body: Vec<TopLevel<'a>>,
}
impl<'a> fmt::Display for Module<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(qldoc) = &self.qldoc {
write!(f, "/** {} */", qldoc)?;
}
writeln!(f, "module {} {{ ", self.name)?;
for decl in &self.body {
writeln!(f, " {}", decl)?;
}
write!(f, "}}")?;
Ok(())
}
}
// The QL type of a column.
#[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub enum Type<'a> {
@@ -77,7 +100,7 @@ pub enum Type<'a> {
String,
/// A database type that will need to be referred to with an `@` prefix.
AtType(&'a str),
At(&'a str),
/// A user-defined type.
Normal(&'a str),
@@ -89,7 +112,7 @@ impl<'a> fmt::Display for Type<'a> {
Type::Int => write!(f, "int"),
Type::String => write!(f, "string"),
Type::Normal(name) => write!(f, "{}", name),
Type::AtType(name) => write!(f, "@{}", name),
Type::At(name) => write!(f, "@{}", name),
}
}
}
@@ -104,12 +127,13 @@ pub enum Expression<'a> {
Or(Vec<Expression<'a>>),
Equals(Box<Expression<'a>>, Box<Expression<'a>>),
Dot(Box<Expression<'a>>, &'a str, Vec<Expression<'a>>),
Aggregate(
&'a str,
Vec<FormalParameter<'a>>,
Box<Expression<'a>>,
Box<Expression<'a>>,
),
Aggregate {
name: &'a str,
vars: Vec<FormalParameter<'a>>,
range: Option<Box<Expression<'a>>>,
expr: Box<Expression<'a>>,
second_expr: Option<Box<Expression<'a>>>,
},
}
impl<'a> fmt::Display for Expression<'a> {
@@ -165,15 +189,31 @@ impl<'a> fmt::Display for Expression<'a> {
}
write!(f, ")")
}
Expression::Aggregate(n, vars, range, term) => {
write!(f, "{}(", n)?;
for (index, var) in vars.iter().enumerate() {
if index > 0 {
write!(f, ", ")?;
Expression::Aggregate {
name,
vars,
range,
expr,
second_expr,
} => {
write!(f, "{}(", name)?;
if !vars.is_empty() {
for (index, var) in vars.iter().enumerate() {
if index > 0 {
write!(f, ", ")?;
}
write!(f, "{}", var)?;
}
write!(f, "{}", var)?;
write!(f, " | ")?;
}
write!(f, " | {} | {})", range, term)
if let Some(range) = range {
write!(f, "{} | ", range)?;
}
write!(f, "{}", expr)?;
if let Some(second_expr) = second_expr {
write!(f, ", {}", second_expr)?;
}
write!(f, ")")
}
}
}
@@ -226,25 +266,10 @@ impl<'a> fmt::Display for FormalParameter<'a> {
}
}
/// Generates a QL library by writing the given `classes` to the `file`.
pub fn write<'a>(
language_name: &str,
file: &mut dyn std::io::Write,
elements: &'a [TopLevel],
) -> std::io::Result<()> {
write!(file, "/*\n")?;
write!(file, " * CodeQL library for {}\n", language_name)?;
write!(
file,
" * Automatically generated from the tree-sitter grammar; do not edit\n"
)?;
write!(file, " */\n\n")?;
write!(file, "module Generated {{\n")?;
/// Generates a QL library by writing the given `elements` to the `file`.
pub fn write<'a>(file: &mut dyn std::io::Write, elements: &'a [TopLevel]) -> std::io::Result<()> {
for element in elements {
write!(file, "{}\n\n", &element)?;
}
write!(file, "}}")?;
Ok(())
}

View File

@@ -1,32 +1,9 @@
use crate::language::Language;
use crate::ql;
use std::collections::BTreeSet;
use std::fs::File;
use std::io::LineWriter;
/// Writes the QL AST library for the given library.
///
/// # Arguments
///
/// `language` - the language for which we're generating a library
/// `classes` - the list of classes to write.
pub fn write(language: &Language, classes: &[ql::TopLevel]) -> std::io::Result<()> {
println!(
"Writing QL library for {} to '{}'",
&language.name,
match language.ql_library_path.to_str() {
None => "<undisplayable>",
Some(p) => p,
}
);
let file = File::create(&language.ql_library_path)?;
let mut file = LineWriter::new(file);
ql::write(&language.name, &mut file, &classes)
}
/// Creates the hard-coded `AstNode` class that acts as a supertype of all
/// classes we generate.
fn create_ast_node_class<'a>() -> ql::Class<'a> {
pub fn create_ast_node_class<'a>(ast_node: &'a str, ast_node_parent: &'a str) -> ql::Class<'a> {
// Default implementation of `toString` calls `this.getAPrimaryQlClass()`
let to_string = ql::Predicate {
qldoc: Some(String::from(
@@ -64,7 +41,7 @@ fn create_ast_node_class<'a>() -> ql::Class<'a> {
return_type: Some(ql::Type::Normal("AstNode")),
formal_parameters: vec![],
body: ql::Expression::Pred(
"ast_node_parent",
ast_node_parent,
vec![
ql::Expression::Var("this"),
ql::Expression::Var("result"),
@@ -81,7 +58,7 @@ fn create_ast_node_class<'a>() -> ql::Class<'a> {
return_type: Some(ql::Type::Int),
formal_parameters: vec![],
body: ql::Expression::Pred(
"ast_node_parent",
ast_node_parent,
vec![
ql::Expression::Var("this"),
ql::Expression::Var("_"),
@@ -102,11 +79,36 @@ fn create_ast_node_class<'a>() -> ql::Class<'a> {
Box::new(ql::Expression::String("???")),
),
};
let get_primary_ql_classes = ql::Predicate {
qldoc: Some(
"Gets a comma-separated list of the names of the primary CodeQL \
classes to which this element belongs."
.to_owned(),
),
name: "getPrimaryQlClasses",
overridden: false,
return_type: Some(ql::Type::String),
formal_parameters: vec![],
body: ql::Expression::Equals(
Box::new(ql::Expression::Var("result")),
Box::new(ql::Expression::Aggregate {
name: "concat",
vars: vec![],
range: None,
expr: Box::new(ql::Expression::Dot(
Box::new(ql::Expression::Var("this")),
"getAPrimaryQlClass",
vec![],
)),
second_expr: Some(Box::new(ql::Expression::String(","))),
}),
),
};
ql::Class {
qldoc: Some(String::from("The base class for all AST nodes")),
name: "AstNode",
is_abstract: false,
supertypes: vec![ql::Type::AtType("ast_node")].into_iter().collect(),
supertypes: vec![ql::Type::At(ast_node)].into_iter().collect(),
characteristic_predicate: None,
predicates: vec![
to_string,
@@ -115,19 +117,20 @@ fn create_ast_node_class<'a>() -> ql::Class<'a> {
get_parent_index,
get_a_field_or_child,
get_a_primary_ql_class,
get_primary_ql_classes,
],
}
}
fn create_token_class<'a>() -> ql::Class<'a> {
let tokeninfo_arity = 6;
pub fn create_token_class<'a>(token_type: &'a str, tokeninfo: &'a str) -> ql::Class<'a> {
let tokeninfo_arity = 4;
let get_value = ql::Predicate {
qldoc: Some(String::from("Gets the value of this token.")),
name: "getValue",
overridden: false,
return_type: Some(ql::Type::String),
formal_parameters: vec![],
body: create_get_field_expr_for_column_storage("result", "tokeninfo", 3, tokeninfo_arity),
body: create_get_field_expr_for_column_storage("result", tokeninfo, 1, tokeninfo_arity),
};
let get_location = ql::Predicate {
qldoc: Some(String::from("Gets the location of this token.")),
@@ -135,7 +138,7 @@ fn create_token_class<'a>() -> ql::Class<'a> {
overridden: true,
return_type: Some(ql::Type::Normal("Location")),
formal_parameters: vec![],
body: create_get_field_expr_for_column_storage("result", "tokeninfo", 4, tokeninfo_arity),
body: create_get_field_expr_for_column_storage("result", tokeninfo, 2, tokeninfo_arity),
};
let to_string = ql::Predicate {
qldoc: Some(String::from(
@@ -147,14 +150,18 @@ fn create_token_class<'a>() -> ql::Class<'a> {
formal_parameters: vec![],
body: ql::Expression::Equals(
Box::new(ql::Expression::Var("result")),
Box::new(ql::Expression::Pred("getValue", vec![])),
Box::new(ql::Expression::Dot(
Box::new(ql::Expression::Var("this")),
"getValue",
vec![],
)),
),
};
ql::Class {
qldoc: Some(String::from("A token.")),
name: "Token",
is_abstract: false,
supertypes: vec![ql::Type::AtType("token"), ql::Type::Normal("AstNode")]
supertypes: vec![ql::Type::At(token_type), ql::Type::Normal("AstNode")]
.into_iter()
.collect(),
characteristic_predicate: None,
@@ -168,15 +175,14 @@ fn create_token_class<'a>() -> ql::Class<'a> {
}
// Creates the `ReservedWord` class.
fn create_reserved_word_class<'a>() -> ql::Class<'a> {
let db_name = "reserved_word";
pub fn create_reserved_word_class(db_name: &str) -> ql::Class {
let class_name = "ReservedWord";
let get_a_primary_ql_class = create_get_a_primary_ql_class(&class_name);
let get_a_primary_ql_class = create_get_a_primary_ql_class(class_name);
ql::Class {
qldoc: Some(String::from("A reserved word.")),
name: class_name,
is_abstract: false,
supertypes: vec![ql::Type::AtType(db_name), ql::Type::Normal("Token")]
supertypes: vec![ql::Type::At(db_name), ql::Type::Normal("Token")]
.into_iter()
.collect(),
characteristic_predicate: None,
@@ -192,8 +198,8 @@ fn create_none_predicate<'a>(
return_type: Option<ql::Type<'a>>,
) -> ql::Predicate<'a> {
ql::Predicate {
qldoc: qldoc,
name: name,
qldoc,
name,
overridden,
return_type,
formal_parameters: Vec::new(),
@@ -203,7 +209,7 @@ fn create_none_predicate<'a>(
/// Creates an overridden `getAPrimaryQlClass` predicate that returns the given
/// name.
fn create_get_a_primary_ql_class<'a>(class_name: &'a str) -> ql::Predicate<'a> {
fn create_get_a_primary_ql_class(class_name: &str) -> ql::Predicate {
ql::Predicate {
qldoc: Some(String::from(
"Gets the name of the primary QL class for this element.",
@@ -225,7 +231,7 @@ fn create_get_a_primary_ql_class<'a>(class_name: &'a str) -> ql::Predicate<'a> {
///
/// `def_table` - the name of the table that defines the entity and its location.
/// `arity` - the total number of columns in the table
fn create_get_location_predicate<'a>(def_table: &'a str, arity: usize) -> ql::Predicate<'a> {
fn create_get_location_predicate(def_table: &str, arity: usize) -> ql::Predicate {
ql::Predicate {
qldoc: Some(String::from("Gets the location of this element.")),
name: "getLocation",
@@ -250,7 +256,7 @@ fn create_get_location_predicate<'a>(def_table: &'a str, arity: usize) -> ql::Pr
/// # Arguments
///
/// `def_table` - the name of the table that defines the entity and its text.
fn create_get_text_predicate<'a>(def_table: &'a str) -> ql::Predicate<'a> {
fn create_get_text_predicate(def_table: &str) -> ql::Predicate {
ql::Predicate {
qldoc: Some(String::from("Gets the text content of this element.")),
name: "getText",
@@ -341,13 +347,13 @@ fn create_field_getters<'a>(
) -> (ql::Predicate<'a>, Option<ql::Expression<'a>>) {
let return_type = match &field.type_info {
node_types::FieldTypeInfo::Single(t) => {
Some(ql::Type::Normal(&nodes.get(&t).unwrap().ql_class_name))
Some(ql::Type::Normal(&nodes.get(t).unwrap().ql_class_name))
}
node_types::FieldTypeInfo::Multiple {
types: _,
dbscheme_union: _,
ql_class,
} => Some(ql::Type::Normal(&ql_class)),
} => Some(ql::Type::Normal(ql_class)),
node_types::FieldTypeInfo::ReservedWordInt(_) => Some(ql::Type::String),
};
let formal_parameters = match &field.storage {
@@ -383,13 +389,13 @@ fn create_field_getters<'a>(
(
create_get_field_expr_for_column_storage(
get_value_result_var_name,
&main_table_name,
main_table_name,
column_index,
main_table_arity,
),
create_get_field_expr_for_column_storage(
get_value_result_var_name,
&main_table_name,
main_table_name,
column_index,
main_table_arity,
),
@@ -402,12 +408,12 @@ fn create_field_getters<'a>(
} => (
create_get_field_expr_for_table_storage(
get_value_result_var_name,
&field_table_name,
field_table_name,
if *has_index { Some("i") } else { None },
),
create_get_field_expr_for_table_storage(
get_value_result_var_name,
&field_table_name,
field_table_name,
if *has_index { Some("_") } else { None },
),
),
@@ -434,15 +440,16 @@ fn create_field_getters<'a>(
})
.collect();
(
ql::Expression::Aggregate(
"exists",
vec![ql::FormalParameter {
ql::Expression::Aggregate {
name: "exists",
vars: vec![ql::FormalParameter {
name: "value",
param_type: ql::Type::Int,
}],
Box::new(get_value),
Box::new(ql::Expression::Or(disjuncts)),
),
range: Some(Box::new(get_value)),
expr: Box::new(ql::Expression::Or(disjuncts)),
second_expr: None,
},
// Since the getter returns a string and not an AstNode, it won't be part of getAFieldOrChild:
None,
)
@@ -452,11 +459,9 @@ fn create_field_getters<'a>(
}
};
let qldoc = match &field.name {
Some(name) => {
format!("Gets the node corresponding to the field `{}`.", name)
}
Some(name) => format!("Gets the node corresponding to the field `{}`.", name),
None => {
if formal_parameters.len() == 0 {
if formal_parameters.is_empty() {
"Gets the child of this node.".to_owned()
} else {
"Gets the `i`th child of this node.".to_owned()
@@ -477,14 +482,8 @@ fn create_field_getters<'a>(
}
/// Converts the given node types into CodeQL classes wrapping the dbscheme.
pub fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<ql::TopLevel<'a>> {
let mut classes: Vec<ql::TopLevel> = vec![
ql::TopLevel::Import("codeql.files.FileSystem"),
ql::TopLevel::Import("codeql.Locations"),
ql::TopLevel::Class(create_ast_node_class()),
ql::TopLevel::Class(create_token_class()),
ql::TopLevel::Class(create_reserved_word_class()),
];
pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec<ql::TopLevel> {
let mut classes: Vec<ql::TopLevel> = Vec::new();
let mut token_kinds = BTreeSet::new();
for (type_name, node) in nodes {
if let node_types::EntryKind::Token { .. } = &node.kind {
@@ -500,7 +499,7 @@ pub fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<ql::TopLevel
if type_name.named {
let get_a_primary_ql_class = create_get_a_primary_ql_class(&node.ql_class_name);
let mut supertypes: BTreeSet<ql::Type> = BTreeSet::new();
supertypes.insert(ql::Type::AtType(&node.dbscheme_name));
supertypes.insert(ql::Type::At(&node.dbscheme_name));
supertypes.insert(ql::Type::Normal("Token"));
classes.push(ql::TopLevel::Class(ql::Class {
qldoc: Some(format!("A class representing `{}` tokens.", type_name.kind)),
@@ -520,7 +519,7 @@ pub fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<ql::TopLevel
name: &node.ql_class_name,
is_abstract: false,
supertypes: vec![
ql::Type::AtType(&node.dbscheme_name),
ql::Type::At(&node.dbscheme_name),
ql::Type::Normal("AstNode"),
]
.into_iter()
@@ -551,25 +550,25 @@ pub fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<ql::TopLevel
let main_class_name = &node.ql_class_name;
let mut main_class = ql::Class {
qldoc: Some(format!("A class representing `{}` nodes.", type_name.kind)),
name: &main_class_name,
name: main_class_name,
is_abstract: false,
supertypes: vec![
ql::Type::AtType(&node.dbscheme_name),
ql::Type::At(&node.dbscheme_name),
ql::Type::Normal("AstNode"),
]
.into_iter()
.collect(),
characteristic_predicate: None,
predicates: vec![
create_get_a_primary_ql_class(&main_class_name),
create_get_location_predicate(&main_table_name, main_table_arity),
create_get_a_primary_ql_class(main_class_name),
create_get_location_predicate(main_table_name, main_table_arity),
],
};
if fields.is_empty() {
main_class
.predicates
.push(create_get_text_predicate(&main_table_name));
.push(create_get_text_predicate(main_table_name));
} else {
let mut main_table_column_index: usize = 0;
let mut get_child_exprs: Vec<ql::Expression> = Vec::new();
@@ -580,7 +579,7 @@ pub fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec<ql::TopLevel
// - the QL expressions to access the fields that will be part of getAFieldOrChild.
for field in fields {
let (get_pred, get_child_expr) = create_field_getters(
&main_table_name,
main_table_name,
main_table_arity,
&mut main_table_column_index,
field,

View File

@@ -81,15 +81,15 @@ pub enum Storage {
},
}
pub fn read_node_types(node_types_path: &Path) -> std::io::Result<NodeTypeMap> {
pub fn read_node_types(prefix: &str, node_types_path: &Path) -> std::io::Result<NodeTypeMap> {
let file = fs::File::open(node_types_path)?;
let node_types = serde_json::from_reader(file)?;
Ok(convert_nodes(&node_types))
let node_types: Vec<NodeInfo> = serde_json::from_reader(file)?;
Ok(convert_nodes(prefix, &node_types))
}
pub fn read_node_types_str(node_types_json: &str) -> std::io::Result<NodeTypeMap> {
let node_types = serde_json::from_str(node_types_json)?;
Ok(convert_nodes(&node_types))
pub fn read_node_types_str(prefix: &str, node_types_json: &str) -> std::io::Result<NodeTypeMap> {
let node_types: Vec<NodeInfo> = serde_json::from_str(node_types_json)?;
Ok(convert_nodes(prefix, &node_types))
}
fn convert_type(node_type: &NodeType) -> TypeName {
@@ -99,32 +99,33 @@ fn convert_type(node_type: &NodeType) -> TypeName {
}
}
fn convert_types(node_types: &Vec<NodeType>) -> Set<TypeName> {
let iter = node_types.iter().map(convert_type).collect();
std::collections::BTreeSet::from(iter)
fn convert_types(node_types: &[NodeType]) -> Set<TypeName> {
node_types.iter().map(convert_type).collect()
}
pub fn convert_nodes(nodes: &Vec<NodeInfo>) -> NodeTypeMap {
pub fn convert_nodes(prefix: &str, nodes: &[NodeInfo]) -> NodeTypeMap {
let mut entries = NodeTypeMap::new();
let mut token_kinds = Set::new();
// First, find all the token kinds
for node in nodes {
if node.subtypes.is_none() {
if node.fields.as_ref().map_or(0, |x| x.len()) == 0 && node.children.is_none() {
let type_name = TypeName {
kind: node.kind.clone(),
named: node.named,
};
token_kinds.insert(type_name);
}
if node.subtypes.is_none()
&& node.fields.as_ref().map_or(0, |x| x.len()) == 0
&& node.children.is_none()
{
let type_name = TypeName {
kind: node.kind.clone(),
named: node.named,
};
token_kinds.insert(type_name);
}
}
for node in nodes {
let flattened_name = &node_type_name(&node.kind, node.named);
let dbscheme_name = escape_name(&flattened_name);
let dbscheme_name = escape_name(flattened_name);
let ql_class_name = dbscheme_name_to_class_name(&dbscheme_name);
let dbscheme_name = format!("{}_{}", prefix, &dbscheme_name);
if let Some(subtypes) = &node.subtypes {
// It's a tree-sitter supertype node, for which we create a union
// type.
@@ -137,7 +138,7 @@ pub fn convert_nodes(nodes: &Vec<NodeInfo>) -> NodeTypeMap {
dbscheme_name,
ql_class_name,
kind: EntryKind::Union {
members: convert_types(&subtypes),
members: convert_types(subtypes),
},
},
);
@@ -150,6 +151,8 @@ pub fn convert_nodes(nodes: &Vec<NodeInfo>) -> NodeTypeMap {
named: node.named,
};
let table_name = escape_name(&(format!("{}_def", &flattened_name)));
let table_name = format!("{}_{}", prefix, &table_name);
let mut fields = Vec::new();
// If the type also has fields or children, then we create either
@@ -157,6 +160,7 @@ pub fn convert_nodes(nodes: &Vec<NodeInfo>) -> NodeTypeMap {
if let Some(node_fields) = &node.fields {
for (field_name, field_info) in node_fields {
add_field(
prefix,
&type_name,
Some(field_name.to_string()),
field_info,
@@ -167,7 +171,14 @@ pub fn convert_nodes(nodes: &Vec<NodeInfo>) -> NodeTypeMap {
}
if let Some(children) = &node.children {
// Treat children as if they were a field called 'child'.
add_field(&type_name, None, children, &mut fields, &token_kinds);
add_field(
prefix,
&type_name,
None,
children,
&mut fields,
&token_kinds,
);
}
entries.insert(
type_name,
@@ -188,13 +199,13 @@ pub fn convert_nodes(nodes: &Vec<NodeInfo>) -> NodeTypeMap {
counter += 1;
let unprefixed_name = node_type_name(&type_name.kind, true);
Entry {
dbscheme_name: escape_name(&format!("token_{}", &unprefixed_name)),
dbscheme_name: escape_name(&format!("{}_token_{}", &prefix, &unprefixed_name)),
ql_class_name: dbscheme_name_to_class_name(&escape_name(&unprefixed_name)),
kind: EntryKind::Token { kind_id: counter },
}
} else {
Entry {
dbscheme_name: "reserved_word".to_owned(),
dbscheme_name: format!("{}_reserved_word", &prefix),
ql_class_name: "ReservedWord".to_owned(),
kind: EntryKind::Token { kind_id: 0 },
}
@@ -205,6 +216,7 @@ pub fn convert_nodes(nodes: &Vec<NodeInfo>) -> NodeTypeMap {
}
fn add_field(
prefix: &str,
parent_type_name: &TypeName,
field_name: Option<String>,
field_info: &FieldInfo,
@@ -221,7 +233,8 @@ fn add_field(
// Put the field in an auxiliary table.
let has_index = field_info.multiple;
let field_table_name = escape_name(&format!(
"{}_{}",
"{}_{}_{}",
&prefix,
parent_flattened_name,
&name_for_field_or_child(&field_name)
));
@@ -240,13 +253,11 @@ fn add_field(
// All possible types for this field are reserved words. The db
// representation will be an `int` with a `case @foo.field = ...` to
// enumerate the possible values.
let mut counter = 0;
let mut field_token_ints: BTreeMap<String, (usize, String)> = BTreeMap::new();
for t in converted_types {
for (counter, t) in converted_types.into_iter().enumerate() {
let dbscheme_variant_name =
escape_name(&format!("{}_{}", parent_flattened_name, t.kind));
escape_name(&format!("{}_{}_{}", &prefix, parent_flattened_name, t.kind));
field_token_ints.insert(t.kind.to_owned(), (counter, dbscheme_variant_name));
counter += 1;
}
FieldTypeInfo::ReservedWordInt(field_token_ints)
} else if field_info.types.len() == 1 {
@@ -256,7 +267,8 @@ fn add_field(
FieldTypeInfo::Multiple {
types: converted_types,
dbscheme_union: format!(
"{}_{}_type",
"{}_{}_{}_type",
&prefix,
&parent_flattened_name,
&name_for_field_or_child(&field_name)
),
@@ -316,7 +328,7 @@ fn node_type_name(kind: &str, named: bool) -> String {
}
}
const RESERVED_KEYWORDS: [&'static str; 14] = [
const RESERVED_KEYWORDS: [&str; 14] = [
"boolean", "case", "date", "float", "int", "key", "of", "order", "ref", "string", "subtype",
"type", "unique", "varchar",
];
@@ -380,6 +392,23 @@ fn escape_name(name: &str) -> String {
result
}
pub fn to_snake_case(word: &str) -> String {
let mut prev_upper = true;
let mut result = String::new();
for c in word.chars() {
if c.is_uppercase() {
if !prev_upper {
result.push('_')
}
prev_upper = true;
result.push(c.to_ascii_lowercase());
} else {
prev_upper = false;
result.push(c);
}
}
result
}
/// Given a valid dbscheme name (i.e. in snake case), produces the equivalent QL
/// name (i.e. in CamelCase). For example, "foo_bar_baz" becomes "FooBarBaz".
fn dbscheme_name_to_class_name(dbscheme_name: &str) -> String {
@@ -402,3 +431,10 @@ fn dbscheme_name_to_class_name(dbscheme_name: &str) -> String {
.collect::<Vec<String>>()
.join("")
}
#[test]
fn to_snake_case_test() {
assert_eq!("python", to_snake_case("Python"));
assert_eq!("yaml", to_snake_case("YAML"));
assert_eq!("set_literal", to_snake_case("SetLiteral"));
}

View File

@@ -160,7 +160,7 @@ abstract class Container extends @container {
/** A folder. */
class Folder extends Container, @folder {
override string getAbsolutePath() { folders(this, result, _) }
override string getAbsolutePath() { folders(this, result) }
/** Gets the URL of this folder. */
override string getURL() { result = "folder://" + this.getAbsolutePath() }
@@ -168,21 +168,21 @@ class Folder extends Container, @folder {
/** A file. */
class File extends Container, @file {
override string getAbsolutePath() { files(this, result, _, _, _) }
override string getAbsolutePath() { files(this, result) }
/** Gets the URL of this file. */
override string getURL() { result = "file://" + this.getAbsolutePath() + ":0:0:0:0" }
/** Gets a token in this file. */
private Generated::Token getAToken() { result.getLocation().getFile() = this }
private QL::Token getAToken() { result.getLocation().getFile() = this }
/** Holds if `line` contains a token. */
private predicate line(int line, boolean comment) {
exists(Generated::Token token, Location l |
exists(QL::Token token, Location l |
token = this.getAToken() and
l = token.getLocation() and
line in [l.getStartLine() .. l.getEndLine()] and
if token instanceof @token_block_comment or token instanceof @token_line_comment
if token instanceof @ql_token_block_comment or token instanceof @ql_token_line_comment
then comment = true
else comment = false
)
@@ -198,5 +198,5 @@ class File extends Container, @file {
int getNumberOfLinesOfComments() { result = count(int line | this.line(line, true)) }
/** Holds if this file was extracted from ordinary source code. */
predicate fromSource() { files(this, _, _, _, 1) }
predicate fromSource() { any() }
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,69 +4,65 @@ private import Builtins
cached
newtype TAstNode =
TTopLevel(Generated::Ql file) or
TQLDoc(Generated::Qldoc qldoc) or
TClasslessPredicate(Generated::ClasslessPredicate pred) or
TVarDecl(Generated::VarDecl decl) or
TClass(Generated::Dataclass dc) or
TCharPred(Generated::Charpred pred) or
TClassPredicate(Generated::MemberPredicate pred) or
TDBRelation(Generated::DbTable table) or
TSelect(Generated::Select sel) or
TModule(Generated::Module mod) or
TNewType(Generated::Datatype dt) or
TNewTypeBranch(Generated::DatatypeBranch branch) or
TImport(Generated::ImportDirective imp) or
TType(Generated::TypeExpr type) or
TDisjunction(Generated::Disjunction disj) or
TConjunction(Generated::Conjunction conj) or
TComparisonFormula(Generated::CompTerm comp) or
TComparisonOp(Generated::Compop op) or
TQuantifier(Generated::Quantified quant) or
TFullAggregate(Generated::Aggregate agg) {
agg.getChild(_) instanceof Generated::FullAggregateBody
TTopLevel(QL::Ql file) or
TQLDoc(QL::Qldoc qldoc) or
TClasslessPredicate(QL::ClasslessPredicate pred) or
TVarDecl(QL::VarDecl decl) or
TClass(QL::Dataclass dc) or
TCharPred(QL::Charpred pred) or
TClassPredicate(QL::MemberPredicate pred) or
TDBRelation(QL::DbTable table) or
TSelect(QL::Select sel) or
TModule(QL::Module mod) or
TNewType(QL::Datatype dt) or
TNewTypeBranch(QL::DatatypeBranch branch) or
TImport(QL::ImportDirective imp) or
TType(QL::TypeExpr type) or
TDisjunction(QL::Disjunction disj) or
TConjunction(QL::Conjunction conj) or
TComparisonFormula(QL::CompTerm comp) or
TComparisonOp(QL::Compop op) or
TQuantifier(QL::Quantified quant) or
TFullAggregate(QL::Aggregate agg) { agg.getChild(_) instanceof QL::FullAggregateBody } or
TExprAggregate(QL::Aggregate agg) { agg.getChild(_) instanceof QL::ExprAggregateBody } or
TSuper(QL::SuperRef sup) or
TIdentifier(QL::Variable var) or
TAsExpr(QL::AsExpr asExpr) { asExpr.getChild(1) instanceof QL::VarName } or
TPredicateCall(QL::CallOrUnqualAggExpr call) or
TMemberCall(QL::QualifiedExpr expr) {
not expr.getChild(_).(QL::QualifiedRhs).getChild(_) instanceof QL::TypeExpr
} or
TExprAggregate(Generated::Aggregate agg) {
agg.getChild(_) instanceof Generated::ExprAggregateBody
TInlineCast(QL::QualifiedExpr expr) {
expr.getChild(_).(QL::QualifiedRhs).getChild(_) instanceof QL::TypeExpr
} or
TSuper(Generated::SuperRef sup) or
TIdentifier(Generated::Variable var) or
TAsExpr(Generated::AsExpr asExpr) { asExpr.getChild(1) instanceof Generated::VarName } or
TPredicateCall(Generated::CallOrUnqualAggExpr call) or
TMemberCall(Generated::QualifiedExpr expr) {
not expr.getChild(_).(Generated::QualifiedRhs).getChild(_) instanceof Generated::TypeExpr
TNoneCall(QL::SpecialCall call) or
TAnyCall(QL::Aggregate agg) {
"any" = agg.getChild(0).(QL::AggId).getValue() and
not agg.getChild(_) instanceof QL::FullAggregateBody
} or
TInlineCast(Generated::QualifiedExpr expr) {
expr.getChild(_).(Generated::QualifiedRhs).getChild(_) instanceof Generated::TypeExpr
} or
TNoneCall(Generated::SpecialCall call) or
TAnyCall(Generated::Aggregate agg) {
"any" = agg.getChild(0).(Generated::AggId).getValue() and
not agg.getChild(_) instanceof Generated::FullAggregateBody
} or
TNegation(Generated::Negation neg) or
TIfFormula(Generated::IfTerm ifterm) or
TImplication(Generated::Implication impl) or
TInstanceOf(Generated::InstanceOf inst) or
TInFormula(Generated::InExpr inexpr) or
THigherOrderFormula(Generated::HigherOrderTerm hop) or
TExprAnnotation(Generated::ExprAnnotation expr_anno) or
TAddSubExpr(Generated::AddExpr addexp) or
TMulDivModExpr(Generated::MulExpr mulexpr) or
TRange(Generated::Range range) or
TSet(Generated::SetLiteral set) or
TLiteral(Generated::Literal lit) or
TUnaryExpr(Generated::UnaryExpr unaryexpr) or
TDontCare(Generated::Underscore dontcare) or
TModuleExpr(Generated::ModuleExpr me) or
TPredicateExpr(Generated::PredicateExpr pe) or
TAnnotation(Generated::Annotation annot) or
TAnnotationArg(Generated::AnnotArg arg) or
TYamlCommemt(Generated::YamlComment yc) or
TYamlEntry(Generated::YamlEntry ye) or
TYamlKey(Generated::YamlKey yk) or
TYamlListitem(Generated::YamlListitem yli) or
TYamlValue(Generated::YamlValue yv) or
TNegation(QL::Negation neg) or
TIfFormula(QL::IfTerm ifterm) or
TImplication(QL::Implication impl) or
TInstanceOf(QL::InstanceOf inst) or
TInFormula(QL::InExpr inexpr) or
THigherOrderFormula(QL::HigherOrderTerm hop) or
TExprAnnotation(QL::ExprAnnotation expr_anno) or
TAddSubExpr(QL::AddExpr addexp) or
TMulDivModExpr(QL::MulExpr mulexpr) or
TRange(QL::Range range) or
TSet(QL::SetLiteral set) or
TLiteral(QL::Literal lit) or
TUnaryExpr(QL::UnaryExpr unaryexpr) or
TDontCare(QL::Underscore dontcare) or
TModuleExpr(QL::ModuleExpr me) or
TPredicateExpr(QL::PredicateExpr pe) or
TAnnotation(QL::Annotation annot) or
TAnnotationArg(QL::AnnotArg arg) or
TYamlCommemt(QL::YamlComment yc) or
TYamlEntry(QL::YamlEntry ye) or
TYamlKey(QL::YamlKey yk) or
TYamlListitem(QL::YamlListitem yli) or
TYamlValue(QL::YamlValue yv) or
TBuiltinClassless(string ret, string name, string args) { isBuiltinClassless(ret, name, args) } or
TBuiltinMember(string qual, string ret, string name, string args) {
isBuiltinMember(qual, ret, name, args)
@@ -90,7 +86,7 @@ class TModuleRef = TImport or TModuleExpr;
class TYAMLNode = TYamlCommemt or TYamlEntry or TYamlKey or TYamlListitem or TYamlValue;
private Generated::AstNode toGeneratedFormula(AST::AstNode n) {
private QL::AstNode toQLFormula(AST::AstNode n) {
n = TConjunction(result) or
n = TDisjunction(result) or
n = TComparisonFormula(result) or
@@ -106,7 +102,7 @@ private Generated::AstNode toGeneratedFormula(AST::AstNode n) {
n = TInFormula(result)
}
private Generated::AstNode toGeneratedExpr(AST::AstNode n) {
private QL::AstNode toQLExpr(AST::AstNode n) {
n = TAddSubExpr(result) or
n = TMulDivModExpr(result) or
n = TRange(result) or
@@ -120,7 +116,7 @@ private Generated::AstNode toGeneratedExpr(AST::AstNode n) {
n = TDontCare(result)
}
private Generated::AstNode toGenerateYAML(AST::AstNode n) {
private QL::AstNode toGenerateYAML(AST::AstNode n) {
n = TYamlCommemt(result) or
n = TYamlEntry(result) or
n = TYamlKey(result) or
@@ -131,19 +127,19 @@ private Generated::AstNode toGenerateYAML(AST::AstNode n) {
/**
* Gets the underlying TreeSitter entity for a given AST node.
*/
Generated::AstNode toGenerated(AST::AstNode n) {
result = toGeneratedExpr(n)
QL::AstNode toQL(AST::AstNode n) {
result = toQLExpr(n)
or
result = toGeneratedFormula(n)
result = toQLFormula(n)
or
result = toGenerateYAML(n)
or
result.(Generated::ParExpr).getChild() = toGenerated(n)
result.(QL::ParExpr).getChild() = toQL(n)
or
result =
any(Generated::AsExpr ae |
not ae.getChild(1) instanceof Generated::VarName and
toGenerated(n) = ae.getChild(0)
any(QL::AsExpr ae |
not ae.getChild(1) instanceof QL::VarName and
toQL(n) = ae.getChild(0)
)
or
n = TTopLevel(result)

View File

@@ -148,14 +148,16 @@ private predicate resolveSelectionName(Import imp, ContainerOrModule m, int i) {
cached
private module Cached {
// TODO: Use `AstNode::getParent` once it is total
private Generated::AstNode parent(Generated::AstNode n) {
result = n.getParent() and
not n instanceof Generated::Module
}
private Module getEnclosingModule0(AstNode n) {
AstNodes::toGenerated(result) = parent*(AstNodes::toGenerated(n).getParent())
not n instanceof Module and
(
n = result.getAChild(_)
or
exists(AstNode prev |
result = getEnclosingModule0(prev) and
n = prev.getAChild(_)
)
)
}
cached

File diff suppressed because it is too large Load Diff

View File

@@ -60,9 +60,9 @@ pragma[nomagic]
VariableScope scopeOf(AstNode n) { result = parent*(n.getParent()) }
private string getName(Identifier i) {
exists(Generated::Variable v |
exists(QL::Variable v |
i = TIdentifier(v) and
result = v.getChild().(Generated::VarName).getChild().getValue()
result = v.getChild().(QL::VarName).getChild().getValue()
)
}

File diff suppressed because it is too large Load Diff