Skip to content

Commit

Permalink
Use runcard to select compute type as no other way to pass in para wo…
Browse files Browse the repository at this point in the history
…ut changing struct
  • Loading branch information
Tankya2 committed Jan 31, 2024
1 parent 6216a32 commit 1680185
Showing 1 changed file with 68 additions and 62 deletions.
130 changes: 68 additions & 62 deletions src/qibotn/backends/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,24 @@
class CuTensorNet(NumpyBackend): # pragma: no cover
# CI does not test for GPU

def __init__(self):
def __init__(self, runcard):
super().__init__()
import cuquantum # pylint: disable=import-error
from cuquantum import cutensornet as cutn # pylint: disable=import-error

if runcard is not None:
print("inside runcard")
# Parse the runcard or use its values to set flags
self.MPI_enabled = runcard.get("MPI_enabled", False)
self.MPS_enabled = runcard.get("MPS_enabled", False)
self.NCCL_enabled = runcard.get("NCCL_enabled", False)
self.expectation_enabled = runcard.get("expectation_enabled", False)
else:
self.MPI_enabled = False
self.MPS_enabled = False
self.NCCL_enabled = False
self.expectation_enabled = False

self.name = "qibotn"
self.cuquantum = cuquantum
self.cutn = cutn
Expand Down Expand Up @@ -53,7 +66,7 @@ def get_cuda_type(self, dtype="complex64"):
raise TypeError("Type can be either complex64 or complex128")

def execute_circuit(
self, circuit, MPI_enabled=False, MPS_enabled=False, NCCL_enabled=False, expectation_enabled=False, initial_state=None, nshots=None, return_array=False
self, circuit, initial_state=None, nshots=None, return_array=False
): # pragma: no cover
"""Executes a quantum circuit.
Expand All @@ -68,32 +81,31 @@ def execute_circuit(
"""

import qibotn.eval as eval
print("MPI_enabled", MPI_enabled)
print("MPS_enabled", MPS_enabled)
print("NCCL_enabled", NCCL_enabled)
print("expectation_enabled", expectation_enabled)


print("MPI_enabled", self.MPI_enabled)
print("MPS_enabled", self.MPS_enabled)
print("NCCL_enabled", self.NCCL_enabled)
print("expectation_enabled", self.expectation_enabled)

if (
MPI_enabled == False
and MPS_enabled == False
and NCCL_enabled == False
and expectation_enabled == False
self.MPI_enabled == False
and self.MPS_enabled == False
and self.NCCL_enabled == False
and self.expectation_enabled == False
):
if initial_state is not None:
raise_error(NotImplementedError,
"QiboTN cannot support initial state.")
raise_error(NotImplementedError, "QiboTN cannot support initial state.")

state = eval.dense_vector_tn(circuit, self.dtype)

if (
MPI_enabled == False
and MPS_enabled == True
and NCCL_enabled == False
and expectation_enabled == False
elif (
self.MPI_enabled == False
and self.MPS_enabled == True
and self.NCCL_enabled == False
and self.expectation_enabled == False
):
if initial_state is not None:
raise_error(NotImplementedError,
"QiboTN cannot support initial state.")
raise_error(NotImplementedError, "QiboTN cannot support initial state.")

gate_algo = {
"qr_method": False,
Expand All @@ -104,81 +116,75 @@ def execute_circuit(
} # make this user input
state = eval.dense_vector_mps(circuit, gate_algo, self.dtype)

if (
MPI_enabled == True
and MPS_enabled == False
and NCCL_enabled == False
and expectation_enabled == False
elif (
self.MPI_enabled == True
and self.MPS_enabled == False
and self.NCCL_enabled == False
and self.expectation_enabled == False
):
if initial_state is not None:
raise_error(NotImplementedError,
"QiboTN cannot support initial state.")
raise_error(NotImplementedError, "QiboTN cannot support initial state.")

state, rank = eval.dense_vector_tn_MPI(circuit, self.dtype, 32)
if rank > 0:
state = np.array(0)

if (
MPI_enabled == False
and MPS_enabled == False
and NCCL_enabled == True
and expectation_enabled == False
elif (
self.MPI_enabled == False
and self.MPS_enabled == False
and self.NCCL_enabled == True
and self.expectation_enabled == False
):
if initial_state is not None:
raise_error(NotImplementedError,
"QiboTN cannot support initial state.")
raise_error(NotImplementedError, "QiboTN cannot support initial state.")

state, rank = eval.dense_vector_tn_nccl(circuit, self.dtype, 32)
if rank > 0:
state = np.array(0)

if (
MPI_enabled == False
and MPS_enabled == False
and NCCL_enabled == False
and expectation_enabled == True
elif (
self.MPI_enabled == False
and self.MPS_enabled == False
and self.NCCL_enabled == False
and self.expectation_enabled == True
):
if initial_state is not None:
raise_error(NotImplementedError,
"QiboTN cannot support initial state.")
raise_error(NotImplementedError, "QiboTN cannot support initial state.")

state = eval.expectation_tn(circuit, self.dtype)
state = eval.expectation_pauli_tn(circuit, self.dtype)

if (
MPI_enabled == True
and MPS_enabled == False
and NCCL_enabled == False
and expectation_enabled == True
elif (
self.MPI_enabled == True
and self.MPS_enabled == False
and self.NCCL_enabled == False
and self.expectation_enabled == True
):
if initial_state is not None:
raise_error(NotImplementedError,
"QiboTN cannot support initial state.")
raise_error(NotImplementedError, "QiboTN cannot support initial state.")

state, rank = eval.expectation_pauli_tn_MPI(
circuit, self.dtype, 32)
state, rank = eval.expectation_pauli_tn_MPI(circuit, self.dtype, 32)

if rank > 0:
state = np.array(0)

if (
MPI_enabled == False
and MPS_enabled == False
and NCCL_enabled == True
and expectation_enabled == True
elif (
self.MPI_enabled == False
and self.MPS_enabled == False
and self.NCCL_enabled == True
and self.expectation_enabled == True
):
if initial_state is not None:
raise_error(NotImplementedError,
"QiboTN cannot support initial state.")
raise_error(NotImplementedError, "QiboTN cannot support initial state.")

state, rank = eval.expectation_pauli_tn_nccl(
circuit, self.dtype, 32)
state, rank = eval.expectation_pauli_tn_nccl(circuit, self.dtype, 32)

if rank > 0:
state = np.array(0)
else:
raise_error(NotImplementedError, "Backend not supported.")

if return_array:
return state.flatten()
else:
circuit._final_state = CircuitResult(
self, circuit, state.flatten(), nshots)
circuit._final_state = CircuitResult(self, circuit, state.flatten(), nshots)
return circuit._final_state

0 comments on commit 1680185

Please sign in to comment.