Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decompose Z gates in bell state corrections when compiling for hardware #71

Merged
merged 9 commits into from
Sep 12, 2024
1 change: 1 addition & 0 deletions netqasm/sdk/build_epr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class EntRequestParams:
time_unit: TimeUnit = TimeUnit.MICRO_SECONDS
max_time: int = 0
expect_phi_plus: bool = True
expect_psi_plus: bool = False
min_fidelity_all_at_end: Optional[int] = None
max_tries: Optional[int] = None
random_basis_local: Optional[RandomBasis] = None
Expand Down
78 changes: 54 additions & 24 deletions netqasm/sdk/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from netqasm.lang.subroutine import Subroutine
from netqasm.lang.version import NETQASM_VERSION
from netqasm.qlink_compat import BellState, EPRRole, EPRType, LinkLayerOKTypeK
from netqasm.runtime.settings import get_is_using_hardware
from netqasm.sdk.build_epr import (
SER_RESPONSE_KEEP_IDX_BELL_STATE,
SER_RESPONSE_KEEP_LEN,
Expand Down Expand Up @@ -373,7 +374,9 @@ def post_loop(conn: BaseNetQASMConnection, loop_reg: RegFuture):
pair=pair,
)

if params.expect_phi_plus and role == EPRRole.RECV:
if (
params.expect_phi_plus or params.expect_psi_plus
) and role == EPRRole.RECV:
# Perform Bell corrections
bell_state = self._get_raw_bell_state(
ent_results_array, loop_reg, bell_state_reg
Expand All @@ -382,7 +385,9 @@ def post_loop(conn: BaseNetQASMConnection, loop_reg: RegFuture):
instruction=GenericInstr.SET, operands=[qubit_reg, 0]
)
self.subrt_add_pending_command(set_qubit_reg_cmd) # type: ignore
self._build_cmds_epr_keep_corrections_single_pair(bell_state, qubit_reg)
self._build_cmds_epr_keep_corrections_single_pair(
bell_state, qubit_reg, params
)

# If it's the last pair, don't move it to a mem qubit
with loop_reg.if_ne(params.number - 1):
Expand Down Expand Up @@ -444,15 +449,19 @@ def post_loop(conn: BaseNetQASMConnection, loop_reg: RegFuture):
)
assert tp == EPRType.K or tp == EPRType.R

if params.expect_phi_plus and role == EPRRole.RECV:
if (
params.expect_phi_plus or params.expect_psi_plus
) and role == EPRRole.RECV:
bell_state = self._get_raw_bell_state(
ent_results_array, loop_reg, bell_state_reg
)
set_qubit_reg_cmd = ICmd(
instruction=GenericInstr.SET, operands=[qubit_reg, 0]
)
self.subrt_add_pending_command(set_qubit_reg_cmd) # type: ignore
self._build_cmds_epr_keep_corrections_single_pair(bell_state, qubit_reg)
self._build_cmds_epr_keep_corrections_single_pair(
bell_state, qubit_reg, params
)

q_id = qubit_ids.get_future_index(loop_register)
q = FutureQubit(conn=conn, future_id=q_id)
Expand Down Expand Up @@ -1353,26 +1362,43 @@ def undef_result_element(conn: BaseNetQASMConnection, _: RegFuture):
)

