Skip to content

Commit

Permalink
Avoid using full equality (==) to compare float, avoid `assert_arra…
Browse files Browse the repository at this point in the history
…y_equal` compare float array (#4159)

* replace some float equality check

* explicit encoding

* charge is also float

* enhance types

* access gcd via math namespace as math is already imported

* put dunder method to top

* fix typo

* tweak _proj implementation

* support array like

* add arg and return type

* tweak type

* avoid more == for float comparison

* replace some == in test, more left to do

* replace more in core test

* replace more in test

* replace even more

* replace last batch

* clean up assert approx

* replace pytest.approx with approx

* also fix membership check

* replace some equality check of list

* replace some sequences

* fix test

* replace float comparison as dict

* fix test

* replace more float compare, mostly for VASP

* fix test

* fix approx in condition block

* replace sci notation

* suppress buggy ruff sim300

* number_of_permutations to int

* revert change for formula_double_format, in favor of another PR

* c_indices seems to be int

* use sci notation for crazily large int

* simplify numpy.testing usage

* set tol as pos arg

* avoid array equal for list of str

* assert_array_equal should not be used on float array

* fix module level var name

* more assert_array_equal on complex number

* simplify approx on dict value

* avoid module level var when it's used only 3 times

* pytext.approx to approx

* fix approx on nested dict

* avoid unnecessary convert to np.array

* array_equal to all close for float array

* assert all close for float array

* capital class attrib is treated as constant
  • Loading branch information
DanielYang59 authored Jan 24, 2025
1 parent 5f744f2 commit 7e7756e
Show file tree
Hide file tree
Showing 90 changed files with 1,222 additions and 1,139 deletions.
3 changes: 2 additions & 1 deletion src/pymatgen/alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import abc
import math
from collections import defaultdict
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -285,7 +286,7 @@ def __init__(self):

def test(self, structure: Structure):
"""True if structure is neutral."""
return structure.charge == 0.0
return math.isclose(structure.charge, 0.0)


class SpeciesMaxDistFilter(AbstractStructureFilter):
Expand Down
5 changes: 4 additions & 1 deletion src/pymatgen/core/bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def obtain_all_bond_lengths(
If None, a ValueError will be thrown.
Returns:
A dict mapping bond order to bond length in angstrom
dict[float, float]: mapping bond order to bond length in Angstrom.
Todo:
it's better to avoid using float as dict keys.
"""
if isinstance(sp1, Element):
sp1 = sp1.symbol
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/electronic_structure/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def get_elt_projected_plots_color(
proj[b][str(spin)][band_idx][j][str(el)][o]
for o in proj[b][str(spin)][band_idx][j][str(el)]
)
if sum_e == 0.0:
if math.isclose(sum_e, 0.0):
color = [0.0] * len(elt_ordered)
else:
color = [
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/io/aims/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def get_content(
magmom = structure.site_properties.get("magmom", spins)
if (
parameters.get("spin", "") == "collinear"
and np.all(magmom == 0.0)
and np.allclose(magmom, 0.0)
and ("default_initial_moment" not in parameters)
):
warn(
Expand Down
10 changes: 8 additions & 2 deletions src/pymatgen/io/cp2k/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,10 +1320,16 @@ def parse_bandstructure(self, bandstructure_filename=None) -> None:
else:
eigenvals = {Spin.up: bands_data.reshape((nbands, nkpts))}

occ = bands_data[:, 1][bands_data[:, -1] != 0.0]
# Filter out occupied and unoccupied states
occupied_mask = ~np.isclose(bands_data[:, -1], 0.0)
unoccupied_mask = np.isclose(bands_data[:, -1], 0.0)

occ = bands_data[:, 1][occupied_mask]
homo = np.max(occ)
unocc = bands_data[:, 1][bands_data[:, -1] == 0.0]

unocc = bands_data[:, 1][unoccupied_mask]
lumo = np.min(unocc)

efermi = (lumo + homo) / 2
self.efermi = efermi

Expand Down
12 changes: 6 additions & 6 deletions src/pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,13 +782,13 @@ def run_type(self) -> str:
4: "dDsC",
}

if self.parameters.get("AEXX", 1.00) == 1.00:
if math.isclose(self.parameters.get("AEXX", 1.00), 1.00):
run_type = "HF"
elif self.parameters.get("HFSCREEN", 0.30) == 0.30:
elif math.isclose(self.parameters.get("HFSCREEN", 0.30), 0.30):
run_type = "HSE03"
elif self.parameters.get("HFSCREEN", 0.20) == 0.20:
elif math.isclose(self.parameters.get("HFSCREEN", 0.20), 0.20):
run_type = "HSE06"
elif self.parameters.get("AEXX", 0.20) == 0.20:
elif math.isclose(self.parameters.get("AEXX", 0.20), 0.20):
run_type = "B3LYP"
elif self.parameters.get("LHFCALC", True):
run_type = "PBEO or other Hybrid Functional"
Expand Down Expand Up @@ -1031,7 +1031,7 @@ def get_band_structure(
if (hybrid_band or force_hybrid_mode) and not use_kpoints_opt:
start_bs_index = 0
for i in range(len(self.actual_kpoints)):
if self.actual_kpoints_weights[i] == 0.0:
if math.isclose(self.actual_kpoints_weights[i], 0.0):
start_bs_index = i
break
for i in range(start_bs_index, len(kpoint_file.kpts)):
Expand Down Expand Up @@ -5386,7 +5386,7 @@ def get_parchg(
Returns:
A Chgcar object.
"""
if phase and not np.all(self.kpoints[kpoint] == 0.0):
if phase and not np.allclose(self.kpoints[kpoint], 0.0):
warnings.warn(
"phase is True should only be used for the Gamma kpoint! I hope you know what you're doing!",
stacklevel=2,
Expand Down
89 changes: 52 additions & 37 deletions src/pymatgen/transformations/advanced_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Literal

from numpy.typing import NDArray


__author__ = "Shyue Ping Ong, Stephen Dacek, Anubhav Jain, Matthew Horton, Alex Ganose"

Expand All @@ -67,6 +69,9 @@ def __init__(self, charge_balance_sp):
"""
self.charge_balance_sp = str(charge_balance_sp)

def __repr__(self):
return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"

def apply_transformation(self, structure: Structure):
"""Apply the transformation.
Expand All @@ -86,9 +91,6 @@ def apply_transformation(self, structure: Structure):
trans = SubstitutionTransformation({self.charge_balance_sp: {self.charge_balance_sp: 1 - removal_fraction}})
return trans.apply_transformation(structure)

def __repr__(self):
return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"


class SuperTransformation(AbstractTransformation):
"""This is a transformation that is inherently one-to-many. It is constructed
Expand All @@ -110,6 +112,9 @@ def __init__(self, transformations, nstructures_per_trans=1):
self._transformations = transformations
self.nstructures_per_trans = nstructures_per_trans

def __repr__(self):
return f"Super Transformation : Transformations = {' '.join(map(str, self._transformations))}"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.
Expand Down Expand Up @@ -139,11 +144,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
)
return structures

def __repr__(self):
return f"Super Transformation : Transformations = {' '.join(map(str, self._transformations))}"

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -191,6 +193,9 @@ def __init__(
self.charge_balance_species = charge_balance_species
self.order = order

def __repr__(self):
return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.
Expand Down Expand Up @@ -233,11 +238,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
outputs.append({"structure": new_structure})
return outputs

def __repr__(self):
return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -322,6 +324,9 @@ def __init__(
if max_cell_size and max_disordered_sites:
raise ValueError("Cannot set both max_cell_size and max_disordered_sites!")

def __repr__(self):
return "EnumerateStructureTransformation"

def apply_transformation(
self, structure: Structure, return_ranked_list: bool | int = False
) -> Structure | list[dict]:
Expand Down Expand Up @@ -468,11 +473,8 @@ def sort_func(struct):
return self._all_structures[:num_to_return]
return self._all_structures[0]["structure"]

def __repr__(self):
return "EnumerateStructureTransformation"

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand All @@ -494,6 +496,9 @@ def __init__(self, threshold=1e-2, scale_volumes=True, **kwargs):
self.scale_volumes = scale_volumes
self._substitutor = SubstitutionPredictor(threshold=threshold, **kwargs)

def __repr__(self):
return "SubstitutionPredictorTransformation"

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
"""Apply the transformation.
Expand Down Expand Up @@ -528,11 +533,8 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
outputs.append(output)
return outputs

def __repr__(self):
return "SubstitutionPredictorTransformation"

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -895,7 +897,7 @@ def key(struct: Structure) -> int:
return self._all_structures[:num_to_return] # type: ignore[return-value]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -984,15 +986,19 @@ def __init__(
self.allowed_doping_species = allowed_doping_species
self.kwargs = kwargs

def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False):
def apply_transformation(
self,
structure: Structure,
return_ranked_list: bool | int = False,
) -> list[dict[Literal["structure", "energy"], Structure | float]] | Structure:
"""
Args:
structure (Structure): Input structure to dope
return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures.
is returned. If False, only the single lowest energy structure is returned. Defaults to False.
structure (Structure): Input structure to dope.
return_ranked_list (bool | int, optional): If is int, that number of structures is returned.
If False, only the single lowest energy structure is returned. Defaults to False.
Returns:
list[dict] | Structure: each dict has shape {"structure": Structure, "energy": float}.
list[dict] | Structure: each dict as {"structure": Structure, "energy": float}.
"""
comp = structure.composition
logger.info(f"Composition: {comp}")
Expand Down Expand Up @@ -1125,7 +1131,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
return all_structures[0]["structure"]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -1253,7 +1259,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
return disordered_structures

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -1714,7 +1720,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool |
return [{"structure": structure} for structure in structures[:return_ranked_list]]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -1868,16 +1874,25 @@ def apply_transformation(
return [{"structure": structure} for structure in structures[:return_ranked_list]]

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True


def _proj(b, a):
"""Get vector projection (np.ndarray) of vector b (np.ndarray)
onto vector a (np.ndarray).
def _proj(b: NDArray, a: NDArray) -> NDArray:
"""Get vector projection of vector b onto vector a.
Args:
b (NDArray): Vector to be projected.
a (NDArray): Vector onto which `b` is projected.
Returns:
NDArray: Projection of `b` onto `a`.
"""
return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a))
a = np.asarray(a)
b = np.asarray(b)

return (np.dot(b, a) / np.dot(a, a)) * a


class SQSTransformation(AbstractTransformation):
Expand Down Expand Up @@ -2146,7 +2161,7 @@ def _get_unique_best_sqs_structs(sqs, best_only, return_ranked_list, remove_dupl
return to_return

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[True]:
"""Transform one structure to many."""
return True

Expand Down Expand Up @@ -2195,6 +2210,9 @@ def __init__(self, rattle_std: float, min_distance: float, seed: int | None = No
self.random_state = np.random.RandomState(seed)
self.kwargs = kwargs

def __repr__(self):
return f"{__name__} : rattle_std = {self.rattle_std}"

def apply_transformation(self, structure: Structure) -> Structure:
"""Apply the transformation.
Expand All @@ -2216,6 +2234,3 @@ def apply_transformation(self, structure: Structure) -> Structure:
structure.cart_coords + displacements,
coords_are_cartesian=True,
)

def __repr__(self):
return f"{__name__} : rattle_std = {self.rattle_std}"
8 changes: 5 additions & 3 deletions src/pymatgen/transformations/transformation_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from monty.json import MSONable

if TYPE_CHECKING:
from typing import Any, Literal

from pymatgen.core import Structure

__author__ = "Shyue Ping Ong"
Expand Down Expand Up @@ -55,7 +57,7 @@ def inverse(self) -> AbstractTransformation | None:
"""

@property
def is_one_to_many(self) -> bool:
def is_one_to_many(self) -> Literal[False]:
"""Determine if a Transformation is a one-to-many transformation. In that case, the
apply_transformation method should have a keyword arg "return_ranked_list" which
allows for the transformed structures to be returned as a ranked list.
Expand All @@ -64,7 +66,7 @@ def is_one_to_many(self) -> bool:
return False

@property
def use_multiprocessing(self) -> bool:
def use_multiprocessing(self) -> Literal[False]:
"""Indicates whether the transformation can be applied by a
subprocessing pool. This should be overridden to return True for
transformations that the transmuter can parallelize.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_coordination_geometry(self):
assert cg_oct.IUCr_symbol_str == "[6o]"

cg_oct.permutations_safe_override = True
assert cg_oct.number_of_permutations == 720.0
assert cg_oct.number_of_permutations == 720
assert cg_oct.ref_permutation([0, 3, 2, 4, 5, 1]) == (0, 3, 1, 5, 2, 4)

sites = [FakeSite(coords=pp) for pp in cg_oct.points]
Expand Down
Loading

0 comments on commit 7e7756e

Please sign in to comment.