Skip to content

Commit

Permalink
Cast samples to int64 instead of int8 in qml.counts (#5544)
Browse files Browse the repository at this point in the history
**Context:**
[sc-61314] Issue #5513 was being caused by samples that were converted
to decimals to be cast to `int8`. This caused overflow with 8 or more
wires, resulting in negative values being present in the counts
dictionary.

**Description of the Change:**
* Update `qml.counts.process_samples` and
`QubitDevice._samples_to_counts` to cast to `int64` instead of `int8`.

**Benefits:**
No more overflow issues with `qml.counts`.

**Possible Drawbacks:**

**Related GitHub Issues:**
#5513
  • Loading branch information
mudit2812 authored Apr 18, 2024
1 parent c2de427 commit ca9637a
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 26 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@

<h3>Bug fixes 🐛</h3>

* `qml.counts` no longer returns negative samples when measuring 8 or more wires.
[(#5544)](https://github.com/PennyLaneAI/pennylane/pull/5544)

* The `dynamic_one_shot` transform now works with broadcasting.
[(#5473)](https://github.com/PennyLaneAI/pennylane/pull/5473)

Expand Down
2 changes: 1 addition & 1 deletion pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,7 +1459,7 @@ def circuit(x):
if mp.obs is None and not isinstance(mp.mv, MeasurementValue):
# convert samples and outcomes (if using) from arrays to str for dict keys
samples = np.array([sample for sample in samples if not np.any(np.isnan(sample))])
samples = qml.math.cast_like(samples, qml.math.int8(0))
samples = qml.math.cast_like(samples, qml.math.int64(0))
samples = np.apply_along_axis(_sample_to_str, -1, samples)
batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires))
if mp.all_outcomes:
Expand Down
2 changes: 1 addition & 1 deletion pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def convert(x):
exp2 = 2 ** np.arange(num_wires - 1, -1, -1)
samples = np.einsum("...i,i", samples, exp2)
new_shape = samples.shape
samples = qml.math.cast_like(samples, qml.math.int8(0))
samples = qml.math.cast_like(samples, qml.math.int64(0))
samples = list(map(convert, samples.ravel()))
samples = np.array(samples).reshape(new_shape)

Expand Down
14 changes: 14 additions & 0 deletions tests/measurements/test_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,20 @@ def test_counts_with_nan_samples(self):
total_counts = sum(count for count in result.values())
assert total_counts == 997

@pytest.mark.parametrize("n_wires", [5, 8, 10])
@pytest.mark.parametrize("all_outcomes", [True, False])
def test_counts_multi_wires_no_overflow(self, n_wires, all_outcomes):
"""Test that binary strings for wire samples are not negative due to overflow."""
shots = 1000
total_wires = 10
samples = np.random.choice([0, 1], size=(shots, total_wires)).astype(np.float64)
result = qml.counts(wires=list(range(n_wires)), all_outcomes=all_outcomes).process_samples(
samples, wire_order=list(range(total_wires))
)

assert sum(result.values()) == shots
assert all(0 <= int(sample, 2) <= 2**n_wires for sample in result.keys())

def test_counts_obs(self):
"""Test that the counts function outputs counts of the right size for observables"""
shots = 1000
Expand Down
75 changes: 51 additions & 24 deletions tests/test_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,27 +1476,54 @@ def circuit(x):
]


def test_samples_to_counts_with_nan():
"""Test that the counts function disregards failed measurements (samples including
NaN values) when totalling counts"""
# generate 1000 samples for 2 wires, randomly distributed between 0 and 1
device = qml.device("default.qubit.legacy", wires=2, shots=1000)
device._state = [0.5 + 0.0j, 0.5 + 0.0j, 0.5 + 0.0j, 0.5 + 0.0j]
device._samples = device.generate_samples()
samples = device.sample(qml.measurements.CountsMP())

# imitate hardware return with NaNs (requires dtype float)
samples = qml.math.cast_like(samples, np.array([1.2]))
samples[0][0] = np.NaN
samples[17][1] = np.NaN
samples[850][0] = np.NaN

result = device._samples_to_counts(samples, mp=qml.measurements.CountsMP(), num_wires=2)

# no keys with NaNs
assert len(result) == 4
assert set(result.keys()) == {"00", "01", "10", "11"}

# # NaNs were not converted into "0", but were excluded from the counts
total_counts = sum(count for count in result.values())
assert total_counts == 997
class TestSamplesToCounts:
"""Tests for correctness of QubitDevice._samples_to_counts"""

def test_samples_to_counts_with_nan(self):
"""Test that the counts function disregards failed measurements (samples including
NaN values) when totalling counts"""
# generate 1000 samples for 2 wires, randomly distributed between 0 and 1
device = qml.device("default.qubit.legacy", wires=2, shots=1000)
device._state = [0.5 + 0.0j, 0.5 + 0.0j, 0.5 + 0.0j, 0.5 + 0.0j]
device._samples = device.generate_samples()
samples = device.sample(qml.measurements.CountsMP())

# imitate hardware return with NaNs (requires dtype float)
samples = qml.math.cast_like(samples, np.array([1.2]))
samples[0][0] = np.NaN
samples[17][1] = np.NaN
samples[850][0] = np.NaN

result = device._samples_to_counts(samples, mp=qml.measurements.CountsMP(), num_wires=2)

# no keys with NaNs
assert len(result) == 4
assert set(result.keys()) == {"00", "01", "10", "11"}

# # NaNs were not converted into "0", but were excluded from the counts
total_counts = sum(result.values())
assert total_counts == 997

@pytest.mark.parametrize("all_outcomes", [True, False])
def test_samples_to_counts_with_many_wires(self, all_outcomes):
"""Test that the counts function correctly converts wire samples to strings when
the number of wires is 8 or more."""
# generate 1000 samples for 10 wires, randomly distributed between 0 and 1
n_wires = 10
shots = 100
device = qml.device("default.qubit.legacy", wires=n_wires, shots=shots)
state = np.random.rand(*([2] * n_wires))
device._state = state / np.linalg.norm(state)
device._samples = device.generate_samples()
samples = device.sample(qml.measurements.CountsMP(all_outcomes=all_outcomes))

result = device._samples_to_counts(
samples, mp=qml.measurements.CountsMP(), num_wires=n_wires
)

# Check that keys are correct binary strings
assert all(0 <= int(sample, 2) <= 2**n_wires for sample in result.keys())

# # NaNs were not converted into "0", but were excluded from the counts
total_counts = sum(result.values())
assert total_counts == shots

0 comments on commit ca9637a

Please sign in to comment.