Skip to content

Commit

Permalink
MEGNet utils testing (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ws-m committed Oct 25, 2021
1 parent e1cea74 commit 99579d4
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 59 deletions.
51 changes: 51 additions & 0 deletions tests/test_meg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Test for the MEGNet utilities module"""
import os


# Configure CUDA through the environment before the numerical/ML imports
# below are executed. CUDA_VISIBLE_DEVICES is set to an empty string,
# presumably so that no GPU is visible and the tests run on CPU.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from pathlib import Path

import numpy as np
import pytest
from megnet.models.megnet import MEGNetModel
from megnet.utils.preprocessing import Scaler
from unlocknn.megnet_utils import create_megnet_input

from .utils import datadir, load_df_head


class ExampleScaler(Scaler):
    """Toy target scaler used to exercise scaler handling in the tests.

    The forward transform divides a target by the number of atoms ``n``;
    the inverse transform multiplies by ``n`` to undo it.
    """

    def transform(self, target: np.ndarray, n: int = 1) -> np.ndarray:
        """Return ``target`` divided by ``n``."""
        per_atom = target / n
        return per_atom

    def inverse_transform(self, transformed_target: np.ndarray, n: int = 1) -> np.ndarray:
        """Return ``transformed_target`` multiplied by ``n``."""
        restored = transformed_target * n
        return restored

def test_input_with_scaler(datadir: Path):
    """Check that `create_megnet_input` applies the model's target scaler.

    A MEGNet model is loaded from the test data directory, its target scaler
    is swapped for `ExampleScaler` (divide by number of atoms), and the
    targets yielded by the generated input are compared with the manually
    scaled formation energies.

    Args:
        datadir: Temporary copy of this module's data directory.
    """
    binary_df = load_df_head(datadir / "mp_binary_on_hull.pkl")

    meg_model = MEGNetModel.from_file(str(datadir / "formation_energy.hdf5"))
    meg_model.target_scaler = ExampleScaler()

    input_gen, _ = create_megnet_input(
        meg_model,
        binary_df["structure"],
        binary_df["formation_energy_per_atom"],
        batch_size=100,  # We have just one batch
        shuffle=False,
    )

    # Get first batch (the whole input)
    _, scaled_targets = input_gen[0]

    expected = np.asarray(
        binary_df["formation_energy_per_atom"] / binary_df["num_atoms"],
        dtype=float,
    )
    # Check targets are scaled. The comparison must be asserted: as a bare
    # expression its result is discarded and the test can never fail.
    # Both sides are flattened so that any extra batch/feature axes on the
    # generator output do not affect the element-wise comparison.
    assert np.ravel(scaled_targets) == pytest.approx(expected, rel=1e-6)
Binary file added tests/test_meg_utils/formation_energy.hdf5
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/test_meg_utils/formation_energy.hdf5.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{"graph_converter": {"@module": "megnet.data.crystal", "@class": "CrystalGraph", "atom_converter": {"@module": "megnet.data.graph", "@class": "DummyConverter"}, "bond_converter": {"@module": "megnet.data.graph", "@class": "GaussianDistance", "centers": {"@module": "numpy", "@class": "array", "dtype": "float64", "data": [0.0, 0.06060606060606061, 0.12121212121212122, 0.18181818181818182, 0.24242424242424243, 0.30303030303030304, 0.36363636363636365, 0.42424242424242425, 0.48484848484848486, 0.5454545454545454, 0.6060606060606061, 0.6666666666666667, 0.7272727272727273, 0.7878787878787878, 0.8484848484848485, 0.9090909090909092, 0.9696969696969697, 1.0303030303030303, 1.0909090909090908, 1.1515151515151516, 1.2121212121212122, 1.2727272727272727, 1.3333333333333335, 1.393939393939394, 1.4545454545454546, 1.5151515151515151, 1.5757575757575757, 1.6363636363636365, 1.696969696969697, 1.7575757575757576, 1.8181818181818183, 1.878787878787879, 1.9393939393939394, 2.0, 2.0606060606060606, 2.121212121212121, 2.1818181818181817, 2.2424242424242427, 2.303030303030303, 2.3636363636363638, 2.4242424242424243, 2.484848484848485, 2.5454545454545454, 2.606060606060606, 2.666666666666667, 2.7272727272727275, 2.787878787878788, 2.8484848484848486, 2.909090909090909, 2.9696969696969697, 3.0303030303030303, 3.090909090909091, 3.1515151515151514, 3.2121212121212124, 3.272727272727273, 3.3333333333333335, 3.393939393939394, 3.4545454545454546, 3.515151515151515, 3.5757575757575757, 3.6363636363636367, 3.6969696969696972, 3.757575757575758, 3.8181818181818183, 3.878787878787879, 3.9393939393939394, 4.0, 4.0606060606060606, 4.121212121212121, 4.181818181818182, 4.242424242424242, 4.303030303030303, 4.363636363636363, 4.424242424242425, 4.484848484848485, 4.545454545454546, 4.606060606060606, 4.666666666666667, 4.7272727272727275, 4.787878787878788, 4.848484848484849, 4.909090909090909, 4.96969696969697, 5.03030303030303, 5.090909090909091, 5.151515151515151, 5.212121212121212, 
5.2727272727272725, 5.333333333333334, 5.3939393939393945, 5.454545454545455, 5.515151515151516, 5.575757575757576, 5.636363636363637, 5.696969696969697, 5.757575757575758, 5.818181818181818, 5.878787878787879, 5.9393939393939394, 6.0]}, "width": 0.5}, "cutoff": 5.0},
"metadata": {"name": "Formation energy",
"unit": "eV/atom",
"metric": "mae",
"metric value": 0.026,
"training dataset": "mp-2019.4.1",
"description": "This model was trained using structures and formation energy of Materials Project data base downloaded on April 1, 2019. The total data size is 133420. The model was trained by a 0.8-0.1-0.1, train, validation and test data ratio"}
}
Binary file added tests/test_meg_utils/mp_binary_on_hull.pkl
Binary file not shown.
62 changes: 3 additions & 59 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,79 +5,23 @@
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import random as python_random
from distutils import dir_util
from math import floor
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import pytest
import tensorflow as tf
from megnet.models import MEGNetModel
from unlocknn import MEGNetProbModel
from unlocknn.initializers import SampleInitializer

from .utils import (SplitData, datadir, load_df_head, train_test_split,
weights_equal)

# Seed the NumPy, standard-library and TensorFlow random number generators
# with the same fixed value so that test runs are repeatable.
np.random.seed(123)
python_random.seed(123)
tf.random.set_seed(123)

# Return type of `train_test_split`:
# ((train structures, train targets), (test structures, test targets)).
SplitData = Tuple[Tuple[list, list], Tuple[list, list]]

@pytest.fixture
def datadir(tmpdir, request) -> Path:
    """Provide a scratch copy of the requesting test module's data folder.

    The data folder is the directory that sits next to the test module and
    shares its name without the file extension. When that directory exists,
    its contents are copied into pytest's temporary directory so that tests
    can modify them freely.

    Source: https://stackoverflow.com/a/29631801/

    Returns:
        The temporary directory (populated only if a data folder exists).
    """
    module_file = request.module.__file__
    data_root = os.path.splitext(module_file)[0]

    if os.path.isdir(data_root):
        dir_util.copy_tree(data_root, str(tmpdir))

    return tmpdir


def weights_equal(weights_a: List[np.ndarray], weights_b: List[np.ndarray]) -> bool:
    """Check equality between two lists of weight arrays.

    Args:
        weights_a: First list of weight arrays.
        weights_b: Second list of weight arrays.

    Returns:
        True when both lists contain the same number of arrays and every
        corresponding pair agrees to a relative tolerance of 1e-6;
        False otherwise.
    """
    # zip() stops at the shorter input, so without this guard a list would
    # compare equal to any longer list that it is a prefix of (and an empty
    # list would compare equal to everything).
    if len(weights_a) != len(weights_b):
        return False
    return all(
        weight1 == pytest.approx(weight2, rel=1e-6)
        for weight1, weight2 in zip(weights_a, weights_b)
    )

def load_df_head(fname: Path, num_entries: int=100) -> pd.DataFrame:
    """Read a pickled pandas DataFrame and keep only its leading rows.

    If pandas raises ``ValueError`` while reading, the file is loaded with
    the ``pickle5`` backport instead (needed on older Python versions).

    Args:
        fname: The pickle file to open.
        num_entries: How many values to read.

    Returns:
        The first ``num_entries`` entries of the unpickled object.
    """
    head = slice(None, num_entries)
    try:
        frame = pd.read_pickle(fname)
        return frame[head]
    except ValueError:
        # Older python version
        import pickle5 as legacy_pickle

        with fname.open("rb") as handle:
            frame = legacy_pickle.load(handle)
        return frame[head]



def train_test_split(
    structures: list, targets: list, train_frac: float = 0.8
) -> SplitData:
    """Partition structures and targets into leading train and trailing test parts.

    Args:
        structures: The structures to split.
        targets: The targets, sliced at the same position as ``structures``.
        train_frac: Fraction of ``structures`` (rounded down) used for training.

    Returns:
        ``((train_structures, train_targets), (test_structures, test_targets))``.
    """
    cut = floor(len(structures) * train_frac)
    training = (structures[:cut], targets[:cut])
    testing = (structures[cut:], targets[cut:])
    return training, testing

@pytest.fixture
def split_data(datadir: Path) -> SplitData:
Expand Down
69 changes: 69 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Testing suite shared functionality."""
import os

# Configure CUDA through the environment before the remaining imports are
# executed. CUDA_VISIBLE_DEVICES is set to an empty string, presumably so
# that no GPU is visible and the tests run on CPU.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from distutils import dir_util
from math import floor
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import pytest

# Return type of `train_test_split`:
# ((train structures, train targets), (test structures, test targets)).
SplitData = Tuple[Tuple[list, list], Tuple[list, list]]

@pytest.fixture
def datadir(tmpdir, request) -> Path:
    """Provide a scratch copy of the requesting test module's data folder.

    The data folder is the directory that sits next to the test module and
    shares its name without the file extension. When that directory exists,
    its contents are copied into pytest's temporary directory so that tests
    can modify them freely.

    Source: https://stackoverflow.com/a/29631801/

    Returns:
        The temporary directory (populated only if a data folder exists).
    """
    module_file = request.module.__file__
    data_root = os.path.splitext(module_file)[0]

    if os.path.isdir(data_root):
        dir_util.copy_tree(data_root, str(tmpdir))

    return tmpdir


def weights_equal(weights_a: List[np.ndarray], weights_b: List[np.ndarray]) -> bool:
    """Check equality between two lists of weight arrays.

    Args:
        weights_a: First list of weight arrays.
        weights_b: Second list of weight arrays.

    Returns:
        True when both lists contain the same number of arrays and every
        corresponding pair agrees to a relative tolerance of 1e-6;
        False otherwise.
    """
    # zip() stops at the shorter input, so without this guard a list would
    # compare equal to any longer list that it is a prefix of (and an empty
    # list would compare equal to everything).
    if len(weights_a) != len(weights_b):
        return False
    return all(
        weight1 == pytest.approx(weight2, rel=1e-6)
        for weight1, weight2 in zip(weights_a, weights_b)
    )

def load_df_head(fname: Path, num_entries: int=100) -> pd.DataFrame:
    """Read a pickled pandas DataFrame and keep only its leading rows.

    If pandas raises ``ValueError`` while reading, the file is loaded with
    the ``pickle5`` backport instead (needed on older Python versions).

    Args:
        fname: The pickle file to open.
        num_entries: How many values to read.

    Returns:
        The first ``num_entries`` entries of the unpickled object.
    """
    head = slice(None, num_entries)
    try:
        frame = pd.read_pickle(fname)
        return frame[head]
    except ValueError:
        # Older python version
        import pickle5 as legacy_pickle

        with fname.open("rb") as handle:
            frame = legacy_pickle.load(handle)
        return frame[head]

def train_test_split(
    structures: list, targets: list, train_frac: float = 0.8
) -> SplitData:
    """Partition structures and targets into leading train and trailing test parts.

    Args:
        structures: The structures to split.
        targets: The targets, sliced at the same position as ``structures``.
        train_frac: Fraction of ``structures`` (rounded down) used for training.

    Returns:
        ``((train_structures, train_targets), (test_structures, test_targets))``.
    """
    cut = floor(len(structures) * train_frac)
    training = (structures[:cut], targets[:cut])
    testing = (structures[cut:], targets[cut:])
    return training, testing

0 comments on commit 99579d4

Please sign in to comment.