-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 6 changed files with 131 additions and 59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Test for the MEGNet utilities module""" | ||
import os | ||
|
||
|
||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | ||
os.environ["CUDA_VISIBLE_DEVICES"] = "" | ||
|
||
from pathlib import Path | ||
|
||
import numpy as np | ||
import pytest | ||
from megnet.models.megnet import MEGNetModel | ||
from megnet.utils.preprocessing import Scaler | ||
from unlocknn.megnet_utils import create_megnet_input | ||
|
||
from .utils import datadir, load_df_head | ||
|
||
|
||
class ExampleScaler(Scaler):
    """Toy target scaler used by the tests.

    ``transform`` divides the target by ``n`` (meant to be the number of
    atoms) and ``inverse_transform`` undoes this by multiplying by ``n``.
    """

    def transform(self, target: np.ndarray, n: int = 1) -> np.ndarray:
        """Return ``target`` divided by ``n``."""
        scaled = target / n
        return scaled

    def inverse_transform(self, transformed_target: np.ndarray, n: int = 1) -> np.ndarray:
        """Return ``transformed_target`` multiplied by ``n``."""
        restored = transformed_target * n
        return restored
|
||
def test_input_with_scaler(datadir: Path):
    """Check that ``create_megnet_input`` applies the model's target scaler.

    A MEGNet model is loaded from the test data directory, its target scaler
    is swapped for ``ExampleScaler`` and a single batch of inputs is
    generated. The targets of that batch must equal the formation energies
    divided by the number of atoms.

    Args:
        datadir: Directory holding this module's test data files.
    """
    binary_df = load_df_head(datadir / "mp_binary_on_hull.pkl")

    meg_model = MEGNetModel.from_file(str(datadir / "formation_energy.hdf5"))
    meg_model.target_scaler = ExampleScaler()

    input_gen, _ = create_megnet_input(
        meg_model,
        binary_df["structure"],
        binary_df["formation_energy_per_atom"],
        batch_size=100,  # We have just one batch
        shuffle=False,
    )

    # Get first batch (the whole input)
    _, scaled_targets = input_gen[0]

    expected = (
        binary_df["formation_energy_per_atom"] / binary_df["num_atoms"]
    ).to_numpy()

    # The comparison must be asserted: a bare ``==`` expression is evaluated
    # and thrown away, so the test would pass whatever the targets are.
    # Both sides are flattened so that the check does not depend on any
    # batch/column axes the generator adds around the targets.
    # NOTE(review): assumes the batch holds one target per structure, in
    # DataFrame order -- confirm against the generator's output shape.
    assert np.ravel(scaled_targets) == pytest.approx(expected, rel=1e-6)
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{"graph_converter": {"@module": "megnet.data.crystal", "@class": "CrystalGraph", "atom_converter": {"@module": "megnet.data.graph", "@class": "DummyConverter"}, "bond_converter": {"@module": "megnet.data.graph", "@class": "GaussianDistance", "centers": {"@module": "numpy", "@class": "array", "dtype": "float64", "data": [0.0, 0.06060606060606061, 0.12121212121212122, 0.18181818181818182, 0.24242424242424243, 0.30303030303030304, 0.36363636363636365, 0.42424242424242425, 0.48484848484848486, 0.5454545454545454, 0.6060606060606061, 0.6666666666666667, 0.7272727272727273, 0.7878787878787878, 0.8484848484848485, 0.9090909090909092, 0.9696969696969697, 1.0303030303030303, 1.0909090909090908, 1.1515151515151516, 1.2121212121212122, 1.2727272727272727, 1.3333333333333335, 1.393939393939394, 1.4545454545454546, 1.5151515151515151, 1.5757575757575757, 1.6363636363636365, 1.696969696969697, 1.7575757575757576, 1.8181818181818183, 1.878787878787879, 1.9393939393939394, 2.0, 2.0606060606060606, 2.121212121212121, 2.1818181818181817, 2.2424242424242427, 2.303030303030303, 2.3636363636363638, 2.4242424242424243, 2.484848484848485, 2.5454545454545454, 2.606060606060606, 2.666666666666667, 2.7272727272727275, 2.787878787878788, 2.8484848484848486, 2.909090909090909, 2.9696969696969697, 3.0303030303030303, 3.090909090909091, 3.1515151515151514, 3.2121212121212124, 3.272727272727273, 3.3333333333333335, 3.393939393939394, 3.4545454545454546, 3.515151515151515, 3.5757575757575757, 3.6363636363636367, 3.6969696969696972, 3.757575757575758, 3.8181818181818183, 3.878787878787879, 3.9393939393939394, 4.0, 4.0606060606060606, 4.121212121212121, 4.181818181818182, 4.242424242424242, 4.303030303030303, 4.363636363636363, 4.424242424242425, 4.484848484848485, 4.545454545454546, 4.606060606060606, 4.666666666666667, 4.7272727272727275, 4.787878787878788, 4.848484848484849, 4.909090909090909, 4.96969696969697, 5.03030303030303, 5.090909090909091, 5.151515151515151, 5.212121212121212, 
5.2727272727272725, 5.333333333333334, 5.3939393939393945, 5.454545454545455, 5.515151515151516, 5.575757575757576, 5.636363636363637, 5.696969696969697, 5.757575757575758, 5.818181818181818, 5.878787878787879, 5.9393939393939394, 6.0]}, "width": 0.5}, "cutoff": 5.0}, | ||
"metadata": {"name": "Formation energy", | ||
"unit": "eV/atom", | ||
"metric": "mae", | ||
"metric value": 0.026, | ||
"training dataset": "mp-2019.4.1", | ||
"description": "This model was trained using structures and formation energy of Materials Project data base downloaded on April 1, 2019. The total data size is 133420. The model was trained by a 0.8-0.1-0.1, train, validation and test data ratio"} | ||
} |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""Testing suite shared functionality.""" | ||
import os | ||
|
||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | ||
os.environ["CUDA_VISIBLE_DEVICES"] = "" | ||
|
||
from distutils import dir_util | ||
from math import floor | ||
from pathlib import Path | ||
from typing import List, Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
SplitData = Tuple[Tuple[list, list], Tuple[list, list]] | ||
|
||
@pytest.fixture
def datadir(tmpdir, request) -> Path:
    """Provide a temporary copy of the requesting test module's data directory.

    Looks for a folder with the same name as the test module (the module's
    file path without its extension) and, if it exists, copies all of its
    contents into the temporary directory so tests can use them freely.

    Source: https://stackoverflow.com/a/29631801/

    Args:
        tmpdir: Temporary directory supplied by pytest.
        request: pytest request object, used to locate the calling module.

    Returns:
        The temporary directory as a ``pathlib.Path``. It is empty when the
        module has no data folder.
    """
    module_file = request.module.__file__
    test_dir, _ = os.path.splitext(module_file)

    if os.path.isdir(test_dir):
        dir_util.copy_tree(test_dir, str(tmpdir))

    # Convert so that the returned object really is the annotated ``Path``
    # rather than whatever path type pytest hands over as ``tmpdir``.
    return Path(str(tmpdir))
|
||
|
||
def weights_equal(weights_a: List[np.ndarray], weights_b: List[np.ndarray]) -> bool:
    """Check approximate equality between two lists of weight arrays.

    Args:
        weights_a: First list of weight arrays.
        weights_b: Second list of weight arrays.

    Returns:
        ``True`` when both lists have the same length and every pair of
        corresponding arrays matches within a relative tolerance of 1e-6;
        ``False`` otherwise.
    """
    # ``zip`` stops at the shorter input, so without this guard a list would
    # compare equal to any longer list that starts with the same arrays.
    if len(weights_a) != len(weights_b):
        return False

    return all(
        weight_a == pytest.approx(weight_b, rel=1e-6)
        for weight_a, weight_b in zip(weights_a, weights_b)
    )
|
||
def load_df_head(fname: Path, num_entries: int=100) -> pd.DataFrame:
    """Read the leading rows of a pickled DataFrame, with a legacy fallback.

    The file is first read with pandas. If that raises ``ValueError``
    (seen on older Python versions), it is read again with the ``pickle5``
    package instead.

    Args:
        fname: The pickle file to open.
        num_entries: How many values to read.

    Returns:
        The first ``num_entries`` entries of the unpickled object.
    """
    try:
        frame_head = pd.read_pickle(fname)[:num_entries]
    except ValueError:
        # Older python version
        import pickle5 as legacy_pickle

        with fname.open("rb") as handle:
            frame_head = legacy_pickle.load(handle)[:num_entries]

    return frame_head
|
||
def train_test_split(
    structures: list, targets: list, train_frac: float = 0.8
) -> SplitData:
    """Partition structures and targets into leading train and trailing test parts.

    The first ``floor(len(structures) * train_frac)`` entries of both lists
    form the training subset; the remaining entries form the testing subset.

    Args:
        structures: The structures to split.
        targets: The targets to split, in the same order as ``structures``.
        train_frac: Fraction of ``structures`` assigned to the training subset.

    Returns:
        ``((train_structures, train_targets), (test_structures, test_targets))``.
    """
    cut = floor(len(structures) * train_frac)
    train_subset = (structures[:cut], targets[:cut])
    test_subset = (structures[cut:], targets[cut:])
    return train_subset, test_subset