Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add phase RX gate #243

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/devices/braket_local.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ from :mod:`braket.pennylane_plugin.ops <.ops>`:
braket.pennylane_plugin.GPi
braket.pennylane_plugin.GPi2
braket.pennylane_plugin.MS
braket.pennylane_plugin.PRx

1 change: 1 addition & 0 deletions doc/devices/braket_remote.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ from :mod:`braket.pennylane_plugin.ops <.ops>`:
braket.pennylane_plugin.GPi
braket.pennylane_plugin.GPi2
braket.pennylane_plugin.MS
braket.pennylane_plugin.PRx

Pulse Programming
~~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions src/braket/pennylane_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CPhaseShift10,
GPi,
GPi2,
PRx,
)

from ._version import __version__ # noqa: F401
51 changes: 51 additions & 0 deletions src/braket/pennylane_plugin/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,57 @@ def adjoint(self):
return CPhaseShift10(-phi, wires=self.wires)


class PRx(Operation):
r"""Phase Rx gate.

Unitary matrix:

.. math:: \mathtt{PRx}(\theta,\phi) = \begin{bmatrix}
\cos{(\theta / 2)} & -i e^{-i \phi} \sin{(\theta / 2)} \\
-i e^{i \phi} \sin{(\theta / 2)} & \cos{(\theta / 2)}
\end{bmatrix}.

**Details**

* Number of wires: 1
* Number of parameters: 2

Args:
theta (Union[FreeParameterExpression, float]): The first angle of the gate in
radians or expression representation.
phi (Union[FreeParameterExpression, float]): The second angle of the gate in
radians or expression representation.
"""

num_params = 2
num_wires = 1
grad_method = "F"

def __init__(self, theta, phi, wires, id=None):
super().__init__(theta, phi, wires=wires, id=id)

@staticmethod
def compute_matrix(theta, phi):
theta = _cast_to_tf(theta)
phi = _cast_to_tf(phi)
return np.array(
[
[
np.cos(theta / 2),
-1j * np.exp(-1j * phi) * np.sin(theta / 2),
],
[
-1j * np.exp(1j * phi) * np.sin(theta / 2),
np.cos(theta / 2),
],
]
)

def adjoint(self):
(theta, phi) = self.parameters
return PRx(-theta, phi, wires=self.wires)


class PSWAP(Operation):
r""" PSWAP(phi, wires)

Expand Down
9 changes: 9 additions & 0 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
CPhaseShift10,
GPi,
GPi2,
PRx,
)

_BRAKET_TO_PENNYLANE_OPERATIONS = {
Expand Down Expand Up @@ -88,6 +89,7 @@
"yy": "IsingYY",
"zz": "IsingZZ",
"ecr": "ECR",
"prx": "PRx",
"gpi": "GPi",
"gpi2": "GPi2",
"ms": "AAMS",
Expand Down Expand Up @@ -399,6 +401,13 @@ def _(zz: qml.IsingZZ, parameters, device=None):
return gates.ZZ(phi)


@_translate_operation.register
def _(_prx: PRx, parameters, device=None):
theta = parameters[0]
phi = parameters[1]
return gates.PRx(theta, phi)


@_translate_operation.register
def _(_gpi: GPi, parameters, device=None):
phi = parameters[0]
Expand Down
3 changes: 2 additions & 1 deletion test/unit_tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from numpy import float64

from braket.pennylane_plugin import PSWAP, CPhaseShift00, CPhaseShift01, CPhaseShift10
from braket.pennylane_plugin.ops import AAMS, MS, GPi, GPi2
from braket.pennylane_plugin.ops import AAMS, MS, GPi, GPi2, PRx

gates_1q_parametrized = [
(GPi, gates.GPi),
Expand All @@ -42,6 +42,7 @@

gates_2q_2p_parametrized = [
(MS, gates.MS),
(PRx, gates.PRx),
]

gates_2q_3p_parametrized = [
Expand Down
5 changes: 3 additions & 2 deletions test/unit_tests/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
CPhaseShift01,
CPhaseShift10,
)
from braket.pennylane_plugin.ops import AAMS, MS, GPi, GPi2
from braket.pennylane_plugin.ops import AAMS, MS, GPi, GPi2, PRx
from braket.pennylane_plugin.translation import (
_BRAKET_TO_PENNYLANE_OPERATIONS,
_translate_observable,
Expand Down Expand Up @@ -140,6 +140,7 @@ def _aws_device(
(GPi, gates.GPi, [0], [2]),
(GPi2, gates.GPi2, [0], [2]),
(MS, gates.MS, [0, 1], [2, 3]),
(PRx, gates.PRx, [0], [2, 3]),
(AAMS, gates.MS, [0, 1], [2, 3, 0.5]),
(qml.ECR, gates.ECR, [0, 1], []),
(qml.ISWAP, gates.ISwap, [0, 1], []),
Expand Down Expand Up @@ -339,7 +340,7 @@ def test_translate_operation(pl_cls, braket_cls, qubits, params):
pl_op = pl_cls(*params, wires=qubits)
braket_gate = braket_cls(*params)
assert translate_operation(pl_op) == braket_gate
if isinstance(pl_op, (GPi, GPi2, MS, AAMS)):
if isinstance(pl_op, (GPi, GPi2, MS, AAMS, PRx)):
translated_back = _braket_to_pl[
re.match("^[a-z0-2]+", braket_gate.to_ir(qubits, ir_type=IRType.OPENQASM)).group(0)
]
Expand Down