Skip to content

Commit

Permalink
Add check_eq for StateDump in Python (#1372)
Browse files Browse the repository at this point in the history
This adds a utility to the `StateDump` object in Python to help with
writing tests that verify quantum state. The check ignores global phase,
so allows for passing in any dictionary where the states differ from the
dump by a constant factor, including unnormalized states.
  • Loading branch information
swernli authored Apr 23, 2024
1 parent 4d891c1 commit 3e85d6d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
35 changes: 34 additions & 1 deletion pip/qsharp/_qsharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Circuit,
)
from warnings import warn
from typing import Any, Callable, Dict, Optional, TypedDict, Union, List
from typing import Any, Callable, Dict, Optional, Tuple, TypedDict, Union, List
from .estimator._estimator import EstimatorResult, EstimatorParams
import json

Expand Down Expand Up @@ -349,6 +349,39 @@ def __str__(self) -> str:
def _repr_html_(self) -> str:
return self.__data._repr_html_()

def check_eq(
self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10
) -> bool:
"""
Checks if the state dump is equal to the given state. This is not mathematical equality,
as the check ignores global phase.
:param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes,
or as a list of real amplitudes.
:param tolerance: The tolerance for the check. Defaults to 1e-10.
"""
phase = None
# Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes
if isinstance(state, list):
state = {i: state[i] for i in range(len(state))}
# Filter out zero states from the state dump and the given state based on tolerance
state = {k: v for k, v in state.items() if abs(v) > tolerance}
inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance}
if len(state) != len(inner_state):
return False
for key in state:
if key not in inner_state:
return False
if phase is None:
# Calculate the phase based on the first state pair encountered.
# Every pair of states after this must have the same phase for the states to be equivalent.
phase = inner_state[key] / state[key]
elif abs(phase - inner_state[key] / state[key]) > tolerance:
# This pair of states does not have the same phase,
# within tolerance, so the equivalence check fails.
return False
return True


def dump_machine() -> StateDump:
"""
Expand Down
28 changes: 28 additions & 0 deletions pip/tests/test_qsharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,34 @@ def test_dump_machine() -> None:
# Check that the state dump correctly supports iteration and membership checks
for idx in state_dump:
assert idx in state_dump
# Check that the state dump is correct and equivalence check ignores global phase, allowing passing
# in of different, potentially unnormalized states. The state should be
# |01⟩: 0.7071+0.0000𝑖, |11⟩: −0.7071+0.0000𝑖
assert state_dump.check_eq({1: complex(0.7071, 0.0), 3: complex(-0.7071, 0.0)})
assert state_dump.check_eq({1: complex(0.0, 0.7071), 3: complex(0.0, -0.7071)})
assert state_dump.check_eq({1: complex(0.5, 0.0), 3: complex(-0.5, 0.0)})
assert state_dump.check_eq(
{1: complex(0.7071, 0.0), 3: complex(-0.7071, 0.0), 0: complex(0.0, 0.0)}
)
assert state_dump.check_eq([0.0, 0.5, 0.0, -0.5])
assert state_dump.check_eq([0.0, 0.5001, 0.00001, -0.5], tolerance=1e-3)
assert state_dump.check_eq(
[complex(0.0, 0.0), complex(0.0, -0.5), complex(0.0, 0.0), complex(0.0, 0.5)]
)
assert not state_dump.check_eq({1: complex(0.7071, 0.0), 3: complex(0.7071, 0.0)})
assert not state_dump.check_eq({1: complex(0.5, 0.0), 3: complex(0.0, 0.5)})
assert not state_dump.check_eq({2: complex(0.5, 0.0), 3: complex(-0.5, 0.0)})
assert not state_dump.check_eq([0.0, 0.5001, 0.0, -0.5], tolerance=1e-6)
# Reset the qubits and apply a small rotation to q1, to confirm that tolerance applies to the dump
# itself and not just the state.
qsharp.eval("ResetAll([q1, q2]);")
qsharp.eval("Ry(0.0001, q1);")
state_dump = qsharp.dump_machine()
assert state_dump.qubit_count == 2
assert len(state_dump) == 2
assert not state_dump.check_eq([1.0])
assert state_dump.check_eq([0.99999999875, 0.0, 4.999999997916667e-05])
assert state_dump.check_eq([1.0], tolerance=1e-4)


def test_dump_operation() -> None:
Expand Down

0 comments on commit 3e85d6d

Please sign in to comment.