Skip to content

Commit

Permalink
[UnitaryHack] Mock provider, backend (#107)
Browse files Browse the repository at this point in the history
* Mock provider, backend

* Fix lint

* Fix tests to deterministic circuit creation

* Add gate support

* Fix test

* Add coverage braket_backend.run

* Fix mypy

* Add test for braket job status

* Drop Literal typing for python3.7
  • Loading branch information
WingCode authored Jun 29, 2023
1 parent a811b5e commit 984101c
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,QuantumCircuit
ignored-classes=optparse.Values,thread._local,_thread._local,QuantumCircuit,Circuit

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
"braket.ir.jaqcd.program": {
"actionType": "braket.ir.jaqcd.program",
"version": ["1"],
"supportedOperations": ["H"],
"supportedOperations": ["H", "CNOT"],
}
},
"paradigm": {"qubitCount": 30},
Expand Down
40 changes: 40 additions & 0 deletions tests/providers/test_braket_job.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Tests for AWS Braket job."""

from unittest import TestCase
from unittest.mock import Mock

import pytest
from braket.aws.aws_quantum_task import AwsQuantumTask
from qiskit.providers import JobStatus

from qiskit_braket_provider.providers import (
Expand Down Expand Up @@ -75,3 +78,40 @@ def test_AWS_result(self):
self.assertEqual(job.result().results[0].status, "COMPLETED")
self.assertEqual(job.result().results[0].shots, 3)
self.assertEqual(job.result().get_memory(), ["10", "10", "01"])


class TestBracketJobStatus:
"""Tests for AWS Braket job status."""

def _get_mock_aws_quantum_task(self, status: str) -> AwsQuantumTask:
"""
Creates a mock AwsQuantumTask with the given status.
Status can be one of "CREATED", "QUEUED", "RUNNING", "COMPLETED",
"FAILED", "CANCELLING", "CANCELLED"
"""
task = Mock(spec=AwsQuantumTask)
task.state.return_value = status
return task

@pytest.mark.parametrize(
"task_states, expected_status",
[
(["COMPLETED", "FAILED"], JobStatus.ERROR),
(["COMPLETED", "CANCELLED"], JobStatus.CANCELLED),
(["COMPLETED", "COMPLETED"], JobStatus.DONE),
(["RUNNING", "RUNNING"], JobStatus.RUNNING),
(["QUEUED", "QUEUED"], JobStatus.QUEUED),
],
)
def test_status(self, task_states, expected_status):
"""Tests job status when multiple task status is present."""
job = AWSBraketJob(
backend=BraketLocalBackend(name="default"),
job_id="MockId",
tasks=[MOCK_LOCAL_QUANTUM_TASK],
shots=100,
)
job._tasks = Mock(spec=AmazonBraketTask)
job._tasks = [self._get_mock_aws_quantum_task(state) for state in task_states]

assert job.status() == expected_status
101 changes: 72 additions & 29 deletions tests/providers/test_braket_provider.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Tests for AWS Braket provider."""
import unittest
from unittest import TestCase
from unittest.mock import Mock, patch
import uuid

from braket.aws import AwsDeviceType
from braket.circuits import Circuit
from braket.aws import AwsSession, AwsQuantumTaskBatch
from braket.aws import AwsDevice, AwsDeviceType
from qiskit import circuit as qiskit_circuit
from qiskit.circuit.random import random_circuit
from qiskit.compiler import transpile

from qiskit_braket_provider.providers import AWSBraketProvider
from qiskit_braket_provider.providers.braket_backend import (
BraketBackend,
Expand All @@ -24,50 +26,91 @@
class TestAWSBraketProvider(TestCase):
"""Tests AWSBraketProvider."""

def test_provider_backends(self):
"""Tests provider."""
mock_session = Mock()
def setUp(self):
self.mock_session = Mock()
simulators = [MOCK_GATE_MODEL_SIMULATOR_SV, MOCK_GATE_MODEL_SIMULATOR_TN]
mock_session.get_device.side_effect = simulators
mock_session.region = SIMULATOR_REGION
mock_session.boto_session.region_name = SIMULATOR_REGION
mock_session.search_devices.return_value = simulators
self.mock_session.get_device.side_effect = simulators
self.mock_session.region = SIMULATOR_REGION
self.mock_session.boto_session.region_name = SIMULATOR_REGION
self.mock_session.search_devices.return_value = simulators

def test_provider_backends(self):
"""Tests provider."""
provider = AWSBraketProvider()
backends = provider.backends(
aws_session=mock_session, types=[AwsDeviceType.SIMULATOR]
aws_session=self.mock_session, types=[AwsDeviceType.SIMULATOR]
)

self.assertTrue(len(backends) > 0)
for backend in backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, BraketBackend)

@unittest.skip("Call to external service")
def test_real_devices(self):
"""Tests real devices."""
provider = AWSBraketProvider()
backends = provider.backends()
self.assertTrue(len(backends) > 0)
for backend in backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, AWSBraketBackend)
with patch(
"qiskit_braket_provider.providers.braket_provider.AwsDevice"
) as mock_get_devices:
mock_get_devices.get_devices.return_value = [
AwsDevice(MOCK_GATE_MODEL_SIMULATOR_SV["deviceArn"], self.mock_session),
AwsDevice(MOCK_GATE_MODEL_SIMULATOR_TN["deviceArn"], self.mock_session),
]
provider = AWSBraketProvider()
backends = provider.backends()
self.assertTrue(len(backends) > 0)
for backend in backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, AWSBraketBackend)

online_simulators_backends = provider.backends(
statuses=["ONLINE"], types=["SIMULATOR"]
online_simulators_backends = provider.backends(
statuses=["ONLINE"], types=["SIMULATOR"]
)
for backend in online_simulators_backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, AWSBraketBackend)

@patch("qiskit_braket_provider.providers.braket_backend.AWSBraketBackend")
@patch("qiskit_braket_provider.providers.braket_backend.AwsDevice.get_devices")
def test_qiskit_circuit_transpilation_run(
self, mock_get_devices, mock_aws_braket_backend
):
"""Tests qiskit circuit transpilation."""
mock_get_devices.return_value = [
AwsDevice(MOCK_GATE_MODEL_SIMULATOR_SV["deviceArn"], self.mock_session)
]
s3_target = AwsSession.S3DestinationFolder("mock_bucket", "mock_key")
q_circuit = qiskit_circuit.QuantumCircuit(2)
q_circuit.h(0)
q_circuit.cx(0, 1)
braket_circuit = Circuit().h(0).cnot(0, 1)

mock_aws_braket_backend = Mock(spec=AWSBraketBackend)
mock_aws_braket_backend._device = Mock(spec=AwsDevice)
task = AwsQuantumTaskBatch(
Mock(),
MOCK_GATE_MODEL_SIMULATOR_SV["deviceArn"],
braket_circuit,
s3_target,
1000,
max_parallel=10,
)
for backend in online_simulators_backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, AWSBraketBackend)
task_mock = Mock()
task_mock.id = str(uuid.uuid4())
task_mock.state.return_value = "RUNNING"
task = Mock(spec=AwsQuantumTaskBatch, return_value=task)
task.tasks = [task_mock]

@unittest.skip("Call to external service")
def test_real_device_circuit_execution(self):
"""Tests circuit execution on real device."""
provider = AWSBraketProvider()
state_vector_backend = provider.get_backend("SV1")
circuit = random_circuit(3, 5, seed=42)
state_vector_backend = provider.get_backend(
"SV1", aws_session=self.mock_session
)

transpiled_circuit = transpile(
circuit, backend=state_vector_backend, seed_transpiler=42
q_circuit, backend=state_vector_backend, seed_transpiler=42
)

state_vector_backend._device.run_batch = Mock(
spec=AwsQuantumTaskBatch, return_value=task
)
result = state_vector_backend.run(transpiled_circuit, shots=10)
self.assertTrue(result)
Expand Down

0 comments on commit 984101c

Please sign in to comment.