Skip to content

Commit

Permalink
Fix getting Atoms results
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Nov 28, 2024
1 parent 4a56b0b commit 93445cc
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions abcd/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def __iter__(self):

@classmethod
def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
"""ASE's original implementation"""
"""Extract data from Atoms info, arrays and results."""
if not isinstance(atoms, Atoms):
raise ValueError("atoms must be an ASE Atoms object.")

reserved_keys = {
"n_atoms",
Expand All @@ -157,12 +159,13 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
"derived",
"formula",
}

arrays_keys = set(atoms.arrays.keys())
info_keys = set(atoms.info.keys())
if store_calc and atoms.calc:
results_keys = atoms.calc.results.keys() - (arrays_keys | info_keys)
else:
results_keys = {}
results_keys = set()

all_keys = (reserved_keys, arrays_keys, info_keys, results_keys)
if len(set.union(*all_keys)) != sum(map(len, all_keys)):
Expand All @@ -173,46 +176,47 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):

n_atoms = len(atoms)

dct = {
data = {
"n_atoms": n_atoms,
"cell": atoms.cell.tolist(),
"pbc": atoms.pbc.tolist(),
"formula": atoms.get_chemical_formula(),
}

info_keys.update({"n_atoms", "cell", "pbc", "formula"})
info_keys.update(data.keys())

for key, value in atoms.arrays.items():
if isinstance(value, np.ndarray):
dct[key] = value.tolist()
data[key] = value.tolist()
else:
dct[key] = value
data[key] = value

for key, value in atoms.info.items():
if isinstance(value, np.ndarray):
dct[key] = value.tolist()
data[key] = value.tolist()
else:
dct[key] = value
data[key] = value

if store_calc and atoms.calc:
dct["calculator_name"] = atoms.calc.__class__.__name__
dct["calculator_parameters"] = atoms.calc.todict()
data["calculator_name"] = atoms.calc.__class__.__name__
data["calculator_parameters"] = atoms.calc.todict()
info_keys.update({"calculator_name", "calculator_parameters"})

for key, value in atoms.calc.results.items():

if isinstance(value, np.ndarray):
if value.shape[0] == n_atoms:
arrays_keys.update(key)
arrays_keys.add(key)
else:
info_keys.update(key)
dct[key] = value.tolist()
info_keys.add(key)
data[key] = value.tolist()
else:
data[key] = value

item.arrays_keys = list(arrays_keys)
item.info_keys = list(info_keys)
item.results_keys = list(results_keys)

item.update(dct)
item.update(data)

if extra_info:
item.info_keys.extend(extra_info.keys())
Expand Down

0 comments on commit 93445cc

Please sign in to comment.