diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 445064a369d..e8eb872565d 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -60,6 +60,9 @@
Improvements ðŸ›
+* Operators can now be left multiplied `x * op` by numpy arrays.
+ [(#5361)](https://github.com/PennyLaneAI/pennylane/pull/5361)
+
* Create the `qml.Reflection` operator, useful for amplitude amplification and its variants.
[(##5159)](https://github.com/PennyLaneAI/pennylane/pull/5159)
diff --git a/pennylane/operation.py b/pennylane/operation.py
index bb9c634265a..26808341cda 100644
--- a/pennylane/operation.py
+++ b/pennylane/operation.py
@@ -691,6 +691,10 @@ def compute_decomposition(theta, wires):
# pylint: disable=too-many-public-methods, too-many-instance-attributes
+ # this allows scalar multiplication from left with numpy arrays np.array(0.5) * ps1
+ # taken from [stackexchange](https://stackoverflow.com/questions/40694380/forcing-multiplication-to-use-rmul-instead-of-numpy-array-mul-or-byp/44634634#44634634)
+ __array_priority__ = 1000
+
def __init_subclass__(cls, **_):
register_pytree(cls, cls._flatten, cls._unflatten)
diff --git a/tests/test_operation.py b/tests/test_operation.py
index ffabfeea358..8a880b915d1 100644
--- a/tests/test_operation.py
+++ b/tests/test_operation.py
@@ -1197,6 +1197,24 @@ def test_mul_scalar_tensor(self):
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar
+ prod_op2 = scalar * qml.RX(1.23, 0)
+ assert isinstance(prod_op2, SProd)
+ assert prod_op.scalar is scalar
+
+ def test_mul_array_numpy(self):
+ """Test that the __mul__ dunder works with a batched scalar."""
+
+ scalar = np.array([0.5, 0.6, 0.7])
+ prod_op = scalar * qml.S(0)
+ assert isinstance(prod_op, SProd)
+ assert prod_op.scalar is scalar
+ assert prod_op.batch_size == 3
+
+ prod_op = qml.S(0) * scalar
+ assert isinstance(prod_op, SProd)
+ assert prod_op.scalar is scalar
+ assert prod_op.batch_size == 3
+
def test_divide_with_scalar(self):
"""Test the __truediv__ dunder method with a scalar value."""
sprod_op = qml.RX(1, 0) / 4
@@ -1255,6 +1273,10 @@ def test_mul_scalar_torch_tensor(self):
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar
+ prod_op = scalar * qml.RX(1.23, 0)
+ assert isinstance(prod_op, SProd)
+ assert prod_op.scalar is scalar
+
@pytest.mark.tf
def test_mul_scalar_tf_tensor(self):
"""Test the __mul__ dunder method with a scalar tf tensor."""
@@ -1265,6 +1287,10 @@ def test_mul_scalar_tf_tensor(self):
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar
+ prod_op = scalar * qml.RX(1.23, 0)
+ assert isinstance(prod_op, SProd)
+ assert prod_op.scalar is scalar
+
@pytest.mark.jax
def test_mul_scalar_jax_tensor(self):
"""Test the __mul__ dunder method with a scalar jax tensor."""
@@ -1275,6 +1301,10 @@ def test_mul_scalar_jax_tensor(self):
assert isinstance(prod_op, SProd)
assert prod_op.scalar is scalar
+ prod_op = scalar * qml.RX(1.23, 0)
+ assert isinstance(prod_op, SProd)
+ assert prod_op.scalar is scalar
+
def test_mul_with_operator(self):
"""Test the __matmul__ dunder method with an operator."""
prod_op = qml.RX(1, 0) @ qml.PauliX(0)
diff --git a/tests/test_qubit_device.py b/tests/test_qubit_device.py
index 452deaed98a..b27f1c67bc7 100644
--- a/tests/test_qubit_device.py
+++ b/tests/test_qubit_device.py
@@ -595,7 +595,7 @@ def test_non_analytic_expval(self, mock_qubit_device_with_original_statistics, m
m.setattr("numpy.mean", lambda obs, axis=None: obs)
res = dev.expval(obs)
- assert res == obs
+ assert res == np.array(obs) # no idea what is trying to cast obs to an array now.
def test_no_eigval_error(self, mock_qubit_device_with_original_statistics):
"""Tests that an error is thrown if expval is called with an observable that does
@@ -675,7 +675,7 @@ def test_non_analytic_var(self, mock_qubit_device_with_original_statistics, monk
m.setattr("numpy.var", lambda obs, axis=None: obs)
res = dev.var(obs)
- assert res == obs
+ assert res == np.array(obs)
def test_no_eigval_error(self, mock_qubit_device_with_original_statistics):
"""Tests that an error is thrown if var is called with an observable that does not have eigenvalues defined."""
diff --git a/tests/test_qutrit_device.py b/tests/test_qutrit_device.py
index 9613e76d748..74cd9fec886 100644
--- a/tests/test_qutrit_device.py
+++ b/tests/test_qutrit_device.py
@@ -700,7 +700,7 @@ def test_non_analytic_expval(self, mock_qutrit_device_with_original_statistics,
m.setattr("numpy.mean", lambda obs, axis=None: obs)
res = dev.expval(obs)
- assert res == obs
+ assert res == np.array(obs)
def test_no_eigval_error(self, mock_qutrit_device_with_original_statistics):
"""Tests that an error is thrown if expval is called with an observable that does
@@ -762,7 +762,7 @@ def test_non_analytic_var(self, mock_qutrit_device_with_original_statistics, mon
m.setattr("numpy.var", lambda obs, axis=None: obs)
res = dev.var(obs)
- assert res == obs
+ assert res == np.array(obs)
def test_no_eigval_error(self, mock_qutrit_device_with_original_statistics):
"""Tests that an error is thrown if var is called with an observable that does not have eigenvalues defined."""