Python: CG trace: Apply better_compare_for_dataclass to all

This commit is contained in:
Rasmus Wriedt Larsen
2020-07-21 23:37:33 +02:00
parent 9bff615fad
commit b86ca19264
3 changed files with 33 additions and 22 deletions

View File

@@ -5,6 +5,8 @@ from dis import Instruction
from types import FrameType
from typing import Any, List
from cg_trace.utils import better_compare_for_dataclass
LOGGER = logging.getLogger(__name__)
# See https://docs.python.org/3/library/dis.html#python-bytecode-instructions for
@@ -18,6 +20,7 @@ class BytecodeExpr:
"""
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeConst(BytecodeExpr):
"""FOR LOAD_CONST"""
@@ -28,6 +31,7 @@ class BytecodeConst(BytecodeExpr):
return repr(self.value)
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeVariableName(BytecodeExpr):
name: str
@@ -36,6 +40,7 @@ class BytecodeVariableName(BytecodeExpr):
return self.name
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeAttribute(BytecodeExpr):
attr_name: str
@@ -45,6 +50,7 @@ class BytecodeAttribute(BytecodeExpr):
return f"{self.object}.{self.attr_name}"
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeSubscript(BytecodeExpr):
key: BytecodeExpr
@@ -54,6 +60,7 @@ class BytecodeSubscript(BytecodeExpr):
return f"{self.object}[{self.key}]"
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeTuple(BytecodeExpr):
elements: List[BytecodeExpr]
@@ -67,6 +74,7 @@ class BytecodeTuple(BytecodeExpr):
return f"({elements_formatted})"
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeList(BytecodeExpr):
elements: List[BytecodeExpr]
@@ -80,6 +88,7 @@ class BytecodeList(BytecodeExpr):
return f"[{elements_formatted}]"
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeCall(BytecodeExpr):
function: BytecodeExpr
@@ -88,6 +97,7 @@ class BytecodeCall(BytecodeExpr):
return f"{self.function}()"
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeUnknown(BytecodeExpr):
opname: str
@@ -96,6 +106,7 @@ class BytecodeUnknown(BytecodeExpr):
return f"<{self.opname}>"
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class BytecodeMakeFunction(BytecodeExpr):
"""For MAKE_FUNCTION opcode"""
@@ -106,6 +117,7 @@ class BytecodeMakeFunction(BytecodeExpr):
return f"<MAKE_FUNCTION>(qualified_name={self.qualified_name})>"
@better_compare_for_dataclass
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class SomethingInvolvingScaryBytecodeJump(BytecodeExpr):
opname: str

View File

@@ -6,6 +6,7 @@ from types import FrameType
from typing import Any, Optional, Tuple
from cg_trace.bytecode_reconstructor import BytecodeExpr, expr_from_frame
from cg_trace.utils import better_compare_for_dataclass
LOGGER = logging.getLogger(__name__)
@@ -75,28 +76,6 @@ class Call:
)
def better_compare_for_dataclass(cls):
"""When dataclass is used with `order=True`, the comparison methods is only implemented for
objects of the same class. This decorator extends the functionality to compare class
name if used against other objects.
"""
for op in [
"__lt__",
"__le__",
"__gt__",
"__ge__",
]:
old = getattr(cls, op)
def new(self, other):
if type(self) == type(other):
return old(self, other)
return getattr(str, op)(self.__class__.__name__, other.__class__.__name__)
setattr(cls, op, new)
return cls
@dataclasses.dataclass(frozen=True, eq=True, order=True)
class Callee:
pass

View File

@@ -0,0 +1,20 @@
def better_compare_for_dataclass(cls):
"""When dataclass is used with `order=True`, the comparison methods is only implemented for
objects of the same class. This decorator extends the functionality to compare class
name if used against other objects.
"""
for op in [
"__lt__",
"__le__",
"__gt__",
"__ge__",
]:
old = getattr(cls, op)
def new(self, other):
if type(self) == type(other):
return old(self, other)
return getattr(str, op)(self.__class__.__name__, other.__class__.__name__)
setattr(cls, op, new)
return cls