Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Improve unitary_to_rot tests #6977

Merged
merged 6 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
* Added class `qml.capture.transforms.UnitaryToRotInterpreter` that decomposes `qml.QubitUnitary` operators
following the same API as `qml.transforms.unitary_to_rot` when experimental program capture is enabled.
[(#6916)](https://github.com/PennyLaneAI/pennylane/pull/6916)
[(#6977)](https://github.com/PennyLaneAI/pennylane/pull/6977)

<h3>Improvements 🛠</h3>

Expand Down
201 changes: 78 additions & 123 deletions tests/capture/transforms/test_capture_unitary_to_rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
qnode_prim,
while_loop_prim,
)
from pennylane.tape.plxpr_conversion import CollectOpsandMeas
from pennylane.transforms.unitary_to_rot import (
UnitaryToRotInterpreter,
one_qubit_decomposition,
Expand Down Expand Up @@ -96,7 +97,7 @@ def f(U):
assert jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive

def test_three_qubit_example(self):
def test_three_qubit_conversion(self):
"""Tests that no decomposition occurs since num_qubits > 2"""

@UnitaryToRotInterpreter()
Expand All @@ -113,143 +114,62 @@ def f(U):
assert jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive


class TestQNodeIntegration:
"""Test that transform works at the QNode level."""

def test_one_qubit_conversion_qnode(self):
"""Test that you can integrate the transform at the QNode level."""
dev = qml.device("default.qubit", wires=1)
def test_traced_arguments(self):
"""Test that traced arguments are correctly handled."""

@UnitaryToRotInterpreter()
@qml.qnode(dev)
def f(U):
qml.QubitUnitary(U, 0)
qml.X(0)
def f(U, wire):
qml.QubitUnitary(U, wire)
return qml.expval(qml.Z(0))

U = qml.Rot(jax.numpy.pi, 0, 0, wires=0)

jaxpr = jax.make_jaxpr(f)(U.matrix())
assert jaxpr.eqns[0].primitive == qnode_prim
qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]

# Qubit Unitary decomposition
with qml.capture.pause():
QU = qml.QubitUnitary(U.matrix(), 0)
decomp = jax.jit(one_qubit_decomposition)(QU.parameters[0], QU.wires[0])
assert len(decomp) > 1
for i, eqn in enumerate(qfunc_jaxpr.eqns[-len(decomp) - 3 : -3]):
assert eqn.primitive == decomp[i]._primitive

# X gate
assert qfunc_jaxpr.eqns[-3].primitive == qml.PauliX._primitive

# Measurement
assert qfunc_jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert qfunc_jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive

res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, U.matrix())
assert qml.math.allclose(res, -1.0)

# two_qubit_decomposition only supports decomps with
# three CNOTs for abstract matrices
def test_two_qubit_three_cnot_conversion_qnode(self):
"""Test that a two qubit unitary can be decomposed correctly."""
dev = qml.device("default.qubit", wires=2)

U1 = qml.Rot(jax.numpy.pi, 0, 0, wires=0)
U2 = qml.Rot(jax.numpy.pi, 0, 0, wires=1)

U = qml.prod(U1, U2)
U = qml.Rot(1.0, 2.0, 3.0, wires=0)
args = (U.matrix(), 0)
jaxpr = jax.make_jaxpr(f)(*args)
collector = CollectOpsandMeas()
collector.eval(jaxpr.jaxpr, jaxpr.consts, *args)

@UnitaryToRotInterpreter()
@qml.qnode(dev)
def f(U):
qml.QubitUnitary(U, [0, 1])
return qml.expval(qml.Z(0)), qml.expval(qml.Z(1))
expected_ops = [
qml.RZ(jax.numpy.array(1.0), wires=[0]),
qml.RY(jax.numpy.array(2.0), wires=[0]),
qml.RZ(jax.numpy.array(3.0), wires=[0]),
]

jaxpr = jax.make_jaxpr(f)(U.matrix())
assert jaxpr.eqns[0].primitive == qnode_prim
qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]
ops = collector.state["ops"]
assert ops == expected_ops

# Theoretical decomposition based on,
# https://docs.pennylane.ai/en/stable/code/api/pennylane.ops.two_qubit_decomposition.html
with qml.capture.pause():
QU = qml.QubitUnitary(U.matrix(), [0, 1])
decomp = jax.jit(two_qubit_decomposition)(QU.parameters[0], QU.wires)
assert len(decomp) > 1
for i, eqn in enumerate(qfunc_jaxpr.eqns[-len(decomp) - 4 : -4]):
assert eqn.primitive == decomp[i]._primitive
expected_meas = [
qml.expval(qml.PauliZ(0)),
]
meas = collector.state["measurements"]
assert meas == expected_meas

# Measurement 1
assert qfunc_jaxpr.eqns[-4].primitive == qml.PauliZ._primitive
assert qfunc_jaxpr.eqns[-3].primitive == qml.measurements.ExpectationMP._obs_primitive

# Measurement 2
assert qfunc_jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert qfunc_jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive
def test_plxpr_to_plxpr():
"""Test that transforming plxpr works correctly."""

res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, U.matrix())
assert qml.math.allclose(res, (1.0, 1.0))
def circuit(U):
qml.QubitUnitary(U, 0)
return qml.expval(qml.Z(0))

