Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-level dynamic decompositions #6881

Merged
merged 113 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from 95 commits
Commits
Show all changes
113 commits
Select commit Hold shift + click to select a range
696acf3
E.C.
PietropaoloFrisoni Jan 20, 2025
9cdfae9
Creating an empty `DynamicDecomposeInterpreter` c;ass
PietropaoloFrisoni Jan 21, 2025
1e138fa
Sbattendo la testa contro il muro tante volte
PietropaoloFrisoni Jan 21, 2025
76c9250
Current prototype version
PietropaoloFrisoni Jan 22, 2025
c821ac9
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 22, 2025
cee6ec4
Fixing one more problem
PietropaoloFrisoni Jan 22, 2025
a208dc5
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 22, 2025
d16a9e0
Moving tests to separate file
PietropaoloFrisoni Jan 23, 2025
a8b9283
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 23, 2025
18bc43c
Pylint fixes (although premature)
PietropaoloFrisoni Jan 23, 2025
e2e8fd0
Removing reundandt tuple calls
PietropaoloFrisoni Jan 23, 2025
0abd620
Tests with dynamic wires
PietropaoloFrisoni Jan 23, 2025
1e3ffb6
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 23, 2025
1ae399b
Adding Autograph test
PietropaoloFrisoni Jan 23, 2025
3fae553
First attempt (probably still not working)
PietropaoloFrisoni Jan 24, 2025
9a54c3e
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 24, 2025
c5f2ae5
Removing unused parameters and adding a few tests
PietropaoloFrisoni Jan 24, 2025
497440c
Adding a few more tests
PietropaoloFrisoni Jan 24, 2025
c7da133
Removing import
PietropaoloFrisoni Jan 24, 2025
2f0417c
Pylint
PietropaoloFrisoni Jan 24, 2025
3c8bc37
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 24, 2025
e9ff110
Adding test with hyperparameters
PietropaoloFrisoni Jan 27, 2025
b9f5d03
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 27, 2025
f2437b4
Black
PietropaoloFrisoni Jan 27, 2025
8c615c5
A few more tests
PietropaoloFrisoni Jan 27, 2025
9fef95b
Changelog
PietropaoloFrisoni Jan 27, 2025
4a56150
Removing redundant operations
PietropaoloFrisoni Jan 28, 2025
5ffcf90
Merge branch 'cond_dynamic_decomp' of https://github.com/PennyLaneAI/…
PietropaoloFrisoni Jan 28, 2025
99e5589
WIP on `max_expansion` parameter
PietropaoloFrisoni Jan 28, 2025
e733762
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 29, 2025
9792df6
Pre-binding hyperparameters [ci skip]
PietropaoloFrisoni Jan 29, 2025
aa422b3
Removing redundant method
PietropaoloFrisoni Jan 30, 2025
b7a18cd
Pylint
PietropaoloFrisoni Jan 30, 2025
97cba03
Testing CI failures (JAX imports)
PietropaoloFrisoni Jan 30, 2025
e05963e
isort
PietropaoloFrisoni Jan 30, 2025
ec73757
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 30, 2025
bc594ca
Support for consts and hyperparameters
PietropaoloFrisoni Jan 31, 2025
4734388
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 31, 2025
1ea4a14
Fixes neede after autograph PR merged on master
PietropaoloFrisoni Jan 31, 2025
f09937d
Merge branch 'cond_dynamic_decomp' of https://github.com/PennyLaneAI/…
PietropaoloFrisoni Feb 3, 2025
75b34e4
Fixing return stat
PietropaoloFrisoni Feb 3, 2025
0a0871d
Trying to remove reduntant (?) method
PietropaoloFrisoni Feb 4, 2025
f72f290
Soluzione ibrida poco convincente
PietropaoloFrisoni Feb 4, 2025
b4a2c76
Straightforward prototype
PietropaoloFrisoni Feb 4, 2025
64146f0
Removing references to nested decompositions
PietropaoloFrisoni Feb 4, 2025
c267be1
Removing references to nested decompositions
PietropaoloFrisoni Feb 4, 2025
f65ab4e
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 5, 2025
f576486
[ci skip]
PietropaoloFrisoni Feb 5, 2025
abd9fe0
Merge branch 'cond_dynamic_decomp' of https://github.com/PennyLaneAI/…
PietropaoloFrisoni Feb 5, 2025
fd817cc
Removing `DynamicDecomposeInterpreter`
PietropaoloFrisoni Feb 5, 2025
90fc1e5
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 5, 2025
e7a2775
pylint
PietropaoloFrisoni Feb 5, 2025
1897b55
Merge branch 'cond_dynamic_decomp' of https://github.com/PennyLaneAI/…
PietropaoloFrisoni Feb 6, 2025
d47fc6c
Sub-interpreter for nested decomposition
PietropaoloFrisoni Feb 6, 2025
1a17f95
TODO: add more tests
PietropaoloFrisoni Feb 6, 2025
f949897
Quagliato
PietropaoloFrisoni Feb 7, 2025
f60cb6c
simplified logic
PietropaoloFrisoni Feb 7, 2025
a41fe70
Improving test logic
PietropaoloFrisoni Feb 7, 2025
c3271b4
Improving test logic
PietropaoloFrisoni Feb 7, 2025
f9c9320
Improving test logic and special case
PietropaoloFrisoni Feb 7, 2025
b6e46c1
Adding more tests
PietropaoloFrisoni Feb 7, 2025
22e63d2
Testing coverage
PietropaoloFrisoni Feb 7, 2025
0959b31
Making the max_expansion parameter coherent
PietropaoloFrisoni Feb 8, 2025
0a892b1
Fixing existing tests
PietropaoloFrisoni Feb 8, 2025
c333bc8
Test coverage holes
PietropaoloFrisoni Feb 8, 2025
cf05139
Corr. tests
PietropaoloFrisoni Feb 8, 2025
43f9521
More tests
PietropaoloFrisoni Feb 9, 2025
a324e95
Cleaning comments
PietropaoloFrisoni Feb 9, 2025
c01521a
Typo
PietropaoloFrisoni Feb 9, 2025
1986b6a
Returning eval in `decompose_plxpr_to_plxpr`
PietropaoloFrisoni Feb 9, 2025
4510e7e
Tests with combined params
PietropaoloFrisoni Feb 9, 2025
f621156
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 10, 2025
e9834af
Suggestions from code review (more tests)
PietropaoloFrisoni Feb 10, 2025
bd9473b
disabling wrong iimport order in test file (conflict between isort an…
PietropaoloFrisoni Feb 10, 2025
43a6838
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 10, 2025
98fe17e
Suggestions from code review
PietropaoloFrisoni Feb 10, 2025
902e4f8
Refactoring changelog with program capture entries
PietropaoloFrisoni Feb 10, 2025
0da28e0
Merge branch 'cond_dynamic_decomp' of https://github.com/PennyLaneAI/…
PietropaoloFrisoni Feb 10, 2025
4f95b74
Merge branch 'master' into cond_dynamic_decomp
PietropaoloFrisoni Feb 10, 2025
8e34af7
Merging from base PR
PietropaoloFrisoni Feb 10, 2025
188d044
Merge branch 'cond_dynamic_decomp' of https://github.com/PennyLaneAI/…
PietropaoloFrisoni Feb 10, 2025
976bc16
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 10, 2025
fa7ec7c
Pylint
PietropaoloFrisoni Feb 10, 2025
fbf7760
Sphynx maledetto
PietropaoloFrisoni Feb 10, 2025
4564e27
Fixing failure in SEL
PietropaoloFrisoni Feb 11, 2025
3d266b3
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 11, 2025
828b9dd
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 19, 2025
5279393
Suggestions from code review
PietropaoloFrisoni Feb 19, 2025
12f7242
Removing redundant code
PietropaoloFrisoni Feb 19, 2025
f0c5a07
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 19, 2025
f032ade
Extended name
PietropaoloFrisoni Feb 19, 2025
a5a47f3
Triggering CI again
PietropaoloFrisoni Feb 19, 2025
a548d0c
Changelog entry
PietropaoloFrisoni Feb 21, 2025
fb7eb53
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Feb 21, 2025
f6d58e8
Merge branch 'master' into multi_level_decomp
PietropaoloFrisoni Feb 21, 2025
ae89262
Suggestion from code review
PietropaoloFrisoni Feb 21, 2025
bc54b24
Pylint
PietropaoloFrisoni Feb 21, 2025
61d0439
Merge branch 'master' into multi_level_decomp
PietropaoloFrisoni Feb 21, 2025
7362df1
Slight improvement in enviroment management
PietropaoloFrisoni Feb 22, 2025
acea027
pylint
PietropaoloFrisoni Feb 22, 2025
e6b31bb
Clarifying comment
PietropaoloFrisoni Feb 22, 2025
3c07579
making attribute private
PietropaoloFrisoni Feb 22, 2025
6adf6b1
Merge branch 'master' into multi_level_decomp
PietropaoloFrisoni Feb 22, 2025
7c4e81b
Replacing ccurrent data structure with a unique `ChainMap`
PietropaoloFrisoni Feb 22, 2025
914d53d
Merge branch 'multi_level_decomp' of https://github.com/PennyLaneAI/p…
PietropaoloFrisoni Feb 22, 2025
2c39162
Removing redundant line
PietropaoloFrisoni Feb 22, 2025
4096e02
Cleaner structure
PietropaoloFrisoni Feb 22, 2025
53d903d
CI
PietropaoloFrisoni Feb 22, 2025
ca1f3c6
Removing read function
PietropaoloFrisoni Feb 22, 2025
e940c17
Avoiding copying the interpreter
PietropaoloFrisoni Feb 23, 2025
8d65f14
Last suggestion from code review
PietropaoloFrisoni Feb 25, 2025
2b4a10c
Merge branch 'master' into multi_level_decomp
PietropaoloFrisoni Feb 25, 2025
30a568b
Merge branch 'master' into multi_level_decomp
PietropaoloFrisoni Feb 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@
* Implemented a `compute_plxpr_decomposition` method in the `qml.operation.Operator` class to apply dynamic decompositions
with program capture enabled.
[(#6859)](https://github.com/PennyLaneAI/pennylane/pull/6859)
[(#6881)](https://github.com/PennyLaneAI/pennylane/pull/6881)

* Autograph can now be used with custom operations defined outside of the pennylane namespace.
[(#6931)](https://github.com/PennyLaneAI/pennylane/pull/6931)
Expand Down
8 changes: 6 additions & 2 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,8 +1361,12 @@ def compute_plxpr_decomposition(*args, **hyperparameters) -> None:
When the program capture feature is enabled with ``qml.capture.enable()``, the decomposition of the operator
is computed with this method if it is defined. Otherwise, the :meth:`~.Operator.compute_decomposition` method is used.

If this method is defined, the control flow operations within the method are recorded in the JAX representation
of the operator's decomposition.
The exception to this rule is when the operator is returned from the :meth:`~.Operator.compute_decomposition` method
of another operator, in which case the decomposition is performed with :meth:`~.Operator.compute_decomposition`
(even if this method is defined), and not with this method.

When ``compute_plxpr_decomposition`` is defined for an operator, the control flow operations within the method
(specifying the decomposition of the operator) are recorded in the JAX representation.

This method is experimental and subject to change.

Expand Down
277 changes: 256 additions & 21 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
# pylint: disable=unnecessary-lambda-assignment

import warnings
from collections.abc import Callable, Generator, Iterable
from collections.abc import Generator, Iterable
from copy import copy
from functools import lru_cache, partial
from typing import Optional
from typing import Callable, Optional, Sequence

import pennylane as qml
from pennylane.transforms.core import transform
Expand Down Expand Up @@ -62,12 +63,18 @@


@lru_cache
def _get_plxpr_decompose(): # pylint: disable=missing-docstring
def _get_plxpr_decompose(): # pylint: disable=missing-docstring, too-many-statements
try:
# pylint: disable=import-outside-toplevel
from jax import make_jaxpr
import jax

from pennylane.capture.primitives import (
cond_prim,
ctrl_transform_prim,
for_loop_prim,
while_loop_prim,
)

from pennylane.capture.primitives import ctrl_transform_prim
except ImportError: # pragma: no cover
return None, None

Expand Down Expand Up @@ -106,6 +113,7 @@
bool: Whether or not ``op`` is valid or needs to be decomposed. ``True`` means
that the operator does not need to be decomposed.
"""

if not op.has_decomposition:
if not self.gate_set(op):
warnings.warn(
Expand All @@ -117,15 +125,13 @@
return True
return self.gate_set(op)

def decompose_operation(self, op: qml.operation.Operator):
def decompose_operation(self, op: qml.operation.Operator, current_depth: int = 0):
"""Decompose a PennyLane operation instance if it does not satisfy the
provided gate set.

Args:
op (Operator): a pennylane operator instance

Returns:
Any
current_depth (int): the current depth of the decomposition

This method is only called when the operator's output is a dropped variable,
so the output will not affect later equations in the circuit.
Expand All @@ -135,47 +141,276 @@
if self.gate_set(op):
return self.interpret_operation(op)

max_expansion = (
self.max_expansion - current_depth if self.max_expansion is not None else None
)

with qml.capture.pause():
decomposition = list(
_operator_decomposition_gen(
op, self.stopping_condition, max_expansion=self.max_expansion
op,
self.stopping_condition,
max_expansion=max_expansion,
)
)

return [self.interpret_operation(decomp_op) for decomp_op in decomposition]

def interpret_operation_eqn(self, eqn):
def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator, current_depth: int = 0):
"""Creates and evaluates a Jaxpr of the plxpr decomposition of an operator."""

if self.gate_set(op):
return self.interpret_operation(op)

if self.max_expansion is not None and current_depth >= self.max_expansion:
return self.interpret_operation(op)

args = (*op.parameters, *op.wires)
jaxpr_decomp = qml.capture.make_plxpr(
partial(op.compute_plxpr_decomposition, **op.hyperparameters)
)(*args)
current_depth += 1

return self.eval(
jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args, current_depth=current_depth
)

def eval(
self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args, current_depth: int = 0
) -> list:
"""
Evaluates a jaxpr, which can also be generated by a dynamic decomposition.

Args:
jaxpr_decomp (jax.core.Jaxpr): the Jaxpr to evaluate
consts (list[TensorLike]): the constant variables for the jaxpr
*args: the arguments to use in the evaluation
current_depth (int): the current depth of the decomposition
"""

# We update the 'self._env' environment directly because the jaxpr of the decomposition
# can be called while evaluating another jaxpr of the previous decomposition.

for arg, invar in zip(args, jaxpr.invars, strict=True):
self._env[invar] = arg
for const, constvar in zip(consts, jaxpr.constvars, strict=True):
self._env[constvar] = const

for eq in jaxpr.eqns:

prim_type = getattr(eq.primitive, "prim_type", "")
custom_handler = self._primitive_registrations.get(eq.primitive, None)

if custom_handler:

invals = [self.read(invar) for invar in eq.invars]
control_flow_handlers = {"handle_for_loop", "handle_while_loop", "handle_cond"}
extra_kwargs = (
{"current_depth": current_depth}
if custom_handler.__name__ in control_flow_handlers
else {}
)
outvals = custom_handler(self, *invals, **eq.params, **extra_kwargs)

elif prim_type == "operator":
outvals = self.interpret_operation_eqn(eq, current_depth)
elif prim_type == "measurement":
outvals = self.interpret_measurement_eqn(eq)
else:
invals = [self.read(invar) for invar in eq.invars]
subfuns, params = eq.primitive.get_bind_params(eq.params)
outvals = eq.primitive.bind(*subfuns, *invals, **params)

if not eq.primitive.multiple_results:
outvals = [outvals]

for outvar, outval in zip(eq.outvars, outvals, strict=True):
self._env[outvar] = outval

outvals = []
for var in jaxpr.outvars:
outval = self.read(var)
if isinstance(outval, qml.operation.Operator):
outvals.append(self.interpret_operation(outval))
else:
outvals.append(outval)

return outvals

def interpret_operation_eqn(self, eqn, current_depth: int = 0):
"""Interpret an equation corresponding to an operator.

Args:
eqn (jax.core.JaxprEqn): a jax equation for an operator.
current_depth (int): the current depth of the decomposition.

See also: :meth:`~.interpret_operation`.

"""
invals = (self.read(invar) for invar in eqn.invars)

with qml.QueuingManager.stop_recording():
op = eqn.primitive.impl(*invals, **eqn.params)
if eqn.outvars[0].__class__.__name__ == "DropVar":

if op.has_plxpr_decomposition:
if not eqn.outvars[0].__class__.__name__ == "DropVar":
return op

args = (*op.parameters, *op.wires)
qml.capture.run_autograph(op.compute_plxpr_decomposition)(
*args, **op.hyperparameters
)
if not op.has_plxpr_decomposition:
return self.decompose_operation(op, current_depth)

else:
return self._evaluate_jaxpr_decomposition(op, current_depth)

return self.decompose_operation(op)
def jaxpr_to_jaxpr_decomp(
interpreter: DecomposeInterpreter, jaxpr: "jax.core.Jaxpr", consts, *args, current_depth
) -> "jax.core.Jaxpr":

return op
f = partial(interpreter.eval, jaxpr, consts, current_depth=current_depth)

return jax.make_jaxpr(f)(*args)

# pylint: disable=unused-variable,missing-function-docstring
@DecomposeInterpreter.register_primitive(ctrl_transform_prim)
def handle_ctrl_transform(*_, **__):
raise NotImplementedError

# We register the primitives to propagate the current depth
# in the dynamic decomposition evaluation with program capture enabled.

@DecomposeInterpreter.register_primitive(cond_prim)
def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice, current_depth=0):
"""Handle a cond primitive."""

args = invals[args_slice]

new_jaxprs = []
new_consts = []
new_consts_slices = []
end_const_ind = len(jaxpr_branches)

for const_slice, jaxpr in zip(consts_slices, jaxpr_branches):
consts = invals[const_slice]
if jaxpr is None:
new_jaxprs.append(None)
new_consts_slices.append(slice(0, 0))
else:
new_jaxpr = jaxpr_to_jaxpr_decomp(
copy(self), jaxpr, consts, *args, current_depth=current_depth
)
new_jaxprs.append(new_jaxpr.jaxpr)
new_consts.extend(new_jaxpr.consts)
new_consts_slices.append(
slice(end_const_ind, end_const_ind + len(new_jaxpr.consts))
)
end_const_ind += len(new_jaxpr.consts)

new_args_slice = slice(end_const_ind, None)
return cond_prim.bind(
*invals[: len(jaxpr_branches)],
*new_consts,
*args,
jaxpr_branches=new_jaxprs,
consts_slices=new_consts_slices,
args_slice=new_args_slice,
)

@DecomposeInterpreter.register_primitive(for_loop_prim)
def handle_for_loop(
self,
start,
stop,
step,
*args,
jaxpr_body_fn,
consts_slice,
args_slice,
abstract_shapes_slice,
current_depth=0,
):
"""Handle a for loop primitive."""

consts = args[consts_slice]
init_state = args[args_slice]
abstract_shapes = args[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr_decomp(
copy(self),
jaxpr_body_fn,
consts,
*abstract_shapes,
start,
*init_state,
current_depth=current_depth,
)

Check notice on line 343 in pennylane/transforms/decompose.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/decompose.py#L343

Too many arguments (9/5) (too-many-arguments)

consts_slice = slice(0, len(new_jaxpr_body_fn.consts))
abstract_shapes_slice = slice(consts_slice.stop, consts_slice.stop + len(abstract_shapes))
args_slice = slice(abstract_shapes_slice.stop, None)
return for_loop_prim.bind(
start,
stop,
step,
*new_jaxpr_body_fn.consts,
*abstract_shapes,
*init_state,
jaxpr_body_fn=new_jaxpr_body_fn.jaxpr,
consts_slice=consts_slice,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)

@DecomposeInterpreter.register_primitive(while_loop_prim)
def handle_while_loop(
self,
*invals,
jaxpr_body_fn,
jaxpr_cond_fn,
body_slice,
cond_slice,
args_slice,
abstract_shapes_slice,
current_depth=0,
):
"""Handle a while loop primitive."""
consts_body = invals[body_slice]
consts_cond = invals[cond_slice]
init_state = invals[args_slice]
abstract_shapes = invals[abstract_shapes_slice]

new_jaxpr_body_fn = jaxpr_to_jaxpr_decomp(
copy(self),
jaxpr_body_fn,
consts_body,
*abstract_shapes,
*init_state,
current_depth=current_depth,
)
new_jaxpr_cond_fn = jaxpr_to_jaxpr_decomp(
copy(self),

Check notice on line 388 in pennylane/transforms/decompose.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/transforms/decompose.py#L388

Too many arguments (8/5) (too-many-arguments)
jaxpr_cond_fn,
consts_cond,
*abstract_shapes,
*init_state,
current_depth=current_depth,
)

body_consts = slice(0, len(new_jaxpr_body_fn.consts))
cond_consts = slice(body_consts.stop, body_consts.stop + len(new_jaxpr_cond_fn.consts))
abstract_shapes_slice = slice(cond_consts.stop, cond_consts.stop + len(abstract_shapes))
args_slice = slice(abstract_shapes_slice.stop, None)

return while_loop_prim.bind(
*new_jaxpr_body_fn.consts,
*new_jaxpr_cond_fn.consts,
*abstract_shapes,
*init_state,
jaxpr_body_fn=new_jaxpr_body_fn.jaxpr,
jaxpr_cond_fn=new_jaxpr_cond_fn.jaxpr,
body_slice=body_consts,
cond_slice=cond_consts,
args_slice=args_slice,
abstract_shapes_slice=abstract_shapes_slice,
)

def decompose_plxpr_to_plxpr(
jaxpr, consts, targs, tkwargs, *args
): # pylint: disable=unused-argument
Expand All @@ -187,7 +422,7 @@
def wrapper(*inner_args):
return decomposer.eval(jaxpr, consts, *inner_args)

return make_jaxpr(wrapper)(*args)
return jax.make_jaxpr(wrapper)(*args)

return DecomposeInterpreter, decompose_plxpr_to_plxpr

Expand Down
Loading