Rust: address review

This commit is contained in:
Paolo Tranquilli
2024-10-15 14:21:11 +02:00
parent 248eb7f00c
commit bd08bc7923
4 changed files with 41 additions and 10 deletions

View File

@@ -1,7 +1,7 @@
from typing import ( from typing import (
Callable as _Callable, Callable as _Callable,
Dict as _Dict, Dict as _Dict,
List as _List, Iterable as _Iterable,
ClassVar as _ClassVar, ClassVar as _ClassVar,
) )
from misc.codegen.lib import schema as _schema from misc.codegen.lib import schema as _schema
@@ -279,7 +279,7 @@ _ = _PropertyAnnotation()
drop = object() drop = object()
def annotate(annotated_cls: type, add_bases: _List[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]: def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]:
""" """
Add or modify schema annotations after a class has been defined previously. Add or modify schema annotations after a class has been defined previously.
@@ -297,7 +297,7 @@ def annotate(annotated_cls: type, add_bases: _List[type] | None = None, replace_
if replace_bases: if replace_bases:
annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__) annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__)
if add_bases: if add_bases:
annotated_cls.__bases__ = tuple(annotated_cls.__bases__) + tuple(add_bases) annotated_cls.__bases__ += tuple(add_bases)
for a in dir(cls): for a in dir(cls):
if a.startswith(_schema.inheritable_pragma_prefix): if a.startswith(_schema.inheritable_pragma_prefix):
setattr(annotated_cls, a, getattr(cls, a)) setattr(annotated_cls, a, getattr(cls, a))

View File

@@ -914,6 +914,36 @@ def test_annotate_replace_bases():
} }
def test_annotate_add_bases():
@load
class data:
class Root:
pass
class A(Root):
pass
class B(Root):
pass
class C(Root):
pass
class Derived(A):
pass
@defs.annotate(Derived, add_bases=(B, C))
class _:
pass
assert data.classes == {
"Root": schema.Class("Root", derived={"A", "B", "C"}),
"A": schema.Class("A", bases=["Root"], derived={"Derived"}),
"B": schema.Class("B", bases=["Root"], derived={"Derived"}),
"C": schema.Class("C", bases=["Root"], derived={"Derived"}),
"Derived": schema.Class("Derived", bases=["A", "B", "C"]),
}
def test_annotate_drop_field(): def test_annotate_drop_field():
@load @load
class data: class data:

View File

@@ -1741,13 +1741,6 @@ class _:
``` ```
""" """
class Callable(AstNode):
"""
A callable. Either a `Function` or a `ClosureExpr`.
"""
param_list: optional["ParamList"] | child
attrs: list["Attr"] | child
@annotate(Function, add_bases=[Callable]) @annotate(Function, add_bases=[Callable])
class _: class _:
param_list: drop param_list: drop

View File

@@ -63,3 +63,11 @@ class Unimplemented(Unextracted):
The base class for unimplemented nodes. This is used to mark nodes that are not yet extracted. The base class for unimplemented nodes. This is used to mark nodes that are not yet extracted.
""" """
pass pass
class Callable(AstNode):
"""
A callable. Either a `Function` or a `ClosureExpr`.
"""
param_list: optional["ParamList"] | child
attrs: list["Attr"] | child