diff --git a/misc/codegen/generators/rustgen.py b/misc/codegen/generators/rustgen.py index 9b850c3cf1b..6fa5d4a05ec 100644 --- a/misc/codegen/generators/rustgen.py +++ b/misc/codegen/generators/rustgen.py @@ -20,7 +20,7 @@ def _get_type(t: str) -> str: case "int": return "usize" case _ if t[0].isupper(): - return "trap::Label" + return f"{t}TrapLabel" case "boolean": assert False, "boolean unsupported" case _: @@ -57,6 +57,15 @@ def _get_properties( yield cls, p +def _get_ancestors( + cls: schema.Class, lookup: dict[str, schema.Class] +) -> typing.Iterable[schema.Class]: + for b in cls.bases: + base = lookup[b] + yield base + yield from _get_ancestors(base, lookup) + + class Processor: def __init__(self, data: schema.Schema): self._classmap = data.classes @@ -69,14 +78,15 @@ class Processor: _get_field(c, p) for c, p in _get_properties(cls, self._classmap) if "rust_skip" not in p.pragmas and not p.synth - ], + ] if not cls.derived else [], + ancestors=sorted(set(a.name for a in _get_ancestors(cls, self._classmap))), table_name=inflection.tableize(cls.name), ) def get_classes(self): ret = {"": []} for k, cls in self._classmap.items(): - if not cls.synth and not cls.derived: + if not cls.synth: ret.setdefault(cls.group, []).append(self._get_class(cls.name)) return ret diff --git a/misc/codegen/lib/rust.py b/misc/codegen/lib/rust.py index ac7bf4313d3..0f4b410db70 100644 --- a/misc/codegen/lib/rust.py +++ b/misc/codegen/lib/rust.py @@ -110,8 +110,9 @@ class Field: @dataclasses.dataclass class Class: name: str - table_name: str + table_name: str | None = None fields: list[Field] = dataclasses.field(default_factory=list) + ancestors: list[str] = dataclasses.field(default_factory=list) @property def single_field_entries(self): diff --git a/misc/codegen/templates/rust_classes.mustache b/misc/codegen/templates/rust_classes.mustache index 3b415683d5f..f749733b2bd 100644 --- a/misc/codegen/templates/rust_classes.mustache +++ b/misc/codegen/templates/rust_classes.mustache @@ -2,48 +2,77 @@ #![cfg_attr(any(), rustfmt::skip)] -use crate::trap::{TrapId, TrapEntry}; -use codeql_extractor::trap; +use crate::trap; {{#classes}} +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct {{name}}TrapLabel(trap::UntypedLabel); + +impl From for {{name}}TrapLabel { + fn from(value: trap::UntypedLabel) -> Self { + Self(value) + } +} + +impl From<{{name}}TrapLabel> for trap::TrapId<{{name}}> { + fn from(value: {{name}}TrapLabel) -> Self { + Self::Label(value) + } +} + +impl trap::Label for {{name}}TrapLabel { + fn as_untyped(&self) -> trap::UntypedLabel { + self.0 + } +} + +impl From<{{name}}TrapLabel> for trap::Arg { + fn from(value: {{name}}TrapLabel) -> Self { + value.0.into() + } +} + +{{#table_name}} #[derive(Debug)] pub struct {{name}} { - pub id: TrapId, + pub id: trap::TrapId<{{name}}>, {{#fields}} pub {{field_name}}: {{type}}, {{/fields}} } -impl TrapEntry for {{name}} { - fn extract_id(&mut self) -> TrapId { - std::mem::replace(&mut self.id, TrapId::Star) +impl trap::TrapEntry for {{name}} { + fn class_name() -> &'static str { "{{name}}" } + + fn extract_id(&mut self) -> trap::TrapId { + std::mem::replace(&mut self.id, trap::TrapId::Star) } - fn emit(self, id: trap::Label, out: &mut trap::Writer) { + fn emit(self, id: Self::Label, out: &mut trap::Writer) { {{#single_field_entries}} - out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id){{#fields}}, self.{{field_name}}.into(){{/fields}}]); + out.add_tuple("{{table_name}}", vec![id.into(){{#fields}}, self.{{field_name}}.into(){{/fields}}]); {{/single_field_entries}} {{#fields}} {{#is_predicate}} if self.{{field_name}} { - out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id)]); + out.add_tuple("{{table_name}}", vec![id.into()]); } {{/is_predicate}} {{#is_optional}} {{^is_repeated}} if let Some(v) = self.{{field_name}} { - out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id), v.into()]); + out.add_tuple("{{table_name}}", vec![id.into(), v.into()]); } {{/is_repeated}} {{/is_optional}} {{#is_repeated}} for (i, v) in self.{{field_name}}.into_iter().enumerate() { {{^is_optional}} - out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id){{^is_unordered}}, i.into(){{/is_unordered}}, v.into()]); + out.add_tuple("{{table_name}}", vec![id.into(){{^is_unordered}}, i.into(){{/is_unordered}}, v.into()]); {{/is_optional}} {{#is_optional}} if let Some(v) = v { - out.add_tuple("{{table_name}}", vec![trap::Arg::Label(id){{^is_unordered}}, i.into(){{/is_unordered}}, v.into()]); + out.add_tuple("{{table_name}}", vec![id.into(){{^is_unordered}}, i.into(){{/is_unordered}}, v.into()]); } {{/is_optional}} } @@ -51,4 +80,26 @@ impl TrapEntry for {{name}} { {{/fields}} } } +{{/table_name}} +{{^table_name}} +{{! virtual class, make it unbuildable }} +pub struct {{name}} { + unused: () +} +{{/table_name}} + +impl trap::TrapClass for {{name}} { + type Label = {{name}}TrapLabel; +} +{{/classes}} + +// Conversions +{{#classes}} +{{#ancestors}} +impl From<{{name}}TrapLabel> for {{.}}TrapLabel { + fn from(value: {{name}}TrapLabel) -> Self { + value.0.into() + } +} +{{/ancestors}} {{/classes}} diff --git a/rust/extractor/src/generated/.generated.list b/rust/extractor/src/generated/.generated.list index b404b2e7541..fd7cc5c6614 100644 --- a/rust/extractor/src/generated/.generated.list +++ b/rust/extractor/src/generated/.generated.list @@ -1,2 +1,2 @@ mod.rs 7cdfedcd68cf8e41134daf810c1af78624082b0c3e8be6570339b1a69a5d457e 7cdfedcd68cf8e41134daf810c1af78624082b0c3e8be6570339b1a69a5d457e -top.rs 7150acaeab0b57039ca9f2ed20311229aab5fd48b533f13410ecc34fd8e3bda0 7150acaeab0b57039ca9f2ed20311229aab5fd48b533f13410ecc34fd8e3bda0 +top.rs e06dc90de4abd57719786fd5e49e6ea3089ec3ec167c64446e25d95a16b1714c e06dc90de4abd57719786fd5e49e6ea3089ec3ec167c64446e25d95a16b1714c diff --git a/rust/extractor/src/generated/top.rs b/rust/extractor/src/generated/top.rs index de6d2106c1d..084fdbb61f0 100644 --- a/rust/extractor/src/generated/top.rs +++ b/rust/extractor/src/generated/top.rs @@ -2,1222 +2,5008 @@ #![cfg_attr(any(), rustfmt::skip)] -use crate::trap::{TrapId, TrapEntry}; -use codeql_extractor::trap; +use crate::trap; + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct ElementTrapLabel(trap::UntypedLabel); + +impl From for ElementTrapLabel { + fn from(value: trap::UntypedLabel) -> Self { + Self(value) + } +} + +impl From for trap::TrapId { + fn from(value: ElementTrapLabel) -> Self { + Self::Label(value) + } +} + +impl trap::Label for ElementTrapLabel { + fn as_untyped(&self) -> trap::UntypedLabel { + self.0 + } +} + +impl From for trap::Arg { + fn from(value: ElementTrapLabel) -> Self { + value.0.into() + } +} + +#[derive(Debug)] +pub struct Element { + pub id: trap::TrapId, +} + +impl trap::TrapEntry for Element { + fn class_name() -> &'static str { "Element" } + + fn extract_id(&mut self) -> trap::TrapId { + std::mem::replace(&mut self.id, trap::TrapId::Star) + } + + fn emit(self, id: Self::Label, out: &mut trap::Writer) { + out.add_tuple("elements", vec![id.into()]); + } +} + +impl trap::TrapClass for Element { + type Label = ElementTrapLabel; +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct LocatableTrapLabel(trap::UntypedLabel); + +impl From for LocatableTrapLabel { + fn from(value: trap::UntypedLabel) -> Self { + Self(value) + } +} + +impl From for trap::TrapId { + fn from(value: LocatableTrapLabel) -> Self { + Self::Label(value) + } +} + +impl trap::Label for LocatableTrapLabel { + fn as_untyped(&self) -> trap::UntypedLabel { + self.0 + } +} + +impl From for trap::Arg { + fn from(value: LocatableTrapLabel) -> Self { + value.0.into() + } +} + +#[derive(Debug)] +pub struct Locatable { + pub id: trap::TrapId, +} + +impl trap::TrapEntry for Locatable { + fn class_name() -> &'static str { "Locatable" } + + fn extract_id(&mut self) -> trap::TrapId { + std::mem::replace(&mut self.id, trap::TrapId::Star) + } + + fn emit(self, id: Self::Label, out: &mut trap::Writer) { + out.add_tuple("locatables", vec![id.into()]); + } +} + +impl trap::TrapClass for Locatable { + type Label = LocatableTrapLabel; +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct AstNodeTrapLabel(trap::UntypedLabel); + +impl From for AstNodeTrapLabel { + fn from(value: trap::UntypedLabel) -> Self { + Self(value) + } +} + +impl From for trap::TrapId { + fn from(value: AstNodeTrapLabel) -> Self { + Self::Label(value) + } +} + +impl trap::Label for AstNodeTrapLabel { + fn as_untyped(&self) -> trap::UntypedLabel { + self.0 + } +} + +impl From for trap::Arg { + fn from(value: AstNodeTrapLabel) -> Self { + value.0.into() + } +} + +#[derive(Debug)] +pub struct AstNode { + pub id: trap::TrapId, +} + +impl trap::TrapEntry for AstNode { + fn class_name() -> &'static str { "AstNode" } + + fn extract_id(&mut self) -> trap::TrapId { + std::mem::replace(&mut self.id, trap::TrapId::Star) + } + + fn emit(self, id: Self::Label, out: &mut trap::Writer) { + out.add_tuple("ast_nodes", vec![id.into()]); + } +} + +impl trap::TrapClass for AstNode { + type Label = AstNodeTrapLabel; +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct DeclarationTrapLabel(trap::UntypedLabel); + +impl From for DeclarationTrapLabel { + fn from(value: trap::UntypedLabel) -> Self { + Self(value) + } +} + +impl From for trap::TrapId { + fn from(value: DeclarationTrapLabel) -> Self { + Self::Label(value) + } +} + +impl trap::Label for DeclarationTrapLabel { + fn as_untyped(&self) -> trap::UntypedLabel { + self.0 + } +} + +impl From for trap::Arg { + fn from(value: DeclarationTrapLabel) -> Self { + value.0.into() + } +} + +#[derive(Debug)] +pub struct Declaration { + pub id: trap::TrapId, +} + +impl trap::TrapEntry for Declaration { + fn class_name() -> &'static str { "Declaration" } + + fn extract_id(&mut self) -> trap::TrapId { + std::mem::replace(&mut self.id, trap::TrapId::Star) + } + + fn emit(self, id: Self::Label, out: &mut trap::Writer) { + out.add_tuple("declarations", vec![id.into()]); + } +} + +impl trap::TrapClass for Declaration { + type Label = DeclarationTrapLabel; +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct ExprTrapLabel(trap::UntypedLabel); + +impl From for ExprTrapLabel { + fn from(value: trap::UntypedLabel) -> Self { + Self(value) + } +} + +impl From for trap::TrapId { + fn from(value: ExprTrapLabel) -> Self { + Self::Label(value) + } +} + +impl trap::Label for ExprTrapLabel { + fn as_untyped(&self) -> trap::UntypedLabel { + self.0 + } +} + +impl From for trap::Arg { + fn from(value: ExprTrapLabel) -> Self { + value.0.into() + } +} + +#[derive(Debug)] +pub struct Expr { + pub id: trap::TrapId, +} + +impl trap::TrapEntry for Expr { + fn class_name() -> &'static str { "Expr" } + + fn extract_id(&mut self) -> trap::TrapId { + std::mem::replace(&mut self.id, trap::TrapId::Star) + } + + fn emit(self, id: Self::Label, out: &mut trap::Writer) { + out.add_tuple("exprs", vec![id.into()]); + } +} + +impl trap::TrapClass for Expr { + type Label = ExprTrapLabel; +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct LabelTrapLabel(trap::UntypedLabel); + +impl From for LabelTrapLabel { + fn from(value: trap::UntypedLabel) -> Self { + Self(value) + } +} + +impl From for trap::TrapId