Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding test for target. #4

Merged
merged 4 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ AttributeIDSupressMenu = "AttributeIDSupressMenu"
Braket = "Braket"
mch = "mch"
IY = "IY"
ket = "ket"
bra = "bra"
4 changes: 2 additions & 2 deletions src/bloqade/pyqrack/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

import numpy as np

from .reg import SimQubitRef
from .reg import SimQubit

if TYPE_CHECKING:
from pyqrack.qrack_simulator import QrackSimulator # noqa: F401

QrackQubitId = SimQubitRef["QrackSimulator"]
QrackQubitId = SimQubit["QrackSimulator"]


class GateQrackRuntimeABC(abc.ABC):
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/pyqrack/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import numpy as np

from .reg import SimQubitRef
from .reg import SimQubit

if TYPE_CHECKING:
from pyqrack.qrack_simulator import QrackSimulator # noqa: F401


QrackQubitId = SimQubitRef["QrackSimulator"]
QrackQubitId = SimQubit["QrackSimulator"]


@dataclasses.dataclass(frozen=True)
Expand Down
12 changes: 5 additions & 7 deletions src/bloqade/pyqrack/noise/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class PyQrackMethods(interp.MethodTable):
def apply_pauli_error(
self,
interp: PyQrackInterpreter,
qarg: reg.SimQubitRef,
qarg: reg.SimQubit,
px: float,
py: float,
pz: float,
Expand All @@ -39,7 +39,7 @@ def single_qubit_error_channel(
px: float = frame.get(stmt.px)
py: float = frame.get(stmt.py)
pz: float = frame.get(stmt.pz)
qarg: reg.SimQubitRef = frame.get(stmt.qarg)
qarg: reg.SimQubit = frame.get(stmt.qarg)

if qarg.is_active():
self.apply_pauli_error(interp, qarg, px, py, pz)
Expand All @@ -56,12 +56,12 @@ def cz_pauli_unpaired(
px_1: float = frame.get(stmt.px_1)
py_1: float = frame.get(stmt.py_1)
pz_1: float = frame.get(stmt.pz_1)
qarg1: reg.SimQubitRef = frame.get(stmt.qarg1)
qarg1: reg.SimQubit = frame.get(stmt.qarg1)

px_2: float = frame.get(stmt.px_2)
py_2: float = frame.get(stmt.py_2)
pz_2: float = frame.get(stmt.pz_2)
qarg2: reg.SimQubitRef = frame.get(stmt.qarg2)
qarg2: reg.SimQubit = frame.get(stmt.qarg2)

is_active_1 = qarg1.is_active()
is_active_2 = qarg2.is_active()
Expand All @@ -88,9 +88,7 @@ def atom_loss_channel(
stmt: native.AtomLossChannel,
):
prob: float = frame.get(stmt.prob)
qarg: reg.SimQubitRef["QrackSimulator"] = frame.get_typed(
stmt.qarg, reg.SimQubitRef
)
qarg: reg.SimQubit["QrackSimulator"] = frame.get_typed(stmt.qarg, reg.SimQubit)

if qarg.is_active() and interp.rng_state.uniform() > prob:
sim_reg = qarg.ref.sim_reg
Expand Down
14 changes: 4 additions & 10 deletions src/bloqade/pyqrack/qasm2/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from typing import TYPE_CHECKING

from kirin import interp
from bloqade.pyqrack.reg import (
CBitRef,
CRegister,
QubitState,
SimQubitRef,
SimQRegister,
)
from bloqade.pyqrack.reg import CBitRef, SimQReg, SimQubit, CRegister, QubitState
from bloqade.pyqrack.base import PyQrackInterpreter
from bloqade.qasm2.dialects import core

Expand All @@ -34,7 +28,7 @@
)

return (
SimQRegister(
SimQReg(
size=n_qubits,
sim_reg=interp.memory.sim_reg,
addrs=tuple(range(curr_allocated, curr_allocated + n_qubits)),
Expand All @@ -53,7 +47,7 @@
def qreg_get(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegGet
):
return (SimQubitRef(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),)
return (SimQubit(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),)

@interp.impl(core.CRegGet)
def creg_get(
Expand All @@ -67,7 +61,7 @@
def measure(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure
):
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)

Check warning on line 64 in src/bloqade/pyqrack/qasm2/core.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/pyqrack/qasm2/core.py#L64

Added line #L64 was not covered by tests
carg: CBitRef = frame.get(stmt.carg)
carg.set_value(bool(qarg.ref.sim_reg.m(qarg.addr)))

Expand Down
10 changes: 5 additions & 5 deletions src/bloqade/pyqrack/qasm2/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from kirin import interp
from kirin.dialects import ilist
from bloqade.pyqrack.reg import SimQubitRef
from bloqade.pyqrack.reg import SimQubit
from bloqade.pyqrack.base import PyQrackInterpreter
from bloqade.qasm2.dialects import parallel

Expand All @@ -16,8 +16,8 @@
@interp.impl(parallel.CZ)
def cz(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: parallel.CZ):

qargs: ilist.IList[SimQubitRef["QrackSimulator"], Any] = frame.get(stmt.qargs)
ctrls: ilist.IList[SimQubitRef["QrackSimulator"], Any] = frame.get(stmt.ctrls)
qargs: ilist.IList[SimQubit["QrackSimulator"], Any] = frame.get(stmt.qargs)
ctrls: ilist.IList[SimQubit["QrackSimulator"], Any] = frame.get(stmt.ctrls)

Check warning on line 20 in src/bloqade/pyqrack/qasm2/parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/pyqrack/qasm2/parallel.py#L19-L20

Added lines #L19 - L20 were not covered by tests
for qarg, ctrl in zip(qargs, ctrls):
if qarg.is_active() and ctrl.is_active():
interp.memory.sim_reg.mcz(qarg, ctrl)
Expand All @@ -27,7 +27,7 @@
def ugate(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: parallel.UGate
):
qargs: ilist.IList[SimQubitRef["QrackSimulator"], Any] = frame.get(stmt.qargs)
qargs: ilist.IList[SimQubit["QrackSimulator"], Any] = frame.get(stmt.qargs)

Check warning on line 30 in src/bloqade/pyqrack/qasm2/parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/pyqrack/qasm2/parallel.py#L30

Added line #L30 was not covered by tests
theta, phi, lam = (
frame.get(stmt.theta),
frame.get(stmt.phi),
Expand All @@ -40,7 +40,7 @@

@interp.impl(parallel.RZ)
def rz(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: parallel.RZ):
qargs: ilist.IList[SimQubitRef["QrackSimulator"], Any] = frame.get(stmt.qargs)
qargs: ilist.IList[SimQubit["QrackSimulator"], Any] = frame.get(stmt.qargs)

Check warning on line 43 in src/bloqade/pyqrack/qasm2/parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/pyqrack/qasm2/parallel.py#L43

Added line #L43 was not covered by tests
phi = frame.get(stmt.theta)
for qarg in qargs:
if qarg.is_active():
Expand Down
34 changes: 17 additions & 17 deletions src/bloqade/pyqrack/qasm2/uop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import TYPE_CHECKING

from kirin import interp
from bloqade.pyqrack.reg import SimQubitRef
from bloqade.pyqrack.reg import SimQubit
from bloqade.qasm2.dialects import uop

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,14 +51,14 @@ def single_qubit_gate(
frame: interp.Frame,
stmt: uop.SingleQubitGate,
):
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if qarg.is_active():
getattr(qarg.sim_reg, self.GATE_TO_METHOD[stmt.name])(qarg.addr)
return ()

@interp.impl(uop.UGate)
def ugate(self, interp: interp.Interpreter, frame: interp.Frame, stmt: uop.UGate):
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if qarg.is_active():
qarg.sim_reg.u(
qarg.addr,
Expand All @@ -78,8 +78,8 @@ def control_gate(
frame: interp.Frame,
stmt: uop.CX | uop.CZ | uop.CY,
):
ctrl: SimQubitRef["QrackSimulator"] = frame.get(stmt.ctrl)
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
ctrl: SimQubit["QrackSimulator"] = frame.get(stmt.ctrl)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if ctrl.is_active() and qarg.is_active():
getattr(qarg.sim_reg, self.GATE_TO_METHOD[stmt.name])(
[ctrl.addr], qarg.addr
Expand All @@ -88,9 +88,9 @@ def control_gate(

@interp.impl(uop.CCX)
def ccx(self, interp: interp.Interpreter, frame: interp.Frame, stmt: uop.CCX):
ctrl1: SimQubitRef["QrackSimulator"] = frame.get(stmt.ctrl1)
ctrl2: SimQubitRef["QrackSimulator"] = frame.get(stmt.ctrl2)
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
ctrl1: SimQubit["QrackSimulator"] = frame.get(stmt.ctrl1)
ctrl2: SimQubit["QrackSimulator"] = frame.get(stmt.ctrl2)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if ctrl1.is_active() and ctrl2.is_active() and qarg.is_active():
qarg.sim_reg.mcx([ctrl1.addr, ctrl2.addr], qarg.addr)
return ()
Expand All @@ -104,21 +104,21 @@ def rotation(
frame: interp.Frame,
stmt: uop.RX | uop.RY | uop.RZ,
):
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if qarg.is_active():
qarg.sim_reg.r(self.AXIS_MAP[stmt.name], frame.get(stmt.theta), qarg.addr)
return ()

@interp.impl(uop.U1)
def u1(self, interp: interp.Interpreter, frame: interp.Frame, stmt: uop.U1):
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if qarg.is_active():
qarg.sim_reg.u(qarg.addr, 0, 0, frame.get(stmt.lam))
return ()

@interp.impl(uop.U2)
def u2(self, interp: interp.Interpreter, frame: interp.Frame, stmt: uop.U2):
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if qarg.is_active():
qarg.sim_reg.u(
qarg.addr, math.pi / 2, frame.get(stmt.phi), frame.get(stmt.lam)
Expand All @@ -127,24 +127,24 @@ def u2(self, interp: interp.Interpreter, frame: interp.Frame, stmt: uop.U2):

@interp.impl(uop.CRX)
def crx(self, interp: interp.Interpreter, frame: interp.Frame, stmt: uop.CRX):
ctrl: SimQubitRef["QrackSimulator"] = frame.get(stmt.ctrl)
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
ctrl: SimQubit["QrackSimulator"] = frame.get(stmt.ctrl)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if qarg.is_active() and ctrl.is_active():
qarg.sim_reg.mcr(1, frame.get(stmt.theta), [ctrl.addr], qarg.addr)
return ()

@interp.impl(uop.CU1)
def cu1(self, interp: interp.Interpreter, frame: interp.Frame, stmt: uop.CU1):
ctrl: SimQubitRef["QrackSimulator"] = frame.get(stmt.ctrl)
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
ctrl: SimQubit["QrackSimulator"] = frame.get(stmt.ctrl)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if qarg.is_active() and ctrl.is_active():
qarg.sim_reg.mcu([ctrl.addr], qarg.addr, 0, 0, frame.get(stmt.lam))
return ()

@interp.impl(uop.CU3)
def cu3(self, interp: interp.Interpreter, frame: interp.Frame, stmt: uop.CU3):
ctrl: SimQubitRef["QrackSimulator"] = frame.get(stmt.ctrl)
qarg: SimQubitRef["QrackSimulator"] = frame.get(stmt.qarg)
ctrl: SimQubit["QrackSimulator"] = frame.get(stmt.ctrl)
qarg: SimQubit["QrackSimulator"] = frame.get(stmt.qarg)
if qarg.is_active() and ctrl.is_active():
qarg.sim_reg.mcu(
[ctrl.addr],
Expand Down
39 changes: 11 additions & 28 deletions src/bloqade/pyqrack/reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,7 @@
from typing import List, Generic, TypeVar
from dataclasses import dataclass


class QubitState(enum.Enum):
Active = enum.auto()
Lost = enum.auto()


@dataclass(frozen=True)
class QRegister:
size: int

def __hash__(self):
return id(self)

def __eq__(self, other):
return self is other

def __getitem__(self, pos: int):
return QubitRef(self, pos)


@dataclass(frozen=True)
class QubitRef:
ref: QRegister
pos: int
from bloqade.qasm2.types import QReg, Qubit


class CRegister(list[bool]):
Expand All @@ -45,11 +22,17 @@
return self.ref[self.pos]


class QubitState(enum.Enum):
Active = enum.auto()
Lost = enum.auto()


SimRegType = TypeVar("SimRegType")


@dataclass(frozen=True)
class SimQRegister(QRegister, Generic[SimRegType]):
class SimQReg(QReg, Generic[SimRegType]):
size: int
sim_reg: SimRegType
addrs: tuple[int, ...]
qubit_state: List[QubitState]
Expand All @@ -59,12 +42,12 @@
self.qubit_state[pos] = QubitState.Lost

def __getitem__(self, pos: int):
return SimQubitRef(self, pos)
return SimQubit(self, pos)

Check warning on line 45 in src/bloqade/pyqrack/reg.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/pyqrack/reg.py#L45

Added line #L45 was not covered by tests


@dataclass(frozen=True)
class SimQubitRef(QubitRef, Generic[SimRegType]):
ref: SimQRegister[SimRegType]
class SimQubit(Qubit, Generic[SimRegType]):
ref: SimQReg[SimRegType]
pos: int

@property
Expand Down
2 changes: 1 addition & 1 deletion test/runtime/noise/native/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_atom_loss():

memory = Memory(total=2, allocated=0, sim_reg=Mock())

result: reg.SimQRegister[Mock] = (
result: reg.SimQReg[Mock] = (
PyQrackInterpreter(simulation, memory=memory, rng_state=rng_state)
.run(test_atom_loss, ())
.expect()
Expand Down
46 changes: 46 additions & 0 deletions test/test_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import math

from bloqade import qasm2
from pyqrack import QrackSimulator
from bloqade.pyqrack import PyQrack, reg


def test_target():

@qasm2.main
def ghz():
q = qasm2.qreg(3)

qasm2.h(q[0])
qasm2.cx(q[0], q[1])
qasm2.cx(q[1], q[2])

return q

target = PyQrack(3)

q = target.run(ghz)

assert isinstance(q, reg.SimQReg)
assert isinstance(q.sim_reg, QrackSimulator)

out = q.sim_reg.out_ket()

norm = math.sqrt(sum(abs(ele) ** 2 for ele in out))
phase = out[0] / abs(out[0])

out = [ele / (phase * norm) for ele in out]

abs_tol = 2.2e-15

assert all(math.isclose(ele.imag, 0.0, abs_tol=abs_tol) for ele in out)

val = 1.0 / math.sqrt(2.0)

assert math.isclose(out[0].real, val, abs_tol=abs_tol)
assert math.isclose(out[-1].real, val, abs_tol=abs_tol)
assert all(math.isclose(ele.real, 0.0, abs_tol=abs_tol) for ele in out[1:-1])


if __name__ == "__main__":
test_target()