diff --git a/shared/tree-sitter-extractor/src/generator/mod.rs b/shared/tree-sitter-extractor/src/generator/mod.rs index d3880a74579..6c5fbfabda6 100644 --- a/shared/tree-sitter-extractor/src/generator/mod.rs +++ b/shared/tree-sitter-extractor/src/generator/mod.rs @@ -159,6 +159,7 @@ pub fn generate( )); body.append(&mut ql_gen::convert_nodes(&nodes)); + body.push(ql_gen::create_print_ast_module(&nodes)); ql::write( &mut ql_writer, &[ql::TopLevel::Module(ql::Module { diff --git a/shared/tree-sitter-extractor/src/generator/ql.rs b/shared/tree-sitter-extractor/src/generator/ql.rs index cdfe5d8c639..24ae25d854b 100644 --- a/shared/tree-sitter-extractor/src/generator/ql.rs +++ b/shared/tree-sitter-extractor/src/generator/ql.rs @@ -150,12 +150,14 @@ impl fmt::Display for Type<'_> { pub enum Expression<'a> { Var(&'a str), String(&'a str), - Integer(usize), + Integer(i64), Pred(&'a str, Vec>), And(Vec>), Or(Vec>), Equals(Box>, Box>), Dot(Box>, &'a str, Vec>), + /// A type cast, rendered as `x.(Type)`. + Cast(Box>, &'a str), Aggregate { name: &'a str, vars: Vec>, @@ -219,6 +221,7 @@ impl fmt::Display for Expression<'_> { } write!(f, ")") } + Expression::Cast(x, type_name) => write!(f, "{x}.({type_name})"), Expression::Aggregate { name, vars, diff --git a/shared/tree-sitter-extractor/src/generator/ql_gen.rs b/shared/tree-sitter-extractor/src/generator/ql_gen.rs index f827b12580e..c6feb58ba3f 100644 --- a/shared/tree-sitter-extractor/src/generator/ql_gen.rs +++ b/shared/tree-sitter-extractor/src/generator/ql_gen.rs @@ -705,7 +705,7 @@ fn create_field_getters<'a>( ), ql::Expression::Equals( Box::new(ql::Expression::Var("value")), - Box::new(ql::Expression::Integer(*value)), + Box::new(ql::Expression::Integer(*value as i64)), ), ]) }) @@ -874,3 +874,96 @@ pub fn convert_nodes(nodes: &node_types::NodeTypeMap) -> Vec> { classes } + +/// Creates a `PrintAst` module containing a `getChild` predicate that maps each +/// AST node to its children together with the name of the member predicate that +/// produced them (and, for indexed fields, the index). This mirrors the +/// information exposed by `getAFieldOrChild`, but keeps the member predicate +/// name and index so that an AST printer can render labelled edges. +pub fn create_print_ast_module(nodes: &node_types::NodeTypeMap) -> ql::TopLevel<'_> { + let mut disjuncts: Vec = Vec::new(); + for node in nodes.values() { + if let node_types::EntryKind::Table { name: _, fields } = &node.kind { + for field in fields { + // `ReservedWordInt` fields have string-valued getters, so they + // are not children and are excluded (just as they are from + // `getAFieldOrChild`). + if matches!(field.type_info, node_types::FieldTypeInfo::ReservedWordInt(_)) { + continue; + } + let has_index = matches!( + field.storage, + node_types::Storage::Table { + has_index: true, + .. + } + ); + let getter_call = ql::Expression::Dot( + Box::new(ql::Expression::Cast( + Box::new(ql::Expression::Var("node")), + &node.ql_class_name, + )), + &field.getter_name, + if has_index { + vec![ql::Expression::Var("i")] + } else { + vec![] + }, + ); + let mut conjuncts = vec![ql::Expression::Equals( + Box::new(ql::Expression::Var("result")), + Box::new(getter_call), + )]; + if !has_index { + conjuncts.push(ql::Expression::Equals( + Box::new(ql::Expression::Var("i")), + Box::new(ql::Expression::Integer(-1)), + )); + } + conjuncts.push(ql::Expression::Equals( + Box::new(ql::Expression::Var("name")), + Box::new(ql::Expression::String(&field.getter_name)), + )); + disjuncts.push(ql::Expression::And(conjuncts)); + } + } + } + + let get_child = ql::Predicate { + qldoc: Some(String::from( + "Gets a child of `node` returned by the member predicate with the given `name`. \ + If the predicate takes an index argument, `i` is bound to that index, otherwise \ + `i` is `-1` (which is never a valid index).", + )), + name: "getChild", + overridden: false, + is_private: false, + is_final: false, + return_type: Some(ql::Type::Normal("AstNode")), + formal_parameters: vec![ + ql::FormalParameter { + name: "node", + param_type: ql::Type::Normal("AstNode"), + }, + ql::FormalParameter { + name: "name", + param_type: ql::Type::String, + }, + ql::FormalParameter { + name: "i", + param_type: ql::Type::Int, + }, + ], + body: ql::Expression::Or(disjuncts), + overlay: None, + }; + + ql::TopLevel::Module(ql::Module { + qldoc: Some(String::from( + "Provides predicates for mapping AST nodes to their named children.", + )), + name: "PrintAst", + body: vec![ql::TopLevel::Predicate(get_child)], + overlay: None, + }) +}