def _build_cmds_epr_keep_corrections_single_pair(
self, bell_state: RegFuture, qubit_reg: operand.Register
self,
bell_state: RegFuture,
qubit_reg: operand.Register,
params: EntRequestParams,
) -> None:
with bell_state.if_eq(BellState.PHI_MINUS.value): # Phi- -> apply Z-gate
correction_cmds = [
ICmd(instruction=GenericInstr.ROT_Z, operands=[qubit_reg, 16, 4])
]
self.subrt_add_pending_commands(correction_cmds) # type: ignore
with bell_state.if_eq(BellState.PSI_PLUS.value): # Psi+ -> apply X-gate
correction_cmds = [
ICmd(instruction=GenericInstr.ROT_X, operands=[qubit_reg, 16, 4])
]
self.subrt_add_pending_commands(correction_cmds) # type: ignore
with bell_state.if_eq(
BellState.PSI_MINUS.value
): # Psi- -> apply X-gate and Z-gate
correction_cmds = [
x180 = [ICmd(instruction=GenericInstr.ROT_X, operands=[qubit_reg, 16, 4])]

if get_is_using_hardware():
# For hardware, don't use Z-gates.
# Decompose Z180 into Y90, X180, -Y90
z180 = [
ICmd(instruction=GenericInstr.ROT_Y, operands=[qubit_reg, 8, 4]),
ICmd(instruction=GenericInstr.ROT_X, operands=[qubit_reg, 16, 4]),
ICmd(instruction=GenericInstr.ROT_Z, operands=[qubit_reg, 16, 4]),
ICmd(instruction=GenericInstr.ROT_Y, operands=[qubit_reg, 24, 4]),
]
self.subrt_add_pending_commands(correction_cmds) # type: ignore
else:
z180 = [ICmd(instruction=GenericInstr.ROT_Z, operands=[qubit_reg, 16, 4])]

if params.expect_phi_plus:
with bell_state.if_eq(BellState.PHI_MINUS.value): # Phi- -> apply Z-gate
self.subrt_add_pending_commands(z180) # type: ignore
with bell_state.if_eq(BellState.PSI_PLUS.value): # Psi+ -> apply X-gate
self.subrt_add_pending_commands(x180) # type: ignore
with bell_state.if_eq(
BellState.PSI_MINUS.value
): # Psi- -> apply X-gate and Z-gate
self.subrt_add_pending_commands(x180 + z180) # type: ignore
else:
assert params.expect_psi_plus
with bell_state.if_eq(BellState.PHI_PLUS.value): # Phi+ -> apply X-gate
self.subrt_add_pending_commands(x180) # type: ignore
with bell_state.if_eq(
BellState.PHI_MINUS.value
): # Phi- -> apply X-gate and Z-gate
self.subrt_add_pending_commands(x180 + z180) # type: ignore
with bell_state.if_eq(BellState.PSI_MINUS.value): # Psi- -> apply Z-gate
self.subrt_add_pending_commands(z180) # type: ignore

def _build_cmds_epr_keep_corrections(
self, qubit_ids_array: Array, ent_results_array: Array, params: EntRequestParams
Expand All @@ -1394,7 +1420,9 @@ def loop(conn: BaseNetQASMConnection, index: RegFuture):
instruction=GenericInstr.SET, operands=[qubit_reg, 0]
)
self.subrt_add_pending_command(set_qubit_reg_cmd) # type: ignore
self._build_cmds_epr_keep_corrections_single_pair(bell_state, qubit_reg)
self._build_cmds_epr_keep_corrections_single_pair(
bell_state, qubit_reg, params
)

self._build_cmds_loop_body(
loop, stop=params.number, loop_register=loop_register
Expand Down Expand Up @@ -1501,7 +1529,7 @@ def _build_cmds_epr_recv_keep(

self.subrt_add_pending_commands(wait_cmds) # type: ignore

if wait_all and params.expect_phi_plus:
if wait_all and (params.expect_phi_plus or params.expect_psi_plus):
self._build_cmds_epr_keep_corrections(
qubit_ids_array, ent_results_array, params
)
Expand Down Expand Up @@ -1641,7 +1669,7 @@ def _build_cmds_epr_recv_rsp(

self.subrt_add_pending_commands(wait_cmds) # type: ignore

if wait_all and params.expect_phi_plus:
if wait_all and (params.expect_phi_plus or params.expect_psi_plus):
self._build_cmds_epr_keep_corrections(
qubit_ids_array, ent_results_array, params
)
Expand Down Expand Up @@ -1871,6 +1899,8 @@ def sdk_epr_keep(
else:
wait_all = True

self._connection._logger.info(f"wait_all = {wait_all}")

if reset_results_array:
self._build_cmds_undefine_array(ent_results_array)

Expand Down
6 changes: 6 additions & 0 deletions netqasm/sdk/epr_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ def recv_keep(
post_routine: Optional[Callable] = None,
sequential: bool = False,
expect_phi_plus: bool = True,
expect_psi_plus: bool = False,
min_fidelity_all_at_end: Optional[int] = None,
max_tries: Optional[int] = None,
) -> List[Qubit]:
Expand Down Expand Up @@ -681,6 +682,10 @@ def recv_keep(
if self.conn is None:
raise RuntimeError("EPRSocket does not have an open connection")

assert not (
expect_phi_plus and expect_psi_plus
), "cannot ask for both phi+ and psi+"

qubits, _ = self.conn._builder.sdk_recv_epr_keep(
params=EntRequestParams(
remote_node_id=self.remote_node_id,
Expand All @@ -689,6 +694,7 @@ def recv_keep(
post_routine=post_routine,
sequential=sequential,
expect_phi_plus=expect_phi_plus,
expect_psi_plus=expect_psi_plus,
min_fidelity_all_at_end=min_fidelity_all_at_end,
max_tries=max_tries,
),
Expand Down
6 changes: 3 additions & 3 deletions netqasm/sdk/transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def _map_single_gate(
),
]
elif isinstance(instr, vanilla.RotZInstruction):
if get_is_using_hardware():
if get_is_using_hardware() and instr.angle_denom.value != 4:
imm0, imm1 = get_hardware_num_denom(instr)
else:
imm0, imm1 = instr.angle_num, instr.angle_denom
Expand All @@ -589,7 +589,7 @@ def _map_single_gate(
),
]
elif isinstance(instr, vanilla.RotXInstruction):
if get_is_using_hardware():
if get_is_using_hardware() and instr.angle_denom.value != 4:
imm0, imm1 = get_hardware_num_denom(instr)
else:
imm0, imm1 = instr.angle_num, instr.angle_denom
Expand All @@ -599,7 +599,7 @@ def _map_single_gate(
),
]
elif isinstance(instr, vanilla.RotYInstruction):
if get_is_using_hardware():
if get_is_using_hardware() and instr.angle_denom.value != 4:
imm0, imm1 = get_hardware_num_denom(instr)
else:
imm0, imm1 = instr.angle_num, instr.angle_denom
Expand Down
86 changes: 86 additions & 0 deletions tests/test_sdk/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from netqasm.lang.version import NETQASM_VERSION
from netqasm.logging.glob import set_log_level
from netqasm.qlink_compat import EPRType, TimeUnit
from netqasm.runtime.settings import set_is_using_hardware
from netqasm.sdk.connection import DebugConnection
from netqasm.sdk.epr_socket import EPRSocket
from netqasm.sdk.qubit import Qubit
Expand Down Expand Up @@ -297,6 +298,91 @@ def test_epr_k_recv():
print(expected)


def test_epr_k_recv_hardware():
set_is_using_hardware(True)

set_log_level(logging.DEBUG)

epr_socket = EPRSocket(remote_app_name="Bob")
with DebugConnection("Alice", epr_sockets=[epr_socket]) as alice:
q1 = epr_socket.recv_keep()[0]
q1.H()

# 5 messages: init, open_epr_socket, subroutine, stop app and stop backend
assert len(alice.storage) == 5
raw_subroutine = deserialize_message(raw=alice.storage[2]).subroutine
subroutine = deserialize_subroutine(raw_subroutine)
print(subroutine)

expected_text = """
# NETQASM 0.0
# APPID 0
set R5 10
array R5 @0
set R5 1
array R5 @1
set R5 0
set R6 0
store R5 @1[R6]
set R5 1
set R6 0
set R7 1
set R8 0
recv_epr R5 R6 R7 R8
set R5 0
set R6 10
wait_all @0[R5:R6]
set R2 0
set R5 1
beq R2 R5 46
load R0 @1[R2]
set R3 9
set R4 0
beq R4 R2 27
set R5 10
add R3 R3 R5
set R5 1
add R4 R4 R5
jmp 21
load R1 @0[R3]
set R0 0
set R5 3
bne R1 R5 34
rot_y R0 8 4
rot_x R0 16 4
rot_y R0 24 4
set R5 1
bne R1 R5 37
rot_x R0 16 4
set R5 2
bne R1 R5 43
rot_x R0 16 4
rot_y R0 8 4
rot_x R0 16 4
rot_y R0 24 4
set R5 1
add R2 R2 R5
jmp 16
set Q0 0
h Q0
ret_arr @0
ret_arr @1
"""

expected = parse_text_subroutine(expected_text)

for i, instr in enumerate(subroutine.instructions):
print(repr(instr))
expected_instr = expected.instructions[i]
print(repr(expected_instr))
print()
assert instr == expected_instr
print(subroutine)
print(expected)

set_is_using_hardware(False)


def test_two_epr_k_create():

set_log_level(logging.DEBUG)
Expand Down
Loading