Skip to content

Commit

Permalink
Use a default dtype obtained from the architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Mar 25, 2024
1 parent 1aeb743 commit 0e18391
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 8 deletions.
7 changes: 4 additions & 3 deletions docs/src/getting-started/advanced_base_config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ be adjusted. They should be written without indentation in the ``options.yaml``

:param device: The device in which the training should be run. Takes two possible
values: ``cpu`` and ``gpu``. Default: ``cpu``
:param base_precision: Override the base precision of all floats during training. By
    default, an optimal precision is obtained from the architecture. Changing this
    affects the memory consumption during training and possibly also the accuracy of
    the model. Possible values: ``64``, ``32`` or ``16``.
:param seed: Seed used to start the training. Set all the seeds
of ``numpy.random``, ``random``, ``torch`` and ``torch.cuda`` (if available)
to the same value ``seed``.
If ``seed=None`` all the seeds are set to a random number. Default: ``None``
Note: in a ``.yaml`` file ``None`` is ``null``.
:param base_precision: This may increase the accuracy but will increase the
memory consumption during training. Possible values:
``64``, ``32`` or ``16``. Default: ``64``

In the next tutorials we show how to override the default parameters of an architecture.
2 changes: 1 addition & 1 deletion src/metatensor/models/cli/conf/base.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
device: "cpu"
base_precision: 32
base_precision: ${default_precision:}
seed: null
35 changes: 33 additions & 2 deletions src/metatensor/models/utils/omegaconf.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,50 @@
import importlib
from pathlib import Path
from typing import Union

from omegaconf import DictConfig, ListConfig, OmegaConf
import torch
from omegaconf import Container, DictConfig, ListConfig, OmegaConf
from omegaconf.basecontainer import BaseContainer


def file_format(_parent_: DictConfig) -> str:
def file_format(_parent_: Container) -> str:
    """Custom OmegaConf resolver to find the file format.

    The file format is deduced from the suffix of the ``read_from`` field
    located in the same configuration section.
    """
    read_from = Path(_parent_["read_from"])
    return read_from.suffix


def default_precision(_root_: BaseContainer) -> int:
    """Custom OmegaConf resolver to find the default precision of an architecture.

    The precision is obtained from the architecture name stored in the root
    config and the first entry of the architecture's ``supported_dtypes`` list.

    :param _root_: root configuration; must contain ``architecture.name``
    :returns: the floating point bit width (``64``, ``32`` or ``16``)
    :raises ValueError: if the architecture's default dtype is not a known
        torch floating point dtype
    """

    architecture_name = _root_["architecture"]["name"]
    architecture = importlib.import_module(f"metatensor.models.{architecture_name}")
    architecture_capabilities = architecture.__ARCHITECTURE_CAPABILITIES__

    # desired `dtype` is the first entry
    default_dtype = architecture_capabilities["supported_dtypes"][0]

    # `base_precision` has to be an integer and not a torch dtype.
    # NOTE: `torch.double` is an alias of `torch.float64`, so the single
    # `torch.float64` key covers both.
    precision_for_dtype = {torch.float64: 64, torch.float32: 32, torch.float16: 16}
    try:
        return precision_for_dtype[default_dtype]
    except KeyError:
        raise ValueError(
            f"architectures `default_dtype` ({default_dtype}) refers to an unknown "
            "torch dtype. This should not happen."
        ) from None


# Register custom resolvers at import time so that `${file_format:}` and
# `${default_precision:}` interpolations can be used in configuration files.
OmegaConf.register_new_resolver("file_format", file_format)
OmegaConf.register_new_resolver("default_precision", default_precision)


def _resolve_single_str(config: str) -> DictConfig:
Expand Down
6 changes: 4 additions & 2 deletions tests/resources/generate-outputs.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/bin/bash
# Regenerate the model files used by the test suite.
set -e

echo "Generate data for testing..."

# Resolve the directory containing this script so it can be run from any cwd.
ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

cd $ROOT_DIR

# NOTE(review): the four `train` invocations below are diff artifacts from a
# scraped commit page (pre- and post-change lines interleaved); the intended
# script runs only the two variants redirecting stdout to /dev/null.
metatensor-models train options.yaml -o model-32-bit.pt -y base_precision=32
metatensor-models train options.yaml -o model-64-bit.pt -y base_precision=64
metatensor-models train options.yaml -o model-32-bit.pt -y base_precision=32 > /dev/null
metatensor-models train options.yaml -o model-64-bit.pt -y base_precision=64 > /dev/null
43 changes: 43 additions & 0 deletions tests/utils/test_omegaconf.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import re

import pytest
import torch
from omegaconf import ListConfig, OmegaConf

from metatensor.models.experimental import soap_bpnn
from metatensor.models.utils.omegaconf import (
check_options_list,
check_units,
Expand All @@ -16,6 +18,47 @@ def test_file_format_resolver():
assert (conf["file_format"]) == ".xyz"


@pytest.mark.parametrize(
    "dtype, precision",
    [(torch.float64, 64), (torch.double, 64), (torch.float32, 32), (torch.float16, 16)],
)
def test_default_precision_resolver(dtype, precision, monkeypatch):
    """The resolver maps an architecture's first supported dtype to the
    matching integer precision."""
    fake_capabilities = {"supported_dtypes": [dtype]}
    monkeypatch.setattr(
        soap_bpnn, "__ARCHITECTURE_CAPABILITIES__", fake_capabilities, raising=True
    )

    options = {
        "base_precision": "${default_precision:}",
        "architecture": {"name": "experimental.soap_bpnn"},
    }
    conf = OmegaConf.create(options)

    assert conf["base_precision"] == precision


def test_default_precision_resolver_unknown_dtype(monkeypatch):
    """A non-float default dtype makes the resolver raise a ``ValueError``."""
    fake_capabilities = {"supported_dtypes": [torch.int64]}
    monkeypatch.setattr(
        soap_bpnn, "__ARCHITECTURE_CAPABILITIES__", fake_capabilities, raising=True
    )

    options = {
        "base_precision": "${default_precision:}",
        "architecture": {"name": "experimental.soap_bpnn"},
    }
    conf = OmegaConf.create(options)

    match = (
        r"architectures `default_dtype` \(torch.int64\) refers to an unknown torch "
        "dtype. This should not happen."
    )
    with pytest.raises(ValueError, match=match):
        conf["base_precision"]


@pytest.mark.parametrize("n_datasets", [1, 2])
def test_expand_dataset_config(n_datasets):
"""Test dataset expansion for a list of n_datasets times the same config"""
Expand Down

0 comments on commit 0e18391

Please sign in to comment.