diff --git a/pip/qsharp/_qsharp.py b/pip/qsharp/_qsharp.py index 75d148bb30..48eb06e984 100644 --- a/pip/qsharp/_qsharp.py +++ b/pip/qsharp/_qsharp.py @@ -10,7 +10,7 @@ Circuit, ) from warnings import warn -from typing import Any, Callable, Dict, Optional, TypedDict, Union, List +from typing import Any, Callable, Dict, Optional, Tuple, TypedDict, Union, List from .estimator._estimator import EstimatorResult, EstimatorParams import json @@ -349,6 +349,39 @@ def __str__(self) -> str: def _repr_html_(self) -> str: return self.__data._repr_html_() + def check_eq( + self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10 + ) -> bool: + """ + Checks if the state dump is equal to the given state. This is not mathematical equality, + as the check ignores global phase. + + :param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes, + or as a list of real amplitudes. + :param tolerance: The tolerance for the check. Defaults to 1e-10. + """ + phase = None + # Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes + if isinstance(state, list): + state = {i: state[i] for i in range(len(state))} + # Filter out zero states from the state dump and the given state based on tolerance + state = {k: v for k, v in state.items() if abs(v) > tolerance} + inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance} + if len(state) != len(inner_state): + return False + for key in state: + if key not in inner_state: + return False + if phase is None: + # Calculate the phase based on the first state pair encountered. + # Every pair of states after this must have the same phase for the states to be equivalent. + phase = inner_state[key] / state[key] + elif abs(phase - inner_state[key] / state[key]) > tolerance: + # This pair of states does not have the same phase, + # within tolerance, so the equivalence check fails. + return False + return True + def dump_machine() -> StateDump: """ diff --git a/pip/tests/test_qsharp.py b/pip/tests/test_qsharp.py index 4cd2316d84..719ac290d5 100644 --- a/pip/tests/test_qsharp.py +++ b/pip/tests/test_qsharp.py @@ -98,6 +98,34 @@ def test_dump_machine() -> None: # Check that the state dump correctly supports iteration and membership checks for idx in state_dump: assert idx in state_dump + # Check that the state dump is correct and equivalence check ignores global phase, allowing passing + # in of different, potentially unnormalized states. The state should be + # |01⟩: 0.7071+0.0000𝑖, |11⟩: −0.7071+0.0000𝑖 + assert state_dump.check_eq({1: complex(0.7071, 0.0), 3: complex(-0.7071, 0.0)}) + assert state_dump.check_eq({1: complex(0.0, 0.7071), 3: complex(0.0, -0.7071)}) + assert state_dump.check_eq({1: complex(0.5, 0.0), 3: complex(-0.5, 0.0)}) + assert state_dump.check_eq( + {1: complex(0.7071, 0.0), 3: complex(-0.7071, 0.0), 0: complex(0.0, 0.0)} + ) + assert state_dump.check_eq([0.0, 0.5, 0.0, -0.5]) + assert state_dump.check_eq([0.0, 0.5001, 0.00001, -0.5], tolerance=1e-3) + assert state_dump.check_eq( + [complex(0.0, 0.0), complex(0.0, -0.5), complex(0.0, 0.0), complex(0.0, 0.5)] + ) + assert not state_dump.check_eq({1: complex(0.7071, 0.0), 3: complex(0.7071, 0.0)}) + assert not state_dump.check_eq({1: complex(0.5, 0.0), 3: complex(0.0, 0.5)}) + assert not state_dump.check_eq({2: complex(0.5, 0.0), 3: complex(-0.5, 0.0)}) + assert not state_dump.check_eq([0.0, 0.5001, 0.0, -0.5], tolerance=1e-6) + # Reset the qubits and apply a small rotation to q1, to confirm that tolerance applies to the dump + # itself and not just the state. + qsharp.eval("ResetAll([q1, q2]);") + qsharp.eval("Ry(0.0001, q1);") + state_dump = qsharp.dump_machine() + assert state_dump.qubit_count == 2 + assert len(state_dump) == 2 + assert not state_dump.check_eq([1.0]) + assert state_dump.check_eq([0.99999999875, 0.0, 4.999999997916667e-05]) + assert state_dump.check_eq([1.0], tolerance=1e-4) def test_dump_operation() -> None: