From d883f98ff413c0a48d022027f66c0436c429bdc5 Mon Sep 17 00:00:00 2001 From: Jimmy Shen <14003693+jmmshn@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:04:27 -0800 Subject: [PATCH] new flag for AutoOxi (#4150) new flag for AutoOxi --- .../transformations/standard_transformations.py | 14 +++++++++++++- .../test_standard_transformations.py | 9 +++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/pymatgen/transformations/standard_transformations.py b/src/pymatgen/transformations/standard_transformations.py index 83ba693b4ba..634b74840c6 100644 --- a/src/pymatgen/transformations/standard_transformations.py +++ b/src/pymatgen/transformations/standard_transformations.py @@ -107,6 +107,7 @@ def __init__( max_radius=4, max_permutations=100000, distance_scale_factor=1.015, + zeros_on_fail=False, ): """ Args: @@ -121,12 +122,16 @@ def __init__( calculation-relaxed structures, which may tend to under (GGA) or over bind (LDA). The default of 1.015 works for GGA. For experimental structure, set this to 1. + zeros_on_fail (bool): If True and the BVAnalyzer fails to come up + with a guess for the oxidation states, we will set the all the + oxidation states to zero. """ self.symm_tol = symm_tol self.max_radius = max_radius self.max_permutations = max_permutations self.distance_scale_factor = distance_scale_factor self.analyzer = BVAnalyzer(symm_tol, max_radius, max_permutations, distance_scale_factor) + self.zeros_on_fail = zeros_on_fail def apply_transformation(self, structure): """Apply the transformation. @@ -137,7 +142,14 @@ def apply_transformation(self, structure): Returns: Oxidation state decorated Structure. """ - return self.analyzer.get_oxi_state_decorated_structure(structure) + try: + return self.analyzer.get_oxi_state_decorated_structure(structure) + except ValueError as er: + if self.zeros_on_fail: + struct_ = structure.copy() + struct_.add_oxidation_state_by_site([0] * len(struct_)) + return struct_ + raise ValueError(f"BVAnalyzer failed with error: {er}") class OxidationStateRemovalTransformation(AbstractTransformation): diff --git a/tests/transformations/test_standard_transformations.py b/tests/transformations/test_standard_transformations.py index 21204c67bab..42ee06704d0 100644 --- a/tests/transformations/test_standard_transformations.py +++ b/tests/transformations/test_standard_transformations.py @@ -200,6 +200,15 @@ def test_as_from_dict(self): trafo = AutoOxiStateDecorationTransformation.from_dict(dct) assert trafo.analyzer.dist_scale_factor == 1.015 + def test_failure(self): + trafo_fail = AutoOxiStateDecorationTransformation() + trafo_no_fail = AutoOxiStateDecorationTransformation(zeros_on_fail=True) + struct_metal = Structure.from_spacegroup("Fm-3m", Lattice.cubic(3.677), ["Cu"], [[0, 0, 0]]) + with pytest.raises(ValueError, match="BVAnalyzer failed with error"): + trafo_fail.apply_transformation(struct_metal) + zero_oxi_struct = trafo_no_fail.apply_transformation(struct_metal) + assert all(site.specie.oxi_state == 0 for site in zero_oxi_struct) + class TestOxidationStateRemovalTransformation: def test_apply_transformation(self):