Skip to content

Commit

Permalink
[Capture] Switch from binding qnode_kwargs to execution_config (#…
Browse files Browse the repository at this point in the history
…6991)

**Context:**

As the capture workflow is getting more complicated, we should start
using the well-defined `ExecutionConfig` object instead of the ambiguous
and unspecified `qnode_kwargs`.

We also need to start passing the `execution_config` to the device, as
that information is needed for handling mid circuit measurements.

**Description of the Change:**

Switches from binding a `qnode_kwargs` dictionary to an
`execution_config` object.

**Benefits:**

Easier to manage the configuration of a workflow. Can specify MCM
configuration info for device execution.

**Possible Drawbacks:**

Technically a breaking change, but a breaking change to an experimental
project.

**Related GitHub Issues:**

[sc-84916]

---------

Co-authored-by: Mudit Pandey <[email protected]>
Co-authored-by: Pietropaolo Frisoni <[email protected]>
  • Loading branch information
3 people authored Feb 24, 2025
1 parent d57655c commit 76b585d
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 66 deletions.
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@
`jnp.arange`, and `jnp.full`.
[#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865)

* The qnode primitive now stores the `ExecutionConfig` instead of `qnode_kwargs`.
[(#6991)](https://github.com/PennyLaneAI/pennylane/pull/6991)

* `Device.eval_jaxpr` now accepts an `execution_config` keyword argument.
[(#6991)](https://github.com/PennyLaneAI/pennylane/pull/6991)

* The adjoint jvp of a jaxpr can be computed using default.qubit tooling.
[(#6875)](https://github.com/PennyLaneAI/pennylane/pull/6875)

Expand Down
4 changes: 2 additions & 2 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def handle_while_loop(

# pylint: disable=unused-argument, too-many-arguments
@PlxprInterpreter.register_primitive(qnode_prim)
def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts):
def handle_qnode(self, *invals, shots, qnode, device, execution_config, qfunc_jaxpr, n_consts):
"""Handle a qnode primitive."""
consts = invals[:n_consts]
args = invals[n_consts:]
Expand All @@ -604,7 +604,7 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr,
shots=shots,
qnode=qnode,
device=device,
qnode_kwargs=qnode_kwargs,
execution_config=execution_config,
qfunc_jaxpr=new_qfunc_jaxpr.jaxpr,
n_consts=len(new_qfunc_jaxpr.consts),
)
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,9 @@ def execute_and_compute_vjp(

return tuple(zip(*results))

# pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel, unused-argument
def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args, execution_config=None
) -> list[TensorLike]:
from .qubit.dq_interpreter import DefaultQubitInterpreter

Expand Down
11 changes: 9 additions & 2 deletions pennylane/devices/device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,21 @@ def supports_vjp(
return type(self).compute_vjp != Device.compute_vjp

def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args
self,
jaxpr: "jax.core.Jaxpr",
consts: list[TensorLike],
*args,
execution_config: Optional[ExecutionConfig] = None,
) -> list[TensorLike]:
"""An **experimental** method for natively evaluating PLXPR. See the ``capture`` module for more details.
Args:
jaxpr (jax.core.Jaxpr): Pennylane variant jaxpr containing quantum operations and measurements
consts (list[TensorLike]): the closure variables ``consts`` corresponding to the jaxpr
*args (TensorLike): the variables to use with the jaxpr'.
*args (TensorLike): the variables to use with the jaxpr.
Keyword Args:
execution_config (Optional[ExecutionConfig]): a data structure with additional information required for execution
Returns:
list[TensorLike]: the result of evaluating the jaxpr with the given parameters.
Expand Down
5 changes: 4 additions & 1 deletion pennylane/devices/null_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,10 @@ def execute_and_compute_vjp(
vjps = tuple(self._vjp(c, _interface(execution_config)) for c in circuits)
return results, vjps

def eval_jaxpr(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
# pylint: disable= unused-argument
def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list, *args, execution_config=None
) -> list:
from pennylane.capture.primitives import ( # pylint: disable=import-outside-toplevel
AbstractMeasurement,
)
Expand Down
1 change: 1 addition & 0 deletions pennylane/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def __getattr__(name):
"is_independent",
"iscomplex",
"jacobian",
"Interface",
"marginal_prob",
"max_entropy",
"min_entropy",
Expand Down
14 changes: 7 additions & 7 deletions pennylane/tape/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from .qscript import QuantumScript, QuantumScriptBatch, QuantumScriptOrBatch, make_qscript
from .tape import QuantumTape, QuantumTapeBatch, TapeError, expand_tape_state_prep

try:
from .plxpr_conversion import plxpr_to_tape
except ImportError: # pragma: no cover

# pragma: no cover
def plxpr_to_tape(jaxpr: "jax.core.Jaxpr", consts, *args, shots=None): # pragma: no cover
"""A dummy version of ``plxpr_to_tape`` when jax is not installed on the system."""
raise ImportError("plxpr_to_tape requires jax to be installed") # pragma: no cover
# pylint: disable=import-outside-toplevel
def __getattr__(key):
if key == "plxpr_to_tape":
from .plxpr_conversion import plxpr_to_tape

return plxpr_to_tape
raise AttributeError(f"module 'pennylane.tape' has no attribute '{key}'") # pragma: no cover
53 changes: 22 additions & 31 deletions pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@
features is non-exhaustive.
"""
from copy import copy
from functools import partial
from numbers import Number
from warnings import warn
Expand All @@ -119,6 +118,8 @@
from pennylane.capture.custom_primitives import QmlPrimitive
from pennylane.typing import TensorLike

from .construct_execution_config import construct_execution_config


def _is_scalar_tensor(arg) -> bool:
"""Check if an argument is a scalar tensor-like object or a numeric scalar."""
Expand Down Expand Up @@ -184,7 +185,7 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=(

# pylint: disable=too-many-arguments, unused-argument
@qnode_prim.def_impl
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None):
def _(*args, qnode, shots, device, execution_config, qfunc_jaxpr, n_consts, batch_dims=None):
if shots != device.shots:
raise NotImplementedError(
"Overriding shots is not yet supported with the program capture execution."
Expand All @@ -193,16 +194,17 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_di
consts = args[:n_consts]
non_const_args = args[n_consts:]

if batch_dims is None:
return device.eval_jaxpr(qfunc_jaxpr, consts, *non_const_args)
return jax.vmap(partial(device.eval_jaxpr, qfunc_jaxpr, consts), batch_dims[n_consts:])(
*non_const_args
partial_eval = partial(
device.eval_jaxpr, qfunc_jaxpr, consts, execution_config=execution_config
)
if batch_dims is None:
return partial_eval(*non_const_args)
return jax.vmap(partial_eval, batch_dims[n_consts:])(*non_const_args)


# pylint: disable=unused-argument
@qnode_prim.def_abstract_eval
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None):
def _(*args, qnode, shots, device, execution_config, qfunc_jaxpr, n_consts, batch_dims=None):

mps = qfunc_jaxpr.outvars

Expand All @@ -223,7 +225,7 @@ def _qnode_batching_rule(
qnode,
shots,
device,
qnode_kwargs,
execution_config,
qfunc_jaxpr,
n_consts,
):
Expand Down Expand Up @@ -265,7 +267,7 @@ def _qnode_batching_rule(
shots=shots,
qnode=qnode,
device=device,
qnode_kwargs=qnode_kwargs,
execution_config=execution_config,
qfunc_jaxpr=qfunc_jaxpr,
n_consts=n_consts,
batch_dims=batch_dims,
Expand All @@ -292,29 +294,21 @@ def _backprop(args, tangents, **impl_kwargs):
def _finite_diff(args, tangents, **impl_kwargs):
f = partial(qnode_prim.bind, **impl_kwargs)
return qml.gradients.finite_diff_jvp(
f, args, tangents, **impl_kwargs["qnode_kwargs"]["gradient_kwargs"]
f, args, tangents, **impl_kwargs["execution_config"].gradient_keyword_arguments
)


diff_method_map = {"backprop": _backprop, "finite-diff": _finite_diff}


def _resolve_diff_method(diff_method: str, device) -> str:
# check if best is backprop
if diff_method == "best":
config = qml.devices.ExecutionConfig(gradient_method=diff_method, interface="jax")
diff_method = device.setup_execution_config(config).gradient_method

if diff_method not in diff_method_map:
raise NotImplementedError(f"diff_method {diff_method} not yet implemented.")

return diff_method
def _qnode_jvp(args, tangents, *, execution_config, device, **impl_kwargs):
config = device.setup_execution_config(execution_config)

if config.gradient_method not in diff_method_map:
raise NotImplementedError(f"diff_method {config.gradient_method} not yet implemented.")

def _qnode_jvp(args, tangents, *, qnode_kwargs, device, **impl_kwargs):
diff_method = _resolve_diff_method(qnode_kwargs["diff_method"], device)
return diff_method_map[diff_method](
args, tangents, qnode_kwargs=qnode_kwargs, device=device, **impl_kwargs
return diff_method_map[config.gradient_method](
args, tangents, execution_config=config, device=device, **impl_kwargs
)


Expand Down Expand Up @@ -517,12 +511,9 @@ def f(x):
"flow functions like for_loop, while_loop, etc."
) from exc

execute_kwargs = copy(qnode.execute_kwargs)
qnode_kwargs = {
"diff_method": qnode.diff_method,
**execute_kwargs,
"gradient_kwargs": qnode.gradient_kwargs,
}
config = construct_execution_config(
qnode, resolve=False
)() # no need for args and kwargs as not resolving

res = qnode_prim.bind(
*qfunc_jaxpr.consts,
Expand All @@ -531,7 +522,7 @@ def f(x):
shots=shots,
qnode=qnode,
device=qnode.device,
qnode_kwargs=qnode_kwargs,
execution_config=config,
qfunc_jaxpr=qfunc_jaxpr.jaxpr,
n_consts=len(qfunc_jaxpr.consts),
)
Expand Down
5 changes: 3 additions & 2 deletions pennylane/workflow/construct_execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

import pennylane as qml
from pennylane.math import Interface
from pennylane.workflow import construct_tape
from pennylane.workflow.resolution import _resolve_execution_config

from .construct_tape import construct_tape
from .resolution import _resolve_execution_config


def construct_execution_config(qnode: "qml.QNode", resolve: bool = True):
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def __repr__(self) -> str:
@property
def interface(self) -> str:
"""The interface used by the QNode"""
return self._interface.value
return "jax" if qml.capture.enabled() else self._interface.value

@interface.setter
def interface(self, value: SupportedInterfaceUserInput):
Expand Down
4 changes: 2 additions & 2 deletions tests/capture/test_base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,8 @@ def f():
assert inner_jaxpr.eqns[1].primitive == qml.RX._primitive
assert inner_jaxpr.eqns[3].primitive == qml.RX._primitive

assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "backprop"
assert jaxpr.eqns[0].params["qnode_kwargs"]["grad_on_execution"] is False
assert jaxpr.eqns[0].params["execution_config"].gradient_method == "backprop"
assert jaxpr.eqns[0].params["execution_config"].grad_on_execution is False
assert jaxpr.eqns[0].params["device"] == dev

res1 = f()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def true_fn(phi):
jaxpr = jax.make_jaxpr(f)(x)
assert jaxpr.eqns[0].primitive == qnode_prim
assert jaxpr.eqns[0].params["device"] == dev
assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "parameter-shift"
assert jaxpr.eqns[0].params["execution_config"].gradient_method == "parameter-shift"

inner_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]
collector = CollectOpsandMeas()
Expand Down
32 changes: 17 additions & 15 deletions tests/capture/workflow/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,14 @@ def circuit(x):
assert eqn0.params["device"] == dev
assert eqn0.params["qnode"] == circuit
assert eqn0.params["shots"] == qml.measurements.Shots(None)
expected_kwargs = {"diff_method": "best", "gradient_kwargs": {}}
expected_kwargs.update(circuit.execute_kwargs)
assert eqn0.params["qnode_kwargs"] == expected_kwargs
expected_config = qml.devices.ExecutionConfig(
gradient_method="best",
gradient_keyword_arguments={},
use_device_jacobian_product=False,
interface="jax",
grad_on_execution=False,
)
assert eqn0.params["execution_config"] == expected_config

qfunc_jaxpr = eqn0.params["qfunc_jaxpr"]
assert len(qfunc_jaxpr.eqns) == 3
Expand Down Expand Up @@ -279,18 +284,15 @@ def circuit():
jaxpr = jax.make_jaxpr(circuit)()

assert jaxpr.eqns[0].primitive == qnode_prim
expected = {
"diff_method": "parameter-shift",
"grad_on_execution": False,
"cache": True,
"cachesize": 10,
"max_diff": 2,
"device_vjp": False,
"mcm_method": None,
"postselect_mode": None,
"gradient_kwargs": {},
}
assert jaxpr.eqns[0].params["qnode_kwargs"] == expected
expected_config = qml.devices.ExecutionConfig(
gradient_method="parameter-shift",
grad_on_execution=False,
derivative_order=2,
use_device_jacobian_product=False,
mcm_config=qml.devices.MCMConfig(mcm_method=None, postselect_mode=None),
interface=qml.math.Interface.JAX,
)
assert jaxpr.eqns[0].params["execution_config"] == expected_config


def test_qnode_closure_variables():
Expand Down

0 comments on commit 76b585d

Please sign in to comment.