Implement get_dataset (#297)
frostedoyster authored Jul 15, 2024
1 parent 26d6a0e commit 8fccf03
Showing 6 changed files with 88 additions and 36 deletions.
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/data/get_dataset.rst
@@ -0,0 +1,7 @@
Reading a dataset
#################

.. automodule:: metatrain.utils.data.get_dataset
    :members:
    :undoc-members:
    :show-inheritance:
1 change: 1 addition & 0 deletions docs/src/dev-docs/utils/data/index.rst
@@ -8,6 +8,7 @@ API for handling data in ``metatrain``.
 
     combine_dataloaders
     dataset
+    get_dataset
     readers
     writers
     systems_to_ase
45 changes: 9 additions & 36 deletions src/metatrain/cli/train.py
@@ -15,14 +15,7 @@
 
 from .. import PACKAGE_ROOT
 from ..utils.architectures import check_architecture_options, get_default_hypers
-from ..utils.data import (
-    Dataset,
-    DatasetInfo,
-    TargetInfoDict,
-    get_atomic_types,
-    read_systems,
-    read_targets,
-)
+from ..utils.data import DatasetInfo, TargetInfoDict, get_atomic_types, get_dataset
 from ..utils.data.dataset import _train_test_random_split
 from ..utils.devices import pick_devices
 from ..utils.distributed.logging import is_main_process
@@ -193,19 +186,9 @@ def train_model(
     train_datasets = []
     target_infos = TargetInfoDict()
     for train_options in options["training_set"]:
-        train_systems = read_systems(
-            filename=train_options["systems"]["read_from"],
-            reader=train_options["systems"]["reader"],
-        )
-        train_targets, target_info_dictionary = read_targets(
-            conf=train_options["targets"]
-        )
-        train_targets, target_info_dictionary = read_targets(
-            conf=train_options["targets"]
-        )
-
-        target_infos.update(target_info_dictionary)
-        train_datasets.append(Dataset({"system": train_systems, **train_targets}))
+        dataset, target_info_dict = get_dataset(train_options)
+        train_datasets.append(dataset)
+        target_infos.update(target_info_dict)
 
     train_size = 1.0
 
@@ -249,13 +232,8 @@
     )
 
     for test_options in options["test_set"]:
-        test_systems = read_systems(
-            filename=test_options["systems"]["read_from"],
-            reader=test_options["systems"]["reader"],
-        )
-        test_targets, _ = read_targets(conf=test_options["targets"])
-        test_dataset = Dataset({"system": test_systems, **test_targets})
-        test_datasets.append(test_dataset)
+        dataset, _ = get_dataset(test_options)
+        test_datasets.append(dataset)
 
     ###########################
     # SETUP VALIDATION SET ####
@@ -296,14 +274,9 @@
         desired_options=options["training_set"],
     )
 
-    for val_options in options["validation_set"]:
-        val_systems = read_systems(
-            filename=val_options["systems"]["read_from"],
-            reader=val_options["systems"]["reader"],
-        )
-        val_targets, _ = read_targets(conf=val_options["targets"])
-        val_dataset = Dataset({"system": val_systems, **val_targets})
-        val_datasets.append(val_dataset)
+    for valid_options in options["validation_set"]:
+        dataset, _ = get_dataset(valid_options)
+        val_datasets.append(dataset)
 
     ###########################
     # CREATE DATASET_INFO #####
1 change: 1 addition & 0 deletions src/metatrain/utils/data/__init__.py
@@ -22,3 +22,4 @@
 from .combine_dataloaders import CombinedDataLoader # noqa: F401
 from .system_to_ase import system_to_ase # noqa: F401
 from .extract_targets import get_targets_dict # noqa: F401
+from .get_dataset import get_dataset # noqa: F401
32 changes: 32 additions & 0 deletions src/metatrain/utils/data/get_dataset.py
@@ -0,0 +1,32 @@
from typing import Tuple

from omegaconf import DictConfig

from .dataset import Dataset, TargetInfoDict
from .readers import read_systems, read_targets


def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]:
"""
Gets a dataset given a configuration dictionary.
The system and targets in the dataset are read from one or more
files, as specified in ``options``.
:param options: the configuration options for the dataset.
This configuration dictionary must contain keys for both the
systems and targets in the dataset.
:returns: A tuple containing a ``Dataset`` object and a
``TargetInfoDict`` containing additional information (units,
physical quantities, ...) on the targets in the dataset
"""

systems = read_systems(
filename=options["systems"]["read_from"],
reader=options["systems"]["reader"],
)
targets, target_info_dictionary = read_targets(conf=options["targets"])
dataset = Dataset({"system": systems, **targets})

return dataset, target_info_dictionary
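
Usage note (editor's addition, not part of the commit): a minimal sketch of calling get_dataset on its own. The keys mirror those read inside the function above (systems.read_from, systems.reader, targets); the file path, target name, and target sub-keys are placeholders modelled on the test below, and the full target schema is normally expanded by metatrain's option handling.

from omegaconf import OmegaConf

from metatrain.utils.data import get_dataset

# Placeholder configuration: "structures.xyz" and the "energy" target options
# are illustrative only, following the layout used in the test file below.
options = OmegaConf.create(
    {
        "systems": {"read_from": "structures.xyz", "reader": "ase"},
        "targets": {
            "energy": {
                "quantity": "energy",
                "read_from": "structures.xyz",
                "reader": "ase",
                "key": "energy",
                "unit": "eV",
                "forces": False,
                "stress": False,
                "virial": False,
            }
        },
    }
)

dataset, target_info = get_dataset(options)
# Each sample combines the system with its targets; target_info carries the
# units and physical quantities declared above.
print("energy" in dataset[0], target_info["energy"].unit)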
38 changes: 38 additions & 0 deletions tests/utils/data/test_get_dataset.py
@@ -0,0 +1,38 @@
from pathlib import Path

from omegaconf import OmegaConf

from metatrain.utils.data import get_dataset


RESOURCES_PATH = Path(__file__).parents[2] / "resources"


def test_get_dataset():

    options = {
        "systems": {
            "read_from": RESOURCES_PATH / "qm9_reduced_100.xyz",
            "reader": "ase",
        },
        "targets": {
            "energy": {
                "quantity": "energy",
                "read_from": RESOURCES_PATH / "qm9_reduced_100.xyz",
                "reader": "ase",
                "key": "U0",
                "unit": "eV",
                "forces": False,
                "stress": False,
                "virial": False,
            }
        },
    }

    dataset, target_info = get_dataset(OmegaConf.create(options))

    assert "system" in dataset[0]
    assert "energy" in dataset[0]
    assert "energy" in target_info
    assert target_info["energy"].quantity == "energy"
    assert target_info["energy"].unit == "eV"
