Skip to content

Commit

Permalink
[Bugfix] Unary mid-circuit measurement expressions lack support (#5480)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
@dwierichs reported the following bug
```
import pennylane as qml
import numpy as np

dev = qml.device("default.qubit", shots=100) # Shot-based device -> dynamic_one_shot transform will be used.

@qml.qnode(dev)
def node(x):
    [qml.RX(np.pi/2, w) for w in [0, 1]]
    mcm0 = qml.measure(0)
    mcm1 = qml.measure(1)
    # Not working
    return qml.expval(mcm0 * 2)
    return qml.expval(mcm0 + 1)
    return qml.expval(mcm0 & 3)

node(0.4)
```

**Description of the Change:**
`gather_mcm` uses the none-`use_as_is` branch if the measurement values
branches are not 0 and 1.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
Unary mid circuit measurement expressions [sc-60856]

---------

Co-authored-by: Mudit Pandey <[email protected]>
  • Loading branch information
vincentmr and mudit2812 authored Apr 9, 2024
1 parent 4289473 commit f10e98f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 3 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@

<h3>Bug fixes 🐛</h3>

* Fix a bug where certain unary mid-circuit measurement expressions would raise an uncaught error.
[(#5480)](https://github.com/PennyLaneAI/pennylane/pull/5480)

* The probabilities now sum to one using the `torch` interface with `default_dtype` set to `torch.float32`.
[(#5462)](https://github.com/PennyLaneAI/pennylane/pull/5462)

Expand Down
9 changes: 7 additions & 2 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def gather_mcm(measurement, samples):
TensorLike: The combined measurement outcome
"""
mv = measurement.mv
# The following block handles measurement value lists, like ``qml.counts(op=[mcm0, mcm1, mcm2])``.
if isinstance(measurement, (CountsMP, ProbabilityMP, SampleMP)) and isinstance(mv, Sequence):
wires = qml.wires.Wires(range(len(mv)))
mcm_samples = list(
Expand All @@ -284,8 +285,12 @@ def gather_mcm(measurement, samples):
mcm_samples = np.concatenate(mcm_samples, axis=1)
meas_tmp = measurement.__class__(wires=wires)
return meas_tmp.process_samples(mcm_samples, wire_order=wires)
mcm_samples = np.array([mv.concretize(dct) for dct in samples]).reshape((-1, 1))
use_as_is = len(mv.measurements) == 1
if isinstance(measurement, ProbabilityMP):
mcm_samples = np.array([dct[mv.measurements[0]] for dct in samples]).reshape((-1, 1))
use_as_is = True
else:
mcm_samples = np.array([mv.concretize(dct) for dct in samples]).reshape((-1, 1))
use_as_is = mv.branches == {(0,): 0, (1,): 1}
if use_as_is:
wires, meas_tmp = mv.wires, measurement
else:
Expand Down
51 changes: 50 additions & 1 deletion tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,52 @@ def func(x, y, z):
validate_measurements(measure_f, shots, r1, r2)


@flaky(max_runs=5)
@pytest.mark.parametrize(
"mcm_f",
[
lambda x: x * -1,
lambda x: x * 1,
lambda x: x * 2,
lambda x: 1 - x,
lambda x: x + 1,
lambda x: x & 3,
],
)
@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var])
def test_simple_composite_mcm(mcm_f, measure_f):
"""Tests that DefaultQubit handles a circuit with a composite mid-circuit measurement and a
conditional gate. A single measurement of a composite mid-circuit measurement is performed
at the end."""
shots = 5000

dev = qml.device("default.qubit", shots=shots)
param = np.pi / 3

@qml.qnode(dev)
def func(x):
qml.RX(x, 0)
m0 = qml.measure(0)
qml.RX(0.5 * x, 1)
m1 = qml.measure(1)
qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0)
m2 = qml.measure(0)
return measure_f(op=mcm_f(m2))

func1 = func
func2 = qml.defer_measurements(func)

results1 = func1(param)
results2 = func2(param)

validate_measurements(measure_f, shots, results1, results2)


@flaky(max_runs=5)
@pytest.mark.parametrize("shots", [None, 5000, [5000, 5001]])
@pytest.mark.parametrize("postselect", [None, 0, 1])
@pytest.mark.parametrize("reset", [False, True])
@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.sample, qml.var])
@pytest.mark.parametrize("measure_f", [qml.counts, qml.expval, qml.probs, qml.sample, qml.var])
def test_composite_mcm_measure_composite_mcm(shots, postselect, reset, measure_f):
"""Tests that DefaultQubit handles a circuit with a composite mid-circuit measurement and a
conditional gate. A single measurement of a composite mid-circuit measurement is performed
Expand All @@ -378,6 +419,14 @@ def func(x):
if shots is None and measure_f in (qml.counts, qml.sample):
return

if measure_f == qml.probs:
with pytest.raises(
ValueError,
match=r"Cannot use qml.probs\(\) when measuring multiple mid-circuit measurements collected",
):
_ = func1(param)
return

results1 = func1(param)
results2 = func2(param)

Expand Down

0 comments on commit f10e98f

Please sign in to comment.