diff --git a/wfl/fit/error.py b/wfl/fit/error.py index e3a29a33..e2986d64 100755 --- a/wfl/fit/error.py +++ b/wfl/fit/error.py @@ -7,6 +7,8 @@ from matplotlib.figure import Figure from matplotlib.pyplot import get_cmap +from ase.stress import full_3x3_to_voigt_6_stress + def calc(inputs, calc_property_prefix, ref_property_prefix, config_properties=None, atom_properties=None, category_keys="config_type", @@ -86,6 +88,21 @@ def _reshape_normalize(quant, prop, atoms, per_atom): # convert scalars or lists into arrays quant = np.asarray(quant) + # fix shape of stress/virial + if prop.startswith("stress") or prop.startswith("virial"): + if prop.split("/")[0] in ["stress", "virial"]: + if quant.shape != (6,): + if quant.shape not in [(9,), (3,3)]: + raise ValueError(f"Prop '{prop}' has unknown shape of quant {quant.shape}") + quant = full_3x3_to_voigt_6_stress(quant.reshape((3, 3))) + elif prop.split("/")[0] in ["stresses", "virials"]: + eff_quant_shape = quant.shape[1:] + if eff_quant_shape != (6,): + if eff_quant_shape not in [(9,), (3,3)]: + raise ValueError(f"Prop '{prop}' has unknown shape of quant {quant.shape}") + quant = [full_3x3_to_voigt_6_stress(q.reshape((3, 3))) for q in quant] + quant = np.asarray(quant) + # Reshape to 2-d, with leading dimension 1 for per-config, and len(atoms) for per-atom. # This is the right shape to work with later flattening for per-property and norm calculation # for vector property differences. diff --git a/wfl/fit/utils.py b/wfl/fit/utils.py index 848000cf..a18cc342 100755 --- a/wfl/fit/utils.py +++ b/wfl/fit/utils.py @@ -5,7 +5,7 @@ import shlex import warnings -from ase.constraints import voigt_6_to_full_3x3_stress +from ase.stress import voigt_6_to_full_3x3_stress from wfl.utils.julia import julia_exec_path