U = qml.Rot(1.0, 2.0, 3.0, wires=0)
args = (U.matrix(),)
jaxpr = jax.make_jaxpr(circuit)(*args)
transformed_jaxpr = unitary_to_rot_plxpr_to_plxpr(jaxpr.jaxpr, jaxpr.consts, [], {}, *args)

class TestUnitaryToRotPlxprTransform:
"""Tests that transforming plxpr works correctly."""
assert isinstance(transformed_jaxpr, jax.core.ClosedJaxpr)

def test_one_qubit_plxpr_transform(self):
"""Test that transforming plxpr works correctly."""
# Qubit Unitary decomposition
with qml.capture.pause():
QU = qml.QubitUnitary(U.matrix(), 0)
decomp = jax.jit(one_qubit_decomposition)(QU.parameters[0], QU.wires[0])
assert len(decomp) > 1

def circuit(U):
qml.QubitUnitary(U, 0)
return qml.expval(qml.Z(0))
for i, eqn in enumerate(transformed_jaxpr.eqns[-len(decomp) : -2]):
assert eqn.primitive == decomp[i]._primitive

U = qml.Rot(1.0, 2.0, 3.0, wires=0)
args = (U.matrix(),)
jaxpr = jax.make_jaxpr(circuit)(*args)
transformed_jaxpr = unitary_to_rot_plxpr_to_plxpr(jaxpr.jaxpr, jaxpr.consts, [], {}, *args)

assert isinstance(transformed_jaxpr, jax.core.ClosedJaxpr)

# Qubit Unitary decomposition
with qml.capture.pause():
QU = qml.QubitUnitary(U.matrix(), 0)
decomp = jax.jit(one_qubit_decomposition)(QU.parameters[0], QU.wires[0])
assert len(decomp) > 1
for i, eqn in enumerate(transformed_jaxpr.eqns[-len(decomp) : -2]):
assert eqn.primitive == decomp[i]._primitive

# Measurement
assert transformed_jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert transformed_jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive

# two_qubit_decomposition only supports decomps with
# three CNOTs for abstract matrices
def test_two_qubit_three_cnot_plxpr_transform(self):
"""Test that a two qubit unitary can be decomposed correctly."""

def circuit(U):
qml.QubitUnitary(U, [0, 1])
return qml.expval(qml.Z(0))

U1 = qml.Rot(1.0, 2.0, 3.0, wires=0)
U2 = qml.Rot(1.0, 2.0, 3.0, wires=1)
U = qml.prod(U1, U2)
args = (U.matrix(),)
jaxpr = jax.make_jaxpr(circuit)(*args)
transformed_jaxpr = unitary_to_rot_plxpr_to_plxpr(jaxpr.jaxpr, jaxpr.consts, [], {}, *args)

# Theoretical decomposition based on,
# https://docs.pennylane.ai/en/stable/code/api/pennylane.ops.two_qubit_decomposition.html
with qml.capture.pause():
QU = qml.QubitUnitary(U.matrix(), [0, 1])
decomp = jax.jit(two_qubit_decomposition)(QU.parameters[0], QU.wires)
assert len(decomp) > 1
for i, eqn in enumerate(transformed_jaxpr.eqns[-len(decomp) - 2 : -2]):
assert eqn.primitive == decomp[i]._primitive
# Measurement
assert transformed_jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert transformed_jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive
# Measurement
assert transformed_jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert transformed_jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive


class TestHigherOrderPrimitiveIntegration:
Expand Down Expand Up @@ -474,6 +394,41 @@ def circuit(a, b, c):
assert qfunc_jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert qfunc_jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive

def test_qnode_higher_order_primitive(self):
"""Test that you can integrate the transform at the QNode level."""
dev = qml.device("default.qubit", wires=1)

@UnitaryToRotInterpreter()
@qml.qnode(dev)
def f(U):
qml.QubitUnitary(U, 0)
qml.X(0)
return qml.expval(qml.Z(0))

U = qml.Rot(jax.numpy.pi, 0, 0, wires=0)

jaxpr = jax.make_jaxpr(f)(U.matrix())
assert jaxpr.eqns[0].primitive == qnode_prim
qfunc_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]

# Qubit Unitary decomposition
with qml.capture.pause():
QU = qml.QubitUnitary(U.matrix(), 0)
decomp = jax.jit(one_qubit_decomposition)(QU.parameters[0], QU.wires[0])
assert len(decomp) > 1
for i, eqn in enumerate(qfunc_jaxpr.eqns[-len(decomp) - 3 : -3]):
assert eqn.primitive == decomp[i]._primitive

# X gate
assert qfunc_jaxpr.eqns[-3].primitive == qml.PauliX._primitive

# Measurement
assert qfunc_jaxpr.eqns[-2].primitive == qml.PauliZ._primitive
assert qfunc_jaxpr.eqns[-1].primitive == qml.measurements.ExpectationMP._obs_primitive

res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, U.matrix())
assert qml.math.allclose(res, -1.0)


class TestExpandPlxprTransformIntegration:
"""Test that the transform works with expand_plxpr_transform"""
Expand Down