diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py
index 0097eec83..ce560fd87 100644
--- a/tests/utils/test_device.py
+++ b/tests/utils/test_device.py
@@ -47,13 +47,13 @@ def _get_available_devices() -> List[str]:
 
 
 @pytest.mark.parametrize("desired_device", ["multi-cuda", None])
-def test_pick_devices__multi_cuda(desired_device, monkeypatch):
+def test_pick_devices_multi_cuda(desired_device, monkeypatch):
     def _get_available_devices() -> List[str]:
         return ["cuda:0", "cuda:1", "cpu"]
 
     monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices)
 
-    picked_devices = pick_devices(["cuda", "cpu"], desired_device)
+    picked_devices = pick_devices(["multi-cuda", "cuda", "cpu"], desired_device)
     assert picked_devices == [torch.device("cuda:0"), torch.device("cuda:1")]
 
 
@@ -102,15 +102,6 @@ def test_pick_devices_gpu_mps_map():
     assert picked_devices == [torch.device("mps")]
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="less than 2 CUDA devices")
-@pytest.mark.parametrize("desired_device", ["multi-cuda", "multi-gpu"])
-def test_pick_devices_multi_cuda(desired_device):
-    picked_devices = pick_devices(["cpu", "cuda", "multi-cuda"], desired_device)
-    assert picked_devices == [
-        torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
-    ]
-
-
 @pytest.mark.skipif(
     torch.cuda.is_available()
     or (torch.backends.mps.is_built() and torch.backends.mps.is_available()),