Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Cirq Transforms for Gate Decomposition (#93) #184

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 63 additions & 21 deletions qbraid_qir/cirq/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,70 @@

"""
import itertools
from typing import Iterable
from typing import Iterable, List, Sequence, Type, Union

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult

from .exceptions import CirqConversionError
from .opsets import map_cirq_op_to_pyqir_callable


class QirTargetGateSet(cirq.TwoQubitCompilationTargetGateset):
def __init__(
self,
*,
atol: float = 1e-8,
allow_partial_czs: bool = False,
additional_gates: Sequence[
Union[Type["cirq.Gate"], "cirq.Gate", "cirq.GateFamily"]
] = (),
preserve_moment_structure: bool = True,
) -> None:
super().__init__(
cirq.IdentityGate,
cirq.HPowGate,
cirq.XPowGate,
cirq.YPowGate,
cirq.ZPowGate,
cirq.SWAP,
cirq.CNOT,
cirq.CZ,
cirq.TOFFOLI,
cirq.ResetChannel,
cirq.MeasurementGate,
cirq.PauliMeasurementGate,
*additional_gates,
name="QirTargetGateset",
preserve_moment_structure=preserve_moment_structure,
)
self.allow_partial_czs = allow_partial_czs
self.atol = atol

@property
def postprocess_transformers(self) -> List["cirq.TRANSFORMER"]:
return []

def _decompose_single_qubit_operation(
self, op: "cirq.Operation", moment_idx: int
) -> DecomposeResult:
qubit = op.qubits[0]
mat = cirq.unitary(op)
for gate in cirq.single_qubit_matrix_to_gates(mat, self.atol):
yield gate(qubit)

def _decompose_two_qubit_operation(self, op: "cirq.Operation", _) -> "cirq.OP_TREE":
if not cirq.has_unitary(op):
return NotImplemented
return cirq.two_qubit_matrix_to_cz_operations(
op.qubits[0],
op.qubits[1],
cirq.unitary(op),
allow_partial_czs=self.allow_partial_czs,
atol=self.atol,
)


def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]:
"""Decomposes a single Cirq gate operation into a sequence of operations
that are directly supported by PyQIR.
Expand All @@ -36,12 +92,10 @@ def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]:
_ = map_cirq_op_to_pyqir_callable(operation)
return [operation]
except CirqConversionError:
pass
new_ops = cirq.decompose_once(operation, flatten=True, default=[operation])
if len(new_ops) == 1 and new_ops[0] == operation:
raise CirqConversionError("Couldn't convert circuit to QIR gate set.")
return list(itertools.chain.from_iterable(map(_decompose_gate_op, new_ops)))

new_ops = cirq.decompose_once(operation, flatten=True, default=[operation])
if len(new_ops) == 1 and new_ops[0] == operation:
raise CirqConversionError("Couldn't convert circuit to QIR gate set.")
return list(itertools.chain.from_iterable(map(_decompose_gate_op, new_ops)))

def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit:
"""
Expand All @@ -53,21 +107,9 @@ def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit:
Returns:
cirq.Circuit: A new circuit with unsupported gates decomposed.
"""
new_circuit = cirq.Circuit()
for moment in circuit:
new_ops = []
for operation in moment:
if isinstance(operation, cirq.GateOperation):
decomposed_ops = list(_decompose_gate_op(operation))
new_ops.extend(decomposed_ops)
elif isinstance(operation, cirq.ClassicallyControlledOperation):
new_ops.append(operation)
else:
new_ops.append(operation)

new_circuit.append(new_ops)
return new_circuit
circuit = cirq.optimize_for_target_gateset(circuit, gateset=QirTargetGateSet(), ignore_failures=True, max_num_passes=1)

return circuit

