From 83a0e5fea65f711fc4d3e6311763ebb3cfc84f0d Mon Sep 17 00:00:00 2001 From: Nick Rolfe Date: Tue, 10 Nov 2020 19:08:21 +0000 Subject: [PATCH 1/5] Refactor to move naming decisions to shared library --- extractor/src/extractor.rs | 104 ++++++++++--------- generator/src/main.rs | 152 +++++++++++----------------- generator/src/ql.rs | 43 ++------ generator/src/ql_gen.rs | 152 +++++++++++----------------- node-types/src/lib.rs | 155 ++++++++++++++++++++++------ ql/src/codeql_ruby/ast.qll | 200 ++++++++++++++++++------------------- 6 files changed, 400 insertions(+), 406 deletions(-) diff --git a/extractor/src/extractor.rs b/extractor/src/extractor.rs index 566923c9774..e3fd5108033 100644 --- a/extractor/src/extractor.rs +++ b/extractor/src/extractor.rs @@ -1,4 +1,4 @@ -use node_types::{escape_name, node_type_name, Entry, Field, Storage, TypeName}; +use node_types::{escape_name, EntryKind, Field, NodeTypeMap, Storage, TypeName}; use std::collections::BTreeMap as Map; use std::collections::BTreeSet as Set; use std::fmt; @@ -150,10 +150,10 @@ impl TrapWriter { pub struct Extractor { pub parser: Parser, - pub schema: Vec, + pub schema: NodeTypeMap, } -pub fn create(language: Language, schema: Vec) -> Extractor { +pub fn create(language: Language, schema: NodeTypeMap) -> Extractor { let mut parser = Parser::new(); parser.set_language(language).unwrap(); @@ -189,8 +189,7 @@ impl Extractor { file_label: *file_label, token_counter: 0, stack: Vec::new(), - tables: build_schema_lookup(&self.schema), - union_types: build_union_type_lookup(&self.schema), + schema: &self.schema, }; traverse(&tree, &mut visitor); @@ -246,26 +245,6 @@ fn full_id_for_folder(path: &Path) -> String { format!("{};folder", normalize_path(path)) } -fn build_schema_lookup<'a>(schema: &'a Vec) -> Map<&'a TypeName, &'a Entry> { - let mut map = std::collections::BTreeMap::new(); - for entry in schema { - if let Entry::Token { type_name, .. } | Entry::Table { type_name, .. } = entry { - map.insert(type_name, entry); - } - } - map -} - -fn build_union_type_lookup<'a>(schema: &'a Vec) -> Map<&'a TypeName, &'a Set> { - let mut union_types = std::collections::BTreeMap::new(); - for entry in schema { - if let Entry::Union { type_name, members } = entry { - union_types.insert(type_name, members); - } - } - union_types -} - struct Visitor<'a> { /// The file path of the source code (as string) path: String, @@ -278,10 +257,8 @@ struct Visitor<'a> { trap_writer: TrapWriter, /// A counter for tokens token_counter: usize, - /// A lookup table from type name to dbscheme table entries - tables: Map<&'a TypeName, &'a Entry>, - /// A lookup table for union types mapping a type name to its direct members - union_types: Map<&'a TypeName, &'a Set>, + /// A lookup table from type name to node types + schema: &'a NodeTypeMap, /// A stack for gathering information from hild nodes. Whenever a node is entered /// an empty list is pushed. All children append their data (field name, label, type) to /// the the list. When the visitor leaves a node the list containing the child data is popped @@ -328,13 +305,16 @@ impl Visitor<'_> { end_line, end_column, ); - let table = self.tables.get(&TypeName { - kind: node.kind().to_owned(), - named: node.is_named(), - }); + let table = self + .schema + .get(&TypeName { + kind: node.kind().to_owned(), + named: node.is_named(), + }) + .unwrap(); let mut valid = true; - match table { - Some(Entry::Token { kind_id, .. }) => { + match &table.kind { + EntryKind::Token { kind_id, .. } => { self.trap_writer.add_tuple( "tokeninfo", vec![ @@ -348,11 +328,8 @@ impl Visitor<'_> { ); self.token_counter += 1; } - Some(Entry::Table { fields, .. }) => { - let table_name = escape_name(&format!( - "{}_def", - node_type_name(node.kind(), node.is_named()) - )); + EntryKind::Table { fields, .. } => { + let table_name = escape_name(&format!("{}_def", &table.flattened_name)); if let Some(args) = self.complex_node(&node, fields, child_nodes, id) { let mut all_args = Vec::new(); all_args.push(Arg::Label(id)); @@ -386,6 +363,7 @@ impl Visitor<'_> { }; } } + fn complex_node( &mut self, node: &Node, @@ -400,7 +378,7 @@ impl Visitor<'_> { for (child_field, child_id, child_type) in child_nodes { if let Some((field, values)) = map.get_mut(&child_field.map(|x| x.to_owned())) { //TODO: handle error and missing nodes - if self.type_matches(&child_type, &field.types) { + if self.type_matches(&child_type, &field.type_info) { values.push(child_id); } else if field.name.is_some() { error!( @@ -410,7 +388,7 @@ impl Visitor<'_> { node.kind(), child_field.unwrap_or("child"), child_type, - field.types + field.type_info ) } } else { @@ -464,7 +442,7 @@ impl Visitor<'_> { } let table_name = escape_name(&format!( "{}_{}", - node_type_name(&field.parent.kind, field.parent.named), + self.schema.get(&field.parent).unwrap().flattened_name, field.get_name() )); let mut args = Vec::new(); @@ -484,18 +462,48 @@ impl Visitor<'_> { None } } - fn type_matches(&self, tp: &TypeName, types: &Set) -> bool { + + fn type_matches(&self, tp: &TypeName, type_info: &node_types::FieldTypeInfo) -> bool { + match type_info { + node_types::FieldTypeInfo::Single(single_type) => { + if tp == single_type { + return true; + } + match &self.schema.get(single_type).unwrap().kind { + EntryKind::Union { members } => { + if self.type_matches_set(tp, members) { + return true; + } + } + _ => {} + } + } + node_types::FieldTypeInfo::Multiple { + types, + dbscheme_union: _, + ql_class: _, + } => { + return self.type_matches_set(tp, types); + } + } + false + } + + fn type_matches_set(&self, tp: &TypeName, types: &Set) -> bool { if types.contains(tp) { return true; } for other in types.iter() { - if let Some(x) = self.union_types.get(other) { - if self.type_matches(tp, x) { - return true; + match &self.schema.get(other).unwrap().kind { + EntryKind::Union { members } => { + if self.type_matches_set(tp, members) { + return true; + } } + _ => {} } } - return false; + false } } diff --git a/generator/src/main.rs b/generator/src/main.rs index 68b9cd267ef..73c30a1bbe9 100644 --- a/generator/src/main.rs +++ b/generator/src/main.rs @@ -11,71 +11,51 @@ use std::io::LineWriter; use std::path::PathBuf; use tracing::{error, info}; -fn child_node_type_name(token_types: &Map, t: &node_types::TypeName) -> String { - if !t.named { - // an unnamed token - "reserved_word".to_owned() - } else if token_types.contains_key(&t.kind) { - // a named token - format!("token_{}", t.kind) - } else { - // a normal node - node_types::node_type_name(&t.kind, t.named) - } -} - /// Given the name of the parent node, and its field information, returns the /// name of the field's type. This may be an ad-hoc union of all the possible /// types the field can take, in which case the union is added to `entries`. fn make_field_type( - token_types: &Map, - parent_name: &str, - field_name: &str, - types: &Set, + field: &node_types::Field, entries: &mut Vec, + nodes: &node_types::NodeTypeMap, ) -> String { - if types.len() == 1 { - // This field can only have a single type. - let t = types.iter().next().unwrap(); - node_types::escape_name(&child_node_type_name(token_types, t)) - } else { - // This field can have one of several types. Create an ad-hoc QL union - // type to represent them. - let field_union_name = format!("{}_{}_type", parent_name, field_name); - let field_union_name = node_types::escape_name(&field_union_name); - let members: Set = types - .iter() - .map(|t| node_types::escape_name(&child_node_type_name(token_types, t))) - .collect(); - entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: field_union_name.clone(), - members, - })); - field_union_name + match &field.type_info { + node_types::FieldTypeInfo::Multiple { + types, + dbscheme_union, + ql_class: _, + } => { + // This field can have one of several types. Create an ad-hoc QL union + // type to represent them. + let members: Set = types + .iter() + .map(|t| node_types::escape_name(&nodes.get(t).unwrap().flattened_name)) + .collect(); + entries.push(dbscheme::Entry::Union(dbscheme::Union { + name: node_types::escape_name(&dbscheme_union), + members, + })); + dbscheme_union.clone() + } + node_types::FieldTypeInfo::Single(t) => nodes.get(&t).unwrap().flattened_name.clone(), } } /// Adds the appropriate dbscheme information for the given field, either as a /// column on `main_table`, or as an auxiliary table. fn add_field( - token_types: &Map, main_table: &mut dbscheme::Table, field: &node_types::Field, entries: &mut Vec, + nodes: &node_types::NodeTypeMap, ) { let field_name = field.get_name(); - let parent_name = node_types::node_type_name(&field.parent.kind, field.parent.named); + let parent_name = &nodes.get(&field.parent).unwrap().flattened_name; match &field.storage { node_types::Storage::Table(has_index) => { // This field can appear zero or multiple times, so put // it in an auxiliary table. - let field_type = make_field_type( - token_types, - &parent_name, - &field_name, - &field.types, - entries, - ); + let field_type = node_types::escape_name(&make_field_type(&field, entries, nodes)); let parent_column = dbscheme::Column { unique: !*has_index, db_type: dbscheme::DbColumnType::Int, @@ -93,12 +73,12 @@ fn add_field( let field_column = dbscheme::Column { unique: true, db_type: dbscheme::DbColumnType::Int, - name: node_types::escape_name(&field_type), + name: field_type.clone(), ql_type: ql::Type::AtType(field_type), ql_type_is_ref: true, }; let field_table = dbscheme::Table { - name: format!("{}_{}", parent_name, field_name), + name: node_types::escape_name(&format!("{}_{}", parent_name, field_name)), columns: if *has_index { vec![parent_column, index_column, field_column] } else { @@ -120,18 +100,12 @@ fn add_field( node_types::Storage::Column => { // This field must appear exactly once, so we add it as // a column to the main table for the node type. - let field_type = make_field_type( - token_types, - &parent_name, - &field_name, - &field.types, - entries, - ); + let field_type = make_field_type(&field, entries, nodes); main_table.columns.push(dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, name: node_types::escape_name(&field_name), - ql_type: ql::Type::AtType(field_type), + ql_type: ql::Type::AtType(node_types::escape_name(&field_type)), ql_type_is_ref: true, }); } @@ -139,7 +113,7 @@ fn add_field( } /// Converts the given tree-sitter node types into CodeQL dbscheme entries. -fn convert_nodes(nodes: &Vec) -> Vec { +fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { let mut entries: Vec = vec![ create_location_union(), create_locations_default_table(), @@ -152,59 +126,50 @@ fn convert_nodes(nodes: &Vec) -> Vec { create_source_location_prefix_table(), ]; let mut ast_node_members: Set = Set::new(); - let mut token_kinds: Map = Map::new(); - ast_node_members.insert(node_types::escape_name("token")); - for node in nodes { - if let node_types::Entry::Token { type_name, kind_id } = node { - if type_name.named { - token_kinds.insert(type_name.kind.to_owned(), *kind_id); + let token_kinds: Map = nodes + .iter() + .filter_map(|(_, node)| match &node.kind { + node_types::EntryKind::Token { kind_id } => { + Some((node.flattened_name.clone(), *kind_id)) } - } - } + _ => None, + }) + .collect(); + ast_node_members.insert(node_types::escape_name("token")); - for node in nodes { - match &node { - node_types::Entry::Union { - type_name, - members: n_members, - } => { + for (_, node) in nodes { + match &node.kind { + node_types::EntryKind::Union { members: n_members } => { // It's a tree-sitter supertype node, for which we create a union // type. - let mut members: Set = Set::new(); - for n_member in n_members { - members.insert(node_types::escape_name(&child_node_type_name( - &token_kinds, - n_member, - ))); - } + let members: Set = n_members + .iter() + .map(|n| node_types::escape_name(&nodes.get(n).unwrap().flattened_name)) + .collect(); entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: node_types::escape_name(&node_types::node_type_name( - &type_name.kind, - type_name.named, - )), + name: node_types::escape_name(&node.flattened_name), members, })); } - node_types::Entry::Table { type_name, fields } => { + node_types::EntryKind::Table { fields } => { // It's a product type, defined by a table. - let name = node_types::node_type_name(&type_name.kind, type_name.named); let mut main_table = dbscheme::Table { - name: node_types::escape_name(&(format!("{}_def", name))), + name: node_types::escape_name(&(format!("{}_def", &node.flattened_name))), columns: vec![dbscheme::Column { db_type: dbscheme::DbColumnType::Int, name: "id".to_string(), unique: true, - ql_type: ql::Type::AtType(node_types::escape_name(&name)), + ql_type: ql::Type::AtType(node_types::escape_name(&node.flattened_name)), ql_type_is_ref: false, }], keysets: None, }; - ast_node_members.insert(node_types::escape_name(&name)); + ast_node_members.insert(node_types::escape_name(&node.flattened_name)); // If the type also has fields or children, then we create either // auxiliary tables or columns in the defining table for them. for field in fields { - add_field(&token_kinds, &mut main_table, &field, &mut entries); + add_field(&mut main_table, &field, &mut entries, nodes); } if fields.is_empty() { @@ -230,7 +195,7 @@ fn convert_nodes(nodes: &Vec) -> Vec { entries.push(dbscheme::Entry::Table(main_table)); } - node_types::Entry::Token { .. } => {} + node_types::EntryKind::Token { .. } => {} } } @@ -295,15 +260,10 @@ fn add_tokeninfo_table(entries: &mut Vec, token_kinds: Map = Vec::new(); - branches.push((0, "reserved_word".to_owned())); - for (token_kind, idx) in token_kinds.iter() { - branches.push(( - *idx, - node_types::escape_name(&format!("token_{}", token_kind)), - )); - } - + let branches: Vec<(usize, String)> = token_kinds + .iter() + .map(|(name, kind_id)| (*kind_id, node_types::escape_name(name))) + .collect(); entries.push(dbscheme::Entry::Case(dbscheme::Case { name: "token".to_owned(), column: "kind".to_owned(), diff --git a/generator/src/ql.rs b/generator/src/ql.rs index bb554bca34e..ec935fac8a8 100644 --- a/generator/src/ql.rs +++ b/generator/src/ql.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeSet; use std::fmt; pub enum TopLevel { @@ -14,10 +15,11 @@ impl fmt::Display for TopLevel { } } +#[derive(Clone, Eq, PartialEq, Hash)] pub struct Class { pub name: String, pub is_abstract: bool, - pub supertypes: Vec, + pub supertypes: BTreeSet, pub characteristic_predicate: Option, pub predicates: Vec, } @@ -61,7 +63,7 @@ impl fmt::Display for Class { } // The QL type of a column. -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] pub enum Type { /// Primitive `int` type. Int, @@ -69,11 +71,11 @@ pub enum Type { /// Primitive `string` type. String, - /// A user-defined type. - Normal(String), - /// A database type that will need to be referred to with an `@` prefix. AtType(String), + + /// A user-defined type. + Normal(String), } impl fmt::Display for Type { @@ -87,15 +89,13 @@ impl fmt::Display for Type { } } -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash)] pub enum Expression { Var(String), String(String), Pred(String, Vec), Or(Vec), - And(Vec), Equals(Box, Box), - Exists(Vec, Box), Dot(Box, String, Vec), } @@ -127,30 +127,7 @@ impl fmt::Display for Expression { Ok(()) } } - Expression::And(conjuncts) => { - if conjuncts.is_empty() { - write!(f, "any()") - } else { - for (index, conjunct) in conjuncts.iter().enumerate() { - if index > 0 { - write!(f, " and ")?; - } - write!(f, "{}", conjunct)?; - } - Ok(()) - } - } Expression::Equals(a, b) => write!(f, "{} = {}", a, b), - Expression::Exists(params, formula) => { - write!(f, "exists(")?; - for (index, param) in params.iter().enumerate() { - if index > 0 { - write!(f, ", ")?; - } - write!(f, "{}", param)?; - } - write!(f, " | {})", formula) - } Expression::Dot(x, member_pred, args) => { write!(f, "{}.{}(", x, member_pred)?; for (index, arg) in args.iter().enumerate() { @@ -165,7 +142,7 @@ impl fmt::Display for Expression { } } -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash)] pub struct Predicate { pub name: String, pub overridden: bool, @@ -196,7 +173,7 @@ impl fmt::Display for Predicate { } } -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash)] pub struct FormalParameter { pub name: String, pub param_type: Type, diff --git a/generator/src/ql_gen.rs b/generator/src/ql_gen.rs index 767418ed0c8..996e5e9dde8 100644 --- a/generator/src/ql_gen.rs +++ b/generator/src/ql_gen.rs @@ -67,7 +67,9 @@ fn create_ast_node_class() -> ql::Class { ql::Class { name: "AstNode".to_owned(), is_abstract: false, - supertypes: vec![ql::Type::AtType("ast_node".to_owned())], + supertypes: vec![ql::Type::AtType("ast_node".to_owned())] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![ to_string, @@ -77,6 +79,7 @@ fn create_ast_node_class() -> ql::Class { ], } } + fn create_token_class() -> ql::Class { let get_value = ql::Predicate { name: "getValue".to_owned(), @@ -128,7 +131,9 @@ fn create_token_class() -> ql::Class { supertypes: vec![ ql::Type::AtType("token".to_owned()), ql::Type::Normal("AstNode".to_owned()), - ], + ] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![ get_value, @@ -141,16 +146,18 @@ fn create_token_class() -> ql::Class { // Creates the `ReservedWord` class. fn create_reserved_word_class() -> ql::Class { - let db_name = "reserved_word".to_owned(); - let class_name = dbscheme_name_to_class_name(&db_name); + let db_name = "reserved_word"; + let class_name = "ReservedWord".to_owned(); let describe_ql_class = create_describe_ql_class(&class_name); ql::Class { name: class_name, is_abstract: false, supertypes: vec![ + ql::Type::AtType(db_name.to_owned()), ql::Type::Normal("Token".to_owned()), - ql::Type::AtType(db_name), - ], + ] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![describe_ql_class], } @@ -172,47 +179,6 @@ fn create_none_predicate( } } -/// Given the name of the parent node, and its field information, returns the -/// name of the field's type. This may be an ad-hoc union of all the possible -/// types the field can take, in which case we create a new class and push it to -/// `classes`. -fn create_field_class(token_kinds: &BTreeSet, field: &node_types::Field) -> String { - if field.types.len() == 1 { - // This field can only have a single type. - let t = field.types.iter().next().unwrap(); - if !t.named || token_kinds.contains(&t.kind) { - "Token".to_owned() - } else { - node_types::escape_name(&node_types::node_type_name(&t.kind, t.named)) - } - } else { - "AstNode".to_owned() - } -} - -/// Given a valid dbscheme name (i.e. in snake case), produces the equivalent QL -/// name (i.e. in CamelCase). For example, "foo_bar_baz" becomes "FooBarBaz". -fn dbscheme_name_to_class_name(dbscheme_name: &str) -> String { - fn to_title_case(word: &str) -> String { - let mut first = true; - let mut result = String::new(); - for c in word.chars() { - if first { - first = false; - result.push(c.to_ascii_uppercase()); - } else { - result.push(c); - } - } - result - } - dbscheme_name - .split('_') - .map(|word| to_title_case(word)) - .collect::>() - .join("") -} - /// Creates an overridden `describeQlClass` predicate that returns the given /// name. fn create_describe_ql_class(class_name: &str) -> ql::Predicate { @@ -345,20 +311,24 @@ fn create_field_getters( main_table_column_index: &mut usize, parent_name: &str, field: &node_types::Field, - field_type: &str, + nodes: &node_types::NodeTypeMap, ) -> (ql::Predicate, ql::Expression) { - let predicate_name = format!( - "get{}", - dbscheme_name_to_class_name(&node_types::escape_name(&field.get_name())) - ); - let return_type = Some(ql::Type::Normal(dbscheme_name_to_class_name(field_type))); + let predicate_name = field.get_getter_name(); + let return_type = Some(ql::Type::Normal(match &field.type_info { + node_types::FieldTypeInfo::Single(t) => nodes.get(&t).unwrap().ql_class_name.clone(), + node_types::FieldTypeInfo::Multiple { + types: _, + dbscheme_union: _, + ql_class, + } => ql_class.clone(), + })); match &field.storage { node_types::Storage::Column => { let result = ( ql::Predicate { name: predicate_name, overridden: false, - return_type: return_type, + return_type, formal_parameters: vec![], body: create_get_field_expr_for_column_storage( &main_table_name, @@ -381,7 +351,7 @@ fn create_field_getters( ql::Predicate { name: predicate_name, overridden: false, - return_type: return_type, + return_type, formal_parameters: if *has_index { vec![ql::FormalParameter { name: "i".to_owned(), @@ -405,7 +375,7 @@ fn create_field_getters( } /// Converts the given node types into CodeQL classes wrapping the dbscheme. -pub fn convert_nodes(nodes: &Vec) -> Vec { +pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { let mut classes: Vec = vec![ ql::TopLevel::Import("codeql.files.FileSystem".to_owned()), ql::TopLevel::Import("codeql.Locations".to_owned()), @@ -414,61 +384,48 @@ pub fn convert_nodes(nodes: &Vec) -> Vec { ql::TopLevel::Class(create_reserved_word_class()), ]; let mut token_kinds = BTreeSet::new(); - for node in nodes { - if let node_types::Entry::Token { type_name, .. } = node { + for (type_name, node) in nodes { + if let node_types::EntryKind::Token { .. } = &node.kind { if type_name.named { token_kinds.insert(type_name.kind.to_owned()); } } } - for node in nodes { - match &node { - node_types::Entry::Token { - type_name, - kind_id: _, - } => { + for (type_name, node) in nodes { + match &node.kind { + node_types::EntryKind::Token { kind_id: _ } => { if type_name.named { - let db_name = format!("token_{}", &type_name.kind); - let db_name = node_types::escape_name(&db_name); - let class_name = - dbscheme_name_to_class_name(&node_types::escape_name(&type_name.kind)); - let describe_ql_class = create_describe_ql_class(&class_name); + let describe_ql_class = create_describe_ql_class(&node.ql_class_name); + let mut supertypes: BTreeSet = BTreeSet::new(); + supertypes.insert(ql::Type::AtType(node.flattened_name.to_owned())); + supertypes.insert(ql::Type::Normal("Token".to_owned())); classes.push(ql::TopLevel::Class(ql::Class { - name: class_name, + name: node.ql_class_name.clone(), is_abstract: false, - supertypes: vec![ - ql::Type::Normal("Token".to_owned()), - ql::Type::AtType(db_name), - ], + supertypes, characteristic_predicate: None, predicates: vec![describe_ql_class], })); } } - node_types::Entry::Union { - type_name, - members: _, - } => { + node_types::EntryKind::Union { members: _ } => { // It's a tree-sitter supertype node, so we're wrapping a dbscheme // union type. - let union_name = node_types::escape_name(&node_types::node_type_name( - &type_name.kind, - type_name.named, - )); - let class_name = dbscheme_name_to_class_name(&union_name); classes.push(ql::TopLevel::Class(ql::Class { - name: class_name.clone(), + name: node.ql_class_name.clone(), is_abstract: false, supertypes: vec![ - ql::Type::AtType(union_name), + ql::Type::AtType(node_types::escape_name(&node.flattened_name)), ql::Type::Normal("AstNode".to_owned()), - ], + ] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![], })); } - node_types::Entry::Table { type_name, fields } => { + node_types::EntryKind::Table { fields } => { // Count how many columns there will be in the main table. // There will be: // - one for the id @@ -484,15 +441,19 @@ pub fn convert_nodes(nodes: &Vec) -> Vec { .count() }; - let name = node_types::node_type_name(&type_name.kind, type_name.named); - let dbscheme_name = node_types::escape_name(&name); - let ql_type = ql::Type::AtType(dbscheme_name.clone()); - let main_table_name = node_types::escape_name(&(format!("{}_def", name))); - let main_class_name = dbscheme_name_to_class_name(&dbscheme_name); + let escaped_name = node_types::escape_name(&node.flattened_name); + let main_class_name = &node.ql_class_name; + let main_table_name = + node_types::escape_name(&format!("{}_def", &node.flattened_name)); let mut main_class = ql::Class { name: main_class_name.clone(), is_abstract: false, - supertypes: vec![ql_type, ql::Type::Normal("AstNode".to_owned())], + supertypes: vec![ + ql::Type::AtType(escaped_name), + ql::Type::Normal("AstNode".to_owned()), + ] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![ create_describe_ql_class(&main_class_name), @@ -513,14 +474,13 @@ pub fn convert_nodes(nodes: &Vec) -> Vec { // - predicates to access the fields, // - the QL expressions to access the fields that will be part of getAFieldOrChild. for field in fields { - let field_type = create_field_class(&token_kinds, field); let (get_pred, get_child_expr) = create_field_getters( &main_table_name, main_table_arity, &mut main_table_column_index, - &name, + &node.flattened_name, field, - &field_type, + nodes, ); main_class.predicates.push(get_pred); get_child_exprs.push(get_child_expr); diff --git a/node-types/src/lib.rs b/node-types/src/lib.rs index e2de9df31f3..1ef0da8533b 100644 --- a/node-types/src/lib.rs +++ b/node-types/src/lib.rs @@ -5,20 +5,21 @@ use std::path::Path; use std::collections::BTreeSet as Set; use std::fs; +/// A lookup table from TypeName to Entry. +pub type NodeTypeMap = BTreeMap; + #[derive(Debug)] -pub enum Entry { - Union { - type_name: TypeName, - members: Set, - }, - Table { - type_name: TypeName, - fields: Vec, - }, - Token { - type_name: TypeName, - kind_id: usize, - }, +pub struct Entry { + pub flattened_name: String, + pub ql_class_name: String, + pub kind: EntryKind, +} + +#[derive(Debug)] +pub enum EntryKind { + Union { members: Set }, + Table { fields: Vec }, + Token { kind_id: usize }, } #[derive(Debug, Ord, PartialOrd, Eq, PartialEq)] @@ -27,22 +28,48 @@ pub struct TypeName { pub named: bool, } +#[derive(Debug)] +pub enum FieldTypeInfo { + /// The field has a single type. + Single(TypeName), + + /// The field can take one of several types, so we also provide the name of + /// the database union type that wraps them, and the corresponding QL class + /// name. + Multiple { + types: Set, + dbscheme_union: String, + ql_class: String, + }, +} + #[derive(Debug)] pub struct Field { pub parent: TypeName, - pub types: Set, + pub type_info: FieldTypeInfo, /// The name of the field or None for the anonymous 'children' /// entry from node_types.json pub name: Option, pub storage: Storage, } +fn name_for_field_or_child(name: &Option) -> String { + match name { + Some(name) => name.clone(), + None => "child".to_owned(), + } +} + impl Field { pub fn get_name(&self) -> String { - match &self.name { - Some(name) => name.clone(), - None => "child".to_owned(), - } + name_for_field_or_child(&self.name) + } + + pub fn get_getter_name(&self) -> String { + format!( + "get{}", + dbscheme_name_to_class_name(&escape_name(&name_for_field_or_child(&self.name))) + ) } } @@ -55,13 +82,13 @@ pub enum Storage { Table(bool), } -pub fn read_node_types(node_types_path: &Path) -> std::io::Result> { +pub fn read_node_types(node_types_path: &Path) -> std::io::Result { let file = fs::File::open(node_types_path)?; let node_types = serde_json::from_reader(file)?; Ok(convert_nodes(node_types)) } -pub fn read_node_types_str(node_types_json: &str) -> std::io::Result> { +pub fn read_node_types_str(node_types_json: &str) -> std::io::Result { let node_types = serde_json::from_str(node_types_json)?; Ok(convert_nodes(node_types)) } @@ -77,20 +104,29 @@ fn convert_types(node_types: &Vec) -> Set { let iter = node_types.iter().map(convert_type).collect(); std::collections::BTreeSet::from(iter) } -pub fn convert_nodes(nodes: Vec) -> Vec { - let mut entries: Vec = Vec::new(); + +pub fn convert_nodes(nodes: Vec) -> NodeTypeMap { + let mut entries = NodeTypeMap::new(); let mut token_kinds = Set::new(); for node in nodes { + let flattened_name = node_type_name(&node.kind, node.named); + let ql_class_name = dbscheme_name_to_class_name(&escape_name(&flattened_name)); if let Some(subtypes) = &node.subtypes { // It's a tree-sitter supertype node, for which we create a union // type. - entries.push(Entry::Union { - type_name: TypeName { + entries.insert( + TypeName { kind: node.kind, named: node.named, }, - members: convert_types(&subtypes), - }); + Entry { + flattened_name, + ql_class_name, + kind: EntryKind::Union { + members: convert_types(&subtypes), + }, + }, + ); } else if node.fields.as_ref().map_or(0, |x| x.len()) == 0 && node.children.is_none() { let type_name = TypeName { kind: node.kind, @@ -121,18 +157,34 @@ pub fn convert_nodes(nodes: Vec) -> Vec { // Treat children as if they were a field called 'child'. add_field(&type_name, None, children, &mut fields); } - entries.push(Entry::Table { type_name, fields }); + entries.insert( + type_name, + Entry { + flattened_name, + ql_class_name, + kind: EntryKind::Table { fields }, + }, + ); } } let mut counter = 0; for type_name in token_kinds { - let kind_id = if type_name.named { + let entry = if type_name.named { counter += 1; - counter + let unprefixed_name = node_type_name(&type_name.kind, true); + Entry { + flattened_name: format!("token_{}", &unprefixed_name), + ql_class_name: dbscheme_name_to_class_name(&escape_name(&unprefixed_name)), + kind: EntryKind::Token { kind_id: counter }, + } } else { - 0 + Entry { + flattened_name: "reserved_word".to_owned(), + ql_class_name: "ReservedWord".to_owned(), + kind: EntryKind::Token { kind_id: 0 }, + } }; - entries.push(Entry::Token { type_name, kind_id }); + entries.insert(type_name, entry); } entries } @@ -156,12 +208,26 @@ fn add_field( // with an associated index. Storage::Table(true) }; + let type_info = if field_info.types.len() == 1 { + FieldTypeInfo::Single(convert_type(field_info.types.iter().next().unwrap())) + } else { + // The dbscheme type for this field will be a union. In QL, it'll just be AstNode. + FieldTypeInfo::Multiple { + types: convert_types(&field_info.types), + dbscheme_union: format!( + "{}_{}_type", + &node_type_name(&parent_type_name.kind, parent_type_name.named), + &name_for_field_or_child(&field_name) + ), + ql_class: "AstNode".to_owned(), + } + }; fields.push(Field { parent: TypeName { kind: parent_type_name.kind.to_string(), named: parent_type_name.named, }, - types: convert_types(&field_info.types), + type_info, name: field_name, storage, }); @@ -196,7 +262,7 @@ pub struct FieldInfo { /// Given a tree-sitter node type's (kind, named) pair, returns a single string /// representing the (unescaped) name we'll use to refer to corresponding QL /// type. -pub fn node_type_name(kind: &str, named: bool) -> String { +fn node_type_name(kind: &str, named: bool) -> String { if named { kind.to_string() } else { @@ -267,3 +333,26 @@ pub fn escape_name(name: &str) -> String { result } + +/// Given a valid dbscheme name (i.e. in snake case), produces the equivalent QL +/// name (i.e. in CamelCase). For example, "foo_bar_baz" becomes "FooBarBaz". +fn dbscheme_name_to_class_name(dbscheme_name: &str) -> String { + fn to_title_case(word: &str) -> String { + let mut first = true; + let mut result = String::new(); + for c in word.chars() { + if first { + first = false; + result.push(c.to_ascii_uppercase()); + } else { + result.push(c); + } + } + result + } + dbscheme_name + .split('_') + .map(|word| to_title_case(word)) + .collect::>() + .join("") +} diff --git a/ql/src/codeql_ruby/ast.qll b/ql/src/codeql_ruby/ast.qll index 618cc4c9d97..b66e15806ae 100644 --- a/ql/src/codeql_ruby/ast.qll +++ b/ql/src/codeql_ruby/ast.qll @@ -26,7 +26,7 @@ class Token extends @token, AstNode { override string describeQlClass() { result = "Token" } } -class ReservedWord extends Token, @reserved_word { +class ReservedWord extends @reserved_word, Token { override string describeQlClass() { result = "ReservedWord" } } @@ -173,7 +173,7 @@ class BlockParameter extends @block_parameter, AstNode { override Location getLocation() { block_parameter_def(this, _, result) } - Token getName() { block_parameter_def(this, result, _) } + Identifier getName() { block_parameter_def(this, result, _) } override AstNode getAFieldOrChild() { block_parameter_def(this, result, _) } } @@ -234,6 +234,10 @@ class ChainedString extends @chained_string, AstNode { override AstNode getAFieldOrChild() { chained_string_child(this, _, result) } } +class Character extends @token_character, Token { + override string describeQlClass() { result = "Character" } +} + class Class extends @class, AstNode { override string describeQlClass() { result = "Class" } @@ -246,6 +250,18 @@ class Class extends @class, AstNode { override AstNode getAFieldOrChild() { class_def(this, result, _) or class_child(this, _, result) } } +class ClassVariable extends @token_class_variable, Token { + override string describeQlClass() { result = "ClassVariable" } +} + +class Comment extends @token_comment, Token { + override string describeQlClass() { result = "Comment" } +} + +class Complex extends @token_complex, Token { + override string describeQlClass() { result = "Complex" } +} + class Conditional extends @conditional, AstNode { override string describeQlClass() { result = "Conditional" } @@ -264,6 +280,10 @@ class Conditional extends @conditional, AstNode { } } +class Constant extends @token_constant, Token { + override string describeQlClass() { result = "Constant" } +} + class DestructuredLeftAssignment extends @destructured_left_assignment, AstNode { override string describeQlClass() { result = "DestructuredLeftAssignment" } @@ -344,6 +364,10 @@ class Elsif extends @elsif, AstNode { } } +class EmptyStatement extends @token_empty_statement, Token { + override string describeQlClass() { result = "EmptyStatement" } +} + class EndBlock extends @end_block, AstNode { override string describeQlClass() { result = "EndBlock" } @@ -364,6 +388,10 @@ class Ensure extends @ensure, AstNode { override AstNode getAFieldOrChild() { ensure_child(this, _, result) } } +class EscapeSequence extends @token_escape_sequence, Token { + override string describeQlClass() { result = "EscapeSequence" } +} + class ExceptionVariable extends @exception_variable, AstNode { override string describeQlClass() { result = "ExceptionVariable" } @@ -384,6 +412,14 @@ class Exceptions extends @exceptions, AstNode { override AstNode getAFieldOrChild() { exceptions_child(this, _, result) } } +class False extends @token_false, Token { + override string describeQlClass() { result = "False" } +} + +class Float extends @token_float, Token { + override string describeQlClass() { result = "Float" } +} + class For extends @for, AstNode { override string describeQlClass() { result = "For" } @@ -400,6 +436,10 @@ class For extends @for, AstNode { } } +class GlobalVariable extends @token_global_variable, Token { + override string describeQlClass() { result = "GlobalVariable" } +} + class Hash extends @hash, AstNode { override string describeQlClass() { result = "Hash" } @@ -425,11 +465,15 @@ class HashSplatParameter extends @hash_splat_parameter, AstNode { override Location getLocation() { hash_splat_parameter_def(this, result) } - Token getName() { hash_splat_parameter_name(this, result) } + Identifier getName() { hash_splat_parameter_name(this, result) } override AstNode getAFieldOrChild() { hash_splat_parameter_name(this, result) } } +class HeredocBeginning extends @token_heredoc_beginning, Token { + override string describeQlClass() { result = "HeredocBeginning" } +} + class HeredocBody extends @heredoc_body, AstNode { override string describeQlClass() { result = "HeredocBody" } @@ -440,6 +484,18 @@ class HeredocBody extends @heredoc_body, AstNode { override AstNode getAFieldOrChild() { heredoc_body_child(this, _, result) } } +class HeredocContent extends @token_heredoc_content, Token { + override string describeQlClass() { result = "HeredocContent" } +} + +class HeredocEnd extends @token_heredoc_end, Token { + override string describeQlClass() { result = "HeredocEnd" } +} + +class Identifier extends @token_identifier, Token { + override string describeQlClass() { result = "Identifier" } +} + class If extends @if, AstNode { override string describeQlClass() { result = "If" } @@ -480,6 +536,14 @@ class In extends @in, AstNode { override AstNode getAFieldOrChild() { in_def(this, result, _) } } +class InstanceVariable extends @token_instance_variable, Token { + override string describeQlClass() { result = "InstanceVariable" } +} + +class Integer extends @token_integer, Token { + override string describeQlClass() { result = "Integer" } +} + class Interpolation extends @interpolation, AstNode { override string describeQlClass() { result = "Interpolation" } @@ -495,7 +559,7 @@ class KeywordParameter extends @keyword_parameter, AstNode { override Location getLocation() { keyword_parameter_def(this, _, result) } - Token getName() { keyword_parameter_def(this, result, _) } + Identifier getName() { keyword_parameter_def(this, result, _) } UnderscoreArg getValue() { keyword_parameter_value(this, result) } @@ -606,6 +670,14 @@ class Next extends @next, AstNode { override AstNode getAFieldOrChild() { next_child(this, result) } } +class Nil extends @token_nil, Token { + override string describeQlClass() { result = "Nil" } +} + +class Operator extends @token_operator, Token { + override string describeQlClass() { result = "Operator" } +} + class OperatorAssignment extends @operator_assignment, AstNode { override string describeQlClass() { result = "OperatorAssignment" } @@ -625,7 +697,7 @@ class OptionalParameter extends @optional_parameter, AstNode { override Location getLocation() { optional_parameter_def(this, _, _, result) } - Token getName() { optional_parameter_def(this, result, _, _) } + Identifier getName() { optional_parameter_def(this, result, _, _) } UnderscoreArg getValue() { optional_parameter_def(this, _, result, _) } @@ -693,7 +765,7 @@ class Rational extends @rational, AstNode { override Location getLocation() { rational_def(this, _, result) } - Token getChild() { rational_def(this, result, _) } + Integer getChild() { rational_def(this, result, _) } override AstNode getAFieldOrChild() { rational_def(this, result, _) } } @@ -802,12 +874,16 @@ class ScopeResolution extends @scope_resolution, AstNode { } } +class Self extends @token_self, Token { + override string describeQlClass() { result = "Self" } +} + class Setter extends @setter, AstNode { override string describeQlClass() { result = "Setter" } override Location getLocation() { setter_def(this, _, result) } - Token getChild() { setter_def(this, result, _) } + Identifier getChild() { setter_def(this, result, _) } override AstNode getAFieldOrChild() { setter_def(this, result, _) } } @@ -862,7 +938,7 @@ class SplatParameter extends @splat_parameter, AstNode { override Location getLocation() { splat_parameter_def(this, result) } - Token getName() { splat_parameter_name(this, result) } + Identifier getName() { splat_parameter_name(this, result) } override AstNode getAFieldOrChild() { splat_parameter_name(this, result) } } @@ -887,6 +963,10 @@ class StringArray extends @string_array, AstNode { override AstNode getAFieldOrChild() { string_array_child(this, _, result) } } +class StringContent extends @token_string_content, Token { + override string describeQlClass() { result = "StringContent" } +} + class Subshell extends @subshell, AstNode { override string describeQlClass() { result = "Subshell" } @@ -897,6 +977,10 @@ class Subshell extends @subshell, AstNode { override AstNode getAFieldOrChild() { subshell_child(this, _, result) } } +class Super extends @token_super, Token { + override string describeQlClass() { result = "Super" } +} + class Superclass extends @superclass, AstNode { override string describeQlClass() { result = "Superclass" } @@ -937,6 +1021,10 @@ class Then extends @then, AstNode { override AstNode getAFieldOrChild() { then_child(this, _, result) } } +class True extends @token_true, Token { + override string describeQlClass() { result = "True" } +} + class Unary extends @unary, AstNode { override string describeQlClass() { result = "Unary" } @@ -961,6 +1049,10 @@ class Undef extends @undef, AstNode { override AstNode getAFieldOrChild() { undef_child(this, _, result) } } +class Uninterpreted extends @token_uninterpreted, Token { + override string describeQlClass() { result = "Uninterpreted" } +} + class Unless extends @unless, AstNode { override string describeQlClass() { result = "Unless" } @@ -1070,95 +1162,3 @@ class Yield extends @yield, AstNode { override AstNode getAFieldOrChild() { yield_child(this, result) } } - -class Character extends Token, @token_character { - override string describeQlClass() { result = "Character" } -} - -class ClassVariable extends Token, @token_class_variable { - override string describeQlClass() { result = "ClassVariable" } -} - -class Comment extends Token, @token_comment { - override string describeQlClass() { result = "Comment" } -} - -class Complex extends Token, @token_complex { - override string describeQlClass() { result = "Complex" } -} - -class Constant extends Token, @token_constant { - override string describeQlClass() { result = "Constant" } -} - -class EmptyStatement extends Token, @token_empty_statement { - override string describeQlClass() { result = "EmptyStatement" } -} - -class EscapeSequence extends Token, @token_escape_sequence { - override string describeQlClass() { result = "EscapeSequence" } -} - -class False extends Token, @token_false { - override string describeQlClass() { result = "False" } -} - -class Float extends Token, @token_float { - override string describeQlClass() { result = "Float" } -} - -class GlobalVariable extends Token, @token_global_variable { - override string describeQlClass() { result = "GlobalVariable" } -} - -class HeredocBeginning extends Token, @token_heredoc_beginning { - override string describeQlClass() { result = "HeredocBeginning" } -} - -class HeredocContent extends Token, @token_heredoc_content { - override string describeQlClass() { result = "HeredocContent" } -} - -class HeredocEnd extends Token, @token_heredoc_end { - override string describeQlClass() { result = "HeredocEnd" } -} - -class Identifier extends Token, @token_identifier { - override string describeQlClass() { result = "Identifier" } -} - -class InstanceVariable extends Token, @token_instance_variable { - override string describeQlClass() { result = "InstanceVariable" } -} - -class Integer extends Token, @token_integer { - override string describeQlClass() { result = "Integer" } -} - -class Nil extends Token, @token_nil { - override string describeQlClass() { result = "Nil" } -} - -class Operator extends Token, @token_operator { - override string describeQlClass() { result = "Operator" } -} - -class Self extends Token, @token_self { - override string describeQlClass() { result = "Self" } -} - -class StringContent extends Token, @token_string_content { - override string describeQlClass() { result = "StringContent" } -} - -class Super extends Token, @token_super { - override string describeQlClass() { result = "Super" } -} - -class True extends Token, @token_true { - override string describeQlClass() { result = "True" } -} - -class Uninterpreted extends Token, @token_uninterpreted { - override string describeQlClass() { result = "Uninterpreted" } -} From bbe7c70d3423aa63ecafe6d9c7eebe88775b7e28 Mon Sep 17 00:00:00 2001 From: Nick Rolfe Date: Mon, 16 Nov 2020 15:55:06 +0000 Subject: [PATCH 2/5] more refactoring of names --- extractor/src/extractor.rs | 24 +++++++------- generator/src/main.rs | 51 +++++++++++++++-------------- generator/src/ql_gen.rs | 66 ++++++++++++++++++-------------------- node-types/src/lib.rs | 58 ++++++++++++++++++--------------- 4 files changed, 102 insertions(+), 97 deletions(-) diff --git a/extractor/src/extractor.rs b/extractor/src/extractor.rs index e3fd5108033..be9f32f8b6c 100644 --- a/extractor/src/extractor.rs +++ b/extractor/src/extractor.rs @@ -1,4 +1,4 @@ -use node_types::{escape_name, EntryKind, Field, NodeTypeMap, Storage, TypeName}; +use node_types::{EntryKind, Field, NodeTypeMap, Storage, TypeName}; use std::collections::BTreeMap as Map; use std::collections::BTreeSet as Set; use std::fmt; @@ -328,8 +328,10 @@ impl Visitor<'_> { ); self.token_counter += 1; } - EntryKind::Table { fields, .. } => { - let table_name = escape_name(&format!("{}_def", &table.flattened_name)); + EntryKind::Table { + fields, + name: table_name, + } => { if let Some(args) = self.complex_node(&node, fields, child_nodes, id) { let mut all_args = Vec::new(); all_args.push(Arg::Label(id)); @@ -409,7 +411,7 @@ impl Visitor<'_> { for field in fields { let child_ids = &map.get(&field.name).unwrap().1; match &field.storage { - Storage::Column => { + Storage::Column { name: column_name } => { if child_ids.len() == 1 { args.push(Arg::Label(*child_ids.first().unwrap())); } else { @@ -424,11 +426,14 @@ impl Visitor<'_> { "too many values" }, node.kind(), - &field.get_name() + column_name ) } } - Storage::Table(has_index) => { + Storage::Table { + name: table_name, + has_index, + } => { for (index, child_id) in child_ids.iter().enumerate() { if !*has_index && index > 0 { error!( @@ -436,15 +441,10 @@ impl Visitor<'_> { &self.path, node.start_position().row + 1, node.kind(), - &field.get_name() + table_name, ); break; } - let table_name = escape_name(&format!( - "{}_{}", - self.schema.get(&field.parent).unwrap().flattened_name, - field.get_name() - )); let mut args = Vec::new(); args.push(Arg::Label(parent_id)); if *has_index { diff --git a/generator/src/main.rs b/generator/src/main.rs index 73c30a1bbe9..dd49c2a405b 100644 --- a/generator/src/main.rs +++ b/generator/src/main.rs @@ -29,15 +29,15 @@ fn make_field_type( // type to represent them. let members: Set = types .iter() - .map(|t| node_types::escape_name(&nodes.get(t).unwrap().flattened_name)) + .map(|t| nodes.get(t).unwrap().dbscheme_name.clone()) .collect(); entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: node_types::escape_name(&dbscheme_union), + name: dbscheme_union.clone(), members, })); dbscheme_union.clone() } - node_types::FieldTypeInfo::Single(t) => nodes.get(&t).unwrap().flattened_name.clone(), + node_types::FieldTypeInfo::Single(t) => nodes.get(&t).unwrap().dbscheme_name.clone(), } } @@ -49,18 +49,20 @@ fn add_field( entries: &mut Vec, nodes: &node_types::NodeTypeMap, ) { - let field_name = field.get_name(); - let parent_name = &nodes.get(&field.parent).unwrap().flattened_name; + let parent_name = &nodes.get(&field.parent).unwrap().dbscheme_name; match &field.storage { - node_types::Storage::Table(has_index) => { + node_types::Storage::Table { + name: table_name, + has_index, + } => { // This field can appear zero or multiple times, so put // it in an auxiliary table. - let field_type = node_types::escape_name(&make_field_type(&field, entries, nodes)); + let field_type = make_field_type(&field, entries, nodes); let parent_column = dbscheme::Column { unique: !*has_index, db_type: dbscheme::DbColumnType::Int, - name: node_types::escape_name(&parent_name), - ql_type: ql::Type::AtType(node_types::escape_name(&parent_name)), + name: parent_name.clone(), + ql_type: ql::Type::AtType(parent_name.clone()), ql_type_is_ref: true, }; let index_column = dbscheme::Column { @@ -78,7 +80,7 @@ fn add_field( ql_type_is_ref: true, }; let field_table = dbscheme::Table { - name: node_types::escape_name(&format!("{}_{}", parent_name, field_name)), + name: table_name.clone(), columns: if *has_index { vec![parent_column, index_column, field_column] } else { @@ -87,25 +89,22 @@ fn add_field( // In addition to the field being unique, the combination of // parent+index is unique, so add a keyset for them. keysets: if *has_index { - Some(vec![ - node_types::escape_name(&parent_name), - "index".to_string(), - ]) + Some(vec![parent_name.clone(), "index".to_string()]) } else { None }, }; entries.push(dbscheme::Entry::Table(field_table)); } - node_types::Storage::Column => { + node_types::Storage::Column { name: column_name } => { // This field must appear exactly once, so we add it as // a column to the main table for the node type. let field_type = make_field_type(&field, entries, nodes); main_table.columns.push(dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: node_types::escape_name(&field_name), - ql_type: ql::Type::AtType(node_types::escape_name(&field_type)), + name: column_name.clone(), + ql_type: ql::Type::AtType(field_type), ql_type_is_ref: true, }); } @@ -130,12 +129,12 @@ fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { .iter() .filter_map(|(_, node)| match &node.kind { node_types::EntryKind::Token { kind_id } => { - Some((node.flattened_name.clone(), *kind_id)) + Some((node.dbscheme_name.clone(), *kind_id)) } _ => None, }) .collect(); - ast_node_members.insert(node_types::escape_name("token")); + ast_node_members.insert("token".to_owned()); for (_, node) in nodes { match &node.kind { @@ -144,27 +143,27 @@ fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { // type. let members: Set = n_members .iter() - .map(|n| node_types::escape_name(&nodes.get(n).unwrap().flattened_name)) + .map(|n| nodes.get(n).unwrap().dbscheme_name.clone()) .collect(); entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: node_types::escape_name(&node.flattened_name), + name: node.dbscheme_name.clone(), members, })); } - node_types::EntryKind::Table { fields } => { + node_types::EntryKind::Table { name, fields } => { // It's a product type, defined by a table. let mut main_table = dbscheme::Table { - name: node_types::escape_name(&(format!("{}_def", &node.flattened_name))), + name: name.clone(), columns: vec![dbscheme::Column { db_type: dbscheme::DbColumnType::Int, name: "id".to_string(), unique: true, - ql_type: ql::Type::AtType(node_types::escape_name(&node.flattened_name)), + ql_type: ql::Type::AtType(node.dbscheme_name.clone()), ql_type_is_ref: false, }], keysets: None, }; - ast_node_members.insert(node_types::escape_name(&node.flattened_name)); + ast_node_members.insert(node.dbscheme_name.clone()); // If the type also has fields or children, then we create either // auxiliary tables or columns in the defining table for them. @@ -262,7 +261,7 @@ fn add_tokeninfo_table(entries: &mut Vec, token_kinds: Map = token_kinds .iter() - .map(|(name, kind_id)| (*kind_id, node_types::escape_name(name))) + .map(|(name, kind_id)| (*kind_id, name.clone())) .collect(); entries.push(dbscheme::Entry::Case(dbscheme::Case { name: "token".to_owned(), diff --git a/generator/src/ql_gen.rs b/generator/src/ql_gen.rs index 996e5e9dde8..d1b85fd77b0 100644 --- a/generator/src/ql_gen.rs +++ b/generator/src/ql_gen.rs @@ -309,7 +309,6 @@ fn create_field_getters( main_table_name: &str, main_table_arity: usize, main_table_column_index: &mut usize, - parent_name: &str, field: &node_types::Field, nodes: &node_types::NodeTypeMap, ) -> (ql::Predicate, ql::Expression) { @@ -323,7 +322,7 @@ fn create_field_getters( } => ql_class.clone(), })); match &field.storage { - node_types::Storage::Column => { + node_types::Storage::Column { name: _ } => { let result = ( ql::Predicate { name: predicate_name, @@ -345,32 +344,32 @@ fn create_field_getters( *main_table_column_index += 1; result } - node_types::Storage::Table(has_index) => { - let field_table_name = format!("{}_{}", parent_name, &field.get_name()); - ( - ql::Predicate { - name: predicate_name, - overridden: false, - return_type, - formal_parameters: if *has_index { - vec![ql::FormalParameter { - name: "i".to_owned(), - param_type: ql::Type::Int, - }] - } else { - vec![] - }, - body: create_get_field_expr_for_table_storage( - &field_table_name, - if *has_index { Some("i") } else { None }, - ), + node_types::Storage::Table { + name: field_table_name, + has_index, + } => ( + ql::Predicate { + name: predicate_name, + overridden: false, + return_type, + formal_parameters: if *has_index { + vec![ql::FormalParameter { + name: "i".to_owned(), + param_type: ql::Type::Int, + }] + } else { + vec![] }, - create_get_field_expr_for_table_storage( + body: create_get_field_expr_for_table_storage( &field_table_name, - if *has_index { Some("_") } else { None }, + if *has_index { Some("i") } else { None }, ), - ) - } + }, + create_get_field_expr_for_table_storage( + &field_table_name, + if *has_index { Some("_") } else { None }, + ), + ), } } @@ -398,7 +397,7 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { if type_name.named { let describe_ql_class = create_describe_ql_class(&node.ql_class_name); let mut supertypes: BTreeSet = BTreeSet::new(); - supertypes.insert(ql::Type::AtType(node.flattened_name.to_owned())); + supertypes.insert(ql::Type::AtType(node.dbscheme_name.to_owned())); supertypes.insert(ql::Type::Normal("Token".to_owned())); classes.push(ql::TopLevel::Class(ql::Class { name: node.ql_class_name.clone(), @@ -416,7 +415,7 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { name: node.ql_class_name.clone(), is_abstract: false, supertypes: vec![ - ql::Type::AtType(node_types::escape_name(&node.flattened_name)), + ql::Type::AtType(node.dbscheme_name.clone()), ql::Type::Normal("AstNode".to_owned()), ] .into_iter() @@ -425,7 +424,10 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { predicates: vec![], })); } - node_types::EntryKind::Table { fields } => { + node_types::EntryKind::Table { + name: main_table_name, + fields, + } => { // Count how many columns there will be in the main table. // There will be: // - one for the id @@ -437,19 +439,16 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { } else { fields .iter() - .filter(|&f| matches!(f.storage, node_types::Storage::Column)) + .filter(|&f| matches!(f.storage, node_types::Storage::Column{..})) .count() }; - let escaped_name = node_types::escape_name(&node.flattened_name); let main_class_name = &node.ql_class_name; - let main_table_name = - node_types::escape_name(&format!("{}_def", &node.flattened_name)); let mut main_class = ql::Class { name: main_class_name.clone(), is_abstract: false, supertypes: vec![ - ql::Type::AtType(escaped_name), + ql::Type::AtType(node.dbscheme_name.clone()), ql::Type::Normal("AstNode".to_owned()), ] .into_iter() @@ -478,7 +477,6 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { &main_table_name, main_table_arity, &mut main_table_column_index, - &node.flattened_name, field, nodes, ); diff --git a/node-types/src/lib.rs b/node-types/src/lib.rs index 1ef0da8533b..4b279896730 100644 --- a/node-types/src/lib.rs +++ b/node-types/src/lib.rs @@ -10,7 +10,7 @@ pub type NodeTypeMap = BTreeMap; #[derive(Debug)] pub struct Entry { - pub flattened_name: String, + pub dbscheme_name: String, pub ql_class_name: String, pub kind: EntryKind, } @@ -18,7 +18,7 @@ pub struct Entry { #[derive(Debug)] pub enum EntryKind { Union { members: Set }, - Table { fields: Vec }, + Table { name: String, fields: Vec }, Token { kind_id: usize }, } @@ -61,10 +61,6 @@ fn name_for_field_or_child(name: &Option) -> String { } impl Field { - pub fn get_name(&self) -> String { - name_for_field_or_child(&self.name) - } - pub fn get_getter_name(&self) -> String { format!( "get{}", @@ -76,10 +72,10 @@ impl Field { #[derive(Debug)] pub enum Storage { /// the field is stored as a column in the parent table - Column, + Column { name: String }, /// the field is stored in a link table, and may or may not have an /// associated index column - Table(bool), + Table { name: String, has_index: bool }, } pub fn read_node_types(node_types_path: &Path) -> std::io::Result { @@ -109,8 +105,9 @@ pub fn convert_nodes(nodes: Vec) -> NodeTypeMap { let mut entries = NodeTypeMap::new(); let mut token_kinds = Set::new(); for node in nodes { - let flattened_name = node_type_name(&node.kind, node.named); - let ql_class_name = dbscheme_name_to_class_name(&escape_name(&flattened_name)); + let flattened_name = &node_type_name(&node.kind, node.named); + let dbscheme_name = escape_name(&flattened_name); + let ql_class_name = dbscheme_name_to_class_name(&dbscheme_name); if let Some(subtypes) = &node.subtypes { // It's a tree-sitter supertype node, for which we create a union // type. @@ -120,7 +117,7 @@ pub fn convert_nodes(nodes: Vec) -> NodeTypeMap { named: node.named, }, Entry { - flattened_name, + dbscheme_name, ql_class_name, kind: EntryKind::Union { members: convert_types(&subtypes), @@ -139,6 +136,7 @@ pub fn convert_nodes(nodes: Vec) -> NodeTypeMap { kind: node.kind, named: node.named, }; + let table_name = escape_name(&(format!("{}_def", &flattened_name))); let mut fields = Vec::new(); // If the type also has fields or children, then we create either @@ -160,9 +158,12 @@ pub fn convert_nodes(nodes: Vec) -> NodeTypeMap { entries.insert( type_name, Entry { - flattened_name, + dbscheme_name, ql_class_name, - kind: EntryKind::Table { fields }, + kind: EntryKind::Table { + name: table_name, + fields, + }, }, ); } @@ -173,13 +174,13 @@ pub fn convert_nodes(nodes: Vec) -> NodeTypeMap { counter += 1; let unprefixed_name = node_type_name(&type_name.kind, true); Entry { - flattened_name: format!("token_{}", &unprefixed_name), + dbscheme_name: escape_name(&format!("token_{}", &unprefixed_name)), ql_class_name: dbscheme_name_to_class_name(&escape_name(&unprefixed_name)), kind: EntryKind::Token { kind_id: counter }, } } else { Entry { - flattened_name: "reserved_word".to_owned(), + dbscheme_name: "reserved_word".to_owned(), ql_class_name: "ReservedWord".to_owned(), kind: EntryKind::Token { kind_id: 0 }, } @@ -195,18 +196,25 @@ fn add_field( field_info: &FieldInfo, fields: &mut Vec, ) { + let parent_flattened_name = node_type_name(&parent_type_name.kind, parent_type_name.named); let storage = if !field_info.multiple && field_info.required { // This field must appear exactly once, so we add it as // a column to the main table for the node type. - Storage::Column - } else if !field_info.multiple { - // This field is optional but can occur at most once. Put it in an - // auxiliary table without an index. - Storage::Table(false) + Storage::Column { + name: escape_name(&name_for_field_or_child(&field_name)), + } } else { - // This field can occur multiple times. Put it in an auxiliary table - // with an associated index. - Storage::Table(true) + // Put the field in an auxiliary table. + let has_index = field_info.multiple; + let field_table_name = escape_name(&format!( + "{}_{}", + parent_flattened_name, + &name_for_field_or_child(&field_name) + )); + Storage::Table { + has_index, + name: field_table_name, + } }; let type_info = if field_info.types.len() == 1 { FieldTypeInfo::Single(convert_type(field_info.types.iter().next().unwrap())) @@ -216,7 +224,7 @@ fn add_field( types: convert_types(&field_info.types), dbscheme_union: format!( "{}_{}_type", - &node_type_name(&parent_type_name.kind, parent_type_name.named), + &parent_flattened_name, &name_for_field_or_child(&field_name) ), ql_class: "AstNode".to_owned(), @@ -277,7 +285,7 @@ const RESERVED_KEYWORDS: [&'static str; 14] = [ /// Returns a string that's a copy of `name` but suitably escaped to be a valid /// QL identifier. -pub fn escape_name(name: &str) -> String { +fn escape_name(name: &str) -> String { let mut result = String::new(); // If there's a leading underscore, replace it with 'underscore_'. From ad61f7a0a619f6b2d77954692ae9800cc3cb9240 Mon Sep 17 00:00:00 2001 From: Nick Rolfe Date: Mon, 16 Nov 2020 17:48:47 +0000 Subject: [PATCH 3/5] Use references instead of owned strings in generator --- generator/src/dbscheme.rs | 46 ++--- generator/src/main.rs | 365 ++++++++++++++++++++------------------ generator/src/ql.rs | 68 +++---- generator/src/ql_gen.rs | 241 ++++++++++++------------- node-types/src/lib.rs | 16 +- 5 files changed, 367 insertions(+), 369 deletions(-) diff --git a/generator/src/dbscheme.rs b/generator/src/dbscheme.rs index d9dce0630f3..84cc329d065 100644 --- a/generator/src/dbscheme.rs +++ b/generator/src/dbscheme.rs @@ -2,41 +2,41 @@ use crate::ql; use std::collections::BTreeSet as Set; use std::fmt; /// Represents a distinct entry in the database schema. -pub enum Entry { +pub enum Entry<'a> { /// An entry defining a database table. - Table(Table), + Table(Table<'a>), /// An entry defining a database table. - Case(Case), + Case(Case<'a>), /// An entry defining type that is a union of other types. - Union(Union), + Union(Union<'a>), } /// A table in the database schema. -pub struct Table { - pub name: String, - pub columns: Vec, - pub keysets: Option>, +pub struct Table<'a> { + pub name: &'a str, + pub columns: Vec>, + pub keysets: Option>, } /// A union in the database schema. -pub struct Union { - pub name: String, - pub members: Set, +pub struct Union<'a> { + pub name: &'a str, + pub members: Set<&'a str>, } /// A table in the database schema. -pub struct Case { - pub name: String, - pub column: String, - pub branches: Vec<(usize, String)>, +pub struct Case<'a> { + pub name: &'a str, + pub column: &'a str, + pub branches: Vec<(usize, &'a str)>, } /// A column in a table. -pub struct Column { +pub struct Column<'a> { pub db_type: DbColumnType, - pub name: String, + pub name: &'a str, pub unique: bool, - pub ql_type: ql::Type, + pub ql_type: ql::Type<'a>, pub ql_type_is_ref: bool, } @@ -46,7 +46,7 @@ pub enum DbColumnType { String, } -impl fmt::Display for Case { +impl<'a> fmt::Display for Case<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "case @{}.{} of", &self.name, &self.column)?; let mut sep = " "; @@ -58,7 +58,7 @@ impl fmt::Display for Case { } } -impl fmt::Display for Table { +impl<'a> fmt::Display for Table<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(keyset) = &self.keysets { write!(f, "#keyset[")?; @@ -100,7 +100,7 @@ impl fmt::Display for Table { } } -impl fmt::Display for Union { +impl<'a> fmt::Display for Union<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "@{} = ", self.name)?; let mut first = true; @@ -117,10 +117,10 @@ impl fmt::Display for Union { } /// Generates the dbscheme by writing the given dbscheme `entries` to the `file`. -pub fn write( +pub fn write<'a>( language_name: &str, file: &mut dyn std::io::Write, - entries: &[Entry], + entries: &'a [Entry], ) -> std::io::Result<()> { write!(file, "// CodeQL database schema for {}\n", language_name)?; write!( diff --git a/generator/src/main.rs b/generator/src/main.rs index dd49c2a405b..652209df6bd 100644 --- a/generator/src/main.rs +++ b/generator/src/main.rs @@ -11,14 +11,14 @@ use std::io::LineWriter; use std::path::PathBuf; use tracing::{error, info}; -/// Given the name of the parent node, and its field information, returns the -/// name of the field's type. This may be an ad-hoc union of all the possible -/// types the field can take, in which case the union is added to `entries`. -fn make_field_type( - field: &node_types::Field, - entries: &mut Vec, - nodes: &node_types::NodeTypeMap, -) -> String { +/// Given the name of the parent node, and its field information, returns a pair, +/// the first of which is the name of the field's type. The second is an optional +/// dbscheme entry that should be added, representing a union of all the possible +/// types the field can take. +fn make_field_type<'a>( + field: &'a node_types::Field, + nodes: &'a node_types::NodeTypeMap, +) -> (&'a str, Option>) { match &field.type_info { node_types::FieldTypeInfo::Multiple { types, @@ -27,92 +27,93 @@ fn make_field_type( } => { // This field can have one of several types. Create an ad-hoc QL union // type to represent them. - let members: Set = types + let members: Set<&str> = types .iter() - .map(|t| nodes.get(t).unwrap().dbscheme_name.clone()) + .map(|t| nodes.get(t).unwrap().dbscheme_name.as_str()) .collect(); - entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: dbscheme_union.clone(), - members, - })); - dbscheme_union.clone() + ( + &dbscheme_union, + Some(dbscheme::Entry::Union(dbscheme::Union { + name: dbscheme_union, + members, + })), + ) } - node_types::FieldTypeInfo::Single(t) => nodes.get(&t).unwrap().dbscheme_name.clone(), + node_types::FieldTypeInfo::Single(t) => (&nodes.get(&t).unwrap().dbscheme_name, None), } } -/// Adds the appropriate dbscheme information for the given field, either as a -/// column on `main_table`, or as an auxiliary table. -fn add_field( - main_table: &mut dbscheme::Table, - field: &node_types::Field, - entries: &mut Vec, - nodes: &node_types::NodeTypeMap, -) { +fn add_field_for_table_storage<'a>( + field: &'a node_types::Field, + table_name: &'a str, + has_index: bool, + nodes: &'a node_types::NodeTypeMap, +) -> (dbscheme::Table<'a>, Option>) { let parent_name = &nodes.get(&field.parent).unwrap().dbscheme_name; - match &field.storage { - node_types::Storage::Table { - name: table_name, - has_index, - } => { - // This field can appear zero or multiple times, so put - // it in an auxiliary table. - let field_type = make_field_type(&field, entries, nodes); - let parent_column = dbscheme::Column { - unique: !*has_index, - db_type: dbscheme::DbColumnType::Int, - name: parent_name.clone(), - ql_type: ql::Type::AtType(parent_name.clone()), - ql_type_is_ref: true, - }; - let index_column = dbscheme::Column { - unique: false, - db_type: dbscheme::DbColumnType::Int, - name: "index".to_string(), - ql_type: ql::Type::Int, - ql_type_is_ref: true, - }; - let field_column = dbscheme::Column { - unique: true, - db_type: dbscheme::DbColumnType::Int, - name: field_type.clone(), - ql_type: ql::Type::AtType(field_type), - ql_type_is_ref: true, - }; - let field_table = dbscheme::Table { - name: table_name.clone(), - columns: if *has_index { - vec![parent_column, index_column, field_column] - } else { - vec![parent_column, field_column] - }, - // In addition to the field being unique, the combination of - // parent+index is unique, so add a keyset for them. - keysets: if *has_index { - Some(vec![parent_name.clone(), "index".to_string()]) - } else { - None - }, - }; - entries.push(dbscheme::Entry::Table(field_table)); - } - node_types::Storage::Column { name: column_name } => { - // This field must appear exactly once, so we add it as - // a column to the main table for the node type. - let field_type = make_field_type(&field, entries, nodes); - main_table.columns.push(dbscheme::Column { - unique: false, - db_type: dbscheme::DbColumnType::Int, - name: column_name.clone(), - ql_type: ql::Type::AtType(field_type), - ql_type_is_ref: true, - }); - } - } + // This field can appear zero or multiple times, so put + // it in an auxiliary table. + let (field_type_name, field_type_entry) = make_field_type(&field, nodes); + let parent_column = dbscheme::Column { + unique: !has_index, + db_type: dbscheme::DbColumnType::Int, + name: &parent_name, + ql_type: ql::Type::AtType(&parent_name), + ql_type_is_ref: true, + }; + let index_column = dbscheme::Column { + unique: false, + db_type: dbscheme::DbColumnType::Int, + name: "index", + ql_type: ql::Type::Int, + ql_type_is_ref: true, + }; + let field_column = dbscheme::Column { + unique: true, + db_type: dbscheme::DbColumnType::Int, + name: field_type_name, + ql_type: ql::Type::AtType(field_type_name), + ql_type_is_ref: true, + }; + let field_table = dbscheme::Table { + name: &table_name, + columns: if has_index { + vec![parent_column, index_column, field_column] + } else { + vec![parent_column, field_column] + }, + // In addition to the field being unique, the combination of + // parent+index is unique, so add a keyset for them. + keysets: if has_index { + Some(vec![&parent_name, "index"]) + } else { + None + }, + }; + (field_table, field_type_entry) +} + +fn add_field_for_column_storage<'a>( + field: &'a node_types::Field, + column_name: &'a str, + nodes: &'a node_types::NodeTypeMap, +) -> (dbscheme::Column<'a>, Option>) { + // This field must appear exactly once, so we add it as + // a column to the main table for the node type. + let (field_type_name, field_type_entry) = make_field_type(&field, nodes); + ( + dbscheme::Column { + unique: false, + db_type: dbscheme::DbColumnType::Int, + name: column_name, + ql_type: ql::Type::AtType(field_type_name), + ql_type_is_ref: true, + }, + field_type_entry, + ) } /// Converts the given tree-sitter node types into CodeQL dbscheme entries. -fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { +fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec> { let mut entries: Vec = vec![ create_location_union(), create_locations_default_table(), @@ -124,51 +125,68 @@ fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { create_containerparent_table(), create_source_location_prefix_table(), ]; - let mut ast_node_members: Set = Set::new(); - let token_kinds: Map = nodes + let mut ast_node_members: Set<&str> = Set::new(); + let token_kinds: Map<&str, usize> = nodes .iter() .filter_map(|(_, node)| match &node.kind { node_types::EntryKind::Token { kind_id } => { - Some((node.dbscheme_name.clone(), *kind_id)) + Some((node.dbscheme_name.as_str(), *kind_id)) } _ => None, }) .collect(); - ast_node_members.insert("token".to_owned()); + ast_node_members.insert("token"); for (_, node) in nodes { match &node.kind { node_types::EntryKind::Union { members: n_members } => { // It's a tree-sitter supertype node, for which we create a union // type. - let members: Set = n_members + let members: Set<&str> = n_members .iter() - .map(|n| nodes.get(n).unwrap().dbscheme_name.clone()) + .map(|n| nodes.get(n).unwrap().dbscheme_name.as_str()) .collect(); entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: node.dbscheme_name.clone(), + name: &node.dbscheme_name, members, })); } node_types::EntryKind::Table { name, fields } => { // It's a product type, defined by a table. let mut main_table = dbscheme::Table { - name: name.clone(), + name: &name, columns: vec![dbscheme::Column { db_type: dbscheme::DbColumnType::Int, - name: "id".to_string(), + name: "id", unique: true, - ql_type: ql::Type::AtType(node.dbscheme_name.clone()), + ql_type: ql::Type::AtType(&node.dbscheme_name), ql_type_is_ref: false, }], keysets: None, }; - ast_node_members.insert(node.dbscheme_name.clone()); + ast_node_members.insert(&node.dbscheme_name); // If the type also has fields or children, then we create either // auxiliary tables or columns in the defining table for them. for field in fields { - add_field(&mut main_table, &field, &mut entries, nodes); + match &field.storage { + node_types::Storage::Column { name } => { + let (field_column, field_type_entry) = + add_field_for_column_storage(field, name, nodes); + if let Some(field_type_entry) = field_type_entry { + entries.push(field_type_entry); + } + main_table.columns.push(field_column); + } + node_types::Storage::Table { name, has_index } => { + let (field_table, field_type_entry) = + add_field_for_table_storage(field, name, *has_index, nodes); + if let Some(field_type_entry) = field_type_entry { + entries.push(field_type_entry); + } + entries.push(dbscheme::Entry::Table(field_table)); + } + } } if fields.is_empty() { @@ -177,7 +195,7 @@ fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { main_table.columns.push(dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::String, - name: "text".to_string(), + name: "text", ql_type: ql::Type::String, ql_type_is_ref: true, }); @@ -187,8 +205,8 @@ fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { main_table.columns.push(dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "loc".to_string(), - ql_type: ql::Type::AtType("location".to_string()), + name: "loc", + ql_type: ql::Type::AtType("location"), ql_type_is_ref: true, }); @@ -199,75 +217,80 @@ fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { } // Add the tokeninfo table - add_tokeninfo_table(&mut entries, token_kinds); + let (token_case, token_table) = create_tokeninfo(token_kinds); + entries.push(dbscheme::Entry::Table(token_table)); + entries.push(dbscheme::Entry::Case(token_case)); // Create a union of all database types. entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: "ast_node".to_string(), + name: "ast_node", members: ast_node_members, })); entries } -fn add_tokeninfo_table(entries: &mut Vec, token_kinds: Map) { - entries.push(dbscheme::Entry::Table(dbscheme::Table { - name: "tokeninfo".to_owned(), +fn create_tokeninfo<'a>( + token_kinds: Map<&'a str, usize>, +) -> (dbscheme::Case<'a>, dbscheme::Table<'a>) { + let table = dbscheme::Table { + name: "tokeninfo", keysets: None, columns: vec![ dbscheme::Column { db_type: dbscheme::DbColumnType::Int, - name: "id".to_string(), + name: "id", unique: true, - ql_type: ql::Type::AtType("token".to_owned()), + ql_type: ql::Type::AtType("token"), ql_type_is_ref: false, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "kind".to_string(), + name: "kind", ql_type: ql::Type::Int, ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "file".to_string(), - ql_type: ql::Type::AtType("file".to_string()), + name: "file", + ql_type: ql::Type::AtType("file"), ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "idx".to_string(), + name: "idx", ql_type: ql::Type::Int, ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::String, - name: "value".to_string(), + name: "value", ql_type: ql::Type::String, ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "loc".to_string(), - ql_type: ql::Type::AtType("location".to_string()), + name: "loc", + ql_type: ql::Type::AtType("location"), ql_type_is_ref: true, }, ], - })); - let branches: Vec<(usize, String)> = token_kinds + }; + let branches: Vec<(usize, &str)> = token_kinds .iter() - .map(|(name, kind_id)| (*kind_id, name.clone())) + .map(|(&name, kind_id)| (*kind_id, name)) .collect(); - entries.push(dbscheme::Entry::Case(dbscheme::Case { - name: "token".to_owned(), - column: "kind".to_owned(), + let case = dbscheme::Case { + name: "token", + column: "kind", branches: branches, - })); + }; + (case, table) } fn write_dbscheme(language: &Language, entries: &[dbscheme::Entry]) -> std::io::Result<()> { @@ -284,49 +307,49 @@ fn write_dbscheme(language: &Language, entries: &[dbscheme::Entry]) -> std::io:: dbscheme::write(&language.name, &mut file, &entries) } -fn create_location_union() -> dbscheme::Entry { +fn create_location_union<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Union(dbscheme::Union { - name: "location".to_owned(), - members: vec!["location_default".to_owned()].into_iter().collect(), + name: "location", + members: vec!["location_default"].into_iter().collect(), }) } -fn create_files_table() -> dbscheme::Entry { +fn create_files_table<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Table(dbscheme::Table { - name: "files".to_owned(), + name: "files", keysets: None, columns: vec![ dbscheme::Column { unique: true, db_type: dbscheme::DbColumnType::Int, - name: "id".to_owned(), - ql_type: ql::Type::AtType("file".to_owned()), + name: "id", + ql_type: ql::Type::AtType("file"), ql_type_is_ref: false, }, dbscheme::Column { db_type: dbscheme::DbColumnType::String, - name: "name".to_owned(), + name: "name", unique: false, ql_type: ql::Type::String, ql_type_is_ref: true, }, dbscheme::Column { db_type: dbscheme::DbColumnType::String, - name: "simple".to_owned(), + name: "simple", unique: false, ql_type: ql::Type::String, ql_type_is_ref: true, }, dbscheme::Column { db_type: dbscheme::DbColumnType::String, - name: "ext".to_owned(), + name: "ext", unique: false, ql_type: ql::Type::String, ql_type_is_ref: true, }, dbscheme::Column { db_type: dbscheme::DbColumnType::Int, - name: "fromSource".to_owned(), + name: "fromSource", unique: false, ql_type: ql::Type::Int, ql_type_is_ref: true, @@ -334,28 +357,28 @@ fn create_files_table() -> dbscheme::Entry { ], }) } -fn create_folders_table() -> dbscheme::Entry { +fn create_folders_table<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Table(dbscheme::Table { - name: "folders".to_owned(), + name: "folders", keysets: None, columns: vec![ dbscheme::Column { unique: true, db_type: dbscheme::DbColumnType::Int, - name: "id".to_owned(), - ql_type: ql::Type::AtType("folder".to_owned()), + name: "id", + ql_type: ql::Type::AtType("folder"), ql_type_is_ref: false, }, dbscheme::Column { db_type: dbscheme::DbColumnType::String, - name: "name".to_owned(), + name: "name", unique: false, ql_type: ql::Type::String, ql_type_is_ref: true, }, dbscheme::Column { db_type: dbscheme::DbColumnType::String, - name: "simple".to_owned(), + name: "simple", unique: false, ql_type: ql::Type::String, ql_type_is_ref: true, @@ -364,50 +387,50 @@ fn create_folders_table() -> dbscheme::Entry { }) } -fn create_locations_default_table() -> dbscheme::Entry { +fn create_locations_default_table<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Table(dbscheme::Table { - name: "locations_default".to_string(), + name: "locations_default", keysets: None, columns: vec![ dbscheme::Column { unique: true, db_type: dbscheme::DbColumnType::Int, - name: "id".to_string(), - ql_type: ql::Type::AtType("location_default".to_string()), + name: "id", + ql_type: ql::Type::AtType("location_default"), ql_type_is_ref: false, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "file".to_string(), - ql_type: ql::Type::AtType("file".to_owned()), + name: "file", + ql_type: ql::Type::AtType("file"), ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "start_line".to_string(), + name: "start_line", ql_type: ql::Type::Int, ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "start_column".to_string(), + name: "start_column", ql_type: ql::Type::Int, ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "end_line".to_string(), + name: "end_line", ql_type: ql::Type::Int, ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "end_column".to_string(), + name: "end_column", ql_type: ql::Type::Int, ql_type_is_ref: true, }, @@ -415,42 +438,42 @@ fn create_locations_default_table() -> dbscheme::Entry { }) } -fn create_sourceline_union() -> dbscheme::Entry { +fn create_sourceline_union<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Union(dbscheme::Union { - name: "sourceline".to_owned(), - members: vec!["file".to_owned()].into_iter().collect(), + name: "sourceline", + members: vec!["file"].into_iter().collect(), }) } -fn create_numlines_table() -> dbscheme::Entry { +fn create_numlines_table<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Table(dbscheme::Table { - name: "numlines".to_owned(), + name: "numlines", columns: vec![ dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "element_id".to_string(), - ql_type: ql::Type::AtType("sourceline".to_owned()), + name: "element_id", + ql_type: ql::Type::AtType("sourceline"), ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "num_lines".to_string(), + name: "num_lines", ql_type: ql::Type::Int, ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "num_code".to_string(), + name: "num_code", ql_type: ql::Type::Int, ql_type_is_ref: true, }, dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "num_comment".to_string(), + name: "num_comment", ql_type: ql::Type::Int, ql_type_is_ref: true, }, @@ -459,31 +482,29 @@ fn create_numlines_table() -> dbscheme::Entry { }) } -fn create_container_union() -> dbscheme::Entry { +fn create_container_union<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Union(dbscheme::Union { - name: "container".to_owned(), - members: vec!["folder".to_owned(), "file".to_owned()] - .into_iter() - .collect(), + name: "container", + members: vec!["folder", "file"].into_iter().collect(), }) } -fn create_containerparent_table() -> dbscheme::Entry { +fn create_containerparent_table<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Table(dbscheme::Table { - name: "containerparent".to_owned(), + name: "containerparent", columns: vec![ dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, - name: "parent".to_string(), - ql_type: ql::Type::AtType("container".to_owned()), + name: "parent", + ql_type: ql::Type::AtType("container"), ql_type_is_ref: true, }, dbscheme::Column { unique: true, db_type: dbscheme::DbColumnType::Int, - name: "child".to_string(), - ql_type: ql::Type::AtType("container".to_owned()), + name: "child", + ql_type: ql::Type::AtType("container"), ql_type_is_ref: true, }, ], @@ -491,14 +512,14 @@ fn create_containerparent_table() -> dbscheme::Entry { }) } -fn create_source_location_prefix_table() -> dbscheme::Entry { +fn create_source_location_prefix_table<'a>() -> dbscheme::Entry<'a> { dbscheme::Entry::Table(dbscheme::Table { - name: "sourceLocationPrefix".to_string(), + name: "sourceLocationPrefix", keysets: None, columns: vec![dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::String, - name: "prefix".to_string(), + name: "prefix", ql_type: ql::Type::String, ql_type_is_ref: true, }], @@ -516,7 +537,7 @@ fn main() { // TODO: figure out proper dbscheme output path and/or take it from the // command line. let ruby = Language { - name: "Ruby".to_string(), + name: "Ruby".to_owned(), node_types: tree_sitter_ruby::NODE_TYPES, dbscheme_path: PathBuf::from("ql/src/ruby.dbscheme"), ql_library_path: PathBuf::from("ql/src/codeql_ruby/ast.qll"), diff --git a/generator/src/ql.rs b/generator/src/ql.rs index ec935fac8a8..daa261f099d 100644 --- a/generator/src/ql.rs +++ b/generator/src/ql.rs @@ -1,12 +1,12 @@ use std::collections::BTreeSet; use std::fmt; -pub enum TopLevel { - Class(Class), - Import(String), +pub enum TopLevel<'a> { + Class(Class<'a>), + Import(&'a str), } -impl fmt::Display for TopLevel { +impl<'a> fmt::Display for TopLevel<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { TopLevel::Import(x) => write!(f, "import {}", x), @@ -16,15 +16,15 @@ impl fmt::Display for TopLevel { } #[derive(Clone, Eq, PartialEq, Hash)] -pub struct Class { - pub name: String, +pub struct Class<'a> { + pub name: &'a str, pub is_abstract: bool, - pub supertypes: BTreeSet, - pub characteristic_predicate: Option, - pub predicates: Vec, + pub supertypes: BTreeSet>, + pub characteristic_predicate: Option>, + pub predicates: Vec>, } -impl fmt::Display for Class { +impl<'a> fmt::Display for Class<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if self.is_abstract { write!(f, "abstract ")?; @@ -64,7 +64,7 @@ impl fmt::Display for Class { // The QL type of a column. #[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] -pub enum Type { +pub enum Type<'a> { /// Primitive `int` type. Int, @@ -72,13 +72,13 @@ pub enum Type { String, /// A database type that will need to be referred to with an `@` prefix. - AtType(String), + AtType(&'a str), /// A user-defined type. - Normal(String), + Normal(&'a str), } -impl fmt::Display for Type { +impl<'a> fmt::Display for Type<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Type::Int => write!(f, "int"), @@ -90,16 +90,16 @@ impl fmt::Display for Type { } #[derive(Clone, Eq, PartialEq, Hash)] -pub enum Expression { - Var(String), - String(String), - Pred(String, Vec), - Or(Vec), - Equals(Box, Box), - Dot(Box, String, Vec), +pub enum Expression<'a> { + Var(&'a str), + String(&'a str), + Pred(&'a str, Vec>), + Or(Vec>), + Equals(Box>, Box>), + Dot(Box>, &'a str, Vec>), } -impl fmt::Display for Expression { +impl<'a> fmt::Display for Expression<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Expression::Var(x) => write!(f, "{}", x), @@ -143,15 +143,15 @@ impl fmt::Display for Expression { } #[derive(Clone, Eq, PartialEq, Hash)] -pub struct Predicate { - pub name: String, +pub struct Predicate<'a> { + pub name: &'a str, pub overridden: bool, - pub return_type: Option, - pub formal_parameters: Vec, - pub body: Expression, + pub return_type: Option>, + pub formal_parameters: Vec>, + pub body: Expression<'a>, } -impl fmt::Display for Predicate { +impl<'a> fmt::Display for Predicate<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if self.overridden { write!(f, "override ")?; @@ -174,22 +174,22 @@ impl fmt::Display for Predicate { } #[derive(Clone, Eq, PartialEq, Hash)] -pub struct FormalParameter { - pub name: String, - pub param_type: Type, +pub struct FormalParameter<'a> { + pub name: &'a str, + pub param_type: Type<'a>, } -impl fmt::Display for FormalParameter { +impl<'a> fmt::Display for FormalParameter<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} {}", self.param_type, self.name) } } /// Generates a QL library by writing the given `classes` to the `file`. -pub fn write( +pub fn write<'a>( language_name: &str, file: &mut dyn std::io::Write, - elements: &[TopLevel], + elements: &'a [TopLevel], ) -> std::io::Result<()> { write!(file, "/*\n")?; write!(file, " * CodeQL library for {}\n", language_name)?; diff --git a/generator/src/ql_gen.rs b/generator/src/ql_gen.rs index d1b85fd77b0..15681ee7261 100644 --- a/generator/src/ql_gen.rs +++ b/generator/src/ql_gen.rs @@ -26,50 +26,40 @@ pub fn write(language: &Language, classes: &[ql::TopLevel]) -> std::io::Result<( /// Creates the hard-coded `AstNode` class that acts as a supertype of all /// classes we generate. -fn create_ast_node_class() -> ql::Class { +fn create_ast_node_class<'a>() -> ql::Class<'a> { // Default implementation of `toString` calls `this.describeQlClass()` let to_string = ql::Predicate { - name: "toString".to_owned(), + name: "toString", overridden: false, return_type: Some(ql::Type::String), formal_parameters: vec![], body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result".to_owned())), + Box::new(ql::Expression::Var("result")), Box::new(ql::Expression::Dot( - Box::new(ql::Expression::Var("this".to_owned())), - "describeQlClass".to_owned(), + Box::new(ql::Expression::Var("this")), + "describeQlClass", vec![], )), ), }; - let get_location = create_none_predicate( - "getLocation", - false, - Some(ql::Type::Normal("Location".to_owned())), - vec![], - ); - let get_a_field_or_child = create_none_predicate( - "getAFieldOrChild", - false, - Some(ql::Type::Normal("AstNode".to_owned())), - vec![], - ); + let get_location = + create_none_predicate("getLocation", false, Some(ql::Type::Normal("Location"))); + let get_a_field_or_child = + create_none_predicate("getAFieldOrChild", false, Some(ql::Type::Normal("AstNode"))); let describe_ql_class = ql::Predicate { - name: "describeQlClass".to_owned(), + name: "describeQlClass", overridden: false, return_type: Some(ql::Type::String), formal_parameters: vec![], body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result".to_owned())), - Box::new(ql::Expression::String("???".to_owned())), + Box::new(ql::Expression::Var("result")), + Box::new(ql::Expression::String("???")), ), }; ql::Class { - name: "AstNode".to_owned(), + name: "AstNode", is_abstract: false, - supertypes: vec![ql::Type::AtType("ast_node".to_owned())] - .into_iter() - .collect(), + supertypes: vec![ql::Type::AtType("ast_node")].into_iter().collect(), characteristic_predicate: None, predicates: vec![ to_string, @@ -80,60 +70,57 @@ fn create_ast_node_class() -> ql::Class { } } -fn create_token_class() -> ql::Class { +fn create_token_class<'a>() -> ql::Class<'a> { let get_value = ql::Predicate { - name: "getValue".to_owned(), + name: "getValue", overridden: false, return_type: Some(ql::Type::String), formal_parameters: vec![], body: ql::Expression::Pred( - "tokeninfo".to_owned(), + "tokeninfo", vec![ - ql::Expression::Var("this".to_owned()), - ql::Expression::Var("_".to_owned()), - ql::Expression::Var("_".to_owned()), - ql::Expression::Var("_".to_owned()), - ql::Expression::Var("result".to_owned()), - ql::Expression::Var("_".to_owned()), + ql::Expression::Var("this"), + ql::Expression::Var("_"), + ql::Expression::Var("_"), + ql::Expression::Var("_"), + ql::Expression::Var("result"), + ql::Expression::Var("_"), ], ), }; let get_location = ql::Predicate { - name: "getLocation".to_owned(), + name: "getLocation", overridden: true, - return_type: Some(ql::Type::Normal("Location".to_owned())), + return_type: Some(ql::Type::Normal("Location")), formal_parameters: vec![], body: ql::Expression::Pred( - "tokeninfo".to_owned(), + "tokeninfo", vec![ - ql::Expression::Var("this".to_owned()), - ql::Expression::Var("_".to_owned()), - ql::Expression::Var("_".to_owned()), - ql::Expression::Var("_".to_owned()), - ql::Expression::Var("_".to_owned()), - ql::Expression::Var("result".to_owned()), + ql::Expression::Var("this"), + ql::Expression::Var("_"), + ql::Expression::Var("_"), + ql::Expression::Var("_"), + ql::Expression::Var("_"), + ql::Expression::Var("result"), ], ), }; let to_string = ql::Predicate { - name: "toString".to_owned(), + name: "toString", overridden: true, return_type: Some(ql::Type::String), formal_parameters: vec![], body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result".to_owned())), - Box::new(ql::Expression::Pred("getValue".to_owned(), vec![])), + Box::new(ql::Expression::Var("result")), + Box::new(ql::Expression::Pred("getValue", vec![])), ), }; ql::Class { - name: "Token".to_owned(), + name: "Token", is_abstract: false, - supertypes: vec![ - ql::Type::AtType("token".to_owned()), - ql::Type::Normal("AstNode".to_owned()), - ] - .into_iter() - .collect(), + supertypes: vec![ql::Type::AtType("token"), ql::Type::Normal("AstNode")] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![ get_value, @@ -145,51 +132,47 @@ fn create_token_class() -> ql::Class { } // Creates the `ReservedWord` class. -fn create_reserved_word_class() -> ql::Class { +fn create_reserved_word_class<'a>() -> ql::Class<'a> { let db_name = "reserved_word"; - let class_name = "ReservedWord".to_owned(); + let class_name = "ReservedWord"; let describe_ql_class = create_describe_ql_class(&class_name); ql::Class { name: class_name, is_abstract: false, - supertypes: vec![ - ql::Type::AtType(db_name.to_owned()), - ql::Type::Normal("Token".to_owned()), - ] - .into_iter() - .collect(), + supertypes: vec![ql::Type::AtType(db_name), ql::Type::Normal("Token")] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![describe_ql_class], } } /// Creates a predicate whose body is `none()`. -fn create_none_predicate( - name: &str, +fn create_none_predicate<'a>( + name: &'a str, overridden: bool, - return_type: Option, - formal_parameters: Vec, -) -> ql::Predicate { + return_type: Option>, +) -> ql::Predicate<'a> { ql::Predicate { - name: name.to_owned(), + name: name, overridden, return_type, - formal_parameters, - body: ql::Expression::Pred("none".to_owned(), vec![]), + formal_parameters: Vec::new(), + body: ql::Expression::Pred("none", vec![]), } } /// Creates an overridden `describeQlClass` predicate that returns the given /// name. -fn create_describe_ql_class(class_name: &str) -> ql::Predicate { +fn create_describe_ql_class<'a>(class_name: &'a str) -> ql::Predicate<'a> { ql::Predicate { - name: "describeQlClass".to_owned(), + name: "describeQlClass", overridden: true, return_type: Some(ql::Type::String), formal_parameters: vec![], body: ql::Expression::Equals( - Box::new(ql::Expression::Var("result".to_owned())), - Box::new(ql::Expression::String(class_name.to_owned())), + Box::new(ql::Expression::Var("result")), + Box::new(ql::Expression::String(class_name)), ), } } @@ -200,19 +183,19 @@ fn create_describe_ql_class(class_name: &str) -> ql::Predicate { /// /// `def_table` - the name of the table that defines the entity and its location. /// `arity` - the total number of columns in the table -fn create_get_location_predicate(def_table: &str, arity: usize) -> ql::Predicate { +fn create_get_location_predicate<'a>(def_table: &'a str, arity: usize) -> ql::Predicate<'a> { ql::Predicate { - name: "getLocation".to_owned(), + name: "getLocation", overridden: true, - return_type: Some(ql::Type::Normal("Location".to_owned())), + return_type: Some(ql::Type::Normal("Location")), formal_parameters: vec![], // body of the form: foo_bar_def(_, _, ..., result) body: ql::Expression::Pred( - def_table.to_owned(), + def_table, [ - vec![ql::Expression::Var("this".to_owned())], - vec![ql::Expression::Var("_".to_owned()); arity - 2], - vec![ql::Expression::Var("result".to_owned())], + vec![ql::Expression::Var("this")], + vec![ql::Expression::Var("_"); arity - 2], + vec![ql::Expression::Var("result")], ] .concat(), ), @@ -224,18 +207,18 @@ fn create_get_location_predicate(def_table: &str, arity: usize) -> ql::Predicate /// # Arguments /// /// `def_table` - the name of the table that defines the entity and its text. -fn create_get_text_predicate(def_table: &str) -> ql::Predicate { +fn create_get_text_predicate<'a>(def_table: &'a str) -> ql::Predicate<'a> { ql::Predicate { - name: "getText".to_owned(), + name: "getText", overridden: false, return_type: Some(ql::Type::String), formal_parameters: vec![], body: ql::Expression::Pred( - def_table.to_owned(), + def_table, vec![ - ql::Expression::Var("this".to_owned()), - ql::Expression::Var("result".to_owned()), - ql::Expression::Var("_".to_owned()), + ql::Expression::Var("this"), + ql::Expression::Var("result"), + ql::Expression::Var("_"), ], ), } @@ -248,20 +231,20 @@ fn create_get_text_predicate(def_table: &str) -> ql::Predicate { /// * `table_name` - the name of parent's defining table /// * `column_index` - the index in that table that defines the field /// * `arity` - the total number of columns in the table -fn create_get_field_expr_for_column_storage( - table_name: &str, +fn create_get_field_expr_for_column_storage<'a>( + table_name: &'a str, column_index: usize, arity: usize, -) -> ql::Expression { +) -> ql::Expression<'a> { let num_underscores_before = column_index - 1; let num_underscores_after = arity - 2 - num_underscores_before; ql::Expression::Pred( - table_name.to_owned(), + table_name, [ - vec![ql::Expression::Var("this".to_owned())], - vec![ql::Expression::Var("_".to_owned()); num_underscores_before], - vec![ql::Expression::Var("result".to_owned())], - vec![ql::Expression::Var("_".to_owned()); num_underscores_after], + vec![ql::Expression::Var("this")], + vec![ql::Expression::Var("_"); num_underscores_before], + vec![ql::Expression::Var("result")], + vec![ql::Expression::Var("_"); num_underscores_after], ] .concat(), ) @@ -270,22 +253,19 @@ fn create_get_field_expr_for_column_storage( /// Returns an expression to get the field with the given index from its /// auxiliary table. The index name can be "_" so the expression will hold for /// all indices. -fn create_get_field_expr_for_table_storage( - table_name: &str, - index_var_name: Option<&str>, -) -> ql::Expression { +fn create_get_field_expr_for_table_storage<'a>( + table_name: &'a str, + index_var_name: Option<&'a str>, +) -> ql::Expression<'a> { ql::Expression::Pred( - table_name.to_owned(), + table_name, match index_var_name { Some(index_var_name) => vec![ - ql::Expression::Var("this".to_owned()), - ql::Expression::Var(index_var_name.to_owned()), - ql::Expression::Var("result".to_owned()), - ], - None => vec![ - ql::Expression::Var("this".to_owned()), - ql::Expression::Var("result".to_owned()), + ql::Expression::Var("this"), + ql::Expression::Var(index_var_name), + ql::Expression::Var("result"), ], + None => vec![ql::Expression::Var("this"), ql::Expression::Var("result")], }, ) } @@ -305,27 +285,26 @@ fn create_get_field_expr_for_table_storage( /// `parent_name` - the name of the parent node /// `field` - the field whose getters we are creating /// `field_type` - the db name of the field's type (possibly being a union we created) -fn create_field_getters( - main_table_name: &str, +fn create_field_getters<'a>( + main_table_name: &'a str, main_table_arity: usize, main_table_column_index: &mut usize, - field: &node_types::Field, - nodes: &node_types::NodeTypeMap, -) -> (ql::Predicate, ql::Expression) { - let predicate_name = field.get_getter_name(); + field: &'a node_types::Field, + nodes: &'a node_types::NodeTypeMap, +) -> (ql::Predicate<'a>, ql::Expression<'a>) { let return_type = Some(ql::Type::Normal(match &field.type_info { - node_types::FieldTypeInfo::Single(t) => nodes.get(&t).unwrap().ql_class_name.clone(), + node_types::FieldTypeInfo::Single(t) => &nodes.get(&t).unwrap().ql_class_name, node_types::FieldTypeInfo::Multiple { types: _, dbscheme_union: _, ql_class, - } => ql_class.clone(), + } => &ql_class, })); match &field.storage { node_types::Storage::Column { name: _ } => { let result = ( ql::Predicate { - name: predicate_name, + name: &field.getter_name, overridden: false, return_type, formal_parameters: vec![], @@ -349,12 +328,12 @@ fn create_field_getters( has_index, } => ( ql::Predicate { - name: predicate_name, + name: &field.getter_name, overridden: false, return_type, formal_parameters: if *has_index { vec![ql::FormalParameter { - name: "i".to_owned(), + name: "i", param_type: ql::Type::Int, }] } else { @@ -374,10 +353,10 @@ fn create_field_getters( } /// Converts the given node types into CodeQL classes wrapping the dbscheme. -pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { +pub fn convert_nodes<'a>(nodes: &'a node_types::NodeTypeMap) -> Vec> { let mut classes: Vec = vec![ - ql::TopLevel::Import("codeql.files.FileSystem".to_owned()), - ql::TopLevel::Import("codeql.Locations".to_owned()), + ql::TopLevel::Import("codeql.files.FileSystem"), + ql::TopLevel::Import("codeql.Locations"), ql::TopLevel::Class(create_ast_node_class()), ql::TopLevel::Class(create_token_class()), ql::TopLevel::Class(create_reserved_word_class()), @@ -386,7 +365,7 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { for (type_name, node) in nodes { if let node_types::EntryKind::Token { .. } = &node.kind { if type_name.named { - token_kinds.insert(type_name.kind.to_owned()); + token_kinds.insert(&type_name.kind); } } } @@ -397,10 +376,10 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { if type_name.named { let describe_ql_class = create_describe_ql_class(&node.ql_class_name); let mut supertypes: BTreeSet = BTreeSet::new(); - supertypes.insert(ql::Type::AtType(node.dbscheme_name.to_owned())); - supertypes.insert(ql::Type::Normal("Token".to_owned())); + supertypes.insert(ql::Type::AtType(&node.dbscheme_name)); + supertypes.insert(ql::Type::Normal("Token")); classes.push(ql::TopLevel::Class(ql::Class { - name: node.ql_class_name.clone(), + name: &node.ql_class_name, is_abstract: false, supertypes, characteristic_predicate: None, @@ -412,11 +391,11 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { // It's a tree-sitter supertype node, so we're wrapping a dbscheme // union type. classes.push(ql::TopLevel::Class(ql::Class { - name: node.ql_class_name.clone(), + name: &node.ql_class_name, is_abstract: false, supertypes: vec![ - ql::Type::AtType(node.dbscheme_name.clone()), - ql::Type::Normal("AstNode".to_owned()), + ql::Type::AtType(&node.dbscheme_name), + ql::Type::Normal("AstNode"), ] .into_iter() .collect(), @@ -445,11 +424,11 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { let main_class_name = &node.ql_class_name; let mut main_class = ql::Class { - name: main_class_name.clone(), + name: &main_class_name, is_abstract: false, supertypes: vec![ - ql::Type::AtType(node.dbscheme_name.clone()), - ql::Type::Normal("AstNode".to_owned()), + ql::Type::AtType(&node.dbscheme_name), + ql::Type::Normal("AstNode"), ] .into_iter() .collect(), @@ -485,9 +464,9 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { } main_class.predicates.push(ql::Predicate { - name: "getAFieldOrChild".to_owned(), + name: "getAFieldOrChild", overridden: true, - return_type: Some(ql::Type::Normal("AstNode".to_owned())), + return_type: Some(ql::Type::Normal("AstNode")), formal_parameters: vec![], body: ql::Expression::Or(get_child_exprs), }); diff --git a/node-types/src/lib.rs b/node-types/src/lib.rs index 4b279896730..daa1232e8fc 100644 --- a/node-types/src/lib.rs +++ b/node-types/src/lib.rs @@ -50,6 +50,8 @@ pub struct Field { /// The name of the field or None for the anonymous 'children' /// entry from node_types.json pub name: Option, + /// The name of the predicate to get this field. + pub getter_name: String, pub storage: Storage, } @@ -60,15 +62,6 @@ fn name_for_field_or_child(name: &Option) -> String { } } -impl Field { - pub fn get_getter_name(&self) -> String { - format!( - "get{}", - dbscheme_name_to_class_name(&escape_name(&name_for_field_or_child(&self.name))) - ) - } -} - #[derive(Debug)] pub enum Storage { /// the field is stored as a column in the parent table @@ -230,6 +223,10 @@ fn add_field( ql_class: "AstNode".to_owned(), } }; + let getter_name = format!( + "get{}", + dbscheme_name_to_class_name(&escape_name(&name_for_field_or_child(&field_name))) + ); fields.push(Field { parent: TypeName { kind: parent_type_name.kind.to_string(), @@ -237,6 +234,7 @@ fn add_field( }, type_info, name: field_name, + getter_name, storage, }); } From 68c97a2d137834f5741c8ddd0c9d906336b34531 Mon Sep 17 00:00:00 2001 From: Nick Rolfe Date: Mon, 16 Nov 2020 18:41:18 +0000 Subject: [PATCH 4/5] Use `..` to ignore fields Co-authored-by: Arthur Baars --- extractor/src/extractor.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/extractor/src/extractor.rs b/extractor/src/extractor.rs index be9f32f8b6c..399473e60fc 100644 --- a/extractor/src/extractor.rs +++ b/extractor/src/extractor.rs @@ -480,8 +480,7 @@ impl Visitor<'_> { } node_types::FieldTypeInfo::Multiple { types, - dbscheme_union: _, - ql_class: _, + .. } => { return self.type_matches_set(tp, types); } From 1a9663ff7dd138dbcd365b92abbce9fbf6caa925 Mon Sep 17 00:00:00 2001 From: Nick Rolfe Date: Mon, 16 Nov 2020 18:43:54 +0000 Subject: [PATCH 5/5] Replace single-branch `match` with `if let` --- extractor/src/extractor.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/extractor/src/extractor.rs b/extractor/src/extractor.rs index 399473e60fc..58565ae0e1a 100644 --- a/extractor/src/extractor.rs +++ b/extractor/src/extractor.rs @@ -478,10 +478,7 @@ impl Visitor<'_> { _ => {} } } - node_types::FieldTypeInfo::Multiple { - types, - .. - } => { + node_types::FieldTypeInfo::Multiple { types, .. } => { return self.type_matches_set(tp, types); } } @@ -493,13 +490,10 @@ impl Visitor<'_> { return true; } for other in types.iter() { - match &self.schema.get(other).unwrap().kind { - EntryKind::Union { members } => { - if self.type_matches_set(tp, members) { - return true; - } + if let EntryKind::Union { members } = &self.schema.get(other).unwrap().kind { + if self.type_matches_set(tp, members) { + return true; } - _ => {} } } false