From b86ca19264bb1386809f4d51ceef48630d749d96 Mon Sep 17 00:00:00 2001 From: Rasmus Wriedt Larsen Date: Tue, 21 Jul 2020 23:37:33 +0200 Subject: [PATCH] Python: CG trace: Apply better_compare_for_dataclass to all --- .../src/cg_trace/bytecode_reconstructor.py | 12 ++++++++++ .../src/cg_trace/tracer.py | 23 +------------------ .../src/cg_trace/utils.py | 20 ++++++++++++++++ 3 files changed, 33 insertions(+), 22 deletions(-) create mode 100644 python/tools/recorded-call-graph-metrics/src/cg_trace/utils.py diff --git a/python/tools/recorded-call-graph-metrics/src/cg_trace/bytecode_reconstructor.py b/python/tools/recorded-call-graph-metrics/src/cg_trace/bytecode_reconstructor.py index cf63b996e8c..6944a42615f 100644 --- a/python/tools/recorded-call-graph-metrics/src/cg_trace/bytecode_reconstructor.py +++ b/python/tools/recorded-call-graph-metrics/src/cg_trace/bytecode_reconstructor.py @@ -5,6 +5,8 @@ from dis import Instruction from types import FrameType from typing import Any, List +from cg_trace.utils import better_compare_for_dataclass + LOGGER = logging.getLogger(__name__) # See https://docs.python.org/3/library/dis.html#python-bytecode-instructions for @@ -18,6 +20,7 @@ class BytecodeExpr: """ +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeConst(BytecodeExpr): """FOR LOAD_CONST""" @@ -28,6 +31,7 @@ class BytecodeConst(BytecodeExpr): return repr(self.value) +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeVariableName(BytecodeExpr): name: str @@ -36,6 +40,7 @@ class BytecodeVariableName(BytecodeExpr): return self.name +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeAttribute(BytecodeExpr): attr_name: str @@ -45,6 +50,7 @@ class BytecodeAttribute(BytecodeExpr): return f"{self.object}.{self.attr_name}" +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeSubscript(BytecodeExpr): key: BytecodeExpr @@ -54,6 +60,7 @@ class BytecodeSubscript(BytecodeExpr): return f"{self.object}[{self.key}]" +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeTuple(BytecodeExpr): elements: List[BytecodeExpr] @@ -67,6 +74,7 @@ class BytecodeTuple(BytecodeExpr): return f"({elements_formatted})" +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeList(BytecodeExpr): elements: List[BytecodeExpr] @@ -80,6 +88,7 @@ class BytecodeList(BytecodeExpr): return f"[{elements_formatted}]" +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeCall(BytecodeExpr): function: BytecodeExpr @@ -88,6 +97,7 @@ class BytecodeCall(BytecodeExpr): return f"{self.function}()" +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeUnknown(BytecodeExpr): opname: str @@ -96,6 +106,7 @@ class BytecodeUnknown(BytecodeExpr): return f"<{self.opname}>" +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class BytecodeMakeFunction(BytecodeExpr): """For MAKE_FUNCTION opcode""" @@ -106,6 +117,7 @@ class BytecodeMakeFunction(BytecodeExpr): return f"(qualified_name={self.qualified_name})>" +@better_compare_for_dataclass @dataclasses.dataclass(frozen=True, eq=True, order=True) class SomethingInvolvingScaryBytecodeJump(BytecodeExpr): opname: str diff --git a/python/tools/recorded-call-graph-metrics/src/cg_trace/tracer.py b/python/tools/recorded-call-graph-metrics/src/cg_trace/tracer.py index 4e923a36c52..daea79c816d 100644 --- a/python/tools/recorded-call-graph-metrics/src/cg_trace/tracer.py +++ b/python/tools/recorded-call-graph-metrics/src/cg_trace/tracer.py @@ -6,6 +6,7 @@ from types import FrameType from typing import Any, Optional, Tuple from cg_trace.bytecode_reconstructor import BytecodeExpr, expr_from_frame +from cg_trace.utils import better_compare_for_dataclass LOGGER = logging.getLogger(__name__) @@ -75,28 +76,6 @@ class Call: ) -def better_compare_for_dataclass(cls): - """When dataclass is used with `order=True`, the comparison methods is only implemented for - objects of the same class. This decorator extends the functionality to compare class - name if used against other objects. - """ - for op in [ - "__lt__", - "__le__", - "__gt__", - "__ge__", - ]: - old = getattr(cls, op) - - def new(self, other): - if type(self) == type(other): - return old(self, other) - return getattr(str, op)(self.__class__.__name__, other.__class__.__name__) - - setattr(cls, op, new) - return cls - - @dataclasses.dataclass(frozen=True, eq=True, order=True) class Callee: pass diff --git a/python/tools/recorded-call-graph-metrics/src/cg_trace/utils.py b/python/tools/recorded-call-graph-metrics/src/cg_trace/utils.py new file mode 100644 index 00000000000..f55d033916a --- /dev/null +++ b/python/tools/recorded-call-graph-metrics/src/cg_trace/utils.py @@ -0,0 +1,20 @@ +def better_compare_for_dataclass(cls): + """When dataclass is used with `order=True`, the comparison methods is only implemented for + objects of the same class. This decorator extends the functionality to compare class + name if used against other objects. + """ + for op in [ + "__lt__", + "__le__", + "__gt__", + "__ge__", + ]: + old = getattr(cls, op) + + def new(self, other): + if type(self) == type(other): + return old(self, other) + return getattr(str, op)(self.__class__.__name__, other.__class__.__name__) + + setattr(cls, op, new) + return cls