Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

run_autograph is now idempotent #7001

Merged
merged 10 commits into from
Feb 27, 2025
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@

<h3>Internal changes ⚙️</h3>

* `qml.capture.run_autograph` is now idempotent.
This means `run_autograph(fn) = run_autograph(run_autograph(fn))`.
[(#7001)](https://github.com/PennyLaneAI/pennylane/pull/7001)

* Minor changes to `DQInterpreter` for speedups with program capture execution.
[(#6984)](https://github.com/PennyLaneAI/pennylane/pull/6984)

Expand Down
11 changes: 10 additions & 1 deletion pennylane/capture/autograph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@ def transform(self, obj, user_context):
else:
raise AutoGraphError(f"Unsupported object for transformation: {type(fn)}")

new_fn, module, source_map = self.transform_function(fn, user_context)
# Check if the function has already been converted.

if hasattr(fn, "ag_unconverted"):
new_fn, module, source_map = (
fn,
getattr(fn, "ag_module", None),
getattr(fn, "ag_source_map", None),
)
else:
new_fn, module, source_map = self.transform_function(fn, user_context)
new_obj = new_fn

if isinstance(obj, qml.QNode):
Expand Down
46 changes: 21 additions & 25 deletions tests/capture/autograph/test_autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def test_transform_on_lambda(self):
@pytest.mark.parametrize("autograph", [True, False])
def test_transform_on_qnode(self, autograph):
"""Test the transform method on a QNode updates the qnode.func"""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")
transformer = PennyLaneTransformer()
user_context = converter.ProgramContext(TOPLEVEL_OPTIONS)

Expand Down Expand Up @@ -138,6 +136,18 @@ def fn(x):
class TestIntegration:
"""Test that the autograph transformations trigger correctly in different settings."""

def test_run_autograph_on_converted_function(self):
"""Test that running run_autograph on a function that has already been converted
does not trigger the transformation again."""

def fn(x):
return x**2

ag_fn = run_autograph(fn)
ag_ag_fn = run_autograph(ag_fn)
assert ag_ag_fn == ag_fn
assert ag_ag_fn(4) == 16

def test_unsupported_object(self):
"""Check the error produced when attempting to convert an unsupported object (neither of
QNode, function, method or callable)."""
Expand Down Expand Up @@ -207,8 +217,6 @@ def fn(x: int):
@pytest.mark.parametrize("autograph", [True, False])
def test_qnode(self, autograph):
"""Test autograph on a QNode."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def circ(x: float):
Expand Down Expand Up @@ -266,8 +274,6 @@ def fn(x: float):
@pytest.mark.parametrize("autograph", [True, False])
def test_adjoint_op(self, autograph):
"""Test that the adjoint of an operator successfully passes through autograph"""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=2), autograph=autograph)
def circ():
Expand All @@ -280,8 +286,6 @@ def circ():
@pytest.mark.parametrize("autograph", [True, False])
def test_ctrl_op(self, autograph):
"""Test that controlled operators successfully pass through autograph"""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=2), autograph=autograph)
def circ():
Expand All @@ -295,8 +299,6 @@ def circ():
@pytest.mark.parametrize("autograph", [True, False])
def test_adjoint_wrapper(self, autograph):
"""Test conversion is happening successfully on functions wrapped with 'adjoint'."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

def inner(x):
qml.RY(x, wires=0)
Expand All @@ -321,8 +323,6 @@ def circ(x: float):
@pytest.mark.parametrize("autograph", [True, False])
def test_ctrl_wrapper(self, autograph):
"""Test conversion is happening successfully on functions wrapped with 'ctrl'."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

def inner(x):
qml.RY(x, wires=0)
Expand Down Expand Up @@ -375,25 +375,21 @@ def fn(x: float):
@pytest.mark.parametrize("autograph", [True, False])
def test_tape_transform(self, autograph):
"""Test if tape transform is applied when autograph is on."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")
dev = dev = qml.device("default.qubit", wires=1)

dev = qml.device("default.qubit", wires=1)

@qml.transform
def my_quantum_transform(tape):
raise NotImplementedError

def fn(x):
@my_quantum_transform
@qml.qnode(dev, autograph=autograph)
def circuit(x):
qml.RY(x, wires=0)
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

return circuit(x)
@my_quantum_transform
@qml.qnode(dev, autograph=autograph)
def circuit(x):
qml.RY(x, wires=0)
qml.RX(x, wires=0)
return qml.expval(qml.PauliZ(0))

ag_fn = run_autograph(fn)
ag_fn = run_autograph(circuit)

with pytest.raises(NotImplementedError):
ag_fn(0.5)
Expand Down
40 changes: 0 additions & 40 deletions tests/capture/autograph/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ class TestForLoops:
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_array(self, autograph):
"""Test for loop over JAX array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(params):
Expand All @@ -138,8 +136,6 @@ def res(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_array_unpack(self, autograph):
"""Test for loop over a 2D JAX array unpacking the inner dimension."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(params):
Expand All @@ -159,8 +155,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_numeric_list(self, autograph):
"""Test for loop over a Python list that is convertible to an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -179,8 +173,6 @@ def f():
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_numeric_list_of_list(self, autograph):
"""Test for loop over a nested Python list that is convertible to an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -203,8 +195,6 @@ def f():
def test_for_in_object_list(self, autograph):
"""Test for loop over a Python list that is *not* convertible to an array.
The behaviour should fall back to standard Python."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -222,8 +212,6 @@ def f():
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_static_range(self, autograph):
"""Test for loop over a Python range with static bounds."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f():
Expand All @@ -240,8 +228,6 @@ def f():
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_static_range_indexing_array(self, autograph):
"""Test for loop over a Python range with static bounds that is used to index an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -259,8 +245,6 @@ def f():
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_dynamic_range(self, autograph):
"""Test for loop over a Python range with dynamic bounds."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f(n: int):
Expand All @@ -277,8 +261,6 @@ def f(n: int):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_dynamic_range_indexing_array(self, autograph):
"""Test for loop over a Python range with dynamic bounds that is used to index an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(n: int):
Expand All @@ -296,8 +278,6 @@ def f(n: int):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_array(self, autograph):
"""Test for loop over a Python enumeration on an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f(params):
Expand All @@ -316,8 +296,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_array_no_unpack(self, autograph):
"""Test for loop over a Python enumeration with delayed unpacking."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f(params):
Expand All @@ -336,8 +314,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_nested_unpack(self, autograph):
"""Test for loop over a Python enumeration with nested unpacking."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f(params):
Expand All @@ -359,8 +335,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_start(self, autograph):
"""Test for loop over a Python enumeration with offset indices."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=5), autograph=autograph)
def f(params):
Expand All @@ -379,8 +353,6 @@ def f(params):
@pytest.mark.parametrize("autograph", [True, False])
def test_for_in_enumerate_numeric_list(self, autograph):
"""Test for loop over a Python enumeration on a list that is convertible to an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f():
Expand Down Expand Up @@ -417,8 +389,6 @@ def f():
def test_for_in_enumerate_object_list(self, autograph):
"""Test for loop over a Python enumeration on a list that is *not* convertible to an array.
The behaviour should fall back to standard Python."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=3), autograph=autograph)
def f():
Expand All @@ -440,8 +410,6 @@ def f():
def test_for_in_other_iterable_object(self, autograph):
"""Test for loop over arbitrary iterable Python objects.
The behaviour should fall back to standard Python."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand Down Expand Up @@ -606,8 +574,6 @@ class TestErrors:
def test_for_in_object_list(self, autograph):
"""Check the error raised when a for loop iterates over a Python list that
is *not* convertible to an array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -624,8 +590,6 @@ def test_for_in_static_range_indexing_numeric_list(self, autograph):
"""Test an informative error is raised when using a for loop with a static range
to index through an array-compatible Python list. This can be fixed by wrapping the
list in a jax array, so the error raised here is actionable."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f():
Expand All @@ -645,8 +609,6 @@ def test_for_in_dynamic_range_indexing_numeric_list(self, autograph):
"""Test an informative error is raised when using a for loop with a dynamic range
to index through an array-compatible Python list. This can be fixed by wrapping the
list in a jax array, so the error raised here is actionable."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(n: int):
Expand All @@ -666,8 +628,6 @@ def test_for_in_dynamic_range_indexing_object_list(self, autograph):
"""Test that an error is raised for a for loop over a Python range with dynamic bounds
that is used to index an array-incompatible Python list. This use-case is never possible,
even with AutoGraph, because the list can't be wrapped in a jax array."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(n: int):
Expand Down
4 changes: 0 additions & 4 deletions tests/capture/autograph/test_if_else.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ def res(x):
@pytest.mark.parametrize("autograph", [True, False])
def test_qubit_manipulation_cond(self, autograph):
"""Test conditional with quantum operation."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def circuit(x):
Expand Down Expand Up @@ -291,8 +289,6 @@ def circuit():
def test_multiple_return_different_measurements(self, autograph):
"""Test that different measurements be used in the return in different branches, as
they are all represented by the AbstractMeasurement class."""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=1), autograph=autograph)
def f(switch: bool):
Expand Down
2 changes: 0 additions & 2 deletions tests/capture/autograph/test_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ def f(param):
@pytest.mark.parametrize("autograph", [True, False])
def test_whileloop_qnode(self, autograph):
"""Test while-loop used with a qnode"""
if autograph:
pytest.xfail(reason="Autograph cannot be applied twice in a row. See sc-83366")

@qml.qnode(qml.device("default.qubit", wires=4), autograph=autograph)
def f(p):
Expand Down
Loading