Skip to content

Commit

Permalink
testing max_bond_dimension error
Browse files Browse the repository at this point in the history
  • Loading branch information
KetpuntoG committed Feb 7, 2025
1 parent 03e9ba1 commit df0c52e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
18 changes: 13 additions & 5 deletions pennylane/templates/state_preparations/state_prep_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class MPSPrep(Operation):
product of site matrices. See the usage details section for more information.
wires (Sequence[int]): wires that the template acts on
work_wires (Sequence[int]): list of extra qubits needed in the decomposition. The bond dimension of the mps
is defined as ``2^len(work_wires)``. Default is ``None``.
work_wires (Sequence[int]): list of extra qubits needed in the decomposition. The maximum permissible bond
dimension of the provided MPS is defined as ``2^len(work_wires)``. Default is ``None``.
The decomposition follows Eq. (23) in `[arXiv:2310.18410] <https://arxiv.org/pdf/2310.18410>`_.
Expand Down Expand Up @@ -233,14 +233,22 @@ def compute_decomposition(mps, wires, work_wires): # pylint: disable=arguments-
product of site matrices.
wires (Sequence[int]): wires that the template acts on
work_wires (Sequence[int]): list of extra qubits needed. The bond dimension of the mps
is defined as ``2^len(work_wires)``
work_wires (Sequence[int]): list of extra qubits needed in the decomposition. The maximum permissible bond
dimension of the provided MPS is defined as ``2^len(work_wires)``. Default is ``None``.
Returns:
list[.Operator]: Decomposition of the operator
"""

if work_wires is None:
bond_dimensions = []

for i in range(len(mps) - 1):
bond_dim = mps[i].shape[-1]
bond_dimensions.append(bond_dim)

max_bond_dimension = max(bond_dimensions)

if work_wires is None or 2 ** len(work_wires) < max_bond_dimension:
raise ValueError(
"The qml.MPSPrep decomposition requires `work_wires` to be specified, "
"and the bond dimension cannot exceed `2**len(work_wires)`."
Expand Down
24 changes: 16 additions & 8 deletions tests/templates/test_state_preparations/test_state_prep_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,21 +1546,29 @@ def test_decomposition(self):
assert op.wires == qml.wires.Wires([2 + ind] + [0, 1])
assert op.name == "QubitUnitary"

def test_wires_decomposition(self):
@pytest.mark.parametrize(("work_wires"), [None, 1])
def test_wires_decomposition(self, work_wires):
"""Checks that error is shown if no `work_wires` are given in decomposition"""

mps = [
np.array([[0.0, 0.107j], [0.994, 0.0]], dtype=complex),
np.array([[0.70710678, 0.0], [0.0, 0.70710678]]),
np.array(
[
[[0.0, 0.0], [1.0, 0.0]],
[[0.0, 1.0], [0.0, 0.0]],
],
dtype=complex,
[[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, -0.0, 0.0], [-1.0, 0.0, 0.0, 0.0]],
]
),
np.array(
[
[[0.00000000e00, 1.74315280e-32], [-7.07106781e-01, -7.07106781e-01]],
[[7.07106781e-01, 7.07106781e-01], [0.00000000e00, 0.00000000e00]],
[[0.00000000e00, 0.00000000e00], [-7.07106781e-01, 7.07106781e-01]],
[[-7.07106781e-01, 7.07106781e-01], [0.00000000e00, 0.00000000e00]],
]
),
np.array([[-1.0, -0.0], [-0.0, -1.0]], dtype=complex),
np.array([[1.0, 0.0], [0.0, 1.0]]),
]

op = qml.MPSPrep(mps, wires=range(2, 5))
op = qml.MPSPrep(mps, wires=range(2, 5), work_wires=work_wires)
with pytest.raises(ValueError, match="The qml.MPSPrep decomposition requires"):
op.decomposition()

0 comments on commit df0c52e

Please sign in to comment.