Skip to content

Commit

Permalink
Use default device obtained from architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Apr 3, 2024
1 parent e250a5a commit bd30dd9
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/metatensor/models/cli/conf/base.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
device: "cpu"
device: ${default_device:}
base_precision: ${default_precision:}
seed: null
30 changes: 26 additions & 4 deletions src/metatensor/models/utils/omegaconf.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import importlib
from pathlib import Path
from typing import Union
from typing import Dict, List, Union

import torch
from omegaconf import Container, DictConfig, ListConfig, OmegaConf
from omegaconf.basecontainer import BaseContainer

from .devices import pick_devices


def file_format(_parent_: Container) -> str:
"""Custom OmegaConf resolver to find the file format.
Expand All @@ -15,15 +17,34 @@ def file_format(_parent_: Container) -> str:
return Path(_parent_["read_from"]).suffix


def _get_architecture_capabilities(conf: BaseContainer) -> Dict[str, List[str]]:
architecture_name = conf["architecture"]["name"]
architecture = importlib.import_module(f"metatensor.models.{architecture_name}")
return architecture.__ARCHITECTURE_CAPABILITIES__


def default_device(_root_: BaseContainer) -> str:
"""Custom OmegaConf resolver to find the default precision of an architecture.
File format is obtained based on the architecture name and its first entry in the
``supported_dtypes`` list."""

architecture_capabilities = _get_architecture_capabilities(_root_)
desired_device = pick_devices(architecture_capabilities["supported_devices"])

if len(desired_device) > 1:
return "multi-cuda"
else:
return desired_device[0].type


def default_precision(_root_: BaseContainer) -> int:
"""Custom OmegaConf resolver to find the default precision of an architecture.
File format is obtained based on the architecture name and its first entry in the
``supported_dtypes`` list."""

architecture_name = _root_["architecture"]["name"]
architecture = importlib.import_module(f"metatensor.models.{architecture_name}")
architecture_capabilities = architecture.__ARCHITECTURE_CAPABILITIES__
architecture_capabilities = _get_architecture_capabilities(_root_)

# desired `dtype` is the first entry
default_dtype = architecture_capabilities["supported_dtypes"][0]
Expand All @@ -44,6 +65,7 @@ def default_precision(_root_: BaseContainer) -> int:

# Register custom resolvers
OmegaConf.register_new_resolver("file_format", file_format)
OmegaConf.register_new_resolver("default_device", default_device)
OmegaConf.register_new_resolver("default_precision", default_precision)


Expand Down
32 changes: 30 additions & 2 deletions tests/utils/test_omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from omegaconf import ListConfig, OmegaConf

from metatensor.models.experimental import soap_bpnn
from metatensor.models.utils import omegaconf
from metatensor.models.utils.omegaconf import (
check_options_list,
check_units,
Expand All @@ -18,14 +19,41 @@ def test_file_format_resolver():
assert (conf["file_format"]) == ".xyz"


def test_default_device_resolver():
conf = OmegaConf.create(
{
"device": "${default_device:}",
"architecture": {"name": "experimental.soap_bpnn"},
}
)

assert conf["device"] == "cpu"


def test_default_device_resolver_multi(monkeypatch):
def pick_devices(architecture_devices):
return [torch.device("cuda:0"), torch.device("cuda:1")]

monkeypatch.setattr(omegaconf, "pick_devices", pick_devices)

conf = OmegaConf.create(
{
"device": "${default_device:}",
"architecture": {"name": "experimental.soap_bpnn"},
}
)

assert conf["device"] == "multi-cuda"


@pytest.mark.parametrize(
"dtype, precision",
[(torch.float64, 64), (torch.double, 64), (torch.float32, 32), (torch.float16, 16)],
)
def test_default_precision_resolver(dtype, precision, monkeypatch):
patched_capabilities = {"supported_dtypes": [dtype]}
monkeypatch.setattr(
soap_bpnn, "__ARCHITECTURE_CAPABILITIES__", patched_capabilities, raising=True
soap_bpnn, "__ARCHITECTURE_CAPABILITIES__", patched_capabilities
)

conf = OmegaConf.create(
Expand All @@ -41,7 +69,7 @@ def test_default_precision_resolver(dtype, precision, monkeypatch):
def test_default_precision_resolver_unknown_dtype(monkeypatch):
patched_capabilities = {"supported_dtypes": [torch.int64]}
monkeypatch.setattr(
soap_bpnn, "__ARCHITECTURE_CAPABILITIES__", patched_capabilities, raising=True
soap_bpnn, "__ARCHITECTURE_CAPABILITIES__", patched_capabilities
)

conf = OmegaConf.create(
Expand Down

0 comments on commit bd30dd9

Please sign in to comment.