diff --git a/extractor/src/extractor.rs b/extractor/src/extractor.rs index 566923c9774..e3fd5108033 100644 --- a/extractor/src/extractor.rs +++ b/extractor/src/extractor.rs @@ -1,4 +1,4 @@ -use node_types::{escape_name, node_type_name, Entry, Field, Storage, TypeName}; +use node_types::{escape_name, EntryKind, Field, NodeTypeMap, Storage, TypeName}; use std::collections::BTreeMap as Map; use std::collections::BTreeSet as Set; use std::fmt; @@ -150,10 +150,10 @@ impl TrapWriter { pub struct Extractor { pub parser: Parser, - pub schema: Vec, + pub schema: NodeTypeMap, } -pub fn create(language: Language, schema: Vec) -> Extractor { +pub fn create(language: Language, schema: NodeTypeMap) -> Extractor { let mut parser = Parser::new(); parser.set_language(language).unwrap(); @@ -189,8 +189,7 @@ impl Extractor { file_label: *file_label, token_counter: 0, stack: Vec::new(), - tables: build_schema_lookup(&self.schema), - union_types: build_union_type_lookup(&self.schema), + schema: &self.schema, }; traverse(&tree, &mut visitor); @@ -246,26 +245,6 @@ fn full_id_for_folder(path: &Path) -> String { format!("{};folder", normalize_path(path)) } -fn build_schema_lookup<'a>(schema: &'a Vec) -> Map<&'a TypeName, &'a Entry> { - let mut map = std::collections::BTreeMap::new(); - for entry in schema { - if let Entry::Token { type_name, .. } | Entry::Table { type_name, .. } = entry { - map.insert(type_name, entry); - } - } - map -} - -fn build_union_type_lookup<'a>(schema: &'a Vec) -> Map<&'a TypeName, &'a Set> { - let mut union_types = std::collections::BTreeMap::new(); - for entry in schema { - if let Entry::Union { type_name, members } = entry { - union_types.insert(type_name, members); - } - } - union_types -} - struct Visitor<'a> { /// The file path of the source code (as string) path: String, @@ -278,10 +257,8 @@ struct Visitor<'a> { trap_writer: TrapWriter, /// A counter for tokens token_counter: usize, - /// A lookup table from type name to dbscheme table entries - tables: Map<&'a TypeName, &'a Entry>, - /// A lookup table for union types mapping a type name to its direct members - union_types: Map<&'a TypeName, &'a Set>, + /// A lookup table from type name to node types + schema: &'a NodeTypeMap, /// A stack for gathering information from hild nodes. Whenever a node is entered /// an empty list is pushed. All children append their data (field name, label, type) to /// the the list. When the visitor leaves a node the list containing the child data is popped @@ -328,13 +305,16 @@ impl Visitor<'_> { end_line, end_column, ); - let table = self.tables.get(&TypeName { - kind: node.kind().to_owned(), - named: node.is_named(), - }); + let table = self + .schema + .get(&TypeName { + kind: node.kind().to_owned(), + named: node.is_named(), + }) + .unwrap(); let mut valid = true; - match table { - Some(Entry::Token { kind_id, .. }) => { + match &table.kind { + EntryKind::Token { kind_id, .. } => { self.trap_writer.add_tuple( "tokeninfo", vec![ @@ -348,11 +328,8 @@ impl Visitor<'_> { ); self.token_counter += 1; } - Some(Entry::Table { fields, .. }) => { - let table_name = escape_name(&format!( - "{}_def", - node_type_name(node.kind(), node.is_named()) - )); + EntryKind::Table { fields, .. } => { + let table_name = escape_name(&format!("{}_def", &table.flattened_name)); if let Some(args) = self.complex_node(&node, fields, child_nodes, id) { let mut all_args = Vec::new(); all_args.push(Arg::Label(id)); @@ -386,6 +363,7 @@ impl Visitor<'_> { }; } } + fn complex_node( &mut self, node: &Node, @@ -400,7 +378,7 @@ impl Visitor<'_> { for (child_field, child_id, child_type) in child_nodes { if let Some((field, values)) = map.get_mut(&child_field.map(|x| x.to_owned())) { //TODO: handle error and missing nodes - if self.type_matches(&child_type, &field.types) { + if self.type_matches(&child_type, &field.type_info) { values.push(child_id); } else if field.name.is_some() { error!( @@ -410,7 +388,7 @@ impl Visitor<'_> { node.kind(), child_field.unwrap_or("child"), child_type, - field.types + field.type_info ) } } else { @@ -464,7 +442,7 @@ impl Visitor<'_> { } let table_name = escape_name(&format!( "{}_{}", - node_type_name(&field.parent.kind, field.parent.named), + self.schema.get(&field.parent).unwrap().flattened_name, field.get_name() )); let mut args = Vec::new(); @@ -484,18 +462,48 @@ impl Visitor<'_> { None } } - fn type_matches(&self, tp: &TypeName, types: &Set) -> bool { + + fn type_matches(&self, tp: &TypeName, type_info: &node_types::FieldTypeInfo) -> bool { + match type_info { + node_types::FieldTypeInfo::Single(single_type) => { + if tp == single_type { + return true; + } + match &self.schema.get(single_type).unwrap().kind { + EntryKind::Union { members } => { + if self.type_matches_set(tp, members) { + return true; + } + } + _ => {} + } + } + node_types::FieldTypeInfo::Multiple { + types, + dbscheme_union: _, + ql_class: _, + } => { + return self.type_matches_set(tp, types); + } + } + false + } + + fn type_matches_set(&self, tp: &TypeName, types: &Set) -> bool { if types.contains(tp) { return true; } for other in types.iter() { - if let Some(x) = self.union_types.get(other) { - if self.type_matches(tp, x) { - return true; + match &self.schema.get(other).unwrap().kind { + EntryKind::Union { members } => { + if self.type_matches_set(tp, members) { + return true; + } } + _ => {} } } - return false; + false } } diff --git a/generator/src/main.rs b/generator/src/main.rs index 68b9cd267ef..73c30a1bbe9 100644 --- a/generator/src/main.rs +++ b/generator/src/main.rs @@ -11,71 +11,51 @@ use std::io::LineWriter; use std::path::PathBuf; use tracing::{error, info}; -fn child_node_type_name(token_types: &Map, t: &node_types::TypeName) -> String { - if !t.named { - // an unnamed token - "reserved_word".to_owned() - } else if token_types.contains_key(&t.kind) { - // a named token - format!("token_{}", t.kind) - } else { - // a normal node - node_types::node_type_name(&t.kind, t.named) - } -} - /// Given the name of the parent node, and its field information, returns the /// name of the field's type. This may be an ad-hoc union of all the possible /// types the field can take, in which case the union is added to `entries`. fn make_field_type( - token_types: &Map, - parent_name: &str, - field_name: &str, - types: &Set, + field: &node_types::Field, entries: &mut Vec, + nodes: &node_types::NodeTypeMap, ) -> String { - if types.len() == 1 { - // This field can only have a single type. - let t = types.iter().next().unwrap(); - node_types::escape_name(&child_node_type_name(token_types, t)) - } else { - // This field can have one of several types. Create an ad-hoc QL union - // type to represent them. - let field_union_name = format!("{}_{}_type", parent_name, field_name); - let field_union_name = node_types::escape_name(&field_union_name); - let members: Set = types - .iter() - .map(|t| node_types::escape_name(&child_node_type_name(token_types, t))) - .collect(); - entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: field_union_name.clone(), - members, - })); - field_union_name + match &field.type_info { + node_types::FieldTypeInfo::Multiple { + types, + dbscheme_union, + ql_class: _, + } => { + // This field can have one of several types. Create an ad-hoc QL union + // type to represent them. + let members: Set = types + .iter() + .map(|t| node_types::escape_name(&nodes.get(t).unwrap().flattened_name)) + .collect(); + entries.push(dbscheme::Entry::Union(dbscheme::Union { + name: node_types::escape_name(&dbscheme_union), + members, + })); + dbscheme_union.clone() + } + node_types::FieldTypeInfo::Single(t) => nodes.get(&t).unwrap().flattened_name.clone(), } } /// Adds the appropriate dbscheme information for the given field, either as a /// column on `main_table`, or as an auxiliary table. fn add_field( - token_types: &Map, main_table: &mut dbscheme::Table, field: &node_types::Field, entries: &mut Vec, + nodes: &node_types::NodeTypeMap, ) { let field_name = field.get_name(); - let parent_name = node_types::node_type_name(&field.parent.kind, field.parent.named); + let parent_name = &nodes.get(&field.parent).unwrap().flattened_name; match &field.storage { node_types::Storage::Table(has_index) => { // This field can appear zero or multiple times, so put // it in an auxiliary table. - let field_type = make_field_type( - token_types, - &parent_name, - &field_name, - &field.types, - entries, - ); + let field_type = node_types::escape_name(&make_field_type(&field, entries, nodes)); let parent_column = dbscheme::Column { unique: !*has_index, db_type: dbscheme::DbColumnType::Int, @@ -93,12 +73,12 @@ fn add_field( let field_column = dbscheme::Column { unique: true, db_type: dbscheme::DbColumnType::Int, - name: node_types::escape_name(&field_type), + name: field_type.clone(), ql_type: ql::Type::AtType(field_type), ql_type_is_ref: true, }; let field_table = dbscheme::Table { - name: format!("{}_{}", parent_name, field_name), + name: node_types::escape_name(&format!("{}_{}", parent_name, field_name)), columns: if *has_index { vec![parent_column, index_column, field_column] } else { @@ -120,18 +100,12 @@ fn add_field( node_types::Storage::Column => { // This field must appear exactly once, so we add it as // a column to the main table for the node type. - let field_type = make_field_type( - token_types, - &parent_name, - &field_name, - &field.types, - entries, - ); + let field_type = make_field_type(&field, entries, nodes); main_table.columns.push(dbscheme::Column { unique: false, db_type: dbscheme::DbColumnType::Int, name: node_types::escape_name(&field_name), - ql_type: ql::Type::AtType(field_type), + ql_type: ql::Type::AtType(node_types::escape_name(&field_type)), ql_type_is_ref: true, }); } @@ -139,7 +113,7 @@ fn add_field( } /// Converts the given tree-sitter node types into CodeQL dbscheme entries. -fn convert_nodes(nodes: &Vec) -> Vec { +fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { let mut entries: Vec = vec![ create_location_union(), create_locations_default_table(), @@ -152,59 +126,50 @@ fn convert_nodes(nodes: &Vec) -> Vec { create_source_location_prefix_table(), ]; let mut ast_node_members: Set = Set::new(); - let mut token_kinds: Map = Map::new(); - ast_node_members.insert(node_types::escape_name("token")); - for node in nodes { - if let node_types::Entry::Token { type_name, kind_id } = node { - if type_name.named { - token_kinds.insert(type_name.kind.to_owned(), *kind_id); + let token_kinds: Map = nodes + .iter() + .filter_map(|(_, node)| match &node.kind { + node_types::EntryKind::Token { kind_id } => { + Some((node.flattened_name.clone(), *kind_id)) } - } - } + _ => None, + }) + .collect(); + ast_node_members.insert(node_types::escape_name("token")); - for node in nodes { - match &node { - node_types::Entry::Union { - type_name, - members: n_members, - } => { + for (_, node) in nodes { + match &node.kind { + node_types::EntryKind::Union { members: n_members } => { // It's a tree-sitter supertype node, for which we create a union // type. - let mut members: Set = Set::new(); - for n_member in n_members { - members.insert(node_types::escape_name(&child_node_type_name( - &token_kinds, - n_member, - ))); - } + let members: Set = n_members + .iter() + .map(|n| node_types::escape_name(&nodes.get(n).unwrap().flattened_name)) + .collect(); entries.push(dbscheme::Entry::Union(dbscheme::Union { - name: node_types::escape_name(&node_types::node_type_name( - &type_name.kind, - type_name.named, - )), + name: node_types::escape_name(&node.flattened_name), members, })); } - node_types::Entry::Table { type_name, fields } => { + node_types::EntryKind::Table { fields } => { // It's a product type, defined by a table. - let name = node_types::node_type_name(&type_name.kind, type_name.named); let mut main_table = dbscheme::Table { - name: node_types::escape_name(&(format!("{}_def", name))), + name: node_types::escape_name(&(format!("{}_def", &node.flattened_name))), columns: vec![dbscheme::Column { db_type: dbscheme::DbColumnType::Int, name: "id".to_string(), unique: true, - ql_type: ql::Type::AtType(node_types::escape_name(&name)), + ql_type: ql::Type::AtType(node_types::escape_name(&node.flattened_name)), ql_type_is_ref: false, }], keysets: None, }; - ast_node_members.insert(node_types::escape_name(&name)); + ast_node_members.insert(node_types::escape_name(&node.flattened_name)); // If the type also has fields or children, then we create either // auxiliary tables or columns in the defining table for them. for field in fields { - add_field(&token_kinds, &mut main_table, &field, &mut entries); + add_field(&mut main_table, &field, &mut entries, nodes); } if fields.is_empty() { @@ -230,7 +195,7 @@ fn convert_nodes(nodes: &Vec) -> Vec { entries.push(dbscheme::Entry::Table(main_table)); } - node_types::Entry::Token { .. } => {} + node_types::EntryKind::Token { .. } => {} } } @@ -295,15 +260,10 @@ fn add_tokeninfo_table(entries: &mut Vec, token_kinds: Map = Vec::new(); - branches.push((0, "reserved_word".to_owned())); - for (token_kind, idx) in token_kinds.iter() { - branches.push(( - *idx, - node_types::escape_name(&format!("token_{}", token_kind)), - )); - } - + let branches: Vec<(usize, String)> = token_kinds + .iter() + .map(|(name, kind_id)| (*kind_id, node_types::escape_name(name))) + .collect(); entries.push(dbscheme::Entry::Case(dbscheme::Case { name: "token".to_owned(), column: "kind".to_owned(), diff --git a/generator/src/ql.rs b/generator/src/ql.rs index bb554bca34e..ec935fac8a8 100644 --- a/generator/src/ql.rs +++ b/generator/src/ql.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeSet; use std::fmt; pub enum TopLevel { @@ -14,10 +15,11 @@ impl fmt::Display for TopLevel { } } +#[derive(Clone, Eq, PartialEq, Hash)] pub struct Class { pub name: String, pub is_abstract: bool, - pub supertypes: Vec, + pub supertypes: BTreeSet, pub characteristic_predicate: Option, pub predicates: Vec, } @@ -61,7 +63,7 @@ impl fmt::Display for Class { } // The QL type of a column. -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] pub enum Type { /// Primitive `int` type. Int, @@ -69,11 +71,11 @@ pub enum Type { /// Primitive `string` type. String, - /// A user-defined type. - Normal(String), - /// A database type that will need to be referred to with an `@` prefix. AtType(String), + + /// A user-defined type. + Normal(String), } impl fmt::Display for Type { @@ -87,15 +89,13 @@ impl fmt::Display for Type { } } -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash)] pub enum Expression { Var(String), String(String), Pred(String, Vec), Or(Vec), - And(Vec), Equals(Box, Box), - Exists(Vec, Box), Dot(Box, String, Vec), } @@ -127,30 +127,7 @@ impl fmt::Display for Expression { Ok(()) } } - Expression::And(conjuncts) => { - if conjuncts.is_empty() { - write!(f, "any()") - } else { - for (index, conjunct) in conjuncts.iter().enumerate() { - if index > 0 { - write!(f, " and ")?; - } - write!(f, "{}", conjunct)?; - } - Ok(()) - } - } Expression::Equals(a, b) => write!(f, "{} = {}", a, b), - Expression::Exists(params, formula) => { - write!(f, "exists(")?; - for (index, param) in params.iter().enumerate() { - if index > 0 { - write!(f, ", ")?; - } - write!(f, "{}", param)?; - } - write!(f, " | {})", formula) - } Expression::Dot(x, member_pred, args) => { write!(f, "{}.{}(", x, member_pred)?; for (index, arg) in args.iter().enumerate() { @@ -165,7 +142,7 @@ impl fmt::Display for Expression { } } -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash)] pub struct Predicate { pub name: String, pub overridden: bool, @@ -196,7 +173,7 @@ impl fmt::Display for Predicate { } } -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Hash)] pub struct FormalParameter { pub name: String, pub param_type: Type, diff --git a/generator/src/ql_gen.rs b/generator/src/ql_gen.rs index 767418ed0c8..996e5e9dde8 100644 --- a/generator/src/ql_gen.rs +++ b/generator/src/ql_gen.rs @@ -67,7 +67,9 @@ fn create_ast_node_class() -> ql::Class { ql::Class { name: "AstNode".to_owned(), is_abstract: false, - supertypes: vec![ql::Type::AtType("ast_node".to_owned())], + supertypes: vec![ql::Type::AtType("ast_node".to_owned())] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![ to_string, @@ -77,6 +79,7 @@ fn create_ast_node_class() -> ql::Class { ], } } + fn create_token_class() -> ql::Class { let get_value = ql::Predicate { name: "getValue".to_owned(), @@ -128,7 +131,9 @@ fn create_token_class() -> ql::Class { supertypes: vec![ ql::Type::AtType("token".to_owned()), ql::Type::Normal("AstNode".to_owned()), - ], + ] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![ get_value, @@ -141,16 +146,18 @@ fn create_token_class() -> ql::Class { // Creates the `ReservedWord` class. fn create_reserved_word_class() -> ql::Class { - let db_name = "reserved_word".to_owned(); - let class_name = dbscheme_name_to_class_name(&db_name); + let db_name = "reserved_word"; + let class_name = "ReservedWord".to_owned(); let describe_ql_class = create_describe_ql_class(&class_name); ql::Class { name: class_name, is_abstract: false, supertypes: vec![ + ql::Type::AtType(db_name.to_owned()), ql::Type::Normal("Token".to_owned()), - ql::Type::AtType(db_name), - ], + ] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![describe_ql_class], } @@ -172,47 +179,6 @@ fn create_none_predicate( } } -/// Given the name of the parent node, and its field information, returns the -/// name of the field's type. This may be an ad-hoc union of all the possible -/// types the field can take, in which case we create a new class and push it to -/// `classes`. -fn create_field_class(token_kinds: &BTreeSet, field: &node_types::Field) -> String { - if field.types.len() == 1 { - // This field can only have a single type. - let t = field.types.iter().next().unwrap(); - if !t.named || token_kinds.contains(&t.kind) { - "Token".to_owned() - } else { - node_types::escape_name(&node_types::node_type_name(&t.kind, t.named)) - } - } else { - "AstNode".to_owned() - } -} - -/// Given a valid dbscheme name (i.e. in snake case), produces the equivalent QL -/// name (i.e. in CamelCase). For example, "foo_bar_baz" becomes "FooBarBaz". -fn dbscheme_name_to_class_name(dbscheme_name: &str) -> String { - fn to_title_case(word: &str) -> String { - let mut first = true; - let mut result = String::new(); - for c in word.chars() { - if first { - first = false; - result.push(c.to_ascii_uppercase()); - } else { - result.push(c); - } - } - result - } - dbscheme_name - .split('_') - .map(|word| to_title_case(word)) - .collect::>() - .join("") -} - /// Creates an overridden `describeQlClass` predicate that returns the given /// name. fn create_describe_ql_class(class_name: &str) -> ql::Predicate { @@ -345,20 +311,24 @@ fn create_field_getters( main_table_column_index: &mut usize, parent_name: &str, field: &node_types::Field, - field_type: &str, + nodes: &node_types::NodeTypeMap, ) -> (ql::Predicate, ql::Expression) { - let predicate_name = format!( - "get{}", - dbscheme_name_to_class_name(&node_types::escape_name(&field.get_name())) - ); - let return_type = Some(ql::Type::Normal(dbscheme_name_to_class_name(field_type))); + let predicate_name = field.get_getter_name(); + let return_type = Some(ql::Type::Normal(match &field.type_info { + node_types::FieldTypeInfo::Single(t) => nodes.get(&t).unwrap().ql_class_name.clone(), + node_types::FieldTypeInfo::Multiple { + types: _, + dbscheme_union: _, + ql_class, + } => ql_class.clone(), + })); match &field.storage { node_types::Storage::Column => { let result = ( ql::Predicate { name: predicate_name, overridden: false, - return_type: return_type, + return_type, formal_parameters: vec![], body: create_get_field_expr_for_column_storage( &main_table_name, @@ -381,7 +351,7 @@ fn create_field_getters( ql::Predicate { name: predicate_name, overridden: false, - return_type: return_type, + return_type, formal_parameters: if *has_index { vec![ql::FormalParameter { name: "i".to_owned(), @@ -405,7 +375,7 @@ fn create_field_getters( } /// Converts the given node types into CodeQL classes wrapping the dbscheme. -pub fn convert_nodes(nodes: &Vec) -> Vec { +pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec { let mut classes: Vec = vec![ ql::TopLevel::Import("codeql.files.FileSystem".to_owned()), ql::TopLevel::Import("codeql.Locations".to_owned()), @@ -414,61 +384,48 @@ pub fn convert_nodes(nodes: &Vec) -> Vec { ql::TopLevel::Class(create_reserved_word_class()), ]; let mut token_kinds = BTreeSet::new(); - for node in nodes { - if let node_types::Entry::Token { type_name, .. } = node { + for (type_name, node) in nodes { + if let node_types::EntryKind::Token { .. } = &node.kind { if type_name.named { token_kinds.insert(type_name.kind.to_owned()); } } } - for node in nodes { - match &node { - node_types::Entry::Token { - type_name, - kind_id: _, - } => { + for (type_name, node) in nodes { + match &node.kind { + node_types::EntryKind::Token { kind_id: _ } => { if type_name.named { - let db_name = format!("token_{}", &type_name.kind); - let db_name = node_types::escape_name(&db_name); - let class_name = - dbscheme_name_to_class_name(&node_types::escape_name(&type_name.kind)); - let describe_ql_class = create_describe_ql_class(&class_name); + let describe_ql_class = create_describe_ql_class(&node.ql_class_name); + let mut supertypes: BTreeSet = BTreeSet::new(); + supertypes.insert(ql::Type::AtType(node.flattened_name.to_owned())); + supertypes.insert(ql::Type::Normal("Token".to_owned())); classes.push(ql::TopLevel::Class(ql::Class { - name: class_name, + name: node.ql_class_name.clone(), is_abstract: false, - supertypes: vec![ - ql::Type::Normal("Token".to_owned()), - ql::Type::AtType(db_name), - ], + supertypes, characteristic_predicate: None, predicates: vec![describe_ql_class], })); } } - node_types::Entry::Union { - type_name, - members: _, - } => { + node_types::EntryKind::Union { members: _ } => { // It's a tree-sitter supertype node, so we're wrapping a dbscheme // union type. - let union_name = node_types::escape_name(&node_types::node_type_name( - &type_name.kind, - type_name.named, - )); - let class_name = dbscheme_name_to_class_name(&union_name); classes.push(ql::TopLevel::Class(ql::Class { - name: class_name.clone(), + name: node.ql_class_name.clone(), is_abstract: false, supertypes: vec![ - ql::Type::AtType(union_name), + ql::Type::AtType(node_types::escape_name(&node.flattened_name)), ql::Type::Normal("AstNode".to_owned()), - ], + ] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![], })); } - node_types::Entry::Table { type_name, fields } => { + node_types::EntryKind::Table { fields } => { // Count how many columns there will be in the main table. // There will be: // - one for the id @@ -484,15 +441,19 @@ pub fn convert_nodes(nodes: &Vec) -> Vec { .count() }; - let name = node_types::node_type_name(&type_name.kind, type_name.named); - let dbscheme_name = node_types::escape_name(&name); - let ql_type = ql::Type::AtType(dbscheme_name.clone()); - let main_table_name = node_types::escape_name(&(format!("{}_def", name))); - let main_class_name = dbscheme_name_to_class_name(&dbscheme_name); + let escaped_name = node_types::escape_name(&node.flattened_name); + let main_class_name = &node.ql_class_name; + let main_table_name = + node_types::escape_name(&format!("{}_def", &node.flattened_name)); let mut main_class = ql::Class { name: main_class_name.clone(), is_abstract: false, - supertypes: vec![ql_type, ql::Type::Normal("AstNode".to_owned())], + supertypes: vec![ + ql::Type::AtType(escaped_name), + ql::Type::Normal("AstNode".to_owned()), + ] + .into_iter() + .collect(), characteristic_predicate: None, predicates: vec![ create_describe_ql_class(&main_class_name), @@ -513,14 +474,13 @@ pub fn convert_nodes(nodes: &Vec) -> Vec { // - predicates to access the fields, // - the QL expressions to access the fields that will be part of getAFieldOrChild. for field in fields { - let field_type = create_field_class(&token_kinds, field); let (get_pred, get_child_expr) = create_field_getters( &main_table_name, main_table_arity, &mut main_table_column_index, - &name, + &node.flattened_name, field, - &field_type, + nodes, ); main_class.predicates.push(get_pred); get_child_exprs.push(get_child_expr); diff --git a/node-types/src/lib.rs b/node-types/src/lib.rs index e2de9df31f3..1ef0da8533b 100644 --- a/node-types/src/lib.rs +++ b/node-types/src/lib.rs @@ -5,20 +5,21 @@ use std::path::Path; use std::collections::BTreeSet as Set; use std::fs; +/// A lookup table from TypeName to Entry. +pub type NodeTypeMap = BTreeMap; + #[derive(Debug)] -pub enum Entry { - Union { - type_name: TypeName, - members: Set, - }, - Table { - type_name: TypeName, - fields: Vec, - }, - Token { - type_name: TypeName, - kind_id: usize, - }, +pub struct Entry { + pub flattened_name: String, + pub ql_class_name: String, + pub kind: EntryKind, +} + +#[derive(Debug)] +pub enum EntryKind { + Union { members: Set }, + Table { fields: Vec }, + Token { kind_id: usize }, } #[derive(Debug, Ord, PartialOrd, Eq, PartialEq)] @@ -27,22 +28,48 @@ pub struct TypeName { pub named: bool, } +#[derive(Debug)] +pub enum FieldTypeInfo { + /// The field has a single type. + Single(TypeName), + + /// The field can take one of several types, so we also provide the name of + /// the database union type that wraps them, and the corresponding QL class + /// name. + Multiple { + types: Set, + dbscheme_union: String, + ql_class: String, + }, +} + #[derive(Debug)] pub struct Field { pub parent: TypeName, - pub types: Set, + pub type_info: FieldTypeInfo, /// The name of the field or None for the anonymous 'children' /// entry from node_types.json pub name: Option, pub storage: Storage, } +fn name_for_field_or_child(name: &Option) -> String { + match name { + Some(name) => name.clone(), + None => "child".to_owned(), + } +} + impl Field { pub fn get_name(&self) -> String { - match &self.name { - Some(name) => name.clone(), - None => "child".to_owned(), - } + name_for_field_or_child(&self.name) + } + + pub fn get_getter_name(&self) -> String { + format!( + "get{}", + dbscheme_name_to_class_name(&escape_name(&name_for_field_or_child(&self.name))) + ) } } @@ -55,13 +82,13 @@ pub enum Storage { Table(bool), } -pub fn read_node_types(node_types_path: &Path) -> std::io::Result> { +pub fn read_node_types(node_types_path: &Path) -> std::io::Result { let file = fs::File::open(node_types_path)?; let node_types = serde_json::from_reader(file)?; Ok(convert_nodes(node_types)) } -pub fn read_node_types_str(node_types_json: &str) -> std::io::Result> { +pub fn read_node_types_str(node_types_json: &str) -> std::io::Result { let node_types = serde_json::from_str(node_types_json)?; Ok(convert_nodes(node_types)) } @@ -77,20 +104,29 @@ fn convert_types(node_types: &Vec) -> Set { let iter = node_types.iter().map(convert_type).collect(); std::collections::BTreeSet::from(iter) } -pub fn convert_nodes(nodes: Vec) -> Vec { - let mut entries: Vec = Vec::new(); + +pub fn convert_nodes(nodes: Vec) -> NodeTypeMap { + let mut entries = NodeTypeMap::new(); let mut token_kinds = Set::new(); for node in nodes { + let flattened_name = node_type_name(&node.kind, node.named); + let ql_class_name = dbscheme_name_to_class_name(&escape_name(&flattened_name)); if let Some(subtypes) = &node.subtypes { // It's a tree-sitter supertype node, for which we create a union // type. - entries.push(Entry::Union { - type_name: TypeName { + entries.insert( + TypeName { kind: node.kind, named: node.named, }, - members: convert_types(&subtypes), - }); + Entry { + flattened_name, + ql_class_name, + kind: EntryKind::Union { + members: convert_types(&subtypes), + }, + }, + ); } else if node.fields.as_ref().map_or(0, |x| x.len()) == 0 && node.children.is_none() { let type_name = TypeName { kind: node.kind, @@ -121,18 +157,34 @@ pub fn convert_nodes(nodes: Vec) -> Vec { // Treat children as if they were a field called 'child'. add_field(&type_name, None, children, &mut fields); } - entries.push(Entry::Table { type_name, fields }); + entries.insert( + type_name, + Entry { + flattened_name, + ql_class_name, + kind: EntryKind::Table { fields }, + }, + ); } } let mut counter = 0; for type_name in token_kinds { - let kind_id = if type_name.named { + let entry = if type_name.named { counter += 1; - counter + let unprefixed_name = node_type_name(&type_name.kind, true); + Entry { + flattened_name: format!("token_{}", &unprefixed_name), + ql_class_name: dbscheme_name_to_class_name(&escape_name(&unprefixed_name)), + kind: EntryKind::Token { kind_id: counter }, + } } else { - 0 + Entry { + flattened_name: "reserved_word".to_owned(), + ql_class_name: "ReservedWord".to_owned(), + kind: EntryKind::Token { kind_id: 0 }, + } }; - entries.push(Entry::Token { type_name, kind_id }); + entries.insert(type_name, entry); } entries } @@ -156,12 +208,26 @@ fn add_field( // with an associated index. Storage::Table(true) }; + let type_info = if field_info.types.len() == 1 { + FieldTypeInfo::Single(convert_type(field_info.types.iter().next().unwrap())) + } else { + // The dbscheme type for this field will be a union. In QL, it'll just be AstNode. + FieldTypeInfo::Multiple { + types: convert_types(&field_info.types), + dbscheme_union: format!( + "{}_{}_type", + &node_type_name(&parent_type_name.kind, parent_type_name.named), + &name_for_field_or_child(&field_name) + ), + ql_class: "AstNode".to_owned(), + } + }; fields.push(Field { parent: TypeName { kind: parent_type_name.kind.to_string(), named: parent_type_name.named, }, - types: convert_types(&field_info.types), + type_info, name: field_name, storage, }); @@ -196,7 +262,7 @@ pub struct FieldInfo { /// Given a tree-sitter node type's (kind, named) pair, returns a single string /// representing the (unescaped) name we'll use to refer to corresponding QL /// type. -pub fn node_type_name(kind: &str, named: bool) -> String { +fn node_type_name(kind: &str, named: bool) -> String { if named { kind.to_string() } else { @@ -267,3 +333,26 @@ pub fn escape_name(name: &str) -> String { result } + +/// Given a valid dbscheme name (i.e. in snake case), produces the equivalent QL +/// name (i.e. in CamelCase). For example, "foo_bar_baz" becomes "FooBarBaz". +fn dbscheme_name_to_class_name(dbscheme_name: &str) -> String { + fn to_title_case(word: &str) -> String { + let mut first = true; + let mut result = String::new(); + for c in word.chars() { + if first { + first = false; + result.push(c.to_ascii_uppercase()); + } else { + result.push(c); + } + } + result + } + dbscheme_name + .split('_') + .map(|word| to_title_case(word)) + .collect::>() + .join("") +} diff --git a/ql/src/codeql_ruby/ast.qll b/ql/src/codeql_ruby/ast.qll index 618cc4c9d97..b66e15806ae 100644 --- a/ql/src/codeql_ruby/ast.qll +++ b/ql/src/codeql_ruby/ast.qll @@ -26,7 +26,7 @@ class Token extends @token, AstNode { override string describeQlClass() { result = "Token" } } -class ReservedWord extends Token, @reserved_word { +class ReservedWord extends @reserved_word, Token { override string describeQlClass() { result = "ReservedWord" } } @@ -173,7 +173,7 @@ class BlockParameter extends @block_parameter, AstNode { override Location getLocation() { block_parameter_def(this, _, result) } - Token getName() { block_parameter_def(this, result, _) } + Identifier getName() { block_parameter_def(this, result, _) } override AstNode getAFieldOrChild() { block_parameter_def(this, result, _) } } @@ -234,6 +234,10 @@ class ChainedString extends @chained_string, AstNode { override AstNode getAFieldOrChild() { chained_string_child(this, _, result) } } +class Character extends @token_character, Token { + override string describeQlClass() { result = "Character" } +} + class Class extends @class, AstNode { override string describeQlClass() { result = "Class" } @@ -246,6 +250,18 @@ class Class extends @class, AstNode { override AstNode getAFieldOrChild() { class_def(this, result, _) or class_child(this, _, result) } } +class ClassVariable extends @token_class_variable, Token { + override string describeQlClass() { result = "ClassVariable" } +} + +class Comment extends @token_comment, Token { + override string describeQlClass() { result = "Comment" } +} + +class Complex extends @token_complex, Token { + override string describeQlClass() { result = "Complex" } +} + class Conditional extends @conditional, AstNode { override string describeQlClass() { result = "Conditional" } @@ -264,6 +280,10 @@ class Conditional extends @conditional, AstNode { } } +class Constant extends @token_constant, Token { + override string describeQlClass() { result = "Constant" } +} + class DestructuredLeftAssignment extends @destructured_left_assignment, AstNode { override string describeQlClass() { result = "DestructuredLeftAssignment" } @@ -344,6 +364,10 @@ class Elsif extends @elsif, AstNode { } } +class EmptyStatement extends @token_empty_statement, Token { + override string describeQlClass() { result = "EmptyStatement" } +} + class EndBlock extends @end_block, AstNode { override string describeQlClass() { result = "EndBlock" } @@ -364,6 +388,10 @@ class Ensure extends @ensure, AstNode { override AstNode getAFieldOrChild() { ensure_child(this, _, result) } } +class EscapeSequence extends @token_escape_sequence, Token { + override string describeQlClass() { result = "EscapeSequence" } +} + class ExceptionVariable extends @exception_variable, AstNode { override string describeQlClass() { result = "ExceptionVariable" } @@ -384,6 +412,14 @@ class Exceptions extends @exceptions, AstNode { override AstNode getAFieldOrChild() { exceptions_child(this, _, result) } } +class False extends @token_false, Token { + override string describeQlClass() { result = "False" } +} + +class Float extends @token_float, Token { + override string describeQlClass() { result = "Float" } +} + class For extends @for, AstNode { override string describeQlClass() { result = "For" } @@ -400,6 +436,10 @@ class For extends @for, AstNode { } } +class GlobalVariable extends @token_global_variable, Token { + override string describeQlClass() { result = "GlobalVariable" } +} + class Hash extends @hash, AstNode { override string describeQlClass() { result = "Hash" } @@ -425,11 +465,15 @@ class HashSplatParameter extends @hash_splat_parameter, AstNode { override Location getLocation() { hash_splat_parameter_def(this, result) } - Token getName() { hash_splat_parameter_name(this, result) } + Identifier getName() { hash_splat_parameter_name(this, result) } override AstNode getAFieldOrChild() { hash_splat_parameter_name(this, result) } } +class HeredocBeginning extends @token_heredoc_beginning, Token { + override string describeQlClass() { result = "HeredocBeginning" } +} + class HeredocBody extends @heredoc_body, AstNode { override string describeQlClass() { result = "HeredocBody" } @@ -440,6 +484,18 @@ class HeredocBody extends @heredoc_body, AstNode { override AstNode getAFieldOrChild() { heredoc_body_child(this, _, result) } } +class HeredocContent extends @token_heredoc_content, Token { + override string describeQlClass() { result = "HeredocContent" } +} + +class HeredocEnd extends @token_heredoc_end, Token { + override string describeQlClass() { result = "HeredocEnd" } +} + +class Identifier extends @token_identifier, Token { + override string describeQlClass() { result = "Identifier" } +} + class If extends @if, AstNode { override string describeQlClass() { result = "If" } @@ -480,6 +536,14 @@ class In extends @in, AstNode { override AstNode getAFieldOrChild() { in_def(this, result, _) } } +class InstanceVariable extends @token_instance_variable, Token { + override string describeQlClass() { result = "InstanceVariable" } +} + +class Integer extends @token_integer, Token { + override string describeQlClass() { result = "Integer" } +} + class Interpolation extends @interpolation, AstNode { override string describeQlClass() { result = "Interpolation" } @@ -495,7 +559,7 @@ class KeywordParameter extends @keyword_parameter, AstNode { override Location getLocation() { keyword_parameter_def(this, _, result) } - Token getName() { keyword_parameter_def(this, result, _) } + Identifier getName() { keyword_parameter_def(this, result, _) } UnderscoreArg getValue() { keyword_parameter_value(this, result) } @@ -606,6 +670,14 @@ class Next extends @next, AstNode { override AstNode getAFieldOrChild() { next_child(this, result) } } +class Nil extends @token_nil, Token { + override string describeQlClass() { result = "Nil" } +} + +class Operator extends @token_operator, Token { + override string describeQlClass() { result = "Operator" } +} + class OperatorAssignment extends @operator_assignment, AstNode { override string describeQlClass() { result = "OperatorAssignment" } @@ -625,7 +697,7 @@ class OptionalParameter extends @optional_parameter, AstNode { override Location getLocation() { optional_parameter_def(this, _, _, result) } - Token getName() { optional_parameter_def(this, result, _, _) } + Identifier getName() { optional_parameter_def(this, result, _, _) } UnderscoreArg getValue() { optional_parameter_def(this, _, result, _) } @@ -693,7 +765,7 @@ class Rational extends @rational, AstNode { override Location getLocation() { rational_def(this, _, result) } - Token getChild() { rational_def(this, result, _) } + Integer getChild() { rational_def(this, result, _) } override AstNode getAFieldOrChild() { rational_def(this, result, _) } } @@ -802,12 +874,16 @@ class ScopeResolution extends @scope_resolution, AstNode { } } +class Self extends @token_self, Token { + override string describeQlClass() { result = "Self" } +} + class Setter extends @setter, AstNode { override string describeQlClass() { result = "Setter" } override Location getLocation() { setter_def(this, _, result) } - Token getChild() { setter_def(this, result, _) } + Identifier getChild() { setter_def(this, result, _) } override AstNode getAFieldOrChild() { setter_def(this, result, _) } } @@ -862,7 +938,7 @@ class SplatParameter extends @splat_parameter, AstNode { override Location getLocation() { splat_parameter_def(this, result) } - Token getName() { splat_parameter_name(this, result) } + Identifier getName() { splat_parameter_name(this, result) } override AstNode getAFieldOrChild() { splat_parameter_name(this, result) } } @@ -887,6 +963,10 @@ class StringArray extends @string_array, AstNode { override AstNode getAFieldOrChild() { string_array_child(this, _, result) } } +class StringContent extends @token_string_content, Token { + override string describeQlClass() { result = "StringContent" } +} + class Subshell extends @subshell, AstNode { override string describeQlClass() { result = "Subshell" } @@ -897,6 +977,10 @@ class Subshell extends @subshell, AstNode { override AstNode getAFieldOrChild() { subshell_child(this, _, result) } } +class Super extends @token_super, Token { + override string describeQlClass() { result = "Super" } +} + class Superclass extends @superclass, AstNode { override string describeQlClass() { result = "Superclass" } @@ -937,6 +1021,10 @@ class Then extends @then, AstNode { override AstNode getAFieldOrChild() { then_child(this, _, result) } } +class True extends @token_true, Token { + override string describeQlClass() { result = "True" } +} + class Unary extends @unary, AstNode { override string describeQlClass() { result = "Unary" } @@ -961,6 +1049,10 @@ class Undef extends @undef, AstNode { override AstNode getAFieldOrChild() { undef_child(this, _, result) } } +class Uninterpreted extends @token_uninterpreted, Token { + override string describeQlClass() { result = "Uninterpreted" } +} + class Unless extends @unless, AstNode { override string describeQlClass() { result = "Unless" } @@ -1070,95 +1162,3 @@ class Yield extends @yield, AstNode { override AstNode getAFieldOrChild() { yield_child(this, result) } } - -class Character extends Token, @token_character { - override string describeQlClass() { result = "Character" } -} - -class ClassVariable extends Token, @token_class_variable { - override string describeQlClass() { result = "ClassVariable" } -} - -class Comment extends Token, @token_comment { - override string describeQlClass() { result = "Comment" } -} - -class Complex extends Token, @token_complex { - override string describeQlClass() { result = "Complex" } -} - -class Constant extends Token, @token_constant { - override string describeQlClass() { result = "Constant" } -} - -class EmptyStatement extends Token, @token_empty_statement { - override string describeQlClass() { result = "EmptyStatement" } -} - -class EscapeSequence extends Token, @token_escape_sequence { - override string describeQlClass() { result = "EscapeSequence" } -} - -class False extends Token, @token_false { - override string describeQlClass() { result = "False" } -} - -class Float extends Token, @token_float { - override string describeQlClass() { result = "Float" } -} - -class GlobalVariable extends Token, @token_global_variable { - override string describeQlClass() { result = "GlobalVariable" } -} - -class HeredocBeginning extends Token, @token_heredoc_beginning { - override string describeQlClass() { result = "HeredocBeginning" } -} - -class HeredocContent extends Token, @token_heredoc_content { - override string describeQlClass() { result = "HeredocContent" } -} - -class HeredocEnd extends Token, @token_heredoc_end { - override string describeQlClass() { result = "HeredocEnd" } -} - -class Identifier extends Token, @token_identifier { - override string describeQlClass() { result = "Identifier" } -} - -class InstanceVariable extends Token, @token_instance_variable { - override string describeQlClass() { result = "InstanceVariable" } -} - -class Integer extends Token, @token_integer { - override string describeQlClass() { result = "Integer" } -} - -class Nil extends Token, @token_nil { - override string describeQlClass() { result = "Nil" } -} - -class Operator extends Token, @token_operator { - override string describeQlClass() { result = "Operator" } -} - -class Self extends Token, @token_self { - override string describeQlClass() { result = "Self" } -} - -class StringContent extends Token, @token_string_content { - override string describeQlClass() { result = "StringContent" } -} - -class Super extends Token, @token_super { - override string describeQlClass() { result = "Super" } -} - -class True extends Token, @token_true { - override string describeQlClass() { result = "True" } -} - -class Uninterpreted extends Token, @token_uninterpreted { - override string describeQlClass() { result = "Uninterpreted" } -}