Expected behavior
I expect autograph to work on the functions wrapped by qml.adjoint and qml.ctrl, as it does for "normal" functions with program capture.
Actual behavior
When program capture is enabled with qml.capture.enable(), autograph is not applied to the functions wrapped by qml.adjoint or qml.ctrl. I tried both running autograph manually on a QNode/qfunc and setting the autograph=True argument of QNode, and got the same error both times.
Additional information
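After investigating with @lillian542, we suspect the cause is that autograph skips conversion for PennyLane functions, so the function passed to qml.adjoint or qml.ctrl is never converted and its Python control flow (the `if y < 2:` in adjoint_fn) is traced as-is.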
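The traceback shows malt's converted_call returning `_call_unconverted(f, args, kwargs, options)` when `conversion.is_allowlisted(f)` is true. Below is a simplified, hypothetical model of that behaviour, not the real malt or PennyLane code: `is_allowlisted`, `convert`, `converted_call`, `ALLOWLISTED_MODULES` and `fake_make_jaxpr` are illustrative stand-ins. It shows that once the call target (here a stand-in for jax.make_jaxpr) is allowlisted, the user function it receives is invoked directly and never converted.

```python
# Hypothetical toy model of an autograph-style allowlist check; NOT the real malt API.
ALLOWLISTED_MODULES = {"jax", "pennylane"}  # assumption: modules whose functions are never converted

converted_log = []  # records the names of functions that went through "conversion"


def is_allowlisted(fn):
    """Return True if fn is defined in a module we refuse to convert."""
    module = getattr(fn, "__module__", "") or ""
    return module.split(".")[0] in ALLOWLISTED_MODULES


def convert(fn):
    """Stand-in for source-to-source conversion: it only records the name."""
    converted_log.append(fn.__name__)
    return fn


def converted_call(fn, *args):
    """Convert fn unless it is allowlisted, then call it."""
    if is_allowlisted(fn):
        return fn(*args)  # called unconverted; callables in args are not converted either
    return convert(fn)(*args)


def fake_make_jaxpr(f, *args):
    """Stand-in for jax.make_jaxpr: calls f directly, with no autograph hook."""
    return f(*args)


fake_make_jaxpr.__module__ = "jax._src.api"  # pretend it lives in an allowlisted module


def adjoint_fn(y):
    return "RX" if y < 2 else "RY"


converted_call(adjoint_fn, 1.5)                   # user function called directly: converted
converted_call(fake_make_jaxpr, adjoint_fn, 1.5)  # passed through an allowlisted call: not converted
print(converted_log)  # ['adjoint_fn']: only the direct call went through conversion
```

In this model, the fix would need to convert the inner function before it is handed to the allowlisted tracing call, rather than relying on converted_call to recurse into it.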
Tracebacks
Using QNode (the same error occurs when running autograph manually, so I am not including that traceback):
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
File ~/repos/pennylane/pennylane/workflow/_capture_qnode.py:506, in capture_qnode(qnode, *args, **kwargs)
505 abstracted_axes = jax.tree_util.tree_unflatten(dynamic_args_struct, abstracted_axes)
--> 506 qfunc_jaxpr = jax.make_jaxpr(
507 flat_fn, abstracted_axes=abstracted_axes, static_argnums=qnode.static_argnums
508 )(*args)
509 except (
510 jax.errors.TracerArrayConversionError,
511 jax.errors.TracerIntegerConversionError,
512 jax.errors.TracerBoolConversionError,
513 ) as exc:
[... skipping hidden 6 frame]
File ~/repos/pennylane/pennylane/capture/flatfn.py:74, in FlatFn.__call__(self, *args, **kwargs)
73 args = jax.tree_util.tree_unflatten(self.in_tree, args)
---> 74 out = self.f(*args, **kwargs)
75 out_flat, out_tree = jax.tree_util.tree_flatten(out)
File /var/folders/wb/mp5pns5s4b15rv4xqbmgw1nr0000gs/T/__autograph_generated_fileh5nzvmsl.py:10, in outer_factory.<locals>.inner_factory.<locals>.ag__f(x)
9 retval_ = ag__.UndefinedReturnValue()
---> 10 ag__.converted_call(ag__.converted_call(ag__.ld(qml).adjoint, (ag__.ld(adjoint_fn),), None, fscope), (ag__.ld(x),), None, fscope)
11 try:
File ~/repos/pennylane/pennylane/capture/autograph/ag_primitives.py:424, in converted_call(fn, args, kwargs, caller_fn_scope, options)
422 return new_qnode()
--> 424 return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)
File ~/.pyenv/versions/pennylane/lib/python3.11/site-packages/malt/impl/api.py:382, in converted_call(f, args, kwargs, caller_fn_scope, options)
381 else:
--> 382 result = converted_f(*effective_args)
383 except Exception as e:
File /var/folders/wb/mp5pns5s4b15rv4xqbmgw1nr0000gs/T/__autograph_generated_filefixlyzi0.py:13, in outer_factory.<locals>.inner_factory.<locals>.ag__new_qfunc(*args, **kwargs)
12 abstracted_axes, abstract_shapes = ag__.converted_call(ag__.ld(qml).capture.determine_abstracted_axes, (ag__.ld(args),), None, fscope)
---> 13 jaxpr = ag__.converted_call(ag__.converted_call(ag__.ld(jax).make_jaxpr, (ag__.converted_call(ag__.ld(partial), (ag__.ld(qfunc),), dict(**ag__.ld(kwargs)), fscope),), dict(abstracted_axes=ag__.ld(abstracted_axes)), fscope), tuple(ag__.ld(args)), None, fscope)
14 flat_args = ag__.converted_call(ag__.ld(jax).tree_util.tree_leaves, (ag__.ld(args),), None, fscope)
File ~/repos/pennylane/pennylane/capture/autograph/ag_primitives.py:424, in converted_call(fn, args, kwargs, caller_fn_scope, options)
422 return new_qnode()
--> 424 return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)
File ~/.pyenv/versions/pennylane/lib/python3.11/site-packages/malt/impl/api.py:319, in converted_call(f, args, kwargs, caller_fn_scope, options)
318 if not options.user_requested and conversion.is_allowlisted(f):
--> 319 return _call_unconverted(f, args, kwargs, options)
321 # internal_convert_user_code is for example turned off when issuing a dynamic
322 # call conversion from generated code while in nonrecursive mode. In that
323 # case we evidently don't want to recurse, but we still have to convert
324 # things like builtins.
File ~/.pyenv/versions/pennylane/lib/python3.11/site-packages/malt/impl/api.py:399, in _call_unconverted(f, args, kwargs, options, update_cache)
398 return f(*args, **kwargs)
--> 399 return f(*args)
[... skipping hidden 6 frame]
Cell In[3], line 8, in adjoint_fn(y)
7 def adjoint_fn(y):
----> 8 if y < 2:
9 qml.RX(y, 0)
[... skipping hidden 1 frame]
File ~/.pyenv/versions/pennylane/lib/python3.11/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
1474 def error(self, arg):
-> 1475 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function adjoint_fn at /var/folders/wb/mp5pns5s4b15rv4xqbmgw1nr0000gs/T/ipykernel_82049/3673553018.py:7 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument y.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
The above exception was the direct cause of the following exception:
CaptureError Traceback (most recent call last)
Cell In[3], line 22
19 qml.adjoint(adjoint_fn)(x)
20 return qml.expval(qml.Z(0))
---> 22 f(1.5)
File ~/repos/pennylane/pennylane/workflow/qnode.py:879, in QNode.__call__(self, *args, **kwargs)
876 if qml.capture.enabled():
877 from ._capture_qnode import capture_qnode # pylint: disable=import-outside-toplevel
--> 879 return capture_qnode(self, *args, **kwargs)
880 return self._impl_call(*args, **kwargs)
File ~/repos/pennylane/pennylane/workflow/_capture_qnode.py:514, in capture_qnode(qnode, *args, **kwargs)
506 qfunc_jaxpr = jax.make_jaxpr(
507 flat_fn, abstracted_axes=abstracted_axes, static_argnums=qnode.static_argnums
508 )(*args)
509 except (
510 jax.errors.TracerArrayConversionError,
511 jax.errors.TracerIntegerConversionError,
512 jax.errors.TracerBoolConversionError,
513 ) as exc:
--> 514 raise CaptureError(
515 "Autograph must be used when Python control flow is dependent on a dynamic "
516 "variable (a function input). Please ensure autograph=True or use native control "
517 "flow functions like for_loop, while_loop, etc."
518 ) from exc
520 execute_kwargs = copy(qnode.execute_kwargs)
521 qnode_kwargs = {
522 "diff_method": qnode.diff_method,
523 **execute_kwargs,
524 "gradient_kwargs": qnode.gradient_kwargs,
525 }
CaptureError: Autograph must be used when Python control flow is dependent on a dynamic variable (a function input). Please ensure autograph=True or use native control flow functions like for_loop, while_loop, etc.
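Why the `if y < 2:` in adjoint_fn fails: during tracing, y is an abstract value, so Python's `if` cannot get a concrete bool from it. Autograph's job is to rewrite such an `if` into a functional conditional (the kind of native control flow the error message points to, e.g. jax.lax.cond or qml.cond). The stdlib-only toy below is a hypothetical illustration of that difference; `Tracer`, `trace`, `cond` and the local `TracerBoolConversionError` are simplified stand-ins, not the real JAX or PennyLane implementations.

```python
class TracerBoolConversionError(Exception):
    """Toy analogue of jax.errors.TracerBoolConversionError."""


class Tracer:
    """Placeholder for an abstract value: comparisons build expressions, not bools."""

    def __init__(self, expr):
        self.expr = expr

    def __lt__(self, other):
        return Tracer(f"({self.expr} < {other!r})")

    def __bool__(self):
        # A Python `if` needs a concrete bool, which an abstract value cannot provide.
        raise TracerBoolConversionError(f"cannot convert {self.expr} to bool")


def cond(pred, true_fn, false_fn):
    """Functional conditional: records both branches when pred is abstract."""
    if isinstance(pred, Tracer):
        return f"cond({pred.expr}, {true_fn()}, {false_fn()})"
    return true_fn() if pred else false_fn()


def trace(f):
    """Call f with an abstract argument, roughly as jax.make_jaxpr does."""
    return f(Tracer("y"))


def python_if(y):
    if y < 2:  # triggers Tracer.__bool__ when y is abstract
        return "RX"
    return "RY"


def native_cond(y):
    # The shape of code autograph is expected to produce from python_if.
    return cond(y < 2, lambda: "RX", lambda: "RY")


try:
    trace(python_if)
except TracerBoolConversionError as exc:
    print("python if:", exc)  # python if: cannot convert (y < 2) to bool

print("cond:", trace(native_cond))  # cond: cond((y < 2), RX, RY)
```

Both functions agree on concrete inputs; only the cond-based version survives tracing, which is why autograph must reach the function wrapped by qml.adjoint or qml.ctrl.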
System information
PL dev
Existing GitHub issues
I have searched existing GitHub issues to make sure the issue does not already exist.