-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use default device obtained from architecture (#165)
- Loading branch information
1 parent
6f02805
commit 5548548
Showing
6 changed files
with
207 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
device: "cpu" | ||
device: ${default_device:} | ||
base_precision: ${default_precision:} | ||
seed: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,127 +1,87 @@ | ||
import warnings | ||
from typing import List | ||
from typing import List, Optional | ||
|
||
import torch | ||
|
||
|
||
def get_available_devices() -> List[torch.device]: | ||
"""Returns a list of available torch devices. | ||
This function returns a list of available torch devices, which can | ||
be used to specify the devices on which to run a model. | ||
:return: The list of available torch devices. | ||
""" | ||
devices = [torch.device("cpu")] | ||
def _get_available_devices() -> List[str]: | ||
available_devices = ["cpu"] | ||
if torch.cuda.is_available(): | ||
device_count = torch.cuda.device_count() | ||
for i in range(device_count): | ||
devices.append(torch.device(f"cuda:{i}")) | ||
available_devices.append("cuda") | ||
if torch.cuda.device_count() > 1: | ||
available_devices.append("multi-cuda") | ||
# for torch<2.0 `torch.backends.mps.is_available()` is required for a reasonable | ||
# check. | ||
if torch.backends.mps.is_built() and torch.backends.mps.is_available(): | ||
devices.append(torch.device("mps")) | ||
return devices | ||
available_devices.append("mps") | ||
|
||
return available_devices | ||
|
||
|
||
def pick_devices( | ||
requested_device: str, | ||
available_devices: List[torch.device], | ||
architecture_devices: List[str], | ||
desired_device: Optional[str] = None, | ||
) -> List[torch.device]: | ||
"""Picks the devices to use for training. | ||
This function picks the devices to use for training based on the | ||
requested device, the available devices, and the list of devices | ||
supported by the architecture. | ||
The choice is based on the following logic. First, the requested | ||
device is checked to see if it is supported (i.e., one of "cpu", | ||
"cuda", "mps", "gpu", "multi-gpu", or "multi-cuda"). Then, the | ||
requested device is checked to see if it is available on the system. | ||
Finally, the requested device is checked to see if it is supported | ||
by the architecture. If the requested device is not supported by the | ||
architecture, a ValueError is raised. If the requested device is | ||
supported by the architecture, but a different device is preferred | ||
by the architecture and present on the system, a warning is issued. | ||
:param requested_device: The requested device. | ||
:param available_devices: The available devices. | ||
:param architecture_devices: The devices supported by the architecture. | ||
"""Pick (best) devices for training. | ||
The choice is made on the intersection of the ``architecture_devices`` and the | ||
available devices on the current system. If no ``desired_device`` is provided the | ||
first device of this intersection will be returned. | ||
:param architecture_devices: Devices supported by the architecture. The list should | ||
be sorted by the architecture's preference, with the most preferred device | ||
first and the least preferred device last. | ||
:param desired_device: Device requested by the user, if any. | ||
""" | ||
|
||
requested_device = requested_device.lower() | ||
|
||
# first, we check that the requested device is supported | ||
if requested_device not in ["cpu", "cuda", "multi-cuda", "mps", "gpu", "multi-gpu"]: | ||
raise ValueError( | ||
f"Unsupported device: `{requested_device}`. Please choose from " | ||
"cpu, cuda, mps, gpu, multi-gpu, multi-cuda" | ||
) | ||
|
||
# we convert "gpu" and "multi-gpu" to "cuda" or "mps" if available | ||
if requested_device == "gpu": | ||
if torch.cuda.is_available(): | ||
requested_device = "cuda" | ||
elif torch.backends.mps.is_built() and torch.backends.mps.is_available(): | ||
requested_device = "mps" | ||
else: | ||
raise ValueError( | ||
"Requested `gpu` device, but found no GPU (CUDA or MPS) devices" | ||
) | ||
available_devices = _get_available_devices() | ||
|
||
# we convert "multi-gpu" to "multi-cuda" | ||
if requested_device == "multi-gpu": | ||
requested_device = "multi-cuda" | ||
|
||
# check that the requested device is available | ||
available_device_types = [device.type for device in available_devices] | ||
available_device_strings = ["cpu"] # always available | ||
if "cuda" in available_device_types: | ||
available_device_strings.append("cuda") | ||
if "mps" in available_device_types: | ||
available_device_strings.append("mps") | ||
if available_device_strings.count("cuda") > 1: | ||
available_device_strings.append("multi-cuda") | ||
|
||
if requested_device not in available_device_strings: | ||
if requested_device == "multi-cuda": | ||
if available_device_strings.count("cuda") == 0: | ||
raise ValueError( | ||
"Requested device `multi-gpu` or `multi-cuda`, " | ||
"but found no cuda devices" | ||
) | ||
# intersect between available and architecture's devices. keep order of architecture | ||
possible_devices = [d for d in architecture_devices if d in available_devices] | ||
|
||
# cpu device should always be available | ||
assert "cpu" in possible_devices | ||
|
||
# If desired device given compare the possible devices and try to find a match | ||
if desired_device is None: | ||
desired_device = possible_devices[0] | ||
else: | ||
desired_device = desired_device.lower() | ||
|
||
# convert "gpu" and "multi-gpu" to "cuda" or "mps" if available | ||
if desired_device == "gpu": | ||
if torch.cuda.is_available(): | ||
desired_device = "cuda" | ||
elif torch.backends.mps.is_built() and torch.backends.mps.is_available(): | ||
desired_device = "mps" | ||
else: | ||
raise ValueError( | ||
"Requested device `multi-gpu` or `multi-cuda`, " | ||
"but found only one cuda device. If you want to run on a " | ||
"single GPU, please use `gpu` or `cuda` instead." | ||
"Requested 'gpu' device, but found no GPU (CUDA or MPS) devices." | ||
) | ||
else: | ||
if desired_device == "multi-gpu": | ||
desired_device = "multi-cuda" | ||
|
||
if desired_device not in possible_devices: | ||
raise ValueError( | ||
f"Requested device `{requested_device}` is not available on this system" | ||
f"Unsupported desired device {desired_device!r}. " | ||
f"Please choose from {', '.join(possible_devices)}." | ||
) | ||
if desired_device == "multi-cuda" and torch.cuda.device_count() < 2: | ||
raise ValueError( | ||
"Requested device 'multi-gpu' or 'multi-cuda', but found only one CUDA " | ||
"device. If you want to run on a single GPU, please use 'gpu' or " | ||
"'cuda' instead." | ||
) | ||
|
||
# if the requested device is available, check it against the architecture's devices | ||
if requested_device not in architecture_devices: | ||
raise ValueError( | ||
f"The requested device `{requested_device}` is not supported by the chosen " | ||
f"architecture. Supported devices are {architecture_devices}." | ||
) | ||
|
||
# we check all the devices that come before the requested one in the | ||
# list of architecture devices. If any of them are available, we warn | ||
|
||
requested_device_index = architecture_devices.index(requested_device) | ||
for device in architecture_devices[:requested_device_index]: | ||
if device in available_device_strings: | ||
if possible_devices.index(desired_device) > 0: | ||
warnings.warn( | ||
f"Device `{requested_device}` was requested, but the chosen " | ||
f"architecture prefers `{device}`, which was also found on your " | ||
f"system. Consider using the `{device}` device.", | ||
f"Device {desired_device!r} requested, but {possible_devices[0]!r} is " | ||
"prefferred by the architecture and available on current system.", | ||
stacklevel=2, | ||
) | ||
|
||
# finally, we convert the requested device to a list of devices | ||
if requested_device == "multi-cuda": | ||
return [device for device in available_devices if device.type == "cuda"] | ||
# convert the requested device to a list of torch devices | ||
if desired_device == "multi-cuda": | ||
return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())] | ||
else: | ||
return [torch.device(requested_device)] | ||
return [torch.device(desired_device)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.