diff --git a/netqasm/sdk/build_epr.py b/netqasm/sdk/build_epr.py index 681d085..ebc45a2 100644 --- a/netqasm/sdk/build_epr.py +++ b/netqasm/sdk/build_epr.py @@ -28,6 +28,7 @@ class EntRequestParams: time_unit: TimeUnit = TimeUnit.MICRO_SECONDS max_time: int = 0 expect_phi_plus: bool = True + expect_psi_plus: bool = False min_fidelity_all_at_end: Optional[int] = None max_tries: Optional[int] = None random_basis_local: Optional[RandomBasis] = None diff --git a/netqasm/sdk/builder.py b/netqasm/sdk/builder.py index f6b5f8e..0465649 100644 --- a/netqasm/sdk/builder.py +++ b/netqasm/sdk/builder.py @@ -38,6 +38,7 @@ from netqasm.lang.subroutine import Subroutine from netqasm.lang.version import NETQASM_VERSION from netqasm.qlink_compat import BellState, EPRRole, EPRType, LinkLayerOKTypeK +from netqasm.runtime.settings import get_is_using_hardware from netqasm.sdk.build_epr import ( SER_RESPONSE_KEEP_IDX_BELL_STATE, SER_RESPONSE_KEEP_LEN, @@ -373,7 +374,9 @@ def post_loop(conn: BaseNetQASMConnection, loop_reg: RegFuture): pair=pair, ) - if params.expect_phi_plus and role == EPRRole.RECV: + if ( + params.expect_phi_plus or params.expect_psi_plus + ) and role == EPRRole.RECV: # Perform Bell corrections bell_state = self._get_raw_bell_state( ent_results_array, loop_reg, bell_state_reg @@ -382,7 +385,9 @@ def post_loop(conn: BaseNetQASMConnection, loop_reg: RegFuture): instruction=GenericInstr.SET, operands=[qubit_reg, 0] ) self.subrt_add_pending_command(set_qubit_reg_cmd) # type: ignore - self._build_cmds_epr_keep_corrections_single_pair(bell_state, qubit_reg) + self._build_cmds_epr_keep_corrections_single_pair( + bell_state, qubit_reg, params + ) # If it's the last pair, don't move it to a mem qubit with loop_reg.if_ne(params.number - 1): @@ -444,7 +449,9 @@ def post_loop(conn: BaseNetQASMConnection, loop_reg: RegFuture): ) assert tp == EPRType.K or tp == EPRType.R - if params.expect_phi_plus and role == EPRRole.RECV: + if ( + params.expect_phi_plus or params.expect_psi_plus + ) and role == EPRRole.RECV: bell_state = self._get_raw_bell_state( ent_results_array, loop_reg, bell_state_reg ) @@ -452,7 +459,9 @@ def post_loop(conn: BaseNetQASMConnection, loop_reg: RegFuture): instruction=GenericInstr.SET, operands=[qubit_reg, 0] ) self.subrt_add_pending_command(set_qubit_reg_cmd) # type: ignore - self._build_cmds_epr_keep_corrections_single_pair(bell_state, qubit_reg) + self._build_cmds_epr_keep_corrections_single_pair( + bell_state, qubit_reg, params + ) q_id = qubit_ids.get_future_index(loop_register) q = FutureQubit(conn=conn, future_id=q_id) @@ -1353,26 +1362,43 @@ def undef_result_element(conn: BaseNetQASMConnection, _: RegFuture): ) def _build_cmds_epr_keep_corrections_single_pair( - self, bell_state: RegFuture, qubit_reg: operand.Register + self, + bell_state: RegFuture, + qubit_reg: operand.Register, + params: EntRequestParams, ) -> None: - with bell_state.if_eq(BellState.PHI_MINUS.value): # Phi- -> apply Z-gate - correction_cmds = [ - ICmd(instruction=GenericInstr.ROT_Z, operands=[qubit_reg, 16, 4]) - ] - self.subrt_add_pending_commands(correction_cmds) # type: ignore - with bell_state.if_eq(BellState.PSI_PLUS.value): # Psi+ -> apply X-gate - correction_cmds = [ - ICmd(instruction=GenericInstr.ROT_X, operands=[qubit_reg, 16, 4]) - ] - self.subrt_add_pending_commands(correction_cmds) # type: ignore - with bell_state.if_eq( - BellState.PSI_MINUS.value - ): # Psi- -> apply X-gate and Z-gate - correction_cmds = [ + x180 = [ICmd(instruction=GenericInstr.ROT_X, operands=[qubit_reg, 16, 4])] + + if get_is_using_hardware(): + # For hardware, don't use Z-gates. + # Decompose Z180 into Y90, X180, -Y90 + z180 = [ + ICmd(instruction=GenericInstr.ROT_Y, operands=[qubit_reg, 8, 4]), ICmd(instruction=GenericInstr.ROT_X, operands=[qubit_reg, 16, 4]), - ICmd(instruction=GenericInstr.ROT_Z, operands=[qubit_reg, 16, 4]), + ICmd(instruction=GenericInstr.ROT_Y, operands=[qubit_reg, 24, 4]), ] - self.subrt_add_pending_commands(correction_cmds) # type: ignore + else: + z180 = [ICmd(instruction=GenericInstr.ROT_Z, operands=[qubit_reg, 16, 4])] + + if params.expect_phi_plus: + with bell_state.if_eq(BellState.PHI_MINUS.value): # Phi- -> apply Z-gate + self.subrt_add_pending_commands(z180) # type: ignore + with bell_state.if_eq(BellState.PSI_PLUS.value): # Psi+ -> apply X-gate + self.subrt_add_pending_commands(x180) # type: ignore + with bell_state.if_eq( + BellState.PSI_MINUS.value + ): # Psi- -> apply X-gate and Z-gate + self.subrt_add_pending_commands(x180 + z180) # type: ignore + else: + assert params.expect_psi_plus + with bell_state.if_eq(BellState.PHI_PLUS.value): # Phi+ -> apply X-gate + self.subrt_add_pending_commands(x180) # type: ignore + with bell_state.if_eq( + BellState.PHI_MINUS.value + ): # Phi- -> apply X-gate and Z-gate + self.subrt_add_pending_commands(x180 + z180) # type: ignore + with bell_state.if_eq(BellState.PSI_MINUS.value): # Psi- -> apply Z-gate + self.subrt_add_pending_commands(z180) # type: ignore def _build_cmds_epr_keep_corrections( self, qubit_ids_array: Array, ent_results_array: Array, params: EntRequestParams @@ -1394,7 +1420,9 @@ def loop(conn: BaseNetQASMConnection, index: RegFuture): instruction=GenericInstr.SET, operands=[qubit_reg, 0] ) self.subrt_add_pending_command(set_qubit_reg_cmd) # type: ignore - self._build_cmds_epr_keep_corrections_single_pair(bell_state, qubit_reg) + self._build_cmds_epr_keep_corrections_single_pair( + bell_state, qubit_reg, params + ) self._build_cmds_loop_body( loop, stop=params.number, loop_register=loop_register @@ -1501,7 +1529,7 @@ def _build_cmds_epr_recv_keep( self.subrt_add_pending_commands(wait_cmds) # type: ignore - if wait_all and params.expect_phi_plus: + if wait_all and (params.expect_phi_plus or params.expect_psi_plus): self._build_cmds_epr_keep_corrections( qubit_ids_array, ent_results_array, params ) @@ -1641,7 +1669,7 @@ def _build_cmds_epr_recv_rsp( self.subrt_add_pending_commands(wait_cmds) # type: ignore - if wait_all and params.expect_phi_plus: + if wait_all and (params.expect_phi_plus or params.expect_psi_plus): self._build_cmds_epr_keep_corrections( qubit_ids_array, ent_results_array, params ) @@ -1871,6 +1899,8 @@ def sdk_epr_keep( else: wait_all = True + self._connection._logger.info(f"wait_all = {wait_all}") + if reset_results_array: self._build_cmds_undefine_array(ent_results_array) diff --git a/netqasm/sdk/epr_socket.py b/netqasm/sdk/epr_socket.py index 8ecf1b2..7c137ee 100644 --- a/netqasm/sdk/epr_socket.py +++ b/netqasm/sdk/epr_socket.py @@ -644,6 +644,7 @@ def recv_keep( post_routine: Optional[Callable] = None, sequential: bool = False, expect_phi_plus: bool = True, + expect_psi_plus: bool = False, min_fidelity_all_at_end: Optional[int] = None, max_tries: Optional[int] = None, ) -> List[Qubit]: @@ -681,6 +682,10 @@ def recv_keep( if self.conn is None: raise RuntimeError("EPRSocket does not have an open connection") + assert not ( + expect_phi_plus and expect_psi_plus + ), "cannot ask for both phi+ and psi+" + qubits, _ = self.conn._builder.sdk_recv_epr_keep( params=EntRequestParams( remote_node_id=self.remote_node_id, @@ -689,6 +694,7 @@ def recv_keep( post_routine=post_routine, sequential=sequential, expect_phi_plus=expect_phi_plus, + expect_psi_plus=expect_psi_plus, min_fidelity_all_at_end=min_fidelity_all_at_end, max_tries=max_tries, ), diff --git a/netqasm/sdk/transpile.py b/netqasm/sdk/transpile.py index d389990..a98b0d3 100644 --- a/netqasm/sdk/transpile.py +++ b/netqasm/sdk/transpile.py @@ -579,7 +579,7 @@ def _map_single_gate( ), ] elif isinstance(instr, vanilla.RotZInstruction): - if get_is_using_hardware(): + if get_is_using_hardware() and instr.angle_denom.value != 4: imm0, imm1 = get_hardware_num_denom(instr) else: imm0, imm1 = instr.angle_num, instr.angle_denom @@ -589,7 +589,7 @@ def _map_single_gate( ), ] elif isinstance(instr, vanilla.RotXInstruction): - if get_is_using_hardware(): + if get_is_using_hardware() and instr.angle_denom.value != 4: imm0, imm1 = get_hardware_num_denom(instr) else: imm0, imm1 = instr.angle_num, instr.angle_denom @@ -599,7 +599,7 @@ def _map_single_gate( ), ] elif isinstance(instr, vanilla.RotYInstruction): - if get_is_using_hardware(): + if get_is_using_hardware() and instr.angle_denom.value != 4: imm0, imm1 = get_hardware_num_denom(instr) else: imm0, imm1 = instr.angle_num, instr.angle_denom diff --git a/tests/test_sdk/test_connection.py b/tests/test_sdk/test_connection.py index cf82382..577f5b8 100644 --- a/tests/test_sdk/test_connection.py +++ b/tests/test_sdk/test_connection.py @@ -12,6 +12,7 @@ from netqasm.lang.version import NETQASM_VERSION from netqasm.logging.glob import set_log_level from netqasm.qlink_compat import EPRType, TimeUnit +from netqasm.runtime.settings import set_is_using_hardware from netqasm.sdk.connection import DebugConnection from netqasm.sdk.epr_socket import EPRSocket from netqasm.sdk.qubit import Qubit @@ -297,6 +298,91 @@ def test_epr_k_recv(): print(expected) +def test_epr_k_recv_hardware(): + set_is_using_hardware(True) + + set_log_level(logging.DEBUG) + + epr_socket = EPRSocket(remote_app_name="Bob") + with DebugConnection("Alice", epr_sockets=[epr_socket]) as alice: + q1 = epr_socket.recv_keep()[0] + q1.H() + + # 5 messages: init, open_epr_socket, subroutine, stop app and stop backend + assert len(alice.storage) == 5 + raw_subroutine = deserialize_message(raw=alice.storage[2]).subroutine + subroutine = deserialize_subroutine(raw_subroutine) + print(subroutine) + + expected_text = """ +# NETQASM 0.0 +# APPID 0 +set R5 10 +array R5 @0 +set R5 1 +array R5 @1 +set R5 0 +set R6 0 +store R5 @1[R6] +set R5 1 +set R6 0 +set R7 1 +set R8 0 +recv_epr R5 R6 R7 R8 +set R5 0 +set R6 10 +wait_all @0[R5:R6] +set R2 0 +set R5 1 +beq R2 R5 46 +load R0 @1[R2] +set R3 9 +set R4 0 +beq R4 R2 27 +set R5 10 +add R3 R3 R5 +set R5 1 +add R4 R4 R5 +jmp 21 +load R1 @0[R3] +set R0 0 +set R5 3 +bne R1 R5 34 +rot_y R0 8 4 +rot_x R0 16 4 +rot_y R0 24 4 +set R5 1 +bne R1 R5 37 +rot_x R0 16 4 +set R5 2 +bne R1 R5 43 +rot_x R0 16 4 +rot_y R0 8 4 +rot_x R0 16 4 +rot_y R0 24 4 +set R5 1 +add R2 R2 R5 +jmp 16 +set Q0 0 +h Q0 +ret_arr @0 +ret_arr @1 +""" + + expected = parse_text_subroutine(expected_text) + + for i, instr in enumerate(subroutine.instructions): + print(repr(instr)) + expected_instr = expected.instructions[i] + print(repr(expected_instr)) + print() + assert instr == expected_instr + print(subroutine) + print(expected) + + set_is_using_hardware(False) + + def test_two_epr_k_create(): set_log_level(logging.DEBUG)