-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
26d6a0e
commit 8fccf03
Showing
6 changed files
with
88 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Reading a dataset | ||
################# | ||
|
||
.. automodule:: metatrain.utils.data.get_dataset
    :members:
    :undoc-members:
    :show-inheritance:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ API for handling data in ``metatrain``. | |
|
||
combine_dataloaders | ||
dataset | ||
get_dataset | ||
readers | ||
writers | ||
systems_to_ase |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import Tuple | ||
|
||
from omegaconf import DictConfig | ||
|
||
from .dataset import Dataset, TargetInfoDict | ||
from .readers import read_systems, read_targets | ||
|
||
|
||
def get_dataset(options: DictConfig) -> Tuple[Dataset, TargetInfoDict]:
    """
    Build a dataset from a configuration dictionary.

    The systems and the targets of the dataset are read from one or more
    files, as specified in ``options``.

    :param options: the configuration options for the dataset. It must
        contain a ``"systems"`` entry (with ``"read_from"`` and ``"reader"``
        keys) and a ``"targets"`` entry, which is forwarded as-is to
        ``read_targets``.
    :returns: A tuple containing a ``Dataset`` object and a
        ``TargetInfoDict`` containing additional information (units,
        physical quantities, ...) on the targets in the dataset
    """
    systems_options = options["systems"]
    systems = read_systems(
        filename=systems_options["read_from"],
        reader=systems_options["reader"],
    )

    targets, target_info_dictionary = read_targets(conf=options["targets"])

    # Systems are stored under the "system" field, next to one field per target.
    dataset_fields = {"system": systems}
    dataset_fields.update(targets)

    return Dataset(dataset_fields), target_info_dictionary
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from pathlib import Path | ||
|
||
from omegaconf import OmegaConf | ||
|
||
from metatrain.utils.data import get_dataset | ||
|
||
|
||
# Directory with the shared test resources, two levels above this file's folder.
RESOURCES_PATH = Path(__file__).parents[2].joinpath("resources")


def test_get_dataset():
    """Check that ``get_dataset`` builds a dataset and target info from options."""
    structures_file = RESOURCES_PATH / "qm9_reduced_100.xyz"

    systems_options = {
        "read_from": structures_file,
        "reader": "ase",
    }
    energy_options = {
        "quantity": "energy",
        "read_from": structures_file,
        "reader": "ase",
        "key": "U0",
        "unit": "eV",
        "forces": False,
        "stress": False,
        "virial": False,
    }
    options = OmegaConf.create(
        {
            "systems": systems_options,
            "targets": {"energy": energy_options},
        }
    )

    dataset, target_info = get_dataset(options)

    # Every sample exposes the system together with the requested target.
    for field in ("system", "energy"):
        assert field in dataset[0]

    assert "energy" in target_info
    energy_info = target_info["energy"]
    assert energy_info.quantity == "energy"
    assert energy_info.unit == "eV"