Skip to content

Commit

Permalink
Do not try to use MPS backend on Github Actions
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Sep 2, 2024
1 parent 6fb85c6 commit b9c5d56
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
12 changes: 11 additions & 1 deletion python/rascaline-torch/tests/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,23 @@ def compute(weights):
)


def can_use_mps_backend():
    """Return True when tests may run on the torch MPS backend.

    Github Actions M1 runners advertise MPS but don't have a GPU
    accessible, so MPS is ruled out whenever the GITHUB_ACTIONS
    environment variable is set.
    """
    # Github Actions M1 runners don't have a GPU accessible
    if os.environ.get("GITHUB_ACTIONS") is not None:
        return False
    mps = getattr(torch.backends, "mps", None)
    if mps is None:
        return False
    return mps.is_built() and mps.is_available()


def test_different_device_dtype():
# check autograd if the data is on different devices/dtypes as well
options = [
(torch.device("cpu"), torch.float32),
(torch.device("cpu"), torch.float64),
]
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
if can_use_mps_backend():
options.append((torch.device("mps:0"), torch.float32))

if torch.cuda.is_available():
Expand Down
12 changes: 11 additions & 1 deletion python/rascaline-torch/tests/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_different_device_dtype_errors(system):

# Different devices
custom_device = None
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
if can_use_mps_backend():
custom_device = torch.device("mps:0")

if torch.cuda.is_available():
Expand Down Expand Up @@ -234,3 +234,13 @@ def forward(
with tmpdir.as_cwd():
torch.jit.save(module, "test-save.torch")
module = torch.jit.load("test-save.torch")


def can_use_mps_backend():
    """Return True when the torch MPS backend is both usable and allowed.

    MPS is allowed only outside of Github Actions, because the M1 CI
    runners there don't have a GPU accessible.
    """
    # Github Actions M1 runners don't have a GPU accessible
    running_in_ci = os.environ.get("GITHUB_ACTIONS") is not None
    if running_in_ci or not hasattr(torch.backends, "mps"):
        return False
    backend = torch.backends.mps
    return backend.is_built() and backend.is_available()

0 comments on commit b9c5d56

Please sign in to comment.