diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 445064a369d..e8eb872565d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -60,6 +60,9 @@

Improvements 🛠

+* Operators can now be left multiplied `x * op` by numpy arrays. + [(#5361)](https://github.com/PennyLaneAI/pennylane/pull/5361) + * Create the `qml.Reflection` operator, useful for amplitude amplification and its variants. [(##5159)](https://github.com/PennyLaneAI/pennylane/pull/5159) diff --git a/pennylane/operation.py b/pennylane/operation.py index bb9c634265a..26808341cda 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -691,6 +691,10 @@ def compute_decomposition(theta, wires): # pylint: disable=too-many-public-methods, too-many-instance-attributes + # this allows scalar multiplication from left with numpy arrays np.array(0.5) * ps1 + # taken from [stackexchange](https://stackoverflow.com/questions/40694380/forcing-multiplication-to-use-rmul-instead-of-numpy-array-mul-or-byp/44634634#44634634) + __array_priority__ = 1000 + def __init_subclass__(cls, **_): register_pytree(cls, cls._flatten, cls._unflatten) diff --git a/tests/test_operation.py b/tests/test_operation.py index ffabfeea358..8a880b915d1 100644 --- a/tests/test_operation.py +++ b/tests/test_operation.py @@ -1197,6 +1197,24 @@ def test_mul_scalar_tensor(self): assert isinstance(prod_op, SProd) assert prod_op.scalar is scalar + prod_op2 = scalar * qml.RX(1.23, 0) + assert isinstance(prod_op2, SProd) + assert prod_op.scalar is scalar + + def test_mul_array_numpy(self): + """Test that the __mul__ dunder works with a batched scalar.""" + + scalar = np.array([0.5, 0.6, 0.7]) + prod_op = scalar * qml.S(0) + assert isinstance(prod_op, SProd) + assert prod_op.scalar is scalar + assert prod_op.batch_size == 3 + + prod_op = qml.S(0) * scalar + assert isinstance(prod_op, SProd) + assert prod_op.scalar is scalar + assert prod_op.batch_size == 3 + def test_divide_with_scalar(self): """Test the __truediv__ dunder method with a scalar value.""" sprod_op = qml.RX(1, 0) / 4 @@ -1255,6 +1273,10 @@ def test_mul_scalar_torch_tensor(self): assert isinstance(prod_op, SProd) assert prod_op.scalar is scalar + prod_op = scalar * qml.RX(1.23, 0) + assert isinstance(prod_op, SProd) + assert prod_op.scalar is scalar + @pytest.mark.tf def test_mul_scalar_tf_tensor(self): """Test the __mul__ dunder method with a scalar tf tensor.""" @@ -1265,6 +1287,10 @@ def test_mul_scalar_tf_tensor(self): assert isinstance(prod_op, SProd) assert prod_op.scalar is scalar + prod_op = scalar * qml.RX(1.23, 0) + assert isinstance(prod_op, SProd) + assert prod_op.scalar is scalar + @pytest.mark.jax def test_mul_scalar_jax_tensor(self): """Test the __mul__ dunder method with a scalar jax tensor.""" @@ -1275,6 +1301,10 @@ def test_mul_scalar_jax_tensor(self): assert isinstance(prod_op, SProd) assert prod_op.scalar is scalar + prod_op = scalar * qml.RX(1.23, 0) + assert isinstance(prod_op, SProd) + assert prod_op.scalar is scalar + def test_mul_with_operator(self): """Test the __matmul__ dunder method with an operator.""" prod_op = qml.RX(1, 0) @ qml.PauliX(0) diff --git a/tests/test_qubit_device.py b/tests/test_qubit_device.py index 452deaed98a..b27f1c67bc7 100644 --- a/tests/test_qubit_device.py +++ b/tests/test_qubit_device.py @@ -595,7 +595,7 @@ def test_non_analytic_expval(self, mock_qubit_device_with_original_statistics, m m.setattr("numpy.mean", lambda obs, axis=None: obs) res = dev.expval(obs) - assert res == obs + assert res == np.array(obs) # no idea what is trying to cast obs to an array now. def test_no_eigval_error(self, mock_qubit_device_with_original_statistics): """Tests that an error is thrown if expval is called with an observable that does @@ -675,7 +675,7 @@ def test_non_analytic_var(self, mock_qubit_device_with_original_statistics, monk m.setattr("numpy.var", lambda obs, axis=None: obs) res = dev.var(obs) - assert res == obs + assert res == np.array(obs) def test_no_eigval_error(self, mock_qubit_device_with_original_statistics): """Tests that an error is thrown if var is called with an observable that does not have eigenvalues defined.""" diff --git a/tests/test_qutrit_device.py b/tests/test_qutrit_device.py index 9613e76d748..74cd9fec886 100644 --- a/tests/test_qutrit_device.py +++ b/tests/test_qutrit_device.py @@ -700,7 +700,7 @@ def test_non_analytic_expval(self, mock_qutrit_device_with_original_statistics, m.setattr("numpy.mean", lambda obs, axis=None: obs) res = dev.expval(obs) - assert res == obs + assert res == np.array(obs) def test_no_eigval_error(self, mock_qutrit_device_with_original_statistics): """Tests that an error is thrown if expval is called with an observable that does @@ -762,7 +762,7 @@ def test_non_analytic_var(self, mock_qutrit_device_with_original_statistics, mon m.setattr("numpy.var", lambda obs, axis=None: obs) res = dev.var(obs) - assert res == obs + assert res == np.array(obs) def test_no_eigval_error(self, mock_qutrit_device_with_original_statistics): """Tests that an error is thrown if var is called with an observable that does not have eigenvalues defined."""