Commit 8b39820

Improvements to hydra handling (#54)

PicoCentauri authored Feb 6, 2024
1 parent c4720ad commit 8b39820
Showing 10 changed files with 123 additions and 63 deletions.
4 changes: 2 additions & 2 deletions docs/src/getting-started/custom_dataset_conf.rst
@@ -40,7 +40,7 @@ format, which is also valid for initial input:
structures:
  read_from: dataset.xyz
  file_format: .xyz
  unit: null
  length_unit: null
targets:
  energy:
    quantity: energy
@@ -71,7 +71,7 @@ Describes the structure data like positions and cell information.
:param read_from: The file containing structure data.
:param file_format: The file format, guessed from the suffix if ``null`` or not
    provided.
:param unit: The unit of lengths, optional but recommended for simulations.
:param length_unit: The unit of lengths, optional but recommended for simulations.

A single string in this section automatically expands, using the string as the
``read_from`` parameter.
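
For example (a sketch of the documented expansion rule, using the values from the
snippet above), the shorthand

    structures: dataset.xyz

expands to

    structures:
      read_from: dataset.xyz
      file_format: .xyz
      length_unit: null
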
18 changes: 1 addition & 17 deletions docs/src/getting-started/override.rst
@@ -26,11 +26,8 @@ hyperparameters. The adjustments for ``num_epochs`` and ``cutoff`` look like this:

.. code-block:: yaml
defaults:
  - architecture: soap_bpnn
  - _self_
architecture:
  name: "soap_bpnn"
  model:
    soap:
      cutoff: 7.0
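
The same adjustment can also be passed on the command line; a hypothetical
invocation using the ``-y``/``--hydra`` override flag introduced later in this
commit, with the dotted path taken from the YAML structure above:

    metatensor-models train options.yaml -y architecture.model.soap.cutoff=7.0
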
@@ -65,16 +62,3 @@ syntax are available at https://hydra.cc/docs/advanced/override_grammar/basic/.
For your reference and reproducibility purposes `metatensor-models` always writes the
fully expanded options to the ``.hydra`` subdirectory inside the ``output``
directory of your current training run.


Understanding the Defaults Section
----------------------------------

You may have noticed the ``defaults`` section at the beginning of each file. This list
dictates which defaults should be loaded and how to compose the final config object and
is conventionally the first item in the config.

Append ``_self_`` to the end of the list to have your primary config override values
from the Defaults List. If you do not add a ``_self_`` entry, your primary config
still overrides values from the Defaults List, but Hydra will throw a warning. For
more background, visit https://hydra.cc/docs/tutorials/basic/your_first_app/defaults/.
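
For context, a minimal YAML sketch of the pattern this removed section described,
with ``_self_`` last so the primary config wins:

    defaults:
      - architecture: soap_bpnn  # load the architecture's default options first
      - _self_                   # then apply this file, overriding those defaults

    architecture:
      model:
        soap:
          cutoff: 7.0  # overrides the soap_bpnn default
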
5 changes: 1 addition & 4 deletions docs/static/ethanol/options.yaml
@@ -1,8 +1,5 @@
defaults:
  - architecture: soap_bpnn  # architecture used to train the model
  - _self_

architecture:
  name: soap_bpnn
  training:
    batch_size: 16
    num_epochs: 100
13 changes: 5 additions & 8 deletions docs/static/qm9/options.yaml
@@ -1,12 +1,9 @@
defaults:
  - architecture: soap_bpnn  # architecture used to train the model
  - _self_
# architecture used to train the model
architecture:
  name: soap_bpnn

# Last position of the _self_ this entry defines that default options will be
# overwritten by this config.

# Mandatory section defining the parameters for structure and target data of the trainin
# set
# Mandatory section defining the parameters for structure and target data of the
# training set
training_set:
  structures: "qm9_reduced_100.xyz"  # file where the positions are stored
  targets:
17 changes: 1 addition & 16 deletions src/metatensor/models/__main__.py
@@ -2,7 +2,6 @@

import argparse
import sys
from pathlib import Path

from . import __version__
from .cli import eval_model, export_model, train_model
@@ -40,21 +39,7 @@ def main():
elif callable == "export_model":
export_model(**args.__dict__)
elif callable == "train_model":
# HACK: Hydra parses command line arguments directlty from `sys.argv`. We
# override `sys.argv` to be compatible with our CLI architecture.
argv = sys.argv[:1]

options = Path(args.options)
argv.append(f"--config-dir={options.parent}")
argv.append(f"--config-name={options.name}")
argv.append(f"+output_path={args.output}")

if args.hydra_paramters is not None:
argv += args.hydra_paramters

sys.argv = argv

train_model()
train_model(**args.__dict__)
else:
raise ValueError("internal error when selecting a sub-command.")

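The new dispatch simply unpacks the parsed arguments into keyword arguments; a
minimal sketch, where the ``Namespace`` fields are assumed from the parser and the
new ``train_model`` signature:

    import argparse

    args = argparse.Namespace(
        options="options.yaml", output="model.pt", hydra_parameters=None
    )
    # `train_model(**args.__dict__)` is then equivalent to:
    # train_model(options="options.yaml", output="model.pt", hydra_parameters=None)
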
3 changes: 3 additions & 0 deletions src/metatensor/models/cli/conf/base.yaml
@@ -0,0 +1,3 @@
device: "cpu"
base_precision: 64
seed: -1
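
These values are fallbacks: ``base.yaml`` is placed first in the defaults list that
``train_model`` builds below, so values from the user's options file win. A minimal
sketch of the merge semantics, assuming plain OmegaConf merging:

    from omegaconf import OmegaConf

    base = OmegaConf.create({"device": "cpu", "base_precision": 64, "seed": -1})
    user = OmegaConf.create({"seed": 1234})  # hypothetical user options

    conf = OmegaConf.merge(base, user)  # later configs take precedence
    assert conf.device == "cpu"
    assert conf.base_precision == 64
    assert conf.seed == 1234
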
18 changes: 18 additions & 0 deletions src/metatensor/models/cli/conf/hydra/job_logging/custom.yaml
@@ -0,0 +1,18 @@
version: 1
formatters:
  simple:
    format: "[%(asctime)s][%(levelname)s] - %(message)s"
    datefmt: "%Y-%m-%d %H:%M:%S"
handlers:
  console:
    class: logging.StreamHandler
    formatter: simple
    stream: ext://sys.stdout
  file:
    class: logging.FileHandler
    formatter: simple
    filename: "${hydra:runtime.output_dir}/train.log"
root:
  handlers: [console, file]

disable_existing_loggers: false
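
This follows the standard ``logging.config`` dictionary schema, which Hydra applies
at startup. A self-contained sketch of what it sets up, with the
``${hydra:runtime.output_dir}`` interpolation replaced by a literal filename and a
root level added for the demo:

    import logging
    import logging.config

    logging.config.dictConfig(
        {
            "version": 1,
            "formatters": {
                "simple": {
                    "format": "[%(asctime)s][%(levelname)s] - %(message)s",
                    "datefmt": "%Y-%m-%d %H:%M:%S",
                }
            },
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "formatter": "simple",
                    "stream": "ext://sys.stdout",
                },
                "file": {
                    "class": "logging.FileHandler",
                    "formatter": "simple",
                    "filename": "train.log",  # stands in for the Hydra output dir
                },
            },
            # level added for this sketch; the YAML above leaves it unset
            "root": {"level": "INFO", "handlers": ["console", "file"]},
            "disable_existing_loggers": False,
        }
    )
    logging.getLogger(__name__).info("logged to stdout and train.log")
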
80 changes: 68 additions & 12 deletions src/metatensor/models/cli/train_model.py
@@ -1,13 +1,17 @@
import argparse
import importlib
import logging
import sys
import tempfile
import warnings
from pathlib import Path
from typing import List, Optional

import hydra
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from omegaconf import DictConfig, OmegaConf
from omegaconf.errors import ConfigKeyError

from metatensor.models.utils.data import Dataset
from metatensor.models.utils.data.readers import read_structures, read_targets
@@ -66,16 +70,20 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
    parser.add_argument(
        "-y",
        "--hydra",
        dest="hydra_paramters",
        dest="hydra_parameters",
        nargs="+",
        type=str,
        help="Hydra's command line and override flags.",
    )


@hydra.main(config_path=str(CONFIG_PATH), version_base=None)
def train_model(options: DictConfig) -> None:
    """Train an atomistic machine learning model using configurations provided by Hydra.
def train_model(
    options: str,
    output: str = "model.pt",
    hydra_parameters: Optional[List[str]] = None,
) -> None:
    """
    Train an atomistic machine learning model using configurations provided by Hydra.
    This function sets up the dataset and model architecture, then runs the training
    process. The dataset is prepared by reading structural data and target values from
@@ -88,17 +96,63 @@ def train_model(options: DictConfig) -> None:
    https://hydra.cc/docs/advanced/hydra-command-line-flags/ and
    https://hydra.cc/docs/advanced/override_grammar/basic/ for details.
    :param options: Options file path
    :param output: Path to save the final model
    :param hydra_parameters: Hydra's command line and override flags
    """
    conf = OmegaConf.load(options)

    try:
        architecture_name = conf["architecture"]["name"]
    except ConfigKeyError as exc:
        raise ConfigKeyError("Architecture name is not defined!") from exc

    conf["defaults"] = [
        "base",
        {"architecture": architecture_name},
        {"override hydra/job_logging": "custom"},
        "_self_",
    ]

    with tempfile.TemporaryDirectory() as tmpdirname:
        options_new = Path(tmpdirname) / "options.yaml"
        OmegaConf.save(config=conf, f=options_new)

        # HACK: Hydra parses command line arguments directly from `sys.argv`. We
        # override `sys.argv` to be compatible with our CLI architecture.
        argv = sys.argv[:1]

        argv.append(f"--config-dir={options_new.parent}")
        argv.append(f"--config-name={options_new.name}")
        argv.append(f"+output_path={output}")

        if hydra_parameters is not None:
            argv += hydra_parameters

        sys.argv = argv

        _train_model_hydra()


@hydra.main(config_path=str(CONFIG_PATH), version_base=None)
def _train_model_hydra(options: DictConfig) -> None:
    """Actual fit function called in :func:`train_model`.
    :param options: A dictionary-like object obtained from Hydra, containing all the
        necessary options for dataset preparation, model hyperparameters, and training.
    """
    if options["base_precision"] == 64:
        torch.set_default_dtype(torch.float64)
    elif options["base_precision"] == 32:
        torch.set_default_dtype(torch.float32)
    elif options["base_precision"] == 16:
        torch.set_default_dtype(torch.float16)
    else:
        raise ValueError("Only 64, 32 or 16 are possible values for `base_precision`.")

    # This gives some accuracy improvements. It is very likely that
    # this is just due to the preliminary composition fit in the SOAP-BPNN.
    # TODO: investigate
    torch.set_default_dtype(torch.float64)

    # TODO load seed from config
    generator = torch.Generator()
    if options["seed"] != -1:
        generator.manual_seed(options["seed"])

    logger.info("Setting up training set")
    train_options = expand_dataset_config(options["training_set"])
@@ -173,12 +227,14 @@ def train_model(options: DictConfig) -> None:

    test_dataset

    output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    # Save fully expanded config
    OmegaConf.save(config=options, f=Path(output_dir) / "options.yaml")

    logger.info("Setting up model")
    architecture_name = options["architecture"]["name"]
    architecture = importlib.import_module(f"metatensor.models.{architecture_name}")

    output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    all_species = []
    for dataset in [train_dataset]:  # HACK: only a single train_dataset for now
        all_species += get_all_species(dataset)
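
The ``sys.argv`` rewrite is the crux of this change: a function decorated with
``@hydra.main`` parses its flags from ``sys.argv`` rather than taking them as
arguments. A minimal self-contained sketch of the pattern, where
``fake_hydra_entry`` is a stand-in for the decorated ``_train_model_hydra``:

    import sys


    def fake_hydra_entry():
        # stands in for a @hydra.main-decorated function, which reads sys.argv itself
        print("flags seen:", sys.argv[1:])


    argv = sys.argv[:1]  # keep only the program name
    argv.append("--config-dir=/tmp/options_dir")
    argv.append("--config-name=options.yaml")
    argv.append("+output_path=model.pt")

    sys.argv = argv
    fake_hydra_entry()  # prints: flags seen: ['--config-dir=/tmp/options_dir', ...]
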
23 changes: 23 additions & 0 deletions tests/cli/test_train_model.py
@@ -1,3 +1,4 @@
import glob
import shutil
import subprocess
from pathlib import Path
@@ -27,6 +28,12 @@ def test_train(monkeypatch, tmp_path, output):
    subprocess.check_call(command)
    assert Path(output).is_file()

    # Test if fully expanded options.yaml file is written
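    # (Hydra's default run directory is outputs/<date>/<time>/, hence the two wildcards)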
assert len(glob.glob("outputs/*/*/options.yaml")) == 1

# Test if logfile is written
assert len(glob.glob("outputs/*/*/train.log")) == 1


@pytest.mark.parametrize("test_set_file", (True, False))
@pytest.mark.parametrize("validation_set_file", (True, False))
@@ -85,3 +92,19 @@ def test_hydra_arguments():
    )
    # Check that the num_epochs override is successful
    assert "num_epochs: 1" in str(out)


def test_no_architecture_name(monkeypatch, tmp_path):
    """Test that an error is raised if architecture.name is not set."""
    monkeypatch.chdir(tmp_path)

    options = OmegaConf.load(RESOURCES_PATH / "options.yaml")
    options["architecture"].pop("name")
    OmegaConf.save(config=options, f="options.yaml")

    try:
        subprocess.check_output(
            ["metatensor-models", "train", "options.yaml"], stderr=subprocess.STDOUT
        )
    except subprocess.CalledProcessError as captured:
        assert "Architecture name is not defined!" in str(captured.output)
5 changes: 1 addition & 4 deletions tests/resources/options.yaml
@@ -1,8 +1,5 @@
defaults:
  - architecture: soap_bpnn  # architecture used to train the model
  - _self_

architecture:
  name: soap_bpnn
  training:
    batch_size: 2
    num_epochs: 1
