Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Switch from binding qnode_kwargs to execution_config #6991

Merged
merged 13 commits into from
Feb 24, 2025
Merged
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@
`jnp.arange`, and `jnp.full`.
[#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865)

* The qnode primitive now stores the `ExecutionConfig` instead `qnode_kwargs`.

* `Device.eval_jaxpr` now accepts an `execution_config` keyword argument.

<h3>Breaking changes 💔</h3>

* `MultiControlledX` no longer accepts strings as control values.
Expand Down
4 changes: 2 additions & 2 deletions pennylane/capture/base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def handle_while_loop(

# pylint: disable=unused-argument, too-many-arguments
@PlxprInterpreter.register_primitive(qnode_prim)
def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts):
def handle_qnode(self, *invals, shots, qnode, device, execution_config, qfunc_jaxpr, n_consts):
"""Handle a qnode primitive."""
consts = invals[:n_consts]
args = invals[n_consts:]
Expand All @@ -604,7 +604,7 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr,
shots=shots,
qnode=qnode,
device=device,
qnode_kwargs=qnode_kwargs,
execution_config=execution_config,
qfunc_jaxpr=new_qfunc_jaxpr.jaxpr,
n_consts=len(new_qfunc_jaxpr.consts),
)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ def execute_and_compute_vjp(

# pylint: disable=import-outside-toplevel
def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args, execution_config=None
) -> list[TensorLike]:
from .qubit.dq_interpreter import DefaultQubitInterpreter

Expand Down
9 changes: 8 additions & 1 deletion pennylane/devices/device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,11 @@ def supports_vjp(
return type(self).compute_vjp != Device.compute_vjp

def eval_jaxpr(
self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args
self,
jaxpr: "jax.core.Jaxpr",
consts: list[TensorLike],
*args,
execution_config: Optional[ExecutionConfig] = None,
) -> list[TensorLike]:
"""An **experimental** method for natively evaluating PLXPR. See the ``capture`` module for more details.

Expand All @@ -979,6 +983,9 @@ def eval_jaxpr(
consts (list[TensorLike]): the closure variables ``consts`` corresponding to the jaxpr
*args (TensorLike): the variables to use with the jaxpr'.

Keyword Args:
execution_config (Optional[ExecutionConfig]): a datastructure with additional information required for execution

Returns:
list[TensorLike]: the result of evaluating the jaxpr with the given parameters.

Expand Down
1 change: 1 addition & 0 deletions pennylane/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def __getattr__(name):
"is_independent",
"iscomplex",
"jacobian",
"Interface",
"marginal_prob",
"max_entropy",
"min_entropy",
Expand Down
62 changes: 31 additions & 31 deletions pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@
features is non-exhaustive.

"""
from copy import copy
from functools import partial
from numbers import Number
from warnings import warn
Expand Down Expand Up @@ -184,7 +183,7 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=(

# pylint: disable=too-many-arguments, unused-argument
@qnode_prim.def_impl
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None):
def _(*args, qnode, shots, device, execution_config, qfunc_jaxpr, n_consts, batch_dims=None):
if shots != device.shots:
raise NotImplementedError(
"Overriding shots is not yet supported with the program capture execution."
Expand All @@ -193,16 +192,17 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_di
consts = args[:n_consts]
non_const_args = args[n_consts:]

if batch_dims is None:
return device.eval_jaxpr(qfunc_jaxpr, consts, *non_const_args)
return jax.vmap(partial(device.eval_jaxpr, qfunc_jaxpr, consts), batch_dims[n_consts:])(
*non_const_args
partial_eval = partial(
device.eval_jaxpr, qfunc_jaxpr, consts, execution_config=execution_config
)
if batch_dims is None:
return partial_eval(*non_const_args)
return jax.vmap(partial_eval, batch_dims[n_consts:])(*non_const_args)


# pylint: disable=unused-argument
@qnode_prim.def_abstract_eval
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None):
def _(*args, qnode, shots, device, execution_config, qfunc_jaxpr, n_consts, batch_dims=None):

mps = qfunc_jaxpr.outvars

Expand All @@ -223,7 +223,7 @@ def _qnode_batching_rule(
qnode,
shots,
device,
qnode_kwargs,
execution_config,
qfunc_jaxpr,
n_consts,
):
Expand Down Expand Up @@ -265,7 +265,7 @@ def _qnode_batching_rule(
shots=shots,
qnode=qnode,
device=device,
qnode_kwargs=qnode_kwargs,
execution_config=execution_config,
qfunc_jaxpr=qfunc_jaxpr,
n_consts=n_consts,
batch_dims=batch_dims,
Expand All @@ -292,29 +292,21 @@ def _backprop(args, tangents, **impl_kwargs):
def _finite_diff(args, tangents, **impl_kwargs):
f = partial(qnode_prim.bind, **impl_kwargs)
return qml.gradients.finite_diff_jvp(
f, args, tangents, **impl_kwargs["qnode_kwargs"]["gradient_kwargs"]
f, args, tangents, **impl_kwargs["execution_config"].gradient_keyword_arguments
)


diff_method_map = {"backprop": _backprop, "finite-diff": _finite_diff}


def _resolve_diff_method(diff_method: str, device) -> str:
# check if best is backprop
if diff_method == "best":
config = qml.devices.ExecutionConfig(gradient_method=diff_method, interface="jax")
diff_method = device.setup_execution_config(config).gradient_method

if diff_method not in diff_method_map:
raise NotImplementedError(f"diff_method {diff_method} not yet implemented.")
def _qnode_jvp(args, tangents, *, execution_config, device, **impl_kwargs):
config = device.setup_execution_config(execution_config)

return diff_method
if config.gradient_method not in diff_method_map:
raise NotImplementedError(f"diff_method {config.gradient_method} not yet implemented.")


def _qnode_jvp(args, tangents, *, qnode_kwargs, device, **impl_kwargs):
diff_method = _resolve_diff_method(qnode_kwargs["diff_method"], device)
return diff_method_map[diff_method](
args, tangents, qnode_kwargs=qnode_kwargs, device=device, **impl_kwargs
return diff_method_map[config.gradient_method](
args, tangents, execution_config=config, device=device, **impl_kwargs
)


Expand Down Expand Up @@ -517,12 +509,20 @@ def f(x):
"flow functions like for_loop, while_loop, etc."
) from exc

execute_kwargs = copy(qnode.execute_kwargs)
qnode_kwargs = {
"diff_method": qnode.diff_method,
**execute_kwargs,
"gradient_kwargs": qnode.gradient_kwargs,
}
kwargs = qnode.execute_kwargs
mcm_config = qml.devices.MCMConfig(
mcm_method=kwargs["mcm_method"], postselect_mode=kwargs["postselect_mode"]
)
g_on_ex = kwargs["grad_on_execution"]
config = qml.devices.ExecutionConfig(
grad_on_execution=None if g_on_ex == "best" else g_on_ex,
use_device_jacobian_product=kwargs["device_vjp"],
derivative_order=kwargs["max_diff"],
gradient_method=qnode.diff_method,
gradient_keyword_arguments=qnode.gradient_kwargs,
interface=qml.math.Interface.JAX,
mcm_config=mcm_config,
)

res = qnode_prim.bind(
*qfunc_jaxpr.consts,
Expand All @@ -531,7 +531,7 @@ def f(x):
shots=shots,
qnode=qnode,
device=qnode.device,
qnode_kwargs=qnode_kwargs,
execution_config=config,
qfunc_jaxpr=qfunc_jaxpr.jaxpr,
n_consts=len(qfunc_jaxpr.consts),
)
Expand Down
4 changes: 2 additions & 2 deletions tests/capture/test_base_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,8 +669,8 @@ def f():
assert inner_jaxpr.eqns[1].primitive == qml.RX._primitive
assert inner_jaxpr.eqns[3].primitive == qml.RX._primitive

assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "backprop"
assert jaxpr.eqns[0].params["qnode_kwargs"]["grad_on_execution"] is False
assert jaxpr.eqns[0].params["execution_config"].gradient_method == "backprop"
assert jaxpr.eqns[0].params["execution_config"].grad_on_execution is False
assert jaxpr.eqns[0].params["device"] == dev

res1 = f()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def true_fn(phi):
jaxpr = jax.make_jaxpr(f)(x)
assert jaxpr.eqns[0].primitive == qnode_prim
assert jaxpr.eqns[0].params["device"] == dev
assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "parameter-shift"
assert jaxpr.eqns[0].params["execution_config"].gradient_method == "parameter-shift"

inner_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"]
collector = CollectOpsandMeas()
Expand Down
31 changes: 16 additions & 15 deletions tests/capture/workflow/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,13 @@ def circuit(x):
assert eqn0.params["device"] == dev
assert eqn0.params["qnode"] == circuit
assert eqn0.params["shots"] == qml.measurements.Shots(None)
expected_kwargs = {"diff_method": "best", "gradient_kwargs": {}}
expected_kwargs.update(circuit.execute_kwargs)
assert eqn0.params["qnode_kwargs"] == expected_kwargs
expected_config = qml.devices.ExecutionConfig(
gradient_method="best",
gradient_keyword_arguments={},
use_device_jacobian_product=False,
interface="jax",
)
assert eqn0.params["execution_config"] == expected_config

qfunc_jaxpr = eqn0.params["qfunc_jaxpr"]
assert len(qfunc_jaxpr.eqns) == 3
Expand Down Expand Up @@ -279,18 +283,15 @@ def circuit():
jaxpr = jax.make_jaxpr(circuit)()

assert jaxpr.eqns[0].primitive == qnode_prim
expected = {
"diff_method": "parameter-shift",
"grad_on_execution": False,
"cache": True,
"cachesize": 10,
"max_diff": 2,
"device_vjp": False,
"mcm_method": None,
"postselect_mode": None,
"gradient_kwargs": {},
}
assert jaxpr.eqns[0].params["qnode_kwargs"] == expected
expected_config = qml.devices.ExecutionConfig(
gradient_method="parameter-shift",
grad_on_execution=False,
derivative_order=2,
use_device_jacobian_product=False,
mcm_config=qml.devices.MCMConfig(mcm_method=None, postselect_mode=None),
interface=qml.math.Interface.JAX,
)
assert jaxpr.eqns[0].params["execution_config"] == expected_config


def test_qnode_closure_variables():
Expand Down
Loading