Skip to content

Commit

Permalink
Fix left multiplying operators with numpy arrays (#5361)
Browse files Browse the repository at this point in the history
[sc-58675]

**Context:**

`np.array([0.5, 0.6]) * qml.S(0)` should give a `SProd` op with a batch
dimension.

**Description of the Change:**

Setting the `__array_priority__` property tells numpy we should use
`Operator.__mul__` and `Operator.__rmul__` instead of the corresponding
numpy methods that try and stick an operator into an array.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
  • Loading branch information
albi3ro authored Mar 12, 2024
1 parent 5157192 commit 13afa38
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@

<h3>Improvements 🛠</h3>

* Operators can now be left multiplied `x * op` by numpy arrays.
[(#5361)](https://github.com/PennyLaneAI/pennylane/pull/5361)

* Create the `qml.Reflection` operator, useful for amplitude amplification and its variants.
[(##5159)](https://github.com/PennyLaneAI/pennylane/pull/5159)

Expand Down
4 changes: 4 additions & 0 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,10 @@ def compute_decomposition(theta, wires):

# pylint: disable=too-many-public-methods, too-many-instance-attributes

# this allows scalar multiplication from left with numpy arrays np.array(0.5) * ps1
# taken from [stackexchange](https://stackoverflow.com/questions/40694380/forcing-multiplication-to-use-rmul-instead-of-numpy-array-mul-or-byp/44634634#44634634)
__array_priority__ = 1000

def __init_subclass__(cls, **_):
register_pytree(cls, cls._flatten, cls._unflatten)

Expand Down
30 changes: 30 additions & 0 deletions tests/test_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,24 @@ def test_mul_scalar_tensor(self):
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar

prod_op2 = scalar * qml.RX(1.23, 0)
assert isinstance(prod_op2, SProd)
assert prod_op.scalar is scalar

def test_mul_array_numpy(self):
"""Test that the __mul__ dunder works with a batched scalar."""

scalar = np.array([0.5, 0.6, 0.7])
prod_op = scalar * qml.S(0)
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar
assert prod_op.batch_size == 3

prod_op = qml.S(0) * scalar
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar
assert prod_op.batch_size == 3

def test_divide_with_scalar(self):
"""Test the __truediv__ dunder method with a scalar value."""
sprod_op = qml.RX(1, 0) / 4
Expand Down Expand Up @@ -1255,6 +1273,10 @@ def test_mul_scalar_torch_tensor(self):
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar

prod_op = scalar * qml.RX(1.23, 0)
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar

@pytest.mark.tf
def test_mul_scalar_tf_tensor(self):
"""Test the __mul__ dunder method with a scalar tf tensor."""
Expand All @@ -1265,6 +1287,10 @@ def test_mul_scalar_tf_tensor(self):
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar

prod_op = scalar * qml.RX(1.23, 0)
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar

@pytest.mark.jax
def test_mul_scalar_jax_tensor(self):
"""Test the __mul__ dunder method with a scalar jax tensor."""
Expand All @@ -1275,6 +1301,10 @@ def test_mul_scalar_jax_tensor(self):
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar

prod_op = scalar * qml.RX(1.23, 0)
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar

def test_mul_with_operator(self):
"""Test the __matmul__ dunder method with an operator."""
prod_op = qml.RX(1, 0) @ qml.PauliX(0)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def test_non_analytic_expval(self, mock_qubit_device_with_original_statistics, m
m.setattr("numpy.mean", lambda obs, axis=None: obs)
res = dev.expval(obs)

assert res == obs
assert res == np.array(obs) # no idea what is trying to cast obs to an array now.

def test_no_eigval_error(self, mock_qubit_device_with_original_statistics):
"""Tests that an error is thrown if expval is called with an observable that does
Expand Down Expand Up @@ -675,7 +675,7 @@ def test_non_analytic_var(self, mock_qubit_device_with_original_statistics, monk
m.setattr("numpy.var", lambda obs, axis=None: obs)
res = dev.var(obs)

assert res == obs
assert res == np.array(obs)

def test_no_eigval_error(self, mock_qubit_device_with_original_statistics):
"""Tests that an error is thrown if var is called with an observable that does not have eigenvalues defined."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_qutrit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ def test_non_analytic_expval(self, mock_qutrit_device_with_original_statistics,
m.setattr("numpy.mean", lambda obs, axis=None: obs)
res = dev.expval(obs)

assert res == obs
assert res == np.array(obs)

def test_no_eigval_error(self, mock_qutrit_device_with_original_statistics):
"""Tests that an error is thrown if expval is called with an observable that does
Expand Down Expand Up @@ -762,7 +762,7 @@ def test_non_analytic_var(self, mock_qutrit_device_with_original_statistics, mon
m.setattr("numpy.var", lambda obs, axis=None: obs)
res = dev.var(obs)

assert res == obs
assert res == np.array(obs)

def test_no_eigval_error(self, mock_qutrit_device_with_original_statistics):
"""Tests that an error is thrown if var is called with an observable that does not have eigenvalues defined."""
Expand Down

0 comments on commit 13afa38

Please sign in to comment.