diff --git a/tests/test_abstract_model.py b/tests/test_abstract_model.py index 5d4d7c6..3c185ea 100644 --- a/tests/test_abstract_model.py +++ b/tests/test_abstract_model.py @@ -1,10 +1,15 @@ +import io + +import ase import pytest +from pytest import approx from io import StringIO -from ase.io import read +from ase.io import read, write import numpy as np from abcd.model import AbstractModel +from ase.calculators.lj import LennardJones @pytest.fixture @@ -151,3 +156,101 @@ def test_to_ase_no_results(extxyz_file): assert new_atoms.info["formula"] == atoms.get_chemical_formula() assert new_atoms.calc is None + + +def test_from_atoms_len_atoms_3(): + atoms = ase.Atoms( + "H3", + positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], + pbc=True, + cell=[2, 2, 2], + ) + atoms.calc = LennardJones() + atoms.calc.calculate(atoms) + + # convert + abcd_data = AbstractModel.from_atoms(atoms, store_calc=True) + + assert set(abcd_data.info_keys) == { + "pbc", + "n_atoms", + "cell", + "formula", + "calculator_name", + "calculator_parameters", + } + assert set(abcd_data.arrays_keys) == {"numbers", "positions"} + assert set(abcd_data.results_keys) == { + "stress", + "energy", + "forces", + "energies", + "stresses", + "free_energy", + } + + # check a some keys as well + assert abcd_data["energy"] == atoms.get_potential_energy() + assert abcd_data["forces"] == approx(atoms.get_forces()) + + +@pytest.mark.parametrize("store_calc", [True, False]) +def test_write_and_read(store_calc): + # create atoms & add a calculator + atoms = ase.Atoms( + "H3", + positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], + pbc=True, + cell=[2, 2, 2], + ) + atoms.calc = LennardJones() + atoms.calc.calculate(atoms) + + # dump to XYZ + buffer = io.StringIO() + write(buffer, atoms, format="extxyz") + + # read back + buffer.seek(0) + atoms_read = read(buffer, format="extxyz") + + # read in both of them + abcd_data = AbstractModel.from_atoms(atoms, store_calc=store_calc) + abcd_data_after_read = AbstractModel.from_atoms(atoms_read, store_calc=store_calc) + + # check that all results are the same + for key in ["info_keys", "arrays_keys", "derived_keys", "results_keys"]: + assert set(getattr(abcd_data, key)) == set( + getattr(abcd_data_after_read, key) + ), f"{key} mismatched" + + # info & arrays same, except calc recognised as LJ when not from XYZ + for key in set(abcd_data.info_keys + abcd_data.arrays_keys) - { + "calculator_name", + "calculator_parameters", + }: + assert ( + abcd_data[key] == abcd_data_after_read[key] + ), f"{key}'s value does not match" + + # date & hashed will differ + for key in set(abcd_data.derived_keys) - { + "hash", + "modified", + "uploaded", + "hash_structure", # see issue #118 + }: + assert ( + abcd_data[key] == abcd_data_after_read[key] + ), f"{key}'s value does not match" + + # expected differences - n.b. order of calls above + assert abcd_data_after_read["modified"] > abcd_data["modified"] + assert abcd_data_after_read["uploaded"] > abcd_data["uploaded"] + assert abcd_data_after_read["hash"] != abcd_data["hash"] + + # expect results to match within fp precision + for key in set(abcd_data.results_keys): + assert abcd_data[key] == approx( + np.array(abcd_data_after_read[key]) + ), f"{key}'s value does not match"