Fix argnum kwarg of qml.gradients.stoch_pulse_grad (#5458)
**Context:**
The `argnum` kwarg of gradient transforms applied to a tape is supposed
to index into the set of _trainable_ parameters, not all tape
parameters.
`stoch_pulse_grad` does not do this properly, as reported in #5457.
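
As an illustration (a minimal sketch mirroring the new test added in this commit, not part of the diff itself): for a tape in which only the pulse parameter is marked trainable, `argnum=0` is meant to select that first _trainable_ parameter, regardless of how many non-trainable parameters precede it.

```python
import jax
import jax.numpy as jnp
import pennylane as qml

jax.config.update("jax_enable_x64", True)

ham = qml.pulse.constant * qml.PauliY(0)
op = qml.evolve(ham)([jnp.array(0.04)], 0.1)
# Tape parameter 0 is the RY angle, tape parameter 1 is the pulse parameter.
tape = qml.tape.QuantumScript([qml.RY(0.3, 0), op], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [1]  # only the pulse parameter is trainable

# argnum indexes into the trainable parameters, so argnum=0 refers to the
# pulse parameter here, consistent with other gradient transforms.
tapes, fn = qml.gradients.stoch_pulse_grad(tape, num_split_times=1, argnum=0)
```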

**Description of the Change:**
Changes the indexing to be consistent with other gradient transforms.

**Benefits:**
Consistency across the `gradients` module; extends support for differentiable
parameters.

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
Fixes #5457.


[sc-60286]
dwierichs authored Apr 5, 2024
1 parent 3bd3677 commit 68f4762
Showing 3 changed files with 37 additions and 4 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -300,6 +300,10 @@

 <h3>Bug fixes 🐛</h3>

+* Fix a bug where the `argnum` kwarg of `qml.gradients.stoch_pulse_grad` references the wrong parameters in a tape,
+  creating an inconsistency with other differentiation methods and preventing some use cases.
+  [(#5458)](https://github.com/PennyLaneAI/pennylane/pull/5458)
+
 * Avoid bounded value failures due to numerical noise with calls to `np.random.binomial`.
   [(#5447)](https://github.com/PennyLaneAI/pennylane/pull/5447)

4 changes: 2 additions & 2 deletions pennylane/gradients/pulse_gradient.py
@@ -787,8 +787,8 @@ def _expval_stoch_pulse_grad(tape, argnum, num_split_times, key, use_broadcastin
"""
tapes = []
gradient_data = []
for idx, trainable_idx in enumerate(tape.trainable_params):
if trainable_idx not in argnum:
for idx in range(tape.num_params):
if idx not in argnum:
# Only the number of tapes is needed to indicate a zero gradient entry
gradient_data.append((0, None, None, None))
continue
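
For context, a hedged sketch (with hypothetical example values, not code from this commit) of how the old and new checks differ for a tape whose only trainable parameter is tape parameter 1:

```python
trainable_params = [1]  # absolute tape indices of the trainable parameters
argnum = [0]            # request the gradient of the first *trainable* parameter

# Old check: argnum compared against absolute tape indices -> no match, parameter skipped
old_selected = [tr_idx in argnum for tr_idx in trainable_params]  # [False]

# New check: argnum compared against positions among the trainable parameters -> match
new_selected = [idx in argnum for idx in range(len(trainable_params))]  # [True]
```
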
33 changes: 31 additions & 2 deletions tests/gradients/core/test_pulse_gradient.py
@@ -913,14 +913,14 @@ def test_constant_ry(self, dev_name, num_split_times, t):

         jax.config.update("jax_enable_x64", True)
         params = [jnp.array(0.24)]
-        T = t if isinstance(t, tuple) else (0, t)
+        delta_t = t[-1] - t[0] if isinstance(t, tuple) else t
         ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
         op = qml.evolve(ham_single_q_const)(params, t)
         tape = qml.tape.QuantumScript([op], [qml.expval(qml.PauliZ(0))])

         dev = qml.device(dev_name, wires=1)
         # Effective rotation parameter
-        p = params[0] * (delta_t := T[-1] - T[0])
+        p = params[0] * delta_t
         r = qml.execute([tape], dev, None)
         assert qml.math.isclose(r, jnp.cos(2 * p), atol=1e-4)
         tapes, fn = stoch_pulse_grad(tape, num_split_times=num_split_times)
@@ -930,6 +930,35 @@ def test_constant_ry(self, dev_name, num_split_times, t):
         assert qml.math.isclose(res, -2 * jnp.sin(2 * p) * delta_t)
         jax.clear_caches()

+    def test_constant_ry_argnum(self, dev_name):
+        """Test that the derivative of a pulse generated by a constant Hamiltonian,
+        which is a Pauli word, is computed correctly if it is not the only
+        operation in a tape but selected via `argnum`."""
+        import jax
+        import jax.numpy as jnp
+
+        jax.config.update("jax_enable_x64", True)
+        params = [jnp.array(0.04)]
+        t = 0.1
+        y = 0.3
+        ham_single_q_const = qml.pulse.constant * qml.PauliY(0)
+        op = qml.evolve(ham_single_q_const)(params, t)
+        tape = qml.tape.QuantumScript([qml.RY(y, 0), op], [qml.expval(qml.PauliZ(0))])
+        tape.trainable_params = [1]
+
+        dev = qml.device(dev_name, wires=1)
+        # Effective rotation parameter
+        p = params[0] * t
+        r = qml.execute([tape], dev, None)
+        assert qml.math.isclose(r, jnp.cos(2 * p + y), atol=1e-4)
+        num_split_times = 1
+        tapes, fn = stoch_pulse_grad(tape, num_split_times=num_split_times, argnum=0)
+        assert len(tapes) == num_split_times * 2
+
+        res = fn(qml.execute(tapes, dev, None))
+        assert qml.math.isclose(res, -2 * jnp.sin(2 * p + y) * t)
+        jax.clear_caches()

     @pytest.mark.parametrize("num_split_times", [1, 3])
     @pytest.mark.parametrize("t", [2.0, 3, (0.5, 0.6), (0.1, 0.9, 1.2)])
     def test_constant_ry_rescaled(self, dev_name, num_split_times, t):