def preprocess_circuit(circuit: cirq.Circuit) -> cirq.Circuit:
"""
Expand Down
17 changes: 13 additions & 4 deletions qbraid_qir/cirq/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
from abc import ABCMeta, abstractmethod

import numpy as np
import cirq
import pyqir
import pyqir._native
Expand Down Expand Up @@ -108,6 +109,13 @@ def handle_measurement(pyqir_func):
for qubit, result in zip(qubits, results):
self._measured_qubits[pyqir.qubit_id(qubit)] = True
pyqir_func(self._builder, qubit, result)

def get_rot_gate_angle(operation: cirq.Operation):
if isinstance(operation.gate, (cirq.ops.XPowGate, cirq.ops.YPowGate, cirq.ops.ZPowGate)):
angle = operation.gate.exponent * np.pi
else:
angle = operation.gate._rads
return angle

# dealing with conditional gates
if isinstance(operation, cirq.ClassicallyControlledOperation):
Expand All @@ -121,9 +129,10 @@ def handle_measurement(pyqir_func):

# pylint: disable=unnecessary-lambda-assignment
if op_str in ["Rx", "Ry", "Rz"]:
angle = get_rot_gate_angle(operation._sub_operation)
pyqir_func = lambda: temp_pyqir_func(
self._builder,
operation._sub_operation.gate._rads, # type: ignore[union-attr]
angle, # type: ignore[union-attr]
*qubits,
)
else:
Expand All @@ -144,11 +153,11 @@ def _branch(conds, pyqir_func):
_branch(conditions, pyqir_func)
else:
pyqir_func, op_str = map_cirq_op_to_pyqir_callable(operation)

if op_str.startswith("measure"):
handle_measurement(pyqir_func)
elif op_str in ["Rx", "Ry", "Rz"]:
pyqir_func(self._builder, operation.gate._rads, *qubits) # type: ignore[union-attr]
elif op_str in ["Rx", "Ry", "Rz"]:
angle = get_rot_gate_angle(operation)
pyqir_func(self._builder, angle, *qubits) # type: ignore[union-attr]
else:
pyqir_func(self._builder, *qubits)

Expand Down
21 changes: 5 additions & 16 deletions tests/cirq_qir/test_cirq_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
import cirq
import numpy as np
import pytest
import qbraid

from qbraid_qir.cirq.exceptions import CirqConversionError
from qbraid_qir.cirq.passes import preprocess_circuit

# pylint: disable=redefined-outer-name


@pytest.fixture
def gridqubit_circuit():
qubits = [cirq.GridQubit(x, 0) for x in range(4)]
Expand All @@ -40,17 +38,17 @@ def test_convert_gridqubits_to_linequbits(gridqubit_circuit):
linequbit_circuit = preprocess_circuit(gridqubit_circuit)
for qubit in linequbit_circuit.all_qubits():
assert isinstance(qubit, cirq.LineQubit), "Qubit is not a LineQubit"
assert np.allclose(
linequbit_circuit.unitary(), gridqubit_circuit.unitary()
qbraid.interface.assert_allclose_up_to_global_phase(
linequbit_circuit.unitary(), gridqubit_circuit.unitary(), atol=1e-6
), "Circuits are not equal"


def test_convert_namedqubits_to_linequbits(namedqubit_circuit):
linequbit_circuit = preprocess_circuit(namedqubit_circuit)
for qubit in linequbit_circuit.all_qubits():
assert isinstance(qubit, cirq.LineQubit), "Qubit is not a LineQubit"
assert np.allclose(
linequbit_circuit.unitary(), namedqubit_circuit.unitary()
qbraid.interface.assert_allclose_up_to_global_phase(
linequbit_circuit.unitary(), namedqubit_circuit.unitary(), atol=1e-6
), "Circuits are not equal"


Expand All @@ -59,12 +57,3 @@ def test_empty_circuit_conversion():
converted_circuit = preprocess_circuit(circuit)
assert len(converted_circuit.all_qubits()) == 0, "Converted empty circuit should have no qubits"


def test_multi_qubit_measurement_error():
qubits = cirq.LineQubit.range(3)
circuit = cirq.Circuit()
ps = cirq.X(qubits[0]) * cirq.Y(qubits[1]) * cirq.X(qubits[2])
meas_gates = cirq.measure_single_paulistring(ps)
circuit.append(meas_gates)
with pytest.raises(CirqConversionError):
preprocess_circuit(circuit)
9 changes: 1 addition & 8 deletions tests/cirq_qir/test_cirq_to_qir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import cirq
import pyqir
import pytest
import numpy as np

from qbraid_qir.cirq import CirqConversionError, cirq_to_qir
from tests.cirq_qir.fixtures.basic_gates import (
Expand Down Expand Up @@ -69,14 +70,6 @@ def test_cirq_qir_conversion_error():
cirq_to_qir(None)


def test_cirq_to_qir_conversion_error():
"""Test raising exception for conversion error."""
op = cirq.XPowGate(exponent=0.25).controlled().on(cirq.LineQubit(1), cirq.LineQubit(2))
circuit = cirq.Circuit(op)
with pytest.raises(CirqConversionError):
cirq_to_qir(circuit)


@pytest.mark.parametrize("circuit_name", single_op_tests)
def test_single_qubit_gates(circuit_name, request):
qir_op, circuit = request.getfixturevalue(circuit_name)
Expand Down
15 changes: 0 additions & 15 deletions tests/qasm3_qir/converter/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,3 @@ def test_nested_gate_modifiers():
check_single_qubit_gate_op(generated_qir, 2, [1, 1, 1], "z")


def test_unsupported_modifiers():
# TO DO : add implementations, but till then we have tests
for modifier in ["ctrl", "negctrl"]:
with pytest.raises(
NotImplementedError,
match=r"Controlled modifier gates not yet supported .*",
):
_ = qasm3_to_qir(
f"""
OPENQASM 3;
include "stdgates.inc";
qubit[2] q;
{modifier} @ h q[0], q[1];
"""
)
Loading