Skip to content

Commit

Permalink
run_autograph is now idempotent (#7001)
Browse files Browse the repository at this point in the history
**Context:**

When support for control flow with `autograph` (through the `QNode`
execution pipeline) was introduced in
#6837 a minor bug was
discovered where if you try to apply `autograph` to an already
`autograph`'d function, an _ambiguous_ error would show up,
```python
qml.capture.enable()
from pennylane.capture import run_autograph

dev = qml.device("default.qubit", wires=1)

@qml.qnode(qml.device("default.qubit", wires=1), autograph=True)
def circ(x: float):
    qml.RY(x, wires=0)
    return qml.expval(qml.PauliZ(0))

>>> ag_circ = run_autograph(circ)
>>> ag_ag_circ = run_autograph(ag_circ)
ValueError: closure mismatch, requested ('ag__',), but source function had ()
```

There are a few routes to solve this issue,
1. Raise a more helpful error
2. Throw a warning and proceed with converted function (Catalyst does
this)
3. *Silently* proceed with converted function

**Description of the Change:**

Option (2) was preferred and so if an already converted function is
detected (by the presence of the `ag_uncoverted` attribute), we throw an
`AutoGraphWarning` and just return early. Now we have the behaviour,
```python
qml.capture.enable()
from pennylane.capture import run_autograph

dev = qml.device("default.qubit", wires=1)

@qml.qnode(qml.device("default.qubit", wires=1), autograph=True)
def circ(x: float):
    qml.RY(x, wires=0)
    return qml.expval(qml.PauliZ(0))

ag_circ = run_autograph(circ)
>>> ag_ag_circ = run_autograph(ag_circ)
AutoGraphWarning: AutoGraph will not transform the function <function ...> as it has already been transformed.
>>> ag_ag_circ.func is ag_circ.func
True
>>> ag_ag_circ(-1)
Array(0.5403023, dtype=float32)
```

**Benefits:** Enables `autograph` specific tests to run with the default
`autograph=True` argument through the `QNode`.

**Possible Drawbacks:** 

Suppressed the warnings in the `autograph` test folder. This means that
a future test that emits this warning could be missed if we don't pay
attention to the WAE action that is run weekly.

Fixes [sc-83366]
  • Loading branch information
andrijapau authored Feb 27, 2025
1 parent 8e30920 commit 2c3a11f
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 99 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,10 @@

<h3>Internal changes ⚙️</h3>

* `qml.capture.run_autograph` is now idempotent.
This means `run_autograph(fn) = run_autograph(run_autograph(fn))`.
[(#7001)](https://github.com/PennyLaneAI/pennylane/pull/7001)

* Minor changes to `DQInterpreter` for speedups with program capture execution.
[(#6984)](https://github.com/PennyLaneAI/pennylane/pull/6984)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/capture/autograph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
run_autograph,
)

from .ag_primitives import AutoGraphWarning

__all__ = (
"autograph_source",
"run_autograph",
Expand Down
4 changes: 4 additions & 0 deletions pennylane/capture/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
]


class AutoGraphWarning(Warning):
"""Warnings related to PennyLane's AutoGraph submodule."""


class AutoGraphError(Exception):
"""Errors related to PennyLane's AutoGraph submodule."""

Expand Down
18 changes: 16 additions & 2 deletions pennylane/capture/autograph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
"""
import copy
import inspect
import warnings

from malt.core import converter
from malt.impl.api import PyToPy

import pennylane as qml

from . import ag_primitives
from .ag_primitives import AutoGraphError
from .ag_primitives import AutoGraphError, AutoGraphWarning


class PennyLaneTransformer(PyToPy):
Expand Down Expand Up @@ -60,7 +61,20 @@ def transform(self, obj, user_context):
else:
raise AutoGraphError(f"Unsupported object for transformation: {type(fn)}")

new_fn, module, source_map = self.transform_function(fn, user_context)
# Check if the function has already been converted.
if hasattr(fn, "ag_unconverted"):
warnings.warn(
f"AutoGraph will not transform the function {fn} as it has already been transformed.",
AutoGraphWarning,
)
new_fn, module, source_map = (
fn,
getattr(fn, "ag_module", None),
getattr(fn, "ag_source_map", None),
)
else:
new_fn, module, source_map = self.transform_function(fn, user_context)

new_obj = new_fn

if isinstance(obj, qml.QNode):
Expand Down
36 changes: 36 additions & 0 deletions tests/capture/autograph/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2018-2020 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pytest configuration file for AutoGraph test folder.
"""
import warnings

import pytest

from pennylane.capture.autograph import AutoGraphWarning


# pylint: disable=unused-import
# This is intended to suppress the *expected* warnings that arise when
# testing AutoGraph transformation functions with a `QNode` (which by default
# has AutoGraph transformations applied to it due to the `autograph` argument).
@pytest.fixture(autouse=True)
def filter_expected_warnings():
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=AutoGraphWarning,
message=r"AutoGraph will not transform the function .* as it has already been transformed\.",
)
yield
46 changes: 21 additions & 25 deletions tests/capture/autograph/test_autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def test_transform_on_lambda(self):
@pytest.mark.parametrize("autograph", [True, False])
def test_transform_on_qnode(self, autograph):
"""Test the transform method on a QNode updates the qnode.func"""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")
transformer = PennyLaneTransformer()
user_context = converter.ProgramContext(TOPLEVEL_OPTIONS)

Expand Down Expand Up @@ -138,6 +136,18 @@ def fn(x):
class TestIntegration:
"""Test that the autograph transformations trigger correctly in different settings."""

def test_run_autograph_on_converted_function(self):
"""Test that running run_autograph on a function that has already been converted
does not trigger the transformation again."""

def fn(x):
return x**2

ag_fn = run_autograph(fn)
ag_ag_fn = run_autograph(ag_fn)
assert ag_ag_fn is ag_fn
assert ag_ag_fn(4) == 16

def test_unsupported_object(self):
"""Check the error produced when attempting to convert an unsupported object (neither of
QNode, function, method or callable)."""
Expand Down Expand Up @@ -207,8 +217,6 @@ def fn(x: int):
@pytest.mark.parametrize("autograph", [True, False])
def test_qnode(self, autograph):
"""Test autograph on a QNode."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def circ(x: float):
Expand Down Expand Up @@ -266,8 +274,6 @@ def fn(x: float):
@pytest.mark.parametrize("autograph", [True, False])
def test_adjoint_op(self, autograph):
"""Test that the adjoint of an operator successfully passes through autograph"""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=2), autograph=autograph)
def circ():
Expand All @@ -280,8 +286,6 @@ def circ():
@pytest.mark.parametrize("autograph", [True, False])
def test_ctrl_op(self, autograph):
"""Test that controlled operators successfully pass through autograph"""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=2), autograph=autograph)
def circ():
Expand All @@ -295,8 +299,6 @@ def circ():
@pytest.mark.parametrize("autograph", [True, False])
def test_adjoint_wrapper(self, autograph):
"""Test conversion is happening successfully on functions wrapped with 'adjoint'."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

def inner(x):
qml.RY(x, wires=0)
Expand All @@ -321,8 +323,6 @@ def circ(x: float):
@pytest.mark.parametrize("autograph", [True, False])
def test_ctrl_wrapper(self, autograph):
"""Test conversion is happening successfully on functions wrapped with 'ctrl'."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

def inner(x):
qml.RY(x, wires=0)
Expand Down Expand Up @@ -375,25 +375,21 @@ def fn(x: float):
@pytest.mark.parametrize("autograph", [True, False])
def test_tape_transform(self, autograph):
"""Test if tape transform is applied when autograph is on."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")
dev = dev = qml.device("default.qubit", wires=1)

dev = qml.device("default.qubit", wires=1)

@qml.transform
def my_quantum_transform(tape):
raise NotImplementedError

def fn(x):
@my_quantum_transform
@qml.qnode(dev, autograph=autograph)
def circuit(x):
qml.RY(x, wires=0)
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

return circuit(x)
@my_quantum_transform
@qml.qnode(dev, autograph=autograph)
def circuit(x):
qml.RY(x, wires=0)
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

ag_fn = run_autograph(fn)
ag_fn = run_autograph(circuit)

with pytest.raises(NotImplementedError):
ag_fn(0.5)
Expand Down
40 changes: 0 additions & 40 deletions tests/capture/autograph/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ class TestForLoops:
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_array(self, autograph):
"""Test for loop over JAX array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(params):
Expand All @@ -138,8 +136,6 @@ def res(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_array_unpack(self, autograph):
"""Test for loop over a 2D JAX array unpacking the inner dimension."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(params):
Expand All @@ -159,8 +155,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_numeric_list(self, autograph):
"""Test for loop over a Python list that is convertible to an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -179,8 +173,6 @@ def f():
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_numeric_list_of_list(self, autograph):
"""Test for loop over a nested Python list that is convertible to an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -203,8 +195,6 @@ def f():
def test_for_in_object_list(self, autograph):
"""Test for loop over a Python list that is *not* convertible to an array.
The behaviour should fall back to standard Python."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -222,8 +212,6 @@ def f():
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_static_range(self, autograph):
"""Test for loop over a Python range with static bounds."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f():
Expand All @@ -240,8 +228,6 @@ def f():
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_static_range_indexing_array(self, autograph):
"""Test for loop over a Python range with static bounds that is used to index an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -259,8 +245,6 @@ def f():
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_dynamic_range(self, autograph):
"""Test for loop over a Python range with dynamic bounds."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f(n: int):
Expand All @@ -277,8 +261,6 @@ def f(n: int):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_dynamic_range_indexing_array(self, autograph):
"""Test for loop over a Python range with dynamic bounds that is used to index an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(n: int):
Expand All @@ -296,8 +278,6 @@ def f(n: int):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_array(self, autograph):
"""Test for loop over a Python enumeration on an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f(params):
Expand All @@ -316,8 +296,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_array_no_unpack(self, autograph):
"""Test for loop over a Python enumeration with delayed unpacking."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f(params):
Expand All @@ -336,8 +314,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_nested_unpack(self, autograph):
"""Test for loop over a Python enumeration with nested unpacking."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f(params):
Expand All @@ -359,8 +335,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_start(self, autograph):
"""Test for loop over a Python enumeration with offset indices."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=5), autograph=autograph)
def f(params):
Expand All @@ -379,8 +353,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_numeric_list(self, autograph):
"""Test for loop over a Python enumeration on a list that is convertible to an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f():
Expand Down Expand Up @@ -417,8 +389,6 @@ def f():
def test_for_in_enumerate_object_list(self, autograph):
"""Test for loop over a Python enumeration on a list that is *not* convertible to an array.
The behaviour should fall back to standard Python."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f():
Expand All @@ -440,8 +410,6 @@ def f():
def test_for_in_other_iterable_object(self, autograph):
"""Test for loop over arbitrary iterable Python objects.
The behaviour should fall back to standard Python."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand Down Expand Up @@ -606,8 +574,6 @@ class TestErrors:
def test_for_in_object_list(self, autograph):
"""Check the error raised when a for loop iterates over a Python list that
is *not* convertible to an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -624,8 +590,6 @@ def test_for_in_static_range_indexing_numeric_list(self, autograph):
"""Test an informative error is raised when using a for loop with a static range
to index through an array-compatible Python list. This can be fixed by wrapping the
list in a jax array, so the error raised here is actionable."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -645,8 +609,6 @@ def test_for_in_dynamic_range_indexing_numeric_list(self, autograph):
"""Test an informative error is raised when using a for loop with a dynamic range
to index through an array-compatible Python list. This can be fixed by wrapping the
list in a jax array, so the error raised here is actionable."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(n: int):
Expand All @@ -666,8 +628,6 @@ def test_for_in_dynamic_range_indexing_object_list(self, autograph):
"""Test that an error is raised for a for loop over a Python range with dynamic bounds
that is used to index an array-incompatible Python list. This use-case is never possible,
even with AutoGraph, because the list can't be wrapped in a jax array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(n: int):
Expand Down
Loading

0 comments on commit 2c3a11f

Please sign in to comment.