Skip to content

Commit

Permalink
additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stenczelt committed Dec 8, 2024
1 parent 51cc5c0 commit 38f232a
Showing 1 changed file with 104 additions and 1 deletion.
105 changes: 104 additions & 1 deletion tests/test_abstract_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import io

import ase
import pytest
from pytest import approx

from io import StringIO
from ase.io import read
from ase.io import read, write
import numpy as np

from abcd.model import AbstractModel
from ase.calculators.lj import LennardJones


@pytest.fixture
Expand Down Expand Up @@ -151,3 +156,101 @@ def test_to_ase_no_results(extxyz_file):
assert new_atoms.info["formula"] == atoms.get_chemical_formula()

assert new_atoms.calc is None


def test_from_atoms_len_atoms_3():
atoms = ase.Atoms(
"H3",
positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]],
pbc=True,
cell=[2, 2, 2],
)
atoms.calc = LennardJones()
atoms.calc.calculate(atoms)

# convert
abcd_data = AbstractModel.from_atoms(atoms, store_calc=True)

assert set(abcd_data.info_keys) == {
"pbc",
"n_atoms",
"cell",
"formula",
"calculator_name",
"calculator_parameters",
}
assert set(abcd_data.arrays_keys) == {"numbers", "positions"}
assert set(abcd_data.results_keys) == {
"stress",
"energy",
"forces",
"energies",
"stresses",
"free_energy",
}

# check a some keys as well
assert abcd_data["energy"] == atoms.get_potential_energy()
assert abcd_data["forces"] == approx(atoms.get_forces())


@pytest.mark.parametrize("store_calc", [True, False])
def test_write_and_read(store_calc):
# create atoms & add a calculator
atoms = ase.Atoms(
"H3",
positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]],
pbc=True,
cell=[2, 2, 2],
)
atoms.calc = LennardJones()
atoms.calc.calculate(atoms)

# dump to XYZ
buffer = io.StringIO()
write(buffer, atoms, format="extxyz")

# read back
buffer.seek(0)
atoms_read = read(buffer, format="extxyz")

# read in both of them
abcd_data = AbstractModel.from_atoms(atoms, store_calc=store_calc)
abcd_data_after_read = AbstractModel.from_atoms(atoms_read, store_calc=store_calc)

# check that all results are the same
for key in ["info_keys", "arrays_keys", "derived_keys", "results_keys"]:
assert set(getattr(abcd_data, key)) == set(
getattr(abcd_data_after_read, key)
), f"{key} mismatched"

# info & arrays same, except calc recognised as LJ when not from XYZ
for key in set(abcd_data.info_keys + abcd_data.arrays_keys) - {
"calculator_name",
"calculator_parameters",
}:
assert (
abcd_data[key] == abcd_data_after_read[key]
), f"{key}'s value does not match"

# date & hashed will differ
for key in set(abcd_data.derived_keys) - {
"hash",
"modified",
"uploaded",
"hash_structure", # see issue #118
}:
assert (
abcd_data[key] == abcd_data_after_read[key]
), f"{key}'s value does not match"

# expected differences - n.b. order of calls above
assert abcd_data_after_read["modified"] > abcd_data["modified"]
assert abcd_data_after_read["uploaded"] > abcd_data["uploaded"]
assert abcd_data_after_read["hash"] != abcd_data["hash"]

# expect results to match within fp precision
for key in set(abcd_data.results_keys):
assert abcd_data[key] == approx(
np.array(abcd_data_after_read[key])
), f"{key}'s value does not match"

0 comments on commit 38f232a

Please sign in to comment.