diff --git a/tests/utils/test_device.py b/tests/utils/test_device.py index 0097eec83..ee515ec23 100644 --- a/tests/utils/test_device.py +++ b/tests/utils/test_device.py @@ -46,18 +46,6 @@ def _get_available_devices() -> List[str]: assert picked_devices == [torch.device("mps")] -@pytest.mark.parametrize("desired_device", ["multi-cuda", None]) -def test_pick_devices__multi_cuda(desired_device, monkeypatch): - def _get_available_devices() -> List[str]: - return ["cuda:0", "cuda:1", "cpu"] - - monkeypatch.setattr(devices, "_get_available_devices", _get_available_devices) - - picked_devices = pick_devices(["cuda", "cpu"], desired_device) - - assert picked_devices == [torch.device("cuda:0"), torch.device("cuda:1")] - - def test_pick_devices_unsoprted(): match = "Unsupported desired device 'cuda'. Please choose from cpu." with pytest.raises(ValueError, match=match):