```
/usr/local/lib/python3.10/dist-packages/pennylane/gradients/gradient_transform.py in assert_no_trainable_tape_batching(tape, transform_name)
     95     for idx in range(len(tape.trainable_params)):
     96         if tape.get_operation(idx)[0].batch_size is not None:
---> 97             raise NotImplementedError(
     98                 "Computing the gradient of broadcasted tapes with respect to the broadcasted "
     99                 f"parameters using the {transform_name} gradient transform is currently not "
```
Would this imply there is some behavior not implemented for `diff_method='parameter-shift'`?
@mews6 yep, that's correct! We never updated the parameter-shift rule to support broadcasting. However, it seems that the Torch layer is providing the parameter-shift rule with a broadcasted tape, which is what breaks here. @dwierichs may have more technical details.
As @josh146 and the error message say, `param_shift` cannot handle batched/broadcasted *trainable* parameters.
However, it is often the training data that is batched while the trainable parameters are not. Could that be the case here, so that the batched parameters can be marked as non-trainable? Non-trainable batched parameters are supported by `param_shift`.
Expected behavior
The TorchLayer demo should work with `diff_method='parameter-shift'`.
Actual behavior
It throws an error.
Additional information
This question originated from Forum thread 4940.
Source code
Tracebacks
System information
```
Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info:    Linux-6.1.85+-x86_64-with-glibc2.35
Python version:   3.10.12
Numpy version:    1.26.4
Scipy version:    1.13.1
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.37.0)
- default.clifford (PennyLane-0.37.0)
- default.gaussian (PennyLane-0.37.0)
- default.mixed (PennyLane-0.37.0)
- default.qubit (PennyLane-0.37.0)
- default.qubit.autograd (PennyLane-0.37.0)
- default.qubit.jax (PennyLane-0.37.0)
- default.qubit.legacy (PennyLane-0.37.0)
- default.qubit.tf (PennyLane-0.37.0)
- default.qubit.torch (PennyLane-0.37.0)
- default.qutrit (PennyLane-0.37.0)
- default.qutrit.mixed (PennyLane-0.37.0)
- default.tensor (PennyLane-0.37.0)
- null.qubit (PennyLane-0.37.0)
```
Existing GitHub issues