Skip to content

Commit

Permalink
feat: Introduce run_multiple method
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Jun 19, 2024
1 parent a4d7f98 commit 03b2dad
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/braket/simulator/braket_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# language governing permissions and limitations under the License.

from abc import ABC, abstractmethod
from typing import Union
from multiprocessing import Pool
from os import cpu_count
from typing import Optional, Union

from braket.device_schema import DeviceCapabilities
from braket.ir.ahs import Program as AHSProgram
Expand Down Expand Up @@ -59,6 +61,40 @@ def run(
representing the results of the simulation.
"""

def run_multiple(
self,
payloads: list[Union[OQ3Program, AHSProgram, JaqcdProgram]],
max_parallel: Optional[int] = None,
*args,
**kwargs,
) -> list[Union[GateModelTaskResult, AnalogHamiltonianSimulationTaskResult]]:
"""
Run the tasks specified by the given IR payloads.
Extra arguments will contain any additional information necessary to run the tasks,
such as number of qubits.
Args:
payloads (list[Union[OQ3Program, AHSProgram, JaqcdProgram]]): The IR representations
of the programs
max_parallel (Optional[int]): The maximum number of payloads to run in parallel.
Default is the number of CPUs.
Returns:
list[Union[GateModelTaskResult, AnalogHamiltonianSimulationTaskResult]]: A list of
result objects, with the ith object being the result of the ith program.
"""
max_parallel = max_parallel or cpu_count()
with Pool(min(max_parallel, len(payloads))) as pool:
param_list = [(task, args, kwargs) for task in payloads]
results = pool.starmap(self._run_wrapped, param_list)
return results

def _run_wrapped(
self, ir: Union[OQ3Program, AHSProgram, JaqcdProgram], args, kwargs
): # pragma: no cover
return self.run(ir, *args, **kwargs)

@property
@abstractmethod
def properties(self) -> DeviceCapabilities:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -831,3 +831,23 @@ def test_measure_with_qubits_not_used():
assert np.sum(measurements, axis=0)[3] == 0
assert len(measurements[0]) == 4
assert result.measuredQubits == [0, 1, 2, 3]


def test_run_multiple():
payloads = [
OpenQASMProgram(
source=f"""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
{gate} q[0];
#pragma braket result density_matrix
"""
)
for gate in ["h", "z", "x"]
]
simulator = DensityMatrixSimulator()
results = simulator.run_multiple(payloads, shots=0)
assert np.allclose(results[0].resultTypes[0].value, np.array([[0.5, 0.5], [0.5, 0.5]]))
assert np.allclose(results[1].resultTypes[0].value, np.array([[1, 0], [0, 0]]))
assert np.allclose(results[2].resultTypes[0].value, np.array([[0, 0], [0, 1]]))
Original file line number Diff line number Diff line change
Expand Up @@ -1363,3 +1363,23 @@ def test_rotation_parameter_expressions(operation, state_vector):
result = simulator.run(OpenQASMProgram(source=qasm), shots=0)
assert result.resultTypes[0].type == StateVector()
assert np.allclose(result.resultTypes[0].value, np.array(state_vector))


def test_run_multiple():
payloads = [
OpenQASMProgram(
source=f"""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
{gate} q[0];
#pragma braket result state_vector
"""
)
for gate in ["h", "z", "x"]
]
simulator = StateVectorSimulator()
results = simulator.run_multiple(payloads, shots=0)
assert np.allclose(results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2))
assert np.allclose(results[1].resultTypes[0].value, np.array([1, 0]))
assert np.allclose(results[2].resultTypes[0].value, np.array([0, 1]))

0 comments on commit 03b2dad

Please sign in to comment.