From 23967ca81c92b55053a46cc057bf16ca25e6b607 Mon Sep 17 00:00:00 2001 From: Marquess Valdez Date: Tue, 13 Aug 2024 16:14:47 -0700 Subject: [PATCH] fix: Unpickling an `AbstractInstruction` will result in an `AbstractInstruction` instead of a `quil` `Instruction` (#1801) * fix: The DefMeasureCalibration class returns `pyQuil` `AbstractInstrutcion`s instead of `quil` `Instruction`s * fix: Unpickling an AbstractInstruction will result in an AbstractInstruction instead of a `quil` Instruction * fix ruff checks * mypy is wrong * fix tests * fix typo * fix assertion --- pyquil/quilbase.py | 53 ++++++++++++++++++++++++++++++++++++++ test/unit/test_quilbase.py | 29 ++++++++++++++++++--- 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/pyquil/quilbase.py b/pyquil/quilbase.py index 3e5d981c5..6099f1cf9 100644 --- a/pyquil/quilbase.py +++ b/pyquil/quilbase.py @@ -23,6 +23,7 @@ Callable, ClassVar, Optional, + TypeVar, Union, ) @@ -104,6 +105,22 @@ def __hash__(self) -> int: return hash(str(self)) +_T = TypeVar("_T", bound=type) + + +def _add_reduce_method(cls: _T) -> _T: + def __reduce__(self: Any) -> tuple[Callable[[Any], AbstractInstruction], tuple[Any]]: + init_fn, args = super(cls, self).__reduce__() # type: ignore + obj = init_fn(*args) + return ( + _convert_to_py_instruction, + (obj,), + ) + + cls.__reduce__ = __reduce__ # type: ignore + return cls + + def _convert_to_rs_instruction(instr: Union[AbstractInstruction, quil_rs.Instruction]) -> quil_rs.Instruction: if isinstance(instr, quil_rs.Instruction): return instr @@ -319,6 +336,7 @@ def _convert_to_py_instructions(instrs: Iterable[quil_rs.Instruction]) -> list[A ] +@_add_reduce_method class Gate(quil_rs.Gate, AbstractInstruction): """A quantum gate instruction.""" @@ -488,6 +506,7 @@ def _strip_modifiers(gate: Gate, limit: Optional[int] = None) -> Gate: return stripped +@_add_reduce_method class Measurement(quil_rs.Measurement, AbstractInstruction): """A Quil measurement instruction.""" @@ -565,6 +584,7 @@ def __deepcopy__(self, memo: dict) -> "Measurement": return Measurement._from_rs_measurement(super().__deepcopy__(memo)) +@_add_reduce_method class Reset(quil_rs.Reset, AbstractInstruction): """The RESET instruction.""" @@ -644,6 +664,7 @@ def _from_rs_reset(cls, reset: quil_rs.Reset) -> "ResetQubit": raise ValueError("reset.qubit should not be None") +@_add_reduce_method class DefGate(quil_rs.GateDefinition, AbstractInstruction): """A DEFGATE directive.""" @@ -840,6 +861,7 @@ def __str__(self) -> str: return super().to_quil_or_debug() +@_add_reduce_method class JumpTarget(quil_rs.Label, AbstractInstruction): """Representation of a target that can be jumped to.""" @@ -872,6 +894,7 @@ def __deepcopy__(self, memo: dict) -> "JumpTarget": return JumpTarget._from_rs_label(super().__deepcopy__(memo)) +@_add_reduce_method class JumpWhen(quil_rs.JumpWhen, AbstractInstruction): """The JUMP-WHEN instruction.""" @@ -923,6 +946,7 @@ def __deepcopy__(self, memo: dict) -> "JumpWhen": return JumpWhen._from_rs_jump_when(super().__deepcopy__(memo)) +@_add_reduce_method class JumpUnless(quil_rs.JumpUnless, AbstractInstruction): """The JUMP-UNLESS instruction.""" @@ -1011,6 +1035,7 @@ class Nop(SimpleInstruction): instruction = quil_rs.Instruction.new_nop() +@_add_reduce_method class UnaryClassicalInstruction(quil_rs.UnaryLogic, AbstractInstruction): """Base class for unary classical instructions.""" @@ -1061,6 +1086,7 @@ class ClassicalNot(UnaryClassicalInstruction): op = quil_rs.UnaryOperator.Not +@_add_reduce_method class LogicalBinaryOp(quil_rs.BinaryLogic, AbstractInstruction): """Base class for binary logical classical instructions.""" @@ -1142,6 +1168,7 @@ class ClassicalExclusiveOr(LogicalBinaryOp): op = quil_rs.BinaryOperator.Xor +@_add_reduce_method class ArithmeticBinaryOp(quil_rs.Arithmetic, AbstractInstruction): """Base class for binary arithmetic classical instructions.""" @@ -1216,6 +1243,7 @@ class ClassicalDiv(ArithmeticBinaryOp): op = quil_rs.ArithmeticOperator.Divide +@_add_reduce_method class ClassicalMove(quil_rs.Move, AbstractInstruction): """The MOVE instruction.""" @@ -1259,6 +1287,7 @@ def __deepcopy__(self, memo: dict) -> "ClassicalMove": return ClassicalMove._from_rs_move(super().__deepcopy__(memo)) +@_add_reduce_method class ClassicalExchange(quil_rs.Exchange, AbstractInstruction): """The EXCHANGE instruction.""" @@ -1306,6 +1335,7 @@ def __deepcopy__(self, memo: dict) -> "ClassicalExchange": return ClassicalExchange._from_rs_exchange(super().__deepcopy__(memo)) +@_add_reduce_method class ClassicalConvert(quil_rs.Convert, AbstractInstruction): """The CONVERT instruction.""" @@ -1349,6 +1379,7 @@ def __deepcopy__(self, memo: dict) -> "ClassicalConvert": return ClassicalConvert._from_rs_convert(super().__deepcopy__(memo)) +@_add_reduce_method class ClassicalLoad(quil_rs.Load, AbstractInstruction): """The LOAD instruction.""" @@ -1420,6 +1451,7 @@ def _to_py_arithmetic_operand(operand: quil_rs.ArithmeticOperand) -> Union[Memor return inner +@_add_reduce_method class ClassicalStore(quil_rs.Store, AbstractInstruction): """The STORE instruction.""" @@ -1473,6 +1505,7 @@ def __deepcopy__(self, memo: dict) -> "ClassicalStore": return ClassicalStore._from_rs_store(super().__deepcopy__(memo)) +@_add_reduce_method class ClassicalComparison(quil_rs.Comparison, AbstractInstruction): """Base class for ternary comparison instructions.""" @@ -1588,6 +1621,7 @@ class ClassicalGreaterEqual(ClassicalComparison): op = quil_rs.ComparisonOperator.GreaterThanOrEqual +@_add_reduce_method class Jump(quil_rs.Jump, AbstractInstruction): """Representation of an unconditional jump instruction (JUMP).""" @@ -1624,6 +1658,7 @@ def __deepcopy__(self, memo: dict) -> "Jump": return Jump._from_rs_jump(super().__deepcopy__(memo)) +@_add_reduce_method class Pragma(quil_rs.Pragma, AbstractInstruction): """A PRAGMA instruction. @@ -1712,6 +1747,7 @@ def __deepcopy__(self, memo: dict) -> "Pragma": return Pragma._from_rs_pragma(super().__deepcopy__(memo)) +@_add_reduce_method class Declare(quil_rs.Declaration, AbstractInstruction): """A DECLARE directive. @@ -1838,6 +1874,7 @@ def __deepcopy__(self, memo: dict) -> "Declare": return Declare._from_rs_declaration(super().__deepcopy__(memo)) +@_add_reduce_method class Include(quil_rs.Include, AbstractInstruction): """An INCLUDE directive.""" @@ -1859,6 +1896,7 @@ def __deepcopy__(self, memo: dict) -> "Include": return Include._from_rs_include(super().__deepcopy__(memo)) +@_add_reduce_method class Pulse(quil_rs.Pulse, AbstractInstruction): """A PULSE instruction.""" @@ -1926,6 +1964,7 @@ def __deepcopy__(self, memo: dict) -> "Pulse": return Pulse._from_rs_pulse(super().__deepcopy__(memo)) +@_add_reduce_method class SetFrequency(quil_rs.SetFrequency, AbstractInstruction): """A SET-FREQUENCY instruction.""" @@ -1983,6 +2022,7 @@ def __deepcopy__(self, memo: dict) -> "SetFrequency": return SetFrequency._from_rs_set_frequency(super().__deepcopy__(memo)) +@_add_reduce_method class ShiftFrequency(quil_rs.ShiftFrequency, AbstractInstruction): """The SHIFT-FREQUENCY instruction.""" @@ -2040,6 +2080,7 @@ def __deepcopy__(self, memo: dict) -> "ShiftFrequency": return ShiftFrequency._from_rs_shift_frequency(super().__deepcopy__(memo)) +@_add_reduce_method class SetPhase(quil_rs.SetPhase, AbstractInstruction): """The SET-PHASE instruction.""" @@ -2097,6 +2138,7 @@ def __deepcopy__(self, memo: dict) -> "SetPhase": return SetPhase._from_rs_set_phase(super().__deepcopy__(memo)) +@_add_reduce_method class ShiftPhase(quil_rs.ShiftPhase, AbstractInstruction): """The SHIFT-PHASE instruction.""" @@ -2154,6 +2196,7 @@ def __deepcopy__(self, memo: dict) -> "ShiftPhase": return ShiftPhase._from_rs_shift_phase(super().__deepcopy__(memo)) +@_add_reduce_method class SwapPhases(quil_rs.SwapPhases, AbstractInstruction): """The SWAP-PHASES instruction.""" @@ -2211,6 +2254,7 @@ def __deepcopy__(self, memo: dict) -> "SwapPhases": return SwapPhases._from_rs_swap_phases(super().__deepcopy__(memo)) +@_add_reduce_method class SetScale(quil_rs.SetScale, AbstractInstruction): """The SET-SCALE instruction.""" @@ -2268,6 +2312,7 @@ def __deepcopy__(self, memo: dict) -> "SetScale": return SetScale._from_rs_set_scale(super().__deepcopy__(memo)) +@_add_reduce_method class Capture(quil_rs.Capture, AbstractInstruction): """The CAPTURE instruction.""" @@ -2352,6 +2397,7 @@ def __deepcopy__(self, memo: dict) -> "Capture": return Capture._from_rs_capture(super().__deepcopy__(memo)) +@_add_reduce_method class RawCapture(quil_rs.RawCapture, AbstractInstruction): """The RAW-CAPTURE instruction.""" @@ -2440,6 +2486,7 @@ def __deepcopy__(self, memo: dict) -> "RawCapture": return RawCapture._from_rs_raw_capture(super().__deepcopy__(memo)) +@_add_reduce_method class Delay(quil_rs.Delay, AbstractInstruction): """The DELAY instruction.""" @@ -2534,6 +2581,7 @@ def _from_rs_delay(cls, delay: quil_rs.Delay) -> "DelayQubits": return Delay._from_rs_delay.__func__(cls, delay) # type: ignore +@_add_reduce_method class Fence(quil_rs.Fence, AbstractInstruction): """The FENCE instruction.""" @@ -2579,6 +2627,7 @@ def __new__(cls) -> Self: return super().__new__(cls, []) +@_add_reduce_method class DefWaveform(quil_rs.WaveformDefinition, AbstractInstruction): """A waveform definition.""" @@ -2637,6 +2686,7 @@ def __deepcopy__(self, memo: dict) -> "DefWaveform": return DefWaveform._from_rs_waveform_definition(super().__deepcopy__(memo)) +@_add_reduce_method class DefCircuit(quil_rs.CircuitDefinition, AbstractInstruction): """A circuit definition.""" @@ -2707,6 +2757,7 @@ def __deepcopy__(self, memo: dict) -> "DefCircuit": return DefCircuit._from_rs_circuit_definition(super().__deepcopy__(memo)) +@_add_reduce_method class DefCalibration(quil_rs.Calibration, AbstractInstruction): """A calibration definition.""" @@ -2789,6 +2840,7 @@ def __deepcopy__(self, memo: dict) -> "DefCalibration": return DefCalibration._from_rs_calibration(super().__deepcopy__(memo)) +@_add_reduce_method class DefMeasureCalibration(quil_rs.MeasureCalibrationDefinition, AbstractInstruction): """A measure calibration definition.""" @@ -2866,6 +2918,7 @@ def __deepcopy__(self, memo: dict) -> "DefMeasureCalibration": return DefMeasureCalibration._from_rs_measure_calibration_definition(super().__deepcopy__(memo)) +@_add_reduce_method class DefFrame(quil_rs.FrameDefinition, AbstractInstruction): """A frame definition.""" diff --git a/test/unit/test_quilbase.py b/test/unit/test_quilbase.py index e3c781218..424e7dc3e 100644 --- a/test/unit/test_quilbase.py +++ b/test/unit/test_quilbase.py @@ -191,6 +191,7 @@ def test_compile(self, program: Program, compiler: QPUCompiler): def test_pickle(self, gate: Gate): pickled = pickle.dumps(gate) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, Gate) assert unpickled == gate @@ -270,6 +271,7 @@ def test_copy(self, def_gate: DefGate): def test_pickle(self, def_gate: DefGate, snapshot: SnapshotAssertion): pickled = pickle.dumps(def_gate) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, DefGate) assert unpickled == snapshot @@ -444,6 +446,7 @@ def test_convert(self, calibration: DefCalibration): def test_pickle(self, calibration: DefCalibration): pickled = pickle.dumps(calibration) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, DefCalibration) assert unpickled == calibration @@ -500,6 +503,7 @@ def test_convert(self, measure_calibration: DefMeasureCalibration): def test_pickle(self, measure_calibration: DefMeasureCalibration): pickled = pickle.dumps(measure_calibration) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, DefMeasureCalibration) assert unpickled == measure_calibration @@ -548,6 +552,7 @@ def test_convert(self, measurement: Measurement): def test_pickle(self, measurement: Measurement): pickled = pickle.dumps(measurement) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, Measurement) assert unpickled == measurement @@ -643,9 +648,9 @@ def test_convert(self, def_frame: DefFrame): assert def_frame == _convert_to_py_instruction(rs_def_frame) def test_pickle(self, def_frame: DefFrame): - print(def_frame.to_quil()) pickled = pickle.dumps(def_frame) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, DefFrame) assert unpickled == def_frame @@ -721,6 +726,7 @@ def test_convert(self, declare: Declare): def test_pickle(self, declare: Declare): pickled = pickle.dumps(declare) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, Declare) assert unpickled == declare @@ -771,6 +777,7 @@ def test_convert(self, pragma: Pragma): def test_pickle(self, pragma: Pragma): pickled = pickle.dumps(pragma) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, Pragma) assert unpickled == pragma @@ -823,6 +830,7 @@ def test_convert(self, reset_qubit: Reset): def test_pickle(self, reset_qubit: Reset): pickled = pickle.dumps(reset_qubit) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, (Reset, ResetQubit)) assert unpickled == reset_qubit @@ -861,6 +869,7 @@ def test_convert(self, delay_frames: DelayFrames): def test_pickle(self, delay_frames: DelayFrames): pickled = pickle.dumps(delay_frames) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, DelayFrames) assert unpickled == delay_frames @@ -901,6 +910,7 @@ def test_convert(self, delay_qubits: DelayQubits): def test_pickle(self, delay_qubits: DelayQubits): pickled = pickle.dumps(delay_qubits) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, DelayQubits) assert unpickled == delay_qubits @@ -936,6 +946,7 @@ def test_convert(self, fence: Fence): def test_pickle(self, fence: Fence): pickled = pickle.dumps(fence) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, Fence) assert unpickled == fence @@ -989,9 +1000,9 @@ def test_convert(self, def_waveform: DefWaveform): assert def_waveform == _convert_to_py_instruction(rs_def_waveform) def test_pickle(self, def_waveform: DefWaveform, snapshot: SnapshotAssertion): - print(def_waveform.to_quil()) pickled = pickle.dumps(def_waveform) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, DefWaveform) assert unpickled == snapshot @@ -1054,9 +1065,9 @@ def test_convert(self, def_circuit: DefCircuit): assert def_circuit == _convert_to_py_instruction(rs_def_circuit) def test_pickle(self, def_circuit: DefCircuit): - print(def_circuit.to_quil()) pickled = pickle.dumps(def_circuit) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, DefCircuit) assert unpickled == def_circuit @@ -1123,6 +1134,7 @@ def test_convert(self, capture: Capture): def test_pickle(self, capture: Capture, snapshot: SnapshotAssertion): pickled = pickle.dumps(capture) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, Capture) assert unpickled == snapshot @@ -1218,6 +1230,7 @@ def test_convert(self, pulse: Pulse): def test_pickle(self, pulse: Pulse, snapshot: SnapshotAssertion): pickled = pickle.dumps(pulse) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, Pulse) assert unpickled == snapshot @@ -1284,6 +1297,7 @@ def test_convert(self, raw_capture: RawCapture): def test_pickle(self, raw_capture: RawCapture): pickled = pickle.dumps(raw_capture) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, RawCapture) assert unpickled == raw_capture @@ -1390,6 +1404,7 @@ def test_convert(self, swap_phases: SwapPhases): def test_pickle(self, swap_phases: SwapPhases): pickled = pickle.dumps(swap_phases) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, SwapPhases) assert unpickled == swap_phases @@ -1430,6 +1445,7 @@ def test_convert(self, move: ClassicalMove): def test_pickle(self, move: ClassicalMove): pickled = pickle.dumps(move) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, ClassicalMove) assert unpickled == move @@ -1466,6 +1482,7 @@ def test_convert(self, exchange: ClassicalExchange): def test_pickle(self, exchange: ClassicalExchange): pickled = pickle.dumps(exchange) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, ClassicalExchange) assert unpickled == exchange @@ -1502,6 +1519,7 @@ def test_convert(self, convert: ClassicalConvert): def test_pickle(self, convert: ClassicalConvert): pickled = pickle.dumps(convert) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, ClassicalConvert) assert unpickled == convert @@ -1543,6 +1561,7 @@ def test_convert(self, load: ClassicalLoad): def test_pickle(self, load: ClassicalLoad): pickled = pickle.dumps(load) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, ClassicalLoad) assert unpickled == load @@ -1588,6 +1607,7 @@ def test_convert(self, store: ClassicalStore): def test_pickle(self, store: ClassicalStore): pickled = pickle.dumps(store) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, ClassicalStore) assert unpickled == store @@ -1645,6 +1665,7 @@ def test_convert(self, comparison: ClassicalComparison): def test_pickle(self, comparison: ClassicalComparison): pickled = pickle.dumps(comparison) unpickled = pickle.loads(pickled) + assert isinstance(unpickled, ClassicalComparison) assert unpickled == comparison @@ -1727,7 +1748,7 @@ def test_copy(self, arithmetic: ArithmeticBinaryOp): def valid_in_program(self, arithmetic): try: p = Program(arithmetic) - p[0] == arithmetic + p[0] = arithmetic except Exception: pytest.fail("ArithmeticBinaryOp not valid in Program")