From ca6c05425696caa5c9ee0cabdf3e3bc2686d09ae Mon Sep 17 00:00:00 2001 From: Simon Friis Vindum Date: Wed, 17 Dec 2025 11:06:26 +0100 Subject: [PATCH] Rust: Rename `Adt` class and lift common predicates to it --- rust/ast-generator/src/main.rs | 1 + rust/extractor/src/translate/base.rs | 4 ++-- rust/schema/annotations.py | 32 +++++++++++++++++++++++----- rust/schema/ast.py | 8 +++---- 4 files changed, 34 insertions(+), 11 deletions(-) diff --git a/rust/ast-generator/src/main.rs b/rust/ast-generator/src/main.rs index 4b68c1d42a6..ddacc0d913e 100644 --- a/rust/ast-generator/src/main.rs +++ b/rust/ast-generator/src/main.rs @@ -15,6 +15,7 @@ use ungrammar::Grammar; fn class_name(type_name: &str) -> String { match type_name { + "Adt" => "TypeItem".to_owned(), "BinExpr" => "BinaryExpr".to_owned(), "ElseBranch" => "Expr".to_owned(), "Fn" => "Function".to_owned(), diff --git a/rust/extractor/src/translate/base.rs b/rust/extractor/src/translate/base.rs index ee26da665b2..a3a0b3c9133 100644 --- a/rust/extractor/src/translate/base.rs +++ b/rust/extractor/src/translate/base.rs @@ -674,7 +674,7 @@ impl<'a> Translator<'a> { pub(crate) fn emit_derive_expansion( &mut self, node: &(impl Into + Clone), - label: impl Into> + Copy, + label: impl Into> + Copy, ) { let Some(semantics) = self.semantics else { return; @@ -686,7 +686,7 @@ impl<'a> Translator<'a> { .flatten() .filter_map(|expanded| self.process_item_macro_expansion(&node, expanded)) .collect::>(); - generated::Adt::emit_derive_macro_expansions( + generated::TypeItem::emit_derive_macro_expansions( label.into(), expansions, &mut self.trap.writer, diff --git a/rust/schema/annotations.py b/rust/schema/annotations.py index bad273419d9..8896e5809f2 100644 --- a/rust/schema/annotations.py +++ b/rust/schema/annotations.py @@ -18,13 +18,18 @@ class LoopingExpr(LabelableExpr): loop_body: optional["BlockExpr"] | child -@annotate(Adt, replace_bases={AstNode: Item}) +@annotate(TypeItem, replace_bases={AstNode: Item}) class _: """ - An ADT (Abstract Data Type) definition, such as `Struct`, `Enum`, or `Union`. + An item that defines a type. Either a `Struct`, `Enum`, or `Union`. """ derive_macro_expansions: list[MacroItems] | child | rust.detach + attrs: list["Attr"] | child + generic_param_list: optional["GenericParamList"] | child + name: optional["Name"] | child + visibility: optional["Visibility"] | child + where_clause: optional["WhereClause"] | child @annotate(Module) @@ -1063,7 +1068,7 @@ class _: """ -@annotate(Enum, replace_bases={Item: None}) # still an Item via Adt +@annotate(Enum, replace_bases={Item: None}) class _: """ An enum declaration. @@ -1074,6 +1079,12 @@ class _: ``` """ + attrs: drop + generic_param_list: drop + name: drop + visibility: drop + where_clause: drop + @annotate(ExternBlock) class _: @@ -1893,7 +1904,7 @@ class _: ) -@annotate(Struct, replace_bases={Item: None}) # still an Item via Adt +@annotate(Struct, replace_bases={Item: None}) class _: """ A Struct. For example: @@ -1906,6 +1917,11 @@ class _: """ field_list: _ | ql.db_table_name("struct_field_lists_") + attrs: drop + generic_param_list: drop + name: drop + visibility: drop + where_clause: drop @annotate(TokenTree) @@ -2075,7 +2091,7 @@ class _: """ -@annotate(Union, replace_bases={Item: None}) # still an Item via Adt +@annotate(Union, replace_bases={Item: None}) class _: """ A union declaration. @@ -2086,6 +2102,12 @@ class _: ``` """ + attrs: drop + generic_param_list: drop + name: drop + visibility: drop + where_clause: drop + @annotate(Use) class _: diff --git a/rust/schema/ast.py b/rust/schema/ast.py index d338c7a1636..5d8a7393ea6 100644 --- a/rust/schema/ast.py +++ b/rust/schema/ast.py @@ -2,7 +2,7 @@ from .prelude import * -class Adt(AstNode, ): +class TypeItem(AstNode, ): pass class AsmOperand(AstNode, ): @@ -206,7 +206,7 @@ class ContinueExpr(Expr, ): class DynTraitTypeRepr(TypeRepr, ): type_bound_list: optional["TypeBoundList"] | child -class Enum(Adt, Item, ): +class Enum(TypeItem, Item, ): attrs: list["Attr"] | child generic_param_list: optional["GenericParamList"] | child name: optional["Name"] | child @@ -623,7 +623,7 @@ class StmtList(AstNode, ): statements: list["Stmt"] | child tail_expr: optional["Expr"] | child -class Struct(Adt, Item, ): +class Struct(TypeItem, Item, ): attrs: list["Attr"] | child field_list: optional["FieldList"] | child generic_param_list: optional["GenericParamList"] | child @@ -712,7 +712,7 @@ class TypeParam(GenericParam, ): class UnderscoreExpr(Expr, ): attrs: list["Attr"] | child -class Union(Adt, Item, ): +class Union(TypeItem, Item, ): attrs: list["Attr"] | child generic_param_list: optional["GenericParamList"] | child name: optional["Name"] | child