Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Persist circuits in batch execute #260

4 changes: 4 additions & 0 deletions doc/devices/braket_remote.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ You can set a timeout by using the ``poll_timeout_seconds`` argument;
the device will retry circuits that do not complete within the timeout.
A timeout of 30 to 60 seconds is recommended for circuits with fewer than 25 qubits.

Each of the submitted circuit can be visualised using the attribute ``circuits`` on the device

>> print(remote_device.circuits[0])

Device options
~~~~~~~~~~~~~~

Expand Down
4 changes: 4 additions & 0 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
super().__init__(wires, shots=shots or None)
self._device = device
self._circuit = None
self._circuits = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._circuits here is a private variable. that users should not access. Instead, users access _circuits through properties. So, there need to be a new property of this class called circuits defined as a method of the class, similar to

.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And yes, it's a bit repetitive now that we have both circuits and circuit, but I think it's okay for now. In long term, circuit may be deprecated.

self._task = None
self._noise_model = noise_model
self._parametrize_differentiable = parametrize_differentiable
Expand All @@ -153,6 +154,7 @@ def __init__(
def reset(self):
super().reset()
self._circuit = None
self._circuits = []
self._task = None

@property
Expand Down Expand Up @@ -584,6 +586,7 @@ def __init__(
self._max_parallel = max_parallel
self._max_connections = max_connections
self._max_retries = max_retries
self.circuits = []

@property
def use_grouping(self) -> bool:
Expand Down Expand Up @@ -621,6 +624,7 @@ def batch_execute(self, circuits, **run_kwargs):
**run_kwargs,
)
)
self.circuits.append(circuit)

batch_shots = 0 if self.analytic else self.shots

Expand Down
19 changes: 19 additions & 0 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,25 @@ def test_batch_execute_non_parallel_tracker(mock_run):
callback.assert_called_with(latest=latest, history=history, totals=totals)


@patch.object(AwsDevice, "run_batch")
def test_batch_execute_parallel_circuits_persistance(mock_run_batch):
mock_run_batch.return_value = TASK_BATCH
dev = _aws_device(wires=4, foo="bar", parallel=True)
assert dev.parallel is True

with QuantumTape() as circuit:
qml.Hadamard(wires=0)
qml.CNOT(wires=[0, 1])
qml.probs(wires=[0])
qml.expval(qml.PauliX(1))
qml.var(qml.PauliY(2))
qml.sample(qml.PauliZ(3))

circuits = [circuit, circuit]
dev.batch_execute(circuits)
assert dev.circuits[1]


@patch.object(AwsDevice, "run_batch")
def test_batch_execute_parallel(mock_run_batch):
"""Test batch_execute(parallel=True) correctly calls batch execution methods in Braket SDK"""
Expand Down