Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Capture the single_qubit_fusion transform #6945

Merged
merged 44 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
0edac8d
E.C.
PietropaoloFrisoni Feb 10, 2025
c1cd09c
Tentativo penoso
PietropaoloFrisoni Feb 11, 2025
a06cfd2
Non e' molto, ma e' un lavoro onesto
PietropaoloFrisoni Feb 12, 2025
4c69b8a
Continuing tests
PietropaoloFrisoni Feb 13, 2025
6971a49
One more test
PietropaoloFrisoni Feb 14, 2025
5fa88ee
Updating logic
PietropaoloFrisoni Feb 17, 2025
a117fb6
One more tests, but many more to add tomorrow
PietropaoloFrisoni Feb 18, 2025
037850a
More tests
PietropaoloFrisoni Feb 18, 2025
ec62880
Test for qml.cond
PietropaoloFrisoni Feb 18, 2025
ca24558
Improving cond test
PietropaoloFrisoni Feb 18, 2025
9fa3b54
More tests
PietropaoloFrisoni Feb 18, 2025
a63d106
Removing debug messages
PietropaoloFrisoni Feb 19, 2025
2eb5875
Redundant code in tests
PietropaoloFrisoni Feb 19, 2025
8d3fbb6
More tests
PietropaoloFrisoni Feb 19, 2025
d281509
Testing CI coverance
PietropaoloFrisoni Feb 19, 2025
e1f57ac
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 19, 2025
371193a
Pylint
PietropaoloFrisoni Feb 19, 2025
92db59e
Final parity arguments
PietropaoloFrisoni Feb 20, 2025
11ac15e
Improved logic
PietropaoloFrisoni Feb 20, 2025
2fb67fb
Better name
PietropaoloFrisoni Feb 20, 2025
3b5397b
Simplified logic even more
PietropaoloFrisoni Feb 20, 2025
b865533
Typo in test
PietropaoloFrisoni Feb 21, 2025
8b2c1cc
Test refactoring + code review suggestions
PietropaoloFrisoni Feb 21, 2025
47c5d6c
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 21, 2025
097bd09
Test refactoring
PietropaoloFrisoni Feb 21, 2025
d48fc8f
Pylint maledetto
PietropaoloFrisoni Feb 21, 2025
93c6f9b
Removing init
PietropaoloFrisoni Feb 21, 2025
525e9c8
forgot args for plxpr
PietropaoloFrisoni Feb 24, 2025
554a9c0
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 24, 2025
fe31bd1
Suggestions from code review Part 1 (tests)
PietropaoloFrisoni Feb 25, 2025
dc2da14
Update tests/capture/transforms/test_capture_single_qubit_fusion.py
PietropaoloFrisoni Feb 25, 2025
6558eb9
Test of grad with non-zero returned array
PietropaoloFrisoni Feb 25, 2025
680f308
Updating for_loop test
PietropaoloFrisoni Feb 25, 2025
e2d8dbb
Suggestions from code review (WIP, need to go to the office and chang…
PietropaoloFrisoni Feb 25, 2025
a319e80
Suggestions from code review
PietropaoloFrisoni Feb 25, 2025
256c37b
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 25, 2025
8fda60e
Merge branch 'master' into capture_single_qubit_fusion
PietropaoloFrisoni Feb 25, 2025
083c147
Suggestions from code review
PietropaoloFrisoni Feb 25, 2025
dba1607
Apply suggestions from code review
PietropaoloFrisoni Feb 26, 2025
34cf21d
Removing unnecessary test for inner transform
PietropaoloFrisoni Feb 26, 2025
55ffc92
Merge branch 'master' into capture_single_qubit_fusion
PietropaoloFrisoni Feb 26, 2025
47d68cd
Adding test
PietropaoloFrisoni Feb 26, 2025
29a1017
Update pennylane/transforms/optimization/single_qubit_fusion.py
PietropaoloFrisoni Feb 27, 2025
3acbdaf
Merge branch 'master' into capture_single_qubit_fusion
PietropaoloFrisoni Feb 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@

<h4>Capturing and representing hybrid programs</h4>

* The `qml.transforms.single_qubit_fusion` quantum transform can now be applied with program capture enabled.
[(#6945)](https://github.com/PennyLaneAI/pennylane/pull/6945)

* `qml.QNode` can now cache plxpr. When executing a `QNode` for the first time, its plxpr representation will
be cached based on the abstract evaluation of the arguments. Later executions that have arguments with the
same shapes and data types will be able to use this cached plxpr instead of capturing the program again.
Expand Down
226 changes: 223 additions & 3 deletions pennylane/transforms/optimization/single_qubit_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,232 @@
"""Transform for fusing sequences of single-qubit gates."""
# pylint: disable=too-many-branches

from functools import lru_cache
from typing import Optional

import pennylane as qml
from pennylane.ops.qubit import Rot
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn
from pennylane.typing import PostprocessingFn, TensorLike

from .optimization_utils import find_next_gate, fuse_rot_angles


@lru_cache
def _get_plxpr_single_qubit_fusion(): # pylint: disable=missing-function-docstring,too-many-statements
try:
# pylint: disable=import-outside-toplevel
from jax import make_jaxpr

from pennylane.capture import PlxprInterpreter
from pennylane.operation import Operator
except ImportError: # pragma: no cover
return None, None

# pylint: disable=redefined-outer-name

class SingleQubitFusionInterpreter(PlxprInterpreter):
"""Plxpr Interpreter for applying the ``single_qubit_fusion`` transform to callables or jaxpr
when program capture is enabled.

.. note::

In the process of transforming plxpr, this interpreter may reorder operations that do
not share any wires. This will not impact the correctness of the circuit.
"""

def __init__(self, atol: Optional[float] = 1e-8, exclude_gates: Optional[list[str]] = None):
"""Initialize the interpreter."""
self.atol = atol
self.exclude_gates = exclude_gates
self.previous_ops = {}
self._env = {}
super().__init__()

def setup(self) -> None:
"""Initialize the instance before interpreting equations."""
self.previous_ops = {}
self._env = {}

def cleanup(self) -> None:
"""Clean up the instance after interpreting equations."""
self.previous_ops.clear()
self._env.clear()

def _handle_non_fusible_op(self, op: Operator) -> list:
"""Handle an operation that cannot be fused into a Rot gate."""

# The order might not be deterministic if wires (the keys) are abstract.
# However, this only impacts operators without any shared wires,
# which does not affect the correctness of the result.
previous_ops_on_wires = list(
dict.fromkeys(
self.previous_ops.get(w)
for w in op.wires
if self.previous_ops.get(w) is not None
)
)

res = []
for prev_op in previous_ops_on_wires:
# pylint: disable=protected-access
rot = qml.Rot._primitive.impl(
*qml.math.stack(prev_op.single_qubit_rot_angles()), wires=prev_op.wires
)
res.append(super().interpret_operation(rot))

res.append(super().interpret_operation(op))

for w in op.wires:
self.previous_ops.pop(w, None)

return res

def _handle_fusible_op(self, op: Operator, cumulative_angles: TensorLike) -> list:
"""Handle an operation that can be potentially fused into a Rot gate."""

# Only single-qubit gates are considered for fusion
op_wire = op.wires[0]

prev_op = self.previous_ops.get(op_wire)
if prev_op is None:
self.previous_ops[op_wire] = op
return []

prev_op_angles = qml.math.stack(prev_op.single_qubit_rot_angles())
cumulative_angles = fuse_rot_angles(prev_op_angles, cumulative_angles)

if (
qml.math.is_abstract(cumulative_angles)
or qml.math.requires_grad(cumulative_angles)
or not qml.math.allclose(
qml.math.stack(
[cumulative_angles[0] + cumulative_angles[2], cumulative_angles[1]]
),
0.0,
atol=self.atol,
rtol=0,
)
):
# pylint: disable=protected-access
new_rot = qml.Rot._primitive.impl(*cumulative_angles, wires=op.wires)
self.previous_ops[op_wire] = new_rot
else:
self.previous_ops.pop(op_wire, None)

return []

def interpret_operation(self, op: Operator):
"""Interpret a PennyLane operation instance."""

# Operators like Identity() have no wires, so we interpret them directly
if len(op.wires) == 0:
return super().interpret_operation(op)

# We interpret directly if the gate is explicitly excluded
if self.exclude_gates is not None:
if op.name in self.exclude_gates:
return super().interpret_operation(op)

try:
cumulative_angles = qml.math.stack(op.single_qubit_rot_angles())
except (NotImplementedError, AttributeError):
return self._handle_non_fusible_op(op)

return self._handle_fusible_op(op, cumulative_angles)

def interpret_all_previous_ops(self) -> None:
"""Interpret all previous operations stored in the instance."""

# As above, the order might not be deterministic if wires (the keys) are abstract.
# However, this only impacts operators without any shared wires,
# which does not affect the correctness of the result.
ops_remaining = list(dict.fromkeys(self.previous_ops.values()))

for op in ops_remaining:
super().interpret_operation(op)

self.previous_ops.clear()

def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
"""Evaluate a jaxpr.

Args:
jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
consts (list[TensorLike]): the constant variables for the jaxpr
*args (tuple[TensorLike]): The arguments for the jaxpr.

Returns:
list[TensorLike]: the results of the execution.
"""

self.setup()

for arg, invar in zip(args, jaxpr.invars, strict=True):
self._env[invar] = arg
for const, constvar in zip(consts, jaxpr.constvars, strict=True):
self._env[constvar] = const

for eqn in jaxpr.eqns:

prim_type = getattr(eqn.primitive, "prim_type", "")

custom_handler = self._primitive_registrations.get(eqn.primitive, None)
if custom_handler:
self.interpret_all_previous_ops()
invals = [self.read(invar) for invar in eqn.invars]
outvals = custom_handler(self, *invals, **eqn.params)
elif prim_type == "operator":
outvals = self.interpret_operation_eqn(eqn)
elif prim_type == "measurement":
self.interpret_all_previous_ops()
outvals = self.interpret_measurement_eqn(eqn)
else:
if prim_type == "transform":
self.interpret_all_previous_ops()
invals = [self.read(invar) for invar in eqn.invars]
subfuns, params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **params)

if not eqn.primitive.multiple_results:
outvals = [outvals]
for outvar, outval in zip(eqn.outvars, outvals, strict=True):
self._env[outvar] = outval

self.interpret_all_previous_ops()

outvals = []
for var in jaxpr.outvars:
outval = self.read(var)
if isinstance(outval, Operator):
outvals.append(super().interpret_operation(outval))
else:
outvals.append(outval)

self.cleanup()
return outvals

def single_qubit_fusion_plxpr_to_plxpr(
jaxpr, consts, targs, tkwargs, *args
): # pylint: disable=unused-argument
interpreter = SingleQubitFusionInterpreter()

def wrapper(*inner_args):
return interpreter.eval(jaxpr, consts, *inner_args)

return make_jaxpr(wrapper)(*args)

return SingleQubitFusionInterpreter, single_qubit_fusion_plxpr_to_plxpr


SingleQubitFusionInterpreter, single_qubit_plxpr_to_plxpr = _get_plxpr_single_qubit_fusion()


@transform
def single_qubit_fusion(
tape: QuantumScript, atol=1e-8, exclude_gates=None
tape: QuantumScript, atol: Optional[float] = 1e-8, exclude_gates: Optional[list[str]] = None
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
r"""Quantum function transform to fuse together groups of single-qubit
operations into a general single-qubit unitary operation (:class:`~.Rot`).
Expand Down Expand Up @@ -74,6 +287,12 @@ def qfunc(r1, r2):
because Euler angles are not unique for some rotations. ``single_qubit_fusion``
makes a particular choice in this case.

.. note::

The order of the gates resulting from the fusion may be different depending
on wether program capture is enabled or not. This only impacts the order of
operations that do not share any wires, so the correctness of the circuit is not affected.

.. warning::

This function is not differentiable everywhere. It has singularities for specific
Expand Down Expand Up @@ -263,7 +482,7 @@ def qfunc(r1, r2):
list_copy.pop(0)
continue

# Find the next gate that acts on the same wires
# Find the next gate that acts on at least one of the same wires
next_gate_idx = find_next_gate(current_gate.wires, list_copy[1:])

if next_gate_idx is None:
Expand Down Expand Up @@ -298,6 +517,7 @@ def qfunc(r1, r2):
next_gate_angles = qml.math.stack(next_gate.single_qubit_rot_angles())
except (NotImplementedError, AttributeError):
break

cumulative_angles = fuse_rot_angles(cumulative_angles, next_gate_angles)

list_copy.pop(next_gate_idx + 1)
Expand Down
Loading