diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 87233f8af85..35b651aa29a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -255,6 +255,12 @@ `jnp.arange`, and `jnp.full`. [#6865)](https://github.com/PennyLaneAI/pennylane/pull/6865) +* The qnode primitive now stores the `ExecutionConfig` instead of `qnode_kwargs`. + [(#6991)](https://github.com/PennyLaneAI/pennylane/pull/6991) + +* `Device.eval_jaxpr` now accepts an `execution_config` keyword argument. + [(#6991)](https://github.com/PennyLaneAI/pennylane/pull/6991) + * The adjoint jvp of a jaxpr can be computed using default.qubit tooling. [(#6875)](https://github.com/PennyLaneAI/pennylane/pull/6875) diff --git a/pennylane/capture/base_interpreter.py b/pennylane/capture/base_interpreter.py index c99c1ee46d0..ae99407da29 100644 --- a/pennylane/capture/base_interpreter.py +++ b/pennylane/capture/base_interpreter.py @@ -591,7 +591,7 @@ def handle_while_loop( # pylint: disable=unused-argument, too-many-arguments @PlxprInterpreter.register_primitive(qnode_prim) -def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, n_consts): +def handle_qnode(self, *invals, shots, qnode, device, execution_config, qfunc_jaxpr, n_consts): """Handle a qnode primitive.""" consts = invals[:n_consts] args = invals[n_consts:] @@ -604,7 +604,7 @@ def handle_qnode(self, *invals, shots, qnode, device, qnode_kwargs, qfunc_jaxpr, shots=shots, qnode=qnode, device=device, - qnode_kwargs=qnode_kwargs, + execution_config=execution_config, qfunc_jaxpr=new_qfunc_jaxpr.jaxpr, n_consts=len(new_qfunc_jaxpr.consts), ) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 6f71e0bb6ee..3191da1c186 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -935,9 +935,9 @@ def execute_and_compute_vjp( return tuple(zip(*results)) - # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel, unused-argument def eval_jaxpr( - self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args + self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args, execution_config=None ) -> list[TensorLike]: from .qubit.dq_interpreter import DefaultQubitInterpreter diff --git a/pennylane/devices/device_api.py b/pennylane/devices/device_api.py index de9eaad6732..b99f138d104 100644 --- a/pennylane/devices/device_api.py +++ b/pennylane/devices/device_api.py @@ -970,14 +970,21 @@ def supports_vjp( return type(self).compute_vjp != Device.compute_vjp def eval_jaxpr( - self, jaxpr: "jax.core.Jaxpr", consts: list[TensorLike], *args + self, + jaxpr: "jax.core.Jaxpr", + consts: list[TensorLike], + *args, + execution_config: Optional[ExecutionConfig] = None, ) -> list[TensorLike]: """An **experimental** method for natively evaluating PLXPR. See the ``capture`` module for more details. Args: jaxpr (jax.core.Jaxpr): Pennylane variant jaxpr containing quantum operations and measurements consts (list[TensorLike]): the closure variables ``consts`` corresponding to the jaxpr - *args (TensorLike): the variables to use with the jaxpr'. + *args (TensorLike): the variables to use with the jaxpr. + + Keyword Args: + execution_config (Optional[ExecutionConfig]): a data structure with additional information required for execution Returns: list[TensorLike]: the result of evaluating the jaxpr with the given parameters. diff --git a/pennylane/devices/null_qubit.py b/pennylane/devices/null_qubit.py index 7c00899a99c..d20783c0e49 100644 --- a/pennylane/devices/null_qubit.py +++ b/pennylane/devices/null_qubit.py @@ -405,7 +405,10 @@ def execute_and_compute_vjp( vjps = tuple(self._vjp(c, _interface(execution_config)) for c in circuits) return results, vjps - def eval_jaxpr(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: + # pylint: disable= unused-argument + def eval_jaxpr( + self, jaxpr: "jax.core.Jaxpr", consts: list, *args, execution_config=None + ) -> list: from pennylane.capture.primitives import ( # pylint: disable=import-outside-toplevel AbstractMeasurement, ) diff --git a/pennylane/math/__init__.py b/pennylane/math/__init__.py index d5009ad3617..930560eabd6 100644 --- a/pennylane/math/__init__.py +++ b/pennylane/math/__init__.py @@ -176,6 +176,7 @@ def __getattr__(name): "is_independent", "iscomplex", "jacobian", + "Interface", "marginal_prob", "max_entropy", "min_entropy", diff --git a/pennylane/tape/__init__.py b/pennylane/tape/__init__.py index daf03cd9395..d4ac21f8956 100644 --- a/pennylane/tape/__init__.py +++ b/pennylane/tape/__init__.py @@ -20,11 +20,11 @@ from .qscript import QuantumScript, QuantumScriptBatch, QuantumScriptOrBatch, make_qscript from .tape import QuantumTape, QuantumTapeBatch, TapeError, expand_tape_state_prep -try: - from .plxpr_conversion import plxpr_to_tape -except ImportError: # pragma: no cover - # pragma: no cover - def plxpr_to_tape(jaxpr: "jax.core.Jaxpr", consts, *args, shots=None): # pragma: no cover - """A dummy version of ``plxpr_to_tape`` when jax is not installed on the system.""" - raise ImportError("plxpr_to_tape requires jax to be installed") # pragma: no cover +# pylint: disable=import-outside-toplevel +def __getattr__(key): + if key == "plxpr_to_tape": + from .plxpr_conversion import plxpr_to_tape + + return plxpr_to_tape + raise AttributeError(f"module 'pennylane.tape' has no attribute '{key}'") # pragma: no cover diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index 2f859707963..d779343086f 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -106,7 +106,6 @@ features is non-exhaustive. """ -from copy import copy from functools import partial from numbers import Number from warnings import warn @@ -119,6 +118,8 @@ from pennylane.capture.custom_primitives import QmlPrimitive from pennylane.typing import TensorLike +from .construct_execution_config import construct_execution_config + def _is_scalar_tensor(arg) -> bool: """Check if an argument is a scalar tensor-like object or a numeric scalar.""" @@ -184,7 +185,7 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=( # pylint: disable=too-many-arguments, unused-argument @qnode_prim.def_impl -def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None): +def _(*args, qnode, shots, device, execution_config, qfunc_jaxpr, n_consts, batch_dims=None): if shots != device.shots: raise NotImplementedError( "Overriding shots is not yet supported with the program capture execution." @@ -193,16 +194,17 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_di consts = args[:n_consts] non_const_args = args[n_consts:] - if batch_dims is None: - return device.eval_jaxpr(qfunc_jaxpr, consts, *non_const_args) - return jax.vmap(partial(device.eval_jaxpr, qfunc_jaxpr, consts), batch_dims[n_consts:])( - *non_const_args + partial_eval = partial( + device.eval_jaxpr, qfunc_jaxpr, consts, execution_config=execution_config ) + if batch_dims is None: + return partial_eval(*non_const_args) + return jax.vmap(partial_eval, batch_dims[n_consts:])(*non_const_args) # pylint: disable=unused-argument @qnode_prim.def_abstract_eval -def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None): +def _(*args, qnode, shots, device, execution_config, qfunc_jaxpr, n_consts, batch_dims=None): mps = qfunc_jaxpr.outvars @@ -223,7 +225,7 @@ def _qnode_batching_rule( qnode, shots, device, - qnode_kwargs, + execution_config, qfunc_jaxpr, n_consts, ): @@ -265,7 +267,7 @@ def _qnode_batching_rule( shots=shots, qnode=qnode, device=device, - qnode_kwargs=qnode_kwargs, + execution_config=execution_config, qfunc_jaxpr=qfunc_jaxpr, n_consts=n_consts, batch_dims=batch_dims, @@ -292,29 +294,21 @@ def _backprop(args, tangents, **impl_kwargs): def _finite_diff(args, tangents, **impl_kwargs): f = partial(qnode_prim.bind, **impl_kwargs) return qml.gradients.finite_diff_jvp( - f, args, tangents, **impl_kwargs["qnode_kwargs"]["gradient_kwargs"] + f, args, tangents, **impl_kwargs["execution_config"].gradient_keyword_arguments ) diff_method_map = {"backprop": _backprop, "finite-diff": _finite_diff} -def _resolve_diff_method(diff_method: str, device) -> str: - # check if best is backprop - if diff_method == "best": - config = qml.devices.ExecutionConfig(gradient_method=diff_method, interface="jax") - diff_method = device.setup_execution_config(config).gradient_method - - if diff_method not in diff_method_map: - raise NotImplementedError(f"diff_method {diff_method} not yet implemented.") - - return diff_method +def _qnode_jvp(args, tangents, *, execution_config, device, **impl_kwargs): + config = device.setup_execution_config(execution_config) + if config.gradient_method not in diff_method_map: + raise NotImplementedError(f"diff_method {config.gradient_method} not yet implemented.") -def _qnode_jvp(args, tangents, *, qnode_kwargs, device, **impl_kwargs): - diff_method = _resolve_diff_method(qnode_kwargs["diff_method"], device) - return diff_method_map[diff_method]( - args, tangents, qnode_kwargs=qnode_kwargs, device=device, **impl_kwargs + return diff_method_map[config.gradient_method]( + args, tangents, execution_config=config, device=device, **impl_kwargs ) @@ -517,12 +511,9 @@ def f(x): "flow functions like for_loop, while_loop, etc." ) from exc - execute_kwargs = copy(qnode.execute_kwargs) - qnode_kwargs = { - "diff_method": qnode.diff_method, - **execute_kwargs, - "gradient_kwargs": qnode.gradient_kwargs, - } + config = construct_execution_config( + qnode, resolve=False + )() # no need for args and kwargs as not resolving res = qnode_prim.bind( *qfunc_jaxpr.consts, @@ -531,7 +522,7 @@ def f(x): shots=shots, qnode=qnode, device=qnode.device, - qnode_kwargs=qnode_kwargs, + execution_config=config, qfunc_jaxpr=qfunc_jaxpr.jaxpr, n_consts=len(qfunc_jaxpr.consts), ) diff --git a/pennylane/workflow/construct_execution_config.py b/pennylane/workflow/construct_execution_config.py index 1dfba42f736..4b0577501d3 100644 --- a/pennylane/workflow/construct_execution_config.py +++ b/pennylane/workflow/construct_execution_config.py @@ -16,8 +16,9 @@ import pennylane as qml from pennylane.math import Interface -from pennylane.workflow import construct_tape -from pennylane.workflow.resolution import _resolve_execution_config + +from .construct_tape import construct_tape +from .resolution import _resolve_execution_config def construct_execution_config(qnode: "qml.QNode", resolve: bool = True): diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index c732ba82c78..11ef70c57ce 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -662,7 +662,7 @@ def __repr__(self) -> str: @property def interface(self) -> str: """The interface used by the QNode""" - return self._interface.value + return "jax" if qml.capture.enabled() else self._interface.value @interface.setter def interface(self, value: SupportedInterfaceUserInput): diff --git a/tests/capture/test_base_interpreter.py b/tests/capture/test_base_interpreter.py index b061a8b2b99..30a7679e162 100644 --- a/tests/capture/test_base_interpreter.py +++ b/tests/capture/test_base_interpreter.py @@ -669,8 +669,8 @@ def f(): assert inner_jaxpr.eqns[1].primitive == qml.RX._primitive assert inner_jaxpr.eqns[3].primitive == qml.RX._primitive - assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "backprop" - assert jaxpr.eqns[0].params["qnode_kwargs"]["grad_on_execution"] is False + assert jaxpr.eqns[0].params["execution_config"].gradient_method == "backprop" + assert jaxpr.eqns[0].params["execution_config"].grad_on_execution is False assert jaxpr.eqns[0].params["device"] == dev res1 = f() diff --git a/tests/capture/transforms/test_capture_defer_measurements.py b/tests/capture/transforms/test_capture_defer_measurements.py index 6023cef612f..3930d4f6559 100644 --- a/tests/capture/transforms/test_capture_defer_measurements.py +++ b/tests/capture/transforms/test_capture_defer_measurements.py @@ -643,7 +643,7 @@ def true_fn(phi): jaxpr = jax.make_jaxpr(f)(x) assert jaxpr.eqns[0].primitive == qnode_prim assert jaxpr.eqns[0].params["device"] == dev - assert jaxpr.eqns[0].params["qnode_kwargs"]["diff_method"] == "parameter-shift" + assert jaxpr.eqns[0].params["execution_config"].gradient_method == "parameter-shift" inner_jaxpr = jaxpr.eqns[0].params["qfunc_jaxpr"] collector = CollectOpsandMeas() diff --git a/tests/capture/workflow/test_capture_qnode.py b/tests/capture/workflow/test_capture_qnode.py index 88dd37734de..c567cecb99b 100644 --- a/tests/capture/workflow/test_capture_qnode.py +++ b/tests/capture/workflow/test_capture_qnode.py @@ -130,9 +130,14 @@ def circuit(x): assert eqn0.params["device"] == dev assert eqn0.params["qnode"] == circuit assert eqn0.params["shots"] == qml.measurements.Shots(None) - expected_kwargs = {"diff_method": "best", "gradient_kwargs": {}} - expected_kwargs.update(circuit.execute_kwargs) - assert eqn0.params["qnode_kwargs"] == expected_kwargs + expected_config = qml.devices.ExecutionConfig( + gradient_method="best", + gradient_keyword_arguments={}, + use_device_jacobian_product=False, + interface="jax", + grad_on_execution=False, + ) + assert eqn0.params["execution_config"] == expected_config qfunc_jaxpr = eqn0.params["qfunc_jaxpr"] assert len(qfunc_jaxpr.eqns) == 3 @@ -279,18 +284,15 @@ def circuit(): jaxpr = jax.make_jaxpr(circuit)() assert jaxpr.eqns[0].primitive == qnode_prim - expected = { - "diff_method": "parameter-shift", - "grad_on_execution": False, - "cache": True, - "cachesize": 10, - "max_diff": 2, - "device_vjp": False, - "mcm_method": None, - "postselect_mode": None, - "gradient_kwargs": {}, - } - assert jaxpr.eqns[0].params["qnode_kwargs"] == expected + expected_config = qml.devices.ExecutionConfig( + gradient_method="parameter-shift", + grad_on_execution=False, + derivative_order=2, + use_device_jacobian_product=False, + mcm_config=qml.devices.MCMConfig(mcm_method=None, postselect_mode=None), + interface=qml.math.Interface.JAX, + ) + assert jaxpr.eqns[0].params["execution_config"] == expected_config def test_qnode_closure_variables():