Skip to content

Commit

Permalink
Multi-level dynamic decompositions (#6881)
Browse files Browse the repository at this point in the history
**Context:** After implementing the first minimum working prototype in
#6859, we consider the case where operators in the dynamic
decompositions have another dynamic decomposition available. That is,
another `compute_plxpr_decomposition` method.

**Description of the Change:** The interpreter can now handle both the
`max_expansion` and `gate_set` arguments.

**Benefits:** More flexible and powerful implementation

**Possible Drawbacks:** The idea at the current stage is that the
`compute_plxpr_decomposition` should *always* be preferred over the old
static `compute_decomposition` method if it has been defined for an
operator. However, there is one exception to this logic, which is a
limitation in the current implementation due to code structure and time
estimates.

In the dynamic decomposition of an operator with a
`compute_plxpr_decomposition` method, if we encounter an operator
without a `compute_plxpr_decomposition` method along the decomposition,
then we switch to the current decomposition logic, which **disables**
the program capture mechanism and decomposes recursively until we reach
a max depth or a stopping condition, then **re-actives program capture
again at the end of the decomposition**.

Because of this, if in the dynamic decomposition of an operator we
encounter an operator without a `compute_plxpr_decomposition` and
therefore we call the `compute_decomposition` method of that operator,
*if an operator with a `compute_plxpr_decomposition` is returned, we
ignore that and proceed with the usual `compute_decomposition` method*.
See the `test_nested_decomp_no_plxpr_decomp_max_exp` test for a concrete
example.

This is a convenient compromise that appears because:

- we decided to entangle this logic with the standard decomposition
pipeline
- we have a `compute_plxpr_decomposition` defined for just a few
operators, and a `compute_decomposition` is always present for such
operators.
- the current decomposition pipeline disables program capture globally
and acts recursively, making the integration harder

**In the future, this will not be an issue since we'll have a unique new
way of performing decompositions.**

**Related GitHub Issues:** None,

**Related ShortCut Stories:** [sc-83111]
  • Loading branch information
PietropaoloFrisoni authored Feb 25, 2025
1 parent af65460 commit 7c239b6
Show file tree
Hide file tree
Showing 4 changed files with 829 additions and 97 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@
* Implemented a `compute_plxpr_decomposition` method in the `qml.operation.Operator` class to apply dynamic decompositions
with program capture enabled.
[(#6859)](https://github.com/PennyLaneAI/pennylane/pull/6859)
[(#6881)](https://github.com/PennyLaneAI/pennylane/pull/6881)

* Autograph can now be used with custom operations defined outside of the pennylane namespace.
[(#6931)](https://github.com/PennyLaneAI/pennylane/pull/6931)
Expand Down
8 changes: 6 additions & 2 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,8 +1361,12 @@ def compute_plxpr_decomposition(*args, **hyperparameters) -> None:
When the program capture feature is enabled with ``qml.capture.enable()``, the decomposition of the operator
is computed with this method if it is defined. Otherwise, the :meth:`~.Operator.compute_decomposition` method is used.
If this method is defined, the control flow operations within the method are recorded in the JAX representation
of the operator's decomposition.
The exception to this rule is when the operator is returned from the :meth:`~.Operator.compute_decomposition` method
of another operator, in which case the decomposition is performed with :meth:`~.Operator.compute_decomposition`
(even if this method is defined), and not with this method.
When ``compute_plxpr_decomposition`` is defined for an operator, the control flow operations within the method
(specifying the decomposition of the operator) are recorded in the JAX representation.
This method is experimental and subject to change.
Expand Down
150 changes: 128 additions & 22 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
# pylint: disable=unnecessary-lambda-assignment

import warnings
from collections.abc import Callable, Generator, Iterable
from collections import ChainMap
from collections.abc import Generator, Iterable
from functools import lru_cache, partial
from typing import Optional
from typing import Callable, Optional, Sequence

import pennylane as qml
from pennylane.transforms.core import transform
Expand Down Expand Up @@ -62,12 +63,13 @@ def _operator_decomposition_gen(


@lru_cache
def _get_plxpr_decompose(): # pylint: disable=missing-docstring
def _get_plxpr_decompose(): # pylint: disable=missing-docstring, too-many-statements
try:
# pylint: disable=import-outside-toplevel
from jax import make_jaxpr
import jax

from pennylane.capture.primitives import ctrl_transform_prim

except ImportError: # pragma: no cover
return None, None

Expand All @@ -80,6 +82,14 @@ class DecomposeInterpreter(qml.capture.PlxprInterpreter):

def __init__(self, gate_set=None, max_expansion=None):
self.max_expansion = max_expansion
self._current_depth = 0

# We use a ChainMap to store the environment frames,
# which allows us to push and pop environments without copying
# the interpreter instance when we evaluate a jaxpr of a dynamic decomposition.

# The name is different from the _env in the parent class (a dictionary) to avoid confusion.
self._env_map = ChainMap()

if gate_set is None:
gate_set = set(qml.ops.__all__)
Expand All @@ -94,7 +104,23 @@ def __init__(self, gate_set=None, max_expansion=None):
else:
self.gate_set = gate_set

super().__init__()
def setup(self) -> None:
"""Setup the environment for the interpreter by pushing a new environment frame."""

# This is the local environment for the jaxpr evaluation, on the top of the stack,
# from which the interpreter reads and writes variables.
# ChainMap writes to the first dictionary in the chain by default.
self._env_map = self._env_map.new_child()

def cleanup(self) -> None:
"""Cleanup the environment by popping the top-most environment frame."""

# We delete the top-most environment frame after the evaluation is done.
self._env_map = self._env_map.parents

def read(self, var):
"""Extract the value corresponding to a variable."""
return var.val if isinstance(var, jax.core.Literal) else self._env_map[var]

def stopping_condition(self, op: qml.operation.Operator) -> bool:
"""Function to determine whether or not an operator needs to be decomposed or not.
Expand All @@ -106,6 +132,7 @@ def stopping_condition(self, op: qml.operation.Operator) -> bool:
bool: Whether or not ``op`` is valid or needs to be decomposed. ``True`` means
that the operator does not need to be decomposed.
"""

if not op.has_decomposition:
if not self.gate_set(op):
warnings.warn(
Expand All @@ -124,9 +151,6 @@ def decompose_operation(self, op: qml.operation.Operator):
Args:
op (Operator): a pennylane operator instance
Returns:
Any
This method is only called when the operator's output is a dropped variable,
so the output will not affect later equations in the circuit.
Expand All @@ -135,41 +159,123 @@ def decompose_operation(self, op: qml.operation.Operator):
if self.gate_set(op):
return self.interpret_operation(op)

max_expansion = (
self.max_expansion - self._current_depth if self.max_expansion is not None else None
)

with qml.capture.pause():
decomposition = list(
_operator_decomposition_gen(
op, self.stopping_condition, max_expansion=self.max_expansion
op,
self.stopping_condition,
max_expansion=max_expansion,
)
)

return [self.interpret_operation(decomp_op) for decomp_op in decomposition]

def interpret_operation_eqn(self, eqn):
def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator):
"""Creates and evaluates a Jaxpr of the plxpr decomposition of an operator."""

if self.gate_set(op):
return self.interpret_operation(op)

if self.max_expansion is not None and self._current_depth >= self.max_expansion:
return self.interpret_operation(op)

args = (*op.parameters, *op.wires)

jaxpr_decomp = qml.capture.make_plxpr(
partial(op.compute_plxpr_decomposition, **op.hyperparameters)
)(*args)

self._current_depth += 1
# We don't need to copy the interpreter here, as the jaxpr of the decomposition
# is evaluated with a new environment frame placed on top of the stack.
out = self.eval(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args)
self._current_depth -= 1

return out

def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
"""
Evaluates a jaxpr, which can also be generated by a dynamic decomposition.
Args:
jaxpr_decomp (jax.core.Jaxpr): the Jaxpr to evaluate
consts (list[TensorLike]): the constant variables for the jaxpr
*args: the arguments to use in the evaluation
"""

self.setup()

for arg, invar in zip(args, jaxpr.invars, strict=True):
self._env_map[invar] = arg
for const, constvar in zip(consts, jaxpr.constvars, strict=True):
self._env_map[constvar] = const

for eq in jaxpr.eqns:

prim_type = getattr(eq.primitive, "prim_type", "")
custom_handler = self._primitive_registrations.get(eq.primitive, None)

if custom_handler:

invals = [self.read(invar) for invar in eq.invars]
outvals = custom_handler(self, *invals, **eq.params)

elif prim_type == "operator":
outvals = self.interpret_operation_eqn(eq)
elif prim_type == "measurement":
outvals = self.interpret_measurement_eqn(eq)
else:
invals = [self.read(invar) for invar in eq.invars]
subfuns, params = eq.primitive.get_bind_params(eq.params)
outvals = eq.primitive.bind(*subfuns, *invals, **params)

if not eq.primitive.multiple_results:
outvals = [outvals]

for outvar, outval in zip(eq.outvars, outvals, strict=True):
self._env_map[outvar] = outval

outvals = []
for var in jaxpr.outvars:
outval = self.read(var)
if isinstance(outval, qml.operation.Operator):
outvals.append(self.interpret_operation(outval))
else:
outvals.append(outval)

self.cleanup()

return outvals

def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"):
"""Interpret an equation corresponding to an operator.
If the operator has a dynamic decomposition defined, this method will
create and evaluate the jaxpr of the decomposition using the :meth:`~.eval` method.
Args:
eqn (jax.core.JaxprEqn): a jax equation for an operator.
See also: :meth:`~.interpret_operation`.
"""

invals = (self.read(invar) for invar in eqn.invars)

with qml.QueuingManager.stop_recording():
op = eqn.primitive.impl(*invals, **eqn.params)
if eqn.outvars[0].__class__.__name__ == "DropVar":

if op.has_plxpr_decomposition:

args = (*op.parameters, *op.wires)
qml.capture.run_autograph(op.compute_plxpr_decomposition)(
*args, **op.hyperparameters
)

else:
if not eqn.outvars[0].__class__.__name__ == "DropVar":
return op

return self.decompose_operation(op)
if not op.has_plxpr_decomposition:
return self.decompose_operation(op)

return op
return self._evaluate_jaxpr_decomposition(op)

# pylint: disable=unused-variable,missing-function-docstring
@DecomposeInterpreter.register_primitive(ctrl_transform_prim)
Expand All @@ -187,7 +293,7 @@ def decompose_plxpr_to_plxpr(
def wrapper(*inner_args):
return decomposer.eval(jaxpr, consts, *inner_args)

return make_jaxpr(wrapper)(*args)
return jax.make_jaxpr(wrapper)(*args)

return DecomposeInterpreter, decompose_plxpr_to_plxpr

Expand Down
Loading

0 comments on commit 7c239b6

Please sign in to comment.