Skip to content

Commit

Permalink
Fix test_device.py
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed May 19, 2024
1 parent 9184b1a commit 1ca5a84
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions tests/utils/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def _get_available_devices() -> List[str]:


@pytest.mark.parametrize("desired_device", ["multi-cuda", None])
def test_pick_devices__multi_cuda(desired_device, monkeypatch):
def test_pick_devices_multi_cuda(desired_device, monkeypatch):
def _get_available_devices() -> List[str]:
return ["cuda:0", "cuda:1", "cpu"]

monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices)

picked_devices = pick_devices(["cuda", "cpu"], desired_device)
picked_devices = pick_devices(["multi-cuda", "cuda", "cpu"], desired_device)

assert picked_devices == [torch.device("cuda:0"), torch.device("cuda:1")]

Expand Down Expand Up @@ -102,15 +102,6 @@ def test_pick_devices_gpu_mps_map():
assert picked_devices == [torch.device("mps")]


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="less than 2 CUDA devices")
@pytest.mark.parametrize("desired_device", ["multi-cuda", "multi-gpu"])
def test_pick_devices_multi_cuda(desired_device):
picked_devices = pick_devices(["cpu", "cuda", "multi-cuda"], desired_device)
assert picked_devices == [
torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
]


@pytest.mark.skipif(
torch.cuda.is_available()
or (torch.backends.mps.is_built() and torch.backends.mps.is_available()),
Expand Down

0 comments on commit 1ca5a84

Please sign in to comment.