From 93445cc4269fadde62994c727bd8def714fc7386 Mon Sep 17 00:00:00 2001 From: ElliottKasoar Date: Wed, 27 Nov 2024 18:11:10 +0000 Subject: [PATCH] Fix getting Atoms results --- abcd/model.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/abcd/model.py b/abcd/model.py index cb9f135..ea8933a 100644 --- a/abcd/model.py +++ b/abcd/model.py @@ -146,7 +146,9 @@ def __iter__(self): @classmethod def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): - """ASE's original implementation""" + """Extract data from Atoms info, arrays and results.""" + if not isinstance(atoms, Atoms): + raise ValueError("atoms must be an ASE Atoms object.") reserved_keys = { "n_atoms", @@ -157,12 +159,13 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): "derived", "formula", } + arrays_keys = set(atoms.arrays.keys()) info_keys = set(atoms.info.keys()) if store_calc and atoms.calc: results_keys = atoms.calc.results.keys() - (arrays_keys | info_keys) else: - results_keys = {} + results_keys = set() all_keys = (reserved_keys, arrays_keys, info_keys, results_keys) if len(set.union(*all_keys)) != sum(map(len, all_keys)): @@ -173,46 +176,47 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): n_atoms = len(atoms) - dct = { + data = { "n_atoms": n_atoms, "cell": atoms.cell.tolist(), "pbc": atoms.pbc.tolist(), "formula": atoms.get_chemical_formula(), } - info_keys.update({"n_atoms", "cell", "pbc", "formula"}) + info_keys.update(data.keys()) for key, value in atoms.arrays.items(): if isinstance(value, np.ndarray): - dct[key] = value.tolist() + data[key] = value.tolist() else: - dct[key] = value + data[key] = value for key, value in atoms.info.items(): if isinstance(value, np.ndarray): - dct[key] = value.tolist() + data[key] = value.tolist() else: - dct[key] = value + data[key] = value if store_calc and atoms.calc: - dct["calculator_name"] = atoms.calc.__class__.__name__ - dct["calculator_parameters"] = atoms.calc.todict() + data["calculator_name"] = atoms.calc.__class__.__name__ + data["calculator_parameters"] = atoms.calc.todict() info_keys.update({"calculator_name", "calculator_parameters"}) for key, value in atoms.calc.results.items(): - if isinstance(value, np.ndarray): if value.shape[0] == n_atoms: - arrays_keys.update(key) + arrays_keys.add(key) else: - info_keys.update(key) - dct[key] = value.tolist() + info_keys.add(key) + data[key] = value.tolist() + else: + data[key] = value item.arrays_keys = list(arrays_keys) item.info_keys = list(info_keys) item.results_keys = list(results_keys) - item.update(dct) + item.update(data) if extra_info: item.info_keys.extend(extra_info.keys())