Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Switch from binding qnode_kwargs to execution_config #6991

Merged
merged 13 commits into from
Feb 24, 2025
Merged
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@
`jnp.arange`, and `jnp.full`.
[#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865)

* The qnode primitive now stores the `ExecutionConfig` instead `qnode_kwargs`.
[(#6991)](https://github.com/PennyLaneAI/pennylane/pull/6991)

* `Device.eval_jaxpr` now accepts an `execution_config` keyword argument.
[(#6991)](https://github.com/PennyLaneAI/pennylane/pull/6991)

* The adjoint jvp of a jaxpr can be computed using default.qubit tooling.
[(#6875)](https://github.com/PennyLaneAI/pennylane/pull/6875)

Expand Down
4 changes: 2 additions & 2 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def handle_while_loop(

# pylint: disable=unused-argument, too-many-arguments
@PlxprInterpreter.register_primitive(qnode_prim)
def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts):
def handle_qnode(self, *invals, shots, qnode, device, execution_config, qfunc_jaxpr, n_consts):
"""Handle a qnode primitive."""
consts = invals[:n_consts]
args = invals[n_consts:]
Expand All @@ -604,7 +604,7 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr,
shots=shots,
qnode=qnode,
device=device,
qnode_kwargs=qnode_kwargs,
execution_config=execution_config,
qfunc_jaxpr=new_qfunc_jaxpr.jaxpr,
n_consts=len(new_qfunc_jaxpr.consts),
)
Expand Down
4 changes: 2 additions & 2 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,9 @@ def execute_and_compute_vjp(

return tuple(zip(*results))

# pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel, unused-argument
def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args, execution_config=None
) -> list[TensorLike]:
from .qubit.dq_interpreter import DefaultQubitInterpreter

Expand Down
11 changes: 9 additions & 2 deletions pennylane/devices/device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,21 @@ def supports_vjp(
return type(self).compute_vjp != Device.compute_vjp

def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args
self,
jaxpr: "jax.core.Jaxpr",
consts: list[TensorLike],
*args,
execution_config: Optional[ExecutionConfig] = None,
) -> list[TensorLike]:
"""An **experimental** method for natively evaluating PLXPR. See the ``capture`` module for more details.

Args:
jaxpr (jax.core.Jaxpr): Pennylane variant jaxpr containing quantum operations and measurements
consts (list[TensorLike]): the closure variables ``consts`` corresponding to the jaxpr
*args (TensorLike): the variables to use with the jaxpr'.
*args (TensorLike): the variables to use with the jaxpr.

Keyword Args:
execution_config (Optional[ExecutionConfig]): a datastructure with additional information required for execution

Returns:
list[TensorLike]: the result of evaluating the jaxpr with the given parameters.
Expand Down
5 changes: 4 additions & 1 deletion pennylane/devices/null_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,10 @@ def execute_and_compute_vjp(
vjps = tuple(self._vjp(c, _interface(execution_config)) for c in circuits)
return results, vjps

def eval_jaxpr(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list:
# pylint: disable= unused-argument
def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list, *args, execution_config=None
) -> list:
from pennylane.capture.primitives import ( # pylint: disable=import-outside-toplevel
AbstractMeasurement,
)
Expand Down
1 change: 1 addition & 0 deletions pennylane/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def __getattr__(name):
"is_independent",
"iscomplex",
"jacobian",
"Interface",
"marginal_prob",
"max_entropy",
"min_entropy",
Expand Down
14 changes: 7 additions & 7 deletions pennylane/tape/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from .qscript import QuantumScript, QuantumScriptBatch, QuantumScriptOrBatch, make_qscript
from .tape import QuantumTape, QuantumTapeBatch, TapeError, expand_tape_state_prep

try:
from .plxpr_conversion import plxpr_to_tape
except ImportError: # pragma: no cover

# pragma: no cover
def plxpr_to_tape(jaxpr: "jax.core.Jaxpr", consts, *args, shots=None): # pragma: no cover
"""A dummy version of ``plxpr_to_tape`` when jax is not installed on the system."""
raise ImportError("plxpr_to_tape requires jax to be installed") # pragma: no cover
# pylint: disable=import-outside-toplevel
def __getattr__(key):
if key == "plxpr_to_tape":
from .plxpr_conversion import plxpr_to_tape

return plxpr_to_tape
raise AttributeError(f"module 'pennylane.tape' has no attribute '{key}'") # pragma: no cover
53 changes: 22 additions & 31 deletions pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@
features is non-exhaustive.

"""
from copy import copy
from functools import partial
from numbers import Number
from warnings import warn
Expand All @@ -119,6 +118,8 @@
from pennylane.capture.custom_primitives import QmlPrimitive
from pennylane.typing import TensorLike

from .construct_execution_config import construct_execution_config


def _is_scalar_tensor(arg) -> bool:
"""Check if an argument is a scalar tensor-like object or a numeric scalar."""
Expand Down Expand Up @@ -184,7 +185,7 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=(

# pylint: disable=too-many-arguments, unused-argument
@qnode_prim.def_impl
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None):
def _(*args, qnode, shots, device, execution_config, qfunc_jaxpr, n_consts, batch_dims=None):
if shots != device.shots:
raise NotImplementedError(
"Overriding shots is not yet supported with the program capture execution."
Expand All @@ -193,16 +194,17 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_di
consts = args[:n_consts]
non_const_args = args[n_consts:]

if batch_dims is None:
return device.eval_jaxpr(qfunc_jaxpr, consts, *non_const_args)
return jax.vmap(partial(device.eval_jaxpr, qfunc_jaxpr, consts), batch_dims[n_consts:])(
*non_const_args
partial_eval = partial(
device.eval_jaxpr, qfunc_jaxpr, consts, execution_config=execution_config
)
if batch_dims is None:
return partial_eval(*non_const_args)
return jax.vmap(partial_eval, batch_dims[n_consts:])(*non_const_args)


# pylint: disable=unused-argument
@qnode_prim.def_abstract_eval
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None):
def _(*args, qnode, shots, device, execution_config, qfunc_jaxpr, n_consts, batch_dims=None):

mps = qfunc_jaxpr.outvars

Expand All @@ -223,7 +225,7 @@ def _qnode_batching_rule(
qnode,
shots,
device,
qnode_kwargs,
execution_config,
qfunc_jaxpr,
n_consts,
):
Expand Down Expand Up @@ -265,7 +267,7 @@ def _qnode_batching_rule(
shots=shots,
qnode=qnode,
device=device,
qnode_kwargs=qnode_kwargs,
execution_config=execution_config,
qfunc_jaxpr=qfunc_jaxpr,
n_consts=n_consts,
batch_dims=batch_dims,
Expand All @@ -292,29 +294,21 @@ def _backprop(args, tangents, **impl_kwargs):
def _finite_diff(args, tangents, **impl_kwargs):
f = partial(qnode_prim.bind, **impl_kwargs)
return qml.gradients.finite_diff_jvp(
f, args, tangents, **impl_kwargs["qnode_kwargs"]["gradient_kwargs"]
f, args, tangents, **impl_kwargs["execution_config"].gradient_keyword_arguments
)


diff_method_map = {"backprop": _backprop, "finite-diff": _finite_diff}


def _resolve_diff_method(diff_method: str, device) -> str:
# check if best is backprop
if diff_method == "best":
config = qml.devices.ExecutionConfig(gradient_method=diff_method, interface="jax")
diff_method = device.setup_execution_config(config).gradient_method

if diff_method not in diff_method_map:
raise NotImplementedError(f"diff_method {diff_method} not yet implemented.")

return diff_method
def _qnode_jvp(args, tangents, *, execution_config, device, **impl_kwargs):
config = device.setup_execution_config(execution_config)

if config.gradient_method not in diff_method_map:
raise NotImplementedError(f"diff_method {config.gradient_method} not yet implemented.")

def _qnode_jvp(args, tangents, *, qnode_kwargs, device, **impl_kwargs):
diff_method = _resolve_diff_method(qnode_kwargs["diff_method"], device)
return diff_method_map[diff_method](
args, tangents, qnode_kwargs=qnode_kwargs, device=device, **impl_kwargs
return diff_method_map[config.gradient_method](
args, tangents, execution_config=config, device=device, **impl_kwargs
)


Expand Down Expand Up @@ -517,12 +511,9 @@ def f(x):
"flow functions like for_loop, while_loop, etc."
) from exc

execute_kwargs = copy(qnode.execute_kwargs)
qnode_kwargs = {
"diff_method": qnode.diff_method,
**execute_kwargs,
"gradient_kwargs": qnode.gradient_kwargs,
}
config = construct_execution_config(
qnode, resolve=False
)() # no need for args and kwargs as not resolving

res = qnode_prim.bind(
*qfunc_jaxpr.consts,
Expand All @@ -531,7 +522,7 @@ def f(x):
shots=shots,
qnode=qnode,
device=qnode.device,
qnode_kwargs=qnode_kwargs,
execution_config=config,
qfunc_jaxpr=qfunc_jaxpr.jaxpr,
n_consts=len(qfunc_jaxpr.consts),
)
Expand Down
5 changes: 3 additions & 2 deletions pennylane/workflow/construct_execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

import pennylane as qml
from pennylane.math import Interface
from pennylane.workflow import construct_tape
from pennylane.workflow.resolution import _resolve_execution_config

from .construct_tape import construct_tape
from .resolution import _resolve_execution_config


def construct_execution_config(qnode: "qml.QNode", resolve: bool = True):
Expand Down
2 changes: 1 addition & 1 deletion pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def __repr__(self) -> str:
@property
def interface(self) -> str:
"""The interface used by the QNode"""
return self._interface.value
return "jax" if qml.capture.enabled() else self._interface.value

@interface.setter
def interface(self, value: SupportedInterfaceUserInput):
Expand Down
4 changes: 2 additions & 2 deletions tests/capture/test_base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,8 @@ def f():
assert inner_jaxpr.eqns[1].primitive == qml.RX._primitive
assert inner_jaxpr.eqns[3].primitive == qml.RX._primitive

assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "backprop"
assert jaxpr.eqns[0].params["qnode_kwargs"]["grad_on_execution"] is False
assert jaxpr.eqns[0].params["execution_config"].gradient_method == "backprop"
assert jaxpr.eqns[0].params["execution_config"].grad_on_execution is False
assert jaxpr.eqns[0].params["device"] == dev

res1 = f()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def true_fn(phi):
jaxpr = jax.make_jaxpr(f)(x)
assert jaxpr.eqns[0].primitive == qnode_prim
assert jaxpr.eqns[0].params["device"] == dev
assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "parameter-shift"
assert jaxpr.eqns[0].params["execution_config"].gradient_method == "parameter-shift"

inner_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]
collector = CollectOpsandMeas()
Expand Down
32 changes: 17 additions & 15 deletions tests/capture/workflow/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,14 @@ def circuit(x):
assert eqn0.params["device"] == dev
assert eqn0.params["qnode"] == circuit
assert eqn0.params["shots"] == qml.measurements.Shots(None)
expected_kwargs = {"diff_method": "best", "gradient_kwargs": {}}
expected_kwargs.update(circuit.execute_kwargs)
assert eqn0.params["qnode_kwargs"] == expected_kwargs
expected_config = qml.devices.ExecutionConfig(
gradient_method="best",
gradient_keyword_arguments={},
use_device_jacobian_product=False,
interface="jax",
grad_on_execution=False,
)
assert eqn0.params["execution_config"] == expected_config

qfunc_jaxpr = eqn0.params["qfunc_jaxpr"]
assert len(qfunc_jaxpr.eqns) == 3
Expand Down Expand Up @@ -279,18 +284,15 @@ def circuit():
jaxpr = jax.make_jaxpr(circuit)()

assert jaxpr.eqns[0].primitive == qnode_prim
expected = {
"diff_method": "parameter-shift",
"grad_on_execution": False,
"cache": True,
"cachesize": 10,
"max_diff": 2,
"device_vjp": False,
"mcm_method": None,
"postselect_mode": None,
"gradient_kwargs": {},
}
assert jaxpr.eqns[0].params["qnode_kwargs"] == expected
expected_config = qml.devices.ExecutionConfig(
gradient_method="parameter-shift",
grad_on_execution=False,
derivative_order=2,
use_device_jacobian_product=False,
mcm_config=qml.devices.MCMConfig(mcm_method=None, postselect_mode=None),
interface=qml.math.Interface.JAX,
)
assert jaxpr.eqns[0].params["execution_config"] == expected_config


def test_qnode_closure_variables():
Expand Down