Skip to content

Commit

Permalink
gracefully handle when pytorch is not installed
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Feb 7, 2024
1 parent d2c8d38 commit 2435a50
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pygsti/forwardsims/torchfwdsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,14 @@ def jac_friendly_circuit_probs(self, *free_params: Tuple[torch.Tensor]):


class TorchForwardSimulator(ForwardSimulator):

ENABLED = TORCH_ENABLED

"""
A forward simulator that leverages automatic differentiation in PyTorch.
"""
def __init__(self, model : Optional[ExplicitOpModel] = None):
if not TORCH_ENABLED:
if not TorchForwardSimulator.ENABLED:
raise RuntimeError('PyTorch could not be imported.')
self.model = model
super(ForwardSimulator, self).__init__(model)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

# Extra requirements
extras = {
'pytorch' : ['torch'],
'diamond_norm': [
'cvxopt',
'cvxpy'
Expand Down
2 changes: 2 additions & 0 deletions test/unit/objects/test_forwardsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest import mock

import numpy as np
import pytest

from pygsti.models import modelconstruction as _setc
import pygsti.models as models
Expand Down Expand Up @@ -177,6 +178,7 @@ def test_simple_matrix_fwdsim(self):
def test_simple_map_fwdsim(self):
self._run(SimpleMapForwardSimulator)

@pytest.mark.skipif(not TorchForwardSimulator.ENABLED, reason="PyTorch is not installed.")
def test_torch_fwdsim(self):
self._run(TorchForwardSimulator)

Expand Down

0 comments on commit 2435a50

Please sign in to comment.