Skip to content

Commit

Permalink
reshape stress/virial to voigt-6 in error calc
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Dec 3, 2024
1 parent 5b08766 commit ae6718c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
17 changes: 17 additions & 0 deletions wfl/fit/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from matplotlib.figure import Figure
from matplotlib.pyplot import get_cmap

from ase.stress import full_3x3_to_voigt_6_stress


def calc(inputs, calc_property_prefix, ref_property_prefix,
config_properties=None, atom_properties=None, category_keys="config_type",
Expand Down Expand Up @@ -86,6 +88,21 @@ def _reshape_normalize(quant, prop, atoms, per_atom):
# convert scalars or lists into arrays
quant = np.asarray(quant)

# fix shape of stress/virial
if prop.startswith("stress") or prop.startswith("virial"):
if prop.split("/")[0] in ["stress", "virial"]:
if quant.shape != (6,):
if quant.shape not in [(9,), (3,3)]:
raise ValueError(f"Prop '{prop}' has unknown shape of quant {quant.shape}")
quant = full_3x3_to_voigt_6_stress(quant.reshape((3, 3)))
elif prop.split("/")[0] in ["stresses", "virials"]:
eff_quant_shape = quant.shape[1:]
if eff_quant_shape != (6,):
if eff_quant_shape not in [(9,), (3,3)]:
raise ValueError(f"Prop '{prop}' has unknown shape of quant {quant.shape}")
quant = [full_3x3_to_voigt_6_stress(q.reshape((3, 3))) for q in quant]
quant = np.asarray(quant)

# Reshape to 2-d, with leading dimension 1 for per-config, and len(atoms) for per-atom.
# This is the right shape to work with later flattening for per-property and norm calculation
# for vector property differences.
Expand Down
2 changes: 1 addition & 1 deletion wfl/fit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import shlex
import warnings

from ase.constraints import voigt_6_to_full_3x3_stress
from ase.stress import voigt_6_to_full_3x3_stress

from wfl.utils.julia import julia_exec_path

Expand Down

0 comments on commit ae6718c

Please sign in to comment.