[BUG] Autograph not working with adjoint and ctrl #6992

Open
1 task done
mudit2812 opened this issue Feb 21, 2025 · 0 comments
Labels
bug 🐛 Something isn't working

Expected behavior

I expect autograph to work with the inner functions of qml.adjoint and qml.ctrl as it would for "normal" functions with program capture.

Actual behavior

When using qml.capture.enable(), autograph does not work on the wrapped functions of qml.adjoint or qml.ctrl. I tried manually running autograph on a QNode/qfunc as well as setting the autograph=True argument of QNode, and got the same result both times.

Additional information

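After investigating with @lillian542, we suspect this to be the cause of the issue: it seems that autograph is skipped for PennyLane functions. In the traceback below, the call that reaches adjoint_fn passes through malt's is_allowlisted check and _call_unconverted, so adjoint_fn is traced without being converted.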
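
To make the suspected mechanism concrete, here is a toy model in plain Python. It is not malt or PennyLane code: `simulate`, `plain_call`, `converted_call`, `wrapper`, `user_fn` and `circuit` are all hypothetical names, and passing a `call` dispatcher into each function is only a stand-in for the call rewriting that autograph performs on converted source. The sketch assumes that an allowlisted callee is run as-is, so any user function it invokes is also run unconverted.

```python
# Toy model only: every name below is hypothetical, not a malt or PennyLane API.

def simulate(allowlist):
    """Return a trace of how each function was invoked for a given allowlist."""
    trace = []

    def plain_call(fn, *args):
        # Unconverted code: nested calls stay plain Python calls.
        trace.append((fn.__name__, "plain"))
        return fn(plain_call, *args)

    def converted_call(fn, *args):
        # Converted code: allowlisted callees are run as-is, which also
        # means everything they call internally is left unconverted.
        if fn.__name__ in allowlist:
            return plain_call(fn, *args)
        trace.append((fn.__name__, "converted"))
        return fn(converted_call, *args)

    def user_fn(call, y):
        # Stand-in for adjoint_fn: contains value-dependent control flow.
        return "branch A" if y < 2 else "branch B"

    def wrapper(call, inner, y):
        # Stand-in for a library function that calls a user callable.
        return call(inner, y)

    def circuit(call, y):
        # Stand-in for the QNode body.
        return call(wrapper, user_fn, y)

    converted_call(circuit, 1.5)
    return trace


skipped = simulate({"wrapper"})   # wrapper is allowlisted
converted = simulate(set())       # nothing is allowlisted
```

In this model, `skipped` shows `user_fn` running as plain code once `wrapper` is allowlisted, while `converted` shows it being converted when nothing is skipped. Whether the real fix belongs in the allowlist check or in how qml.adjoint and qml.ctrl call their inner function is not something this sketch settles.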

Source code

import pennylane as qml

qml.capture.enable()

dev = qml.device("default.qubit", wires=5)

def adjoint_fn(y):
    if y < 2:
        qml.RX(y, 0)
        qml.RY(y, 0)
    else:
        qml.RY(y, 0)
        qml.RZ(y, 0)

    qml.Hadamard(0)

@qml.qnode(dev, autograph=True)
def f(x):
    qml.adjoint(adjoint_fn)(x)
    return qml.expval(qml.Z(0))

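# Via the QNode with autograph=True
print(f(1.5))

# Manually applying autograph to the QNode's underlying function
ag_fn = qml.capture.run_autograph(f.func)
print(ag_fn(1.5))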

Tracebacks

Using QNode (same error happens when using autograph manually, so I am not including it):

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
File ~/repos/pennylane/pennylane/workflow/_capture_qnode.py:506, in capture_qnode(qnode, *args, **kwargs)
    505         abstracted_axes = jax.tree_util.tree_unflatten(dynamic_args_struct, abstracted_axes)
--> 506     qfunc_jaxpr = jax.make_jaxpr(
    507         flat_fn, abstracted_axes=abstracted_axes, static_argnums=qnode.static_argnums
    508     )(*args)
    509 except (
    510     jax.errors.TracerArrayConversionError,
    511     jax.errors.TracerIntegerConversionError,
    512     jax.errors.TracerBoolConversionError,
    513 ) as exc:

    [... skipping hidden 6 frame]

File ~/repos/pennylane/pennylane/capture/flatfn.py:74, in FlatFn.__call__(self, *args, **kwargs)
     73     args = jax.tree_util.tree_unflatten(self.in_tree, args)
---> 74 out = self.f(*args, **kwargs)
     75 out_flat, out_tree = jax.tree_util.tree_flatten(out)

File /var/folders/wb/mp5pns5s4b15rv4xqbmgw1nr0000gs/T/__autograph_generated_fileh5nzvmsl.py:10, in outer_factory.<locals>.inner_factory.<locals>.ag__f(x)
      9 retval_ = ag__.UndefinedReturnValue()
---> 10 ag__.converted_call(ag__.converted_call(ag__.ld(qml).adjoint, (ag__.ld(adjoint_fn),), None, fscope), (ag__.ld(x),), None, fscope)
     11 try:

File ~/repos/pennylane/pennylane/capture/autograph/ag_primitives.py:424, in converted_call(fn, args, kwargs, caller_fn_scope, options)
    422     return new_qnode()
--> 424 return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)

File ~/.pyenv/versions/pennylane/lib/python3.11/site-packages/malt/impl/api.py:382, in converted_call(f, args, kwargs, caller_fn_scope, options)
    381   else:
--> 382     result = converted_f(*effective_args)
    383 except Exception as e:

File /var/folders/wb/mp5pns5s4b15rv4xqbmgw1nr0000gs/T/__autograph_generated_filefixlyzi0.py:13, in outer_factory.<locals>.inner_factory.<locals>.ag__new_qfunc(*args, **kwargs)
     12 abstracted_axes, abstract_shapes = ag__.converted_call(ag__.ld(qml).capture.determine_abstracted_axes, (ag__.ld(args),), None, fscope)
---> 13 jaxpr = ag__.converted_call(ag__.converted_call(ag__.ld(jax).make_jaxpr, (ag__.converted_call(ag__.ld(partial), (ag__.ld(qfunc),), dict(**ag__.ld(kwargs)), fscope),), dict(abstracted_axes=ag__.ld(abstracted_axes)), fscope), tuple(ag__.ld(args)), None, fscope)
     14 flat_args = ag__.converted_call(ag__.ld(jax).tree_util.tree_leaves, (ag__.ld(args),), None, fscope)

File ~/repos/pennylane/pennylane/capture/autograph/ag_primitives.py:424, in converted_call(fn, args, kwargs, caller_fn_scope, options)
    422     return new_qnode()
--> 424 return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)

File ~/.pyenv/versions/pennylane/lib/python3.11/site-packages/malt/impl/api.py:319, in converted_call(f, args, kwargs, caller_fn_scope, options)
    318 if not options.user_requested and conversion.is_allowlisted(f):
--> 319   return _call_unconverted(f, args, kwargs, options)
    321 # internal_convert_user_code is for example turned off when issuing a dynamic
    322 # call conversion from generated code while in nonrecursive mode. In that
    323 # case we evidently don't want to recurse, but we still have to convert
    324 # things like builtins.

File ~/.pyenv/versions/pennylane/lib/python3.11/site-packages/malt/impl/api.py:399, in _call_unconverted(f, args, kwargs, options, update_cache)
    398   return f(*args, **kwargs)
--> 399 return f(*args)

    [... skipping hidden 6 frame]

Cell In[3], line 8, in adjoint_fn(y)
      7 def adjoint_fn(y):
----> 8     if y < 2:
      9         qml.RX(y, 0)

    [... skipping hidden 1 frame]

File ~/.pyenv/versions/pennylane/lib/python3.11/site-packages/jax/_src/core.py:1475, in concretization_function_error.<locals>.error(self, arg)
   1474 def error(self, arg):
-> 1475   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function adjoint_fn at /var/folders/wb/mp5pns5s4b15rv4xqbmgw1nr0000gs/T/ipykernel_82049/3673553018.py:7 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument y.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

The above exception was the direct cause of the following exception:

CaptureError                              Traceback (most recent call last)
Cell In[3], line 22
     19     qml.adjoint(adjoint_fn)(x)
     20     return qml.expval(qml.Z(0))
---> 22 f(1.5)

File ~/repos/pennylane/pennylane/workflow/qnode.py:879, in QNode.__call__(self, *args, **kwargs)
    876 if qml.capture.enabled():
    877     from ._capture_qnode import capture_qnode  # pylint: disable=import-outside-toplevel
--> 879     return capture_qnode(self, *args, **kwargs)
    880 return self._impl_call(*args, **kwargs)

File ~/repos/pennylane/pennylane/workflow/_capture_qnode.py:514, in capture_qnode(qnode, *args, **kwargs)
    506         qfunc_jaxpr = jax.make_jaxpr(
    507             flat_fn, abstracted_axes=abstracted_axes, static_argnums=qnode.static_argnums
    508         )(*args)
    509     except (
    510         jax.errors.TracerArrayConversionError,
    511         jax.errors.TracerIntegerConversionError,
    512         jax.errors.TracerBoolConversionError,
    513     ) as exc:
--> 514         raise CaptureError(
    515             "Autograph must be used when Python control flow is dependent on a dynamic "
    516             "variable (a function input). Please ensure autograph=True or use native control "
    517             "flow functions like for_loop, while_loop, etc."
    518         ) from exc
    520 execute_kwargs = copy(qnode.execute_kwargs)
    521 qnode_kwargs = {
    522     "diff_method": qnode.diff_method,
    523     **execute_kwargs,
    524     "gradient_kwargs": qnode.gradient_kwargs,
    525 }

CaptureError: Autograph must be used when Python control flow is dependent on a dynamic variable (a function input). Please ensure autograph=True or use native control flow functions like for_loop, while_loop, etc.
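
For context, the inner TracerBoolConversionError in the traceback above can be reproduced with JAX alone, independent of PennyLane. The following is a minimal sketch: `branch_python`, `branch_cond` and `tracing_fails` are illustrative names, and the arithmetic in the branches is a placeholder for the gate calls in adjoint_fn. It shows that a Python `if` on a traced argument fails under `jax.make_jaxpr`, while the same branch written with `jax.lax.cond` traces. This is the kind of rewrite autograph is expected to perform automatically.

```python
import jax


def branch_python(y):
    # A Python `if` needs a concrete bool, which a traced value cannot provide.
    if y < 2:
        return y + 1.0
    return y - 1.0


def branch_cond(y):
    # Same branching expressed as structured control flow, which can be traced.
    return jax.lax.cond(y < 2, lambda v: v + 1.0, lambda v: v - 1.0, y)


def tracing_fails(fn, arg):
    """Return True if tracing fn raises TracerBoolConversionError."""
    try:
        jax.make_jaxpr(fn)(arg)
    except jax.errors.TracerBoolConversionError:
        return True
    return False


python_fails = tracing_fails(branch_python, 1.5)
cond_fails = tracing_fails(branch_cond, 1.5)
cond_value = float(branch_cond(1.5))
```

Here `jax.lax.cond` plays the role of the native control flow functions that the CaptureError message refers to. The sketch only illustrates the tracing failure; it is not a tested workaround for the qml.adjoint and qml.ctrl case reported here.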

System information

PennyLane dev (development version)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@mudit2812 mudit2812 added the bug 🐛 Something isn't working label Feb 21, 2025