From f82ce1f476313478b4658037e1cf737ab59fefdb Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 28 Jan 2025 15:10:06 -0500 Subject: [PATCH] Add `seed: int = 0` parameter to `Structure.perturb()` method (#4270) * Add seed parameter to Structure.perturb() method - Update tests to verify reproducible random perturbations with same seed - Add test cases for different seed values and min_distance parameter * fix test_perturb --- src/pymatgen/core/structure.py | 10 +++---- tests/core/test_structure.py | 52 +++++++++++++++++++++++----------- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/src/pymatgen/core/structure.py b/src/pymatgen/core/structure.py index 9cde191afa7..9863bd9427e 100644 --- a/src/pymatgen/core/structure.py +++ b/src/pymatgen/core/structure.py @@ -4635,16 +4635,16 @@ def rotate_sites( return self - def perturb(self, distance: float, min_distance: float | None = None) -> Self: + def perturb(self, distance: float, min_distance: float | None = None, seed: int = 0) -> Self: """Perform a random perturbation of the sites in a structure to break symmetries. Modifies the structure in place. Args: distance (float): Distance in angstroms by which to perturb each site. min_distance (None, int, or float): if None, all displacements will - be equal amplitude. If int or float, perturb each site a - distance drawn from the uniform distribution between - 'min_distance' and 'distance'. + be equal amplitude. If int or float, perturb each site a distance drawn + from the uniform distribution between 'min_distance' and 'distance'. + seed (int): Seed for the random number generator. Defaults to 0. Returns: Structure: self with perturbed sites. @@ -4652,7 +4652,7 @@ def perturb(self, distance: float, min_distance: float | None = None) -> Self: def get_rand_vec(): # Deal with zero vectors - rng = np.random.default_rng() + rng = np.random.default_rng(seed=seed) vector = rng.standard_normal(3) vnorm = np.linalg.norm(vector) dist = distance diff --git a/tests/core/test_structure.py b/tests/core/test_structure.py index b26db3edcb9..c400709415c 100644 --- a/tests/core/test_structure.py +++ b/tests/core/test_structure.py @@ -1213,22 +1213,42 @@ def test_propertied_structure(self): assert dct == struct.as_dict() def test_perturb(self): - dist = 0.1 - pre_perturbation_sites = self.struct.copy() - returned = self.struct.perturb(distance=dist) - assert returned is self.struct - post_perturbation_sites = self.struct.sites - - for idx, site in enumerate(pre_perturbation_sites): - assert site.distance(post_perturbation_sites[idx]) == approx(dist), "Bad perturbation distance" - - structure2 = pre_perturbation_sites.copy() - structure2.perturb(distance=dist, min_distance=0) - post_perturbation_sites2 = structure2.sites - - for idx, site in enumerate(pre_perturbation_sites): - assert site.distance(post_perturbation_sites2[idx]) <= dist - assert site.distance(post_perturbation_sites2[idx]) >= 0 + struct = self.get_structure("Li2O") + struct_orig = struct.copy() + struct.perturb(0.1) + # Ensure all sites were perturbed by a distance of at most 0.1 Angstroms + for site, site_orig in zip(struct, struct_orig, strict=True): + cart_dist = site.distance(site_orig) + # allow 1e-6 to account for numerical precision + assert cart_dist <= 0.1 + 1e-6, f"Distance {cart_dist} > 0.1" + + # Test that same seed gives same perturbation + s1 = self.get_structure("Li2O") + s2 = self.get_structure("Li2O") + s1.perturb(0.1, seed=42) + s2.perturb(0.1, seed=42) + for site1, site2 in zip(s1, s2, strict=True): + assert site1.distance(site2) < 1e-7 # should be exactly equal up to numerical precision + + # Test that different seeds give different perturbations + s3 = self.get_structure("Li2O") + s3.perturb(0.1, seed=100) + any_different = False + for site1, site3 in zip(s1, s3, strict=True): + if site1.distance(site3) > 1e-7: + any_different = True + break + assert any_different, "Different seeds should give different perturbations" + + # Test min_distance + s4 = self.get_structure("Li2O") + s4.perturb(0.1, min_distance=0.05, seed=42) + any_different = False + for site1, site4 in zip(s1, s4, strict=True): + if site1.distance(site4) > 1e-7: + any_different = True + break + assert any_different, "Using min_distance should give different perturbations" def test_add_oxidation_state_by_element(self): oxidation_states = {"Si": -4}