Skip to content

Commit

Permalink
Persist circuits and tasks in batch execute (#260)
Browse files Browse the repository at this point in the history
* feature: persist circuits in batch execute

* feature: Add test for batch_execute persistance

* fix: circuits attribute

* documentation: add doc for printing circuits from batch_execute

* fix: linting

* Add tasks, circuits

---------

Co-authored-by: Ryan Shaffer <[email protected]>
Co-authored-by: Tim (Yi-Ting) <[email protected]>
  • Loading branch information
3 people authored Jun 12, 2024
1 parent 7cfecb8 commit 03fe023
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/devices/braket_remote.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ You can set a timeout by using the ``poll_timeout_seconds`` argument;
the device will retry circuits that do not complete within the timeout.
A timeout of 30 to 60 seconds is recommended for circuits with fewer than 25 qubits.

Each of the submitted circuit can be visualised using the attribute ``circuits`` on the device

>> print(remote_device.circuits[0])

Device options
~~~~~~~~~~~~~~

Expand Down
18 changes: 18 additions & 0 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def __init__(
super().__init__(wires, shots=shots or None)
self._device = device
self._circuit = None
self._circuits = []
self._task = None
self._tasks = []
self._noise_model = noise_model
self._parametrize_differentiable = parametrize_differentiable
self._run_kwargs = run_kwargs
Expand All @@ -153,7 +155,9 @@ def __init__(
def reset(self):
super().reset()
self._circuit = None
self._circuits = []
self._task = None
self._tasks = []

@property
def operations(self) -> frozenset[str]:
Expand All @@ -173,11 +177,21 @@ def circuit(self) -> Circuit:
"""Circuit: The last circuit run on this device."""
return self._circuit

@property
def circuits(self) -> list[Circuit]:
"""Circuit: The circuits run on this device."""
return self._circuits

@property
def task(self) -> QuantumTask:
"""QuantumTask: The task corresponding to the last run circuit."""
return self._task

@property
def tasks(self) -> list[QuantumTask]:
"""The tasks corresponding to the circuits run on this device."""
return self._tasks

def _pl_to_braket_circuit(
self,
circuit: QuantumTape,
Expand Down Expand Up @@ -584,6 +598,8 @@ def __init__(
self._max_parallel = max_parallel
self._max_connections = max_connections
self._max_retries = max_retries
self._circuits = []
self._tasks = []

@property
def use_grouping(self) -> bool:
Expand Down Expand Up @@ -621,6 +637,7 @@ def batch_execute(self, circuits, **run_kwargs):
**run_kwargs,
)
)
self._circuits = braket_circuits

batch_shots = 0 if self.analytic else self.shots

Expand All @@ -639,6 +656,7 @@ def batch_execute(self, circuits, **run_kwargs):
),
**self._run_kwargs,
)
self._tasks = task_batch.tasks
# Call results() to retrieve the Braket results in parallel.
try:
braket_results_batch = task_batch.results(
Expand Down
25 changes: 25 additions & 0 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ def test_reset():
"""Tests that the members of the device are cleared on reset."""
dev = _aws_device(wires=2)
dev._circuit = CIRCUIT
dev._circuits = [CIRCUIT, CIRCUIT]
dev._task = TASK
dev._tasks = [TASK, TASK]

dev.reset()
assert dev.circuit is None
assert dev.circuits == []
assert dev.task is None
assert dev.tasks == []


def test_apply():
Expand Down Expand Up @@ -910,6 +914,25 @@ def test_batch_execute_non_parallel_tracker(mock_run):
callback.assert_called_with(latest=latest, history=history, totals=totals)


@patch.object(AwsDevice, "run_batch")
def test_batch_execute_parallel_circuits_persistance(mock_run_batch):
mock_run_batch.return_value = TASK_BATCH
dev = _aws_device(wires=4, foo="bar", parallel=True)
assert dev.parallel is True

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.probs(wires=[0])
qml.expval(qml.PauliX(1))
qml.var(qml.PauliY(2))
qml.sample(qml.PauliZ(3))

circuits = [circuit, circuit]
dev.batch_execute(circuits)
assert dev.circuits[1]


@patch.object(AwsDevice, "run_batch")
def test_batch_execute_parallel(mock_run_batch):
"""Test batch_execute(parallel=True) correctly calls batch execution methods in Braket SDK"""
Expand All @@ -927,6 +950,8 @@ def test_batch_execute_parallel(mock_run_batch):

circuits = [circuit, circuit]
batch_results = dev.batch_execute(circuits)

assert dev.tasks[0]
for results in batch_results:
assert np.allclose(
results[0], RESULT.get_value_by_result_type(result_types.Probability(target=[0]))
Expand Down

0 comments on commit 03fe023

Please sign in to comment.