Use default device obtained from architecture (#165)
PicoCentauri authored Apr 11, 2024
1 parent 6f02805 commit 5548548
Showing 6 changed files with 207 additions and 179 deletions.
2 changes: 1 addition & 1 deletion src/metatensor/models/cli/conf/base.yaml
@@ -1,3 +1,3 @@
-device: "cpu"
+device: ${default_device:}
 base_precision: ${default_precision:}
 seed: null
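
The `${default_device:}` interpolation is filled in by a custom OmegaConf resolver registered in `src/metatensor/models/utils/omegaconf.py` (see below). A minimal sketch of the mechanism, using a stand-in resolver instead of the real one:

from omegaconf import OmegaConf

# stand-in resolver for illustration only: the real `default_device`
# resolver inspects the architecture's supported devices via the config root
OmegaConf.register_new_resolver("default_device", lambda: "cuda")

conf = OmegaConf.create({"device": "${default_device:}"})
print(conf.device)  # -> "cuda"; the interpolation is resolved lazily, on access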
5 changes: 2 additions & 3 deletions src/metatensor/models/cli/train.py
@@ -21,7 +21,7 @@
 from .. import CONFIG_PATH
 from ..utils.data import DatasetInfo, TargetInfo, read_systems, read_targets
 from ..utils.data.dataset import _train_test_random_split
-from ..utils.devices import get_available_devices, pick_devices
+from ..utils.devices import pick_devices
 from ..utils.errors import ArchitectureError
 from ..utils.io import export, save
 from ..utils.omegaconf import check_options_list, check_units, expand_dataset_config
@@ -215,9 +215,8 @@ def _train_model_hydra(options: DictConfig) -> None:
     # PROCESS BASE PARAMETERS #
     ###########################
     devices = pick_devices(
-        requested_device=options["device"],
-        available_devices=get_available_devices(),
         architecture_devices=architecture_capabilities["supported_devices"],
+        desired_device=options["device"],
     )
 
     # process dtypes
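
With the new signature, `pick_devices` discovers the available devices itself. A hypothetical call, assuming an architecture that supports `cuda` and `cpu` (the values below are illustrative, not from the diff):

from metatensor.models.utils.devices import pick_devices

# in practice the list comes from an architecture's
# __ARCHITECTURE_CAPABILITIES__["supported_devices"]
devices = pick_devices(architecture_devices=["cuda", "cpu"])  # best available
devices = pick_devices(architecture_devices=["cuda", "cpu"], desired_device="cpu")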
162 changes: 61 additions & 101 deletions src/metatensor/models/utils/devices.py
@@ -1,127 +1,87 @@
 import warnings
-from typing import List
+from typing import List, Optional
 
 import torch
 
 
-def get_available_devices() -> List[torch.device]:
-    """Returns a list of available torch devices.
-
-    This function returns a list of available torch devices, which can
-    be used to specify the devices on which to run a model.
-
-    :return: The list of available torch devices.
-    """
-    devices = [torch.device("cpu")]
+def _get_available_devices() -> List[str]:
+    available_devices = ["cpu"]
     if torch.cuda.is_available():
-        device_count = torch.cuda.device_count()
-        for i in range(device_count):
-            devices.append(torch.device(f"cuda:{i}"))
+        available_devices.append("cuda")
+        if torch.cuda.device_count() > 1:
+            available_devices.append("multi-cuda")
     # for torch<2.0 `torch.backends.mps.is_available()` is required for a reasonable
     # check.
     if torch.backends.mps.is_built() and torch.backends.mps.is_available():
-        devices.append(torch.device("mps"))
-    return devices
+        available_devices.append("mps")
+
+    return available_devices
 
 
 def pick_devices(
-    requested_device: str,
-    available_devices: List[torch.device],
     architecture_devices: List[str],
+    desired_device: Optional[str] = None,
 ) -> List[torch.device]:
-    """Picks the devices to use for training.
-
-    This function picks the devices to use for training based on the
-    requested device, the available devices, and the list of devices
-    supported by the architecture.
-
-    The choice is based on the following logic. First, the requested
-    device is checked to see if it is supported (i.e., one of "cpu",
-    "cuda", "mps", "gpu", "multi-gpu", or "multi-cuda"). Then, the
-    requested device is checked to see if it is available on the system.
-    Finally, the requested device is checked to see if it is supported
-    by the architecture. If the requested device is not supported by the
-    architecture, a ValueError is raised. If the requested device is
-    supported by the architecture, but a different device is preferred
-    by the architecture and present on the system, a warning is issued.
-
-    :param requested_device: The requested device.
-    :param available_devices: The available devices.
-    :param architecture_devices: The devices supported by the architecture.
+    """Pick the best devices for training.
+
+    The choice is made on the intersection of the ``architecture_devices`` and the
+    devices available on the current system. If no ``desired_device`` is provided,
+    the first device of this intersection is returned.
+
+    :param architecture_devices: Devices supported by the architecture, sorted by
+        the architecture's preference: the most preferred device first, the least
+        preferred one last.
+    :param desired_device: device desired by the user
     """
 
-    requested_device = requested_device.lower()
-
-    # first, we check that the requested device is supported
-    if requested_device not in ["cpu", "cuda", "multi-cuda", "mps", "gpu", "multi-gpu"]:
-        raise ValueError(
-            f"Unsupported device: `{requested_device}`. Please choose from "
-            "cpu, cuda, mps, gpu, multi-gpu, multi-cuda"
-        )
-
-    # we convert "gpu" and "multi-gpu" to "cuda" or "mps" if available
-    if requested_device == "gpu":
-        if torch.cuda.is_available():
-            requested_device = "cuda"
-        elif torch.backends.mps.is_built() and torch.backends.mps.is_available():
-            requested_device = "mps"
-        else:
-            raise ValueError(
-                "Requested `gpu` device, but found no GPU (CUDA or MPS) devices"
-            )
-
-    # we convert "multi-gpu" to "multi-cuda"
-    if requested_device == "multi-gpu":
-        requested_device = "multi-cuda"
-
-    # check that the requested device is available
-    available_device_types = [device.type for device in available_devices]
-    available_device_strings = ["cpu"]  # always available
-    if "cuda" in available_device_types:
-        available_device_strings.append("cuda")
-    if "mps" in available_device_types:
-        available_device_strings.append("mps")
-    if available_device_strings.count("cuda") > 1:
-        available_device_strings.append("multi-cuda")
-
-    if requested_device not in available_device_strings:
-        if requested_device == "multi-cuda":
-            if available_device_strings.count("cuda") == 0:
-                raise ValueError(
-                    "Requested device `multi-gpu` or `multi-cuda`, "
-                    "but found no cuda devices"
-                )
-            else:
-                raise ValueError(
-                    "Requested device `multi-gpu` or `multi-cuda`, "
-                    "but found only one cuda device. If you want to run on a "
-                    "single GPU, please use `gpu` or `cuda` instead."
-                )
-        else:
-            raise ValueError(
-                f"Requested device `{requested_device}` is not available on this system"
-            )
-
-    # if the requested device is available, check it against the architecture's devices
-    if requested_device not in architecture_devices:
-        raise ValueError(
-            f"The requested device `{requested_device}` is not supported by the chosen "
-            f"architecture. Supported devices are {architecture_devices}."
-        )
-
-    # we check all the devices that come before the requested one in the
-    # list of architecture devices. If any of them are available, we warn
-
-    requested_device_index = architecture_devices.index(requested_device)
-    for device in architecture_devices[:requested_device_index]:
-        if device in available_device_strings:
-            warnings.warn(
-                f"Device `{requested_device}` was requested, but the chosen "
-                f"architecture prefers `{device}`, which was also found on your "
-                f"system. Consider using the `{device}` device.",
-                stacklevel=2,
-            )
+    available_devices = _get_available_devices()
+
+    # intersection of the architecture's devices and the available ones,
+    # keeping the architecture's preference order
+    possible_devices = [d for d in architecture_devices if d in available_devices]
+
+    # the cpu device should always be available
+    assert "cpu" in possible_devices
+
+    # if a desired device was given, check it against the possible devices
+    if desired_device is None:
+        desired_device = possible_devices[0]
+    else:
+        desired_device = desired_device.lower()
+
+        # convert "gpu" and "multi-gpu" to "cuda" or "mps" if available
+        if desired_device == "gpu":
+            if torch.cuda.is_available():
+                desired_device = "cuda"
+            elif torch.backends.mps.is_built() and torch.backends.mps.is_available():
+                desired_device = "mps"
+            else:
+                raise ValueError(
+                    "Requested 'gpu' device, but found no GPU (CUDA or MPS) devices."
+                )
+        if desired_device == "multi-gpu":
+            desired_device = "multi-cuda"
+
+        if desired_device not in possible_devices:
+            raise ValueError(
+                f"Unsupported desired device {desired_device!r}. "
+                f"Please choose from {', '.join(possible_devices)}."
+            )
+        if desired_device == "multi-cuda" and torch.cuda.device_count() < 2:
+            raise ValueError(
+                "Requested device 'multi-gpu' or 'multi-cuda', but found only one "
+                "CUDA device. If you want to run on a single GPU, please use 'gpu' "
+                "or 'cuda' instead."
+            )
+
+        if possible_devices.index(desired_device) > 0:
+            warnings.warn(
+                f"Device {desired_device!r} requested, but {possible_devices[0]!r} is "
+                "preferred by the architecture and available on the current system.",
+                stacklevel=2,
+            )
 
-    # finally, we convert the requested device to a list of devices
-    if requested_device == "multi-cuda":
-        return [device for device in available_devices if device.type == "cuda"]
+    # convert the chosen device to a list of torch devices
+    if desired_device == "multi-cuda":
+        return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
     else:
-        return [torch.device(requested_device)]
+        return [torch.device(desired_device)]
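
The core of the new logic is the ordered intersection of the architecture's preferences with what the system offers. A small sketch with assumed inputs (not taken from the diff):

# assumed example inputs
architecture_devices = ["multi-cuda", "cuda", "cpu"]  # preference order
available_devices = ["cpu", "cuda"]  # e.g. a machine with a single GPU

possible_devices = [d for d in architecture_devices if d in available_devices]
print(possible_devices)     # ['cuda', 'cpu'] -- architecture order is kept
print(possible_devices[0])  # 'cuda': the pick when no desired_device is given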
30 changes: 26 additions & 4 deletions src/metatensor/models/utils/omegaconf.py
@@ -1,11 +1,13 @@
 import importlib
 from pathlib import Path
-from typing import Union
+from typing import Dict, List, Union
 
 import torch
 from omegaconf import Container, DictConfig, ListConfig, OmegaConf
 from omegaconf.basecontainer import BaseContainer
 
+from .devices import pick_devices
+
 
 def file_format(_parent_: Container) -> str:
     """Custom OmegaConf resolver to find the file format.
@@ -15,15 +17,34 @@ def file_format(_parent_: Container) -> str:
     return Path(_parent_["read_from"]).suffix
 
 
+def _get_architecture_capabilities(conf: BaseContainer) -> Dict[str, List[str]]:
+    architecture_name = conf["architecture"]["name"]
+    architecture = importlib.import_module(f"metatensor.models.{architecture_name}")
+    return architecture.__ARCHITECTURE_CAPABILITIES__
+
+
+def default_device(_root_: BaseContainer) -> str:
+    """Custom OmegaConf resolver to find the default device of an architecture.
+
+    The device is found using the
+    :py:func:`metatensor.models.utils.devices.pick_devices` function."""
+
+    architecture_capabilities = _get_architecture_capabilities(_root_)
+    desired_device = pick_devices(architecture_capabilities["supported_devices"])
+
+    if len(desired_device) > 1:
+        return "multi-cuda"
+    else:
+        return desired_device[0].type
+
+
 def default_precision(_root_: BaseContainer) -> int:
     """Custom OmegaConf resolver to find the default precision of an architecture.
 
     The precision is obtained based on the architecture name and the first entry
     in its ``supported_dtypes`` list."""
 
-    architecture_name = _root_["architecture"]["name"]
-    architecture = importlib.import_module(f"metatensor.models.{architecture_name}")
-    architecture_capabilities = architecture.__ARCHITECTURE_CAPABILITIES__
+    architecture_capabilities = _get_architecture_capabilities(_root_)
 
     # desired `dtype` is the first entry
     default_dtype = architecture_capabilities["supported_dtypes"][0]
@@ -44,6 +65,7 @@
 
 # Register custom resolvers
 OmegaConf.register_new_resolver("file_format", file_format)
+OmegaConf.register_new_resolver("default_device", default_device)
OmegaConf.register_new_resolver("default_precision", default_precision)
 
 
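A rough sketch of how the pieces fit together once the module is imported (the architecture name below is a placeholder; real names depend on the installed package):

import metatensor.models.utils.omegaconf  # noqa: F401  (registers the resolvers)
from omegaconf import OmegaConf

conf = OmegaConf.create(
    {
        "architecture": {"name": "experimental.soap_bpnn"},  # placeholder name
        "device": "${default_device:}",
        "base_precision": "${default_precision:}",
    }
)
# resolution happens on access: each resolver reads the architecture name via
# `_root_` and queries the architecture's __ARCHITECTURE_CAPABILITIES__
print(conf.device, conf.base_precision)  # e.g. "cpu" 64 on a CPU-only machine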
(Diffs for the remaining two changed files are not shown.)