Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
Former-commit-id: 3014ae2
  • Loading branch information
kavanase committed Jan 31, 2024
1 parent f41b1ee commit 347c681
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,4 @@ snb_test.ipynb
conda_packaging.md
.pytest_cache/
shakenbreak/__pycache__/
tests/__pycache__/
tests/__pycache__/
Empty file modified .pytest_cache/.gitignore
100755 → 100644
Empty file.
Empty file modified .pytest_cache/CACHEDIR.TAG
100755 → 100644
Empty file.
Empty file modified .pytest_cache/README.md
100755 → 100644
Empty file.
5 changes: 3 additions & 2 deletions .pytest_cache/v/cache/lastfailed
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@
"tests/test_input.py::InputTestCase::test_write_vasp_files": true,
"tests/test_input.py::InputTestCase::test_write_vasp_files_from_doped_defect_gen": true,
"tests/test_input.py::InputTestCase::test_write_vasp_files_from_doped_dict": true,
"tests/test_input.py::InputTestCase::test_write_vasp_files_from_list": true
}
"tests/test_input.py::InputTestCase::test_write_vasp_files_from_list": true,
"tests/test_cli.py": true
}
13 changes: 12 additions & 1 deletion .pytest_cache/v/cache/nodeids
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
[
"tests/test_cli.py::CLITestCase::test_analyse",
"tests/test_cli.py::CLITestCase::test_groundstate",
"tests/test_cli.py::CLITestCase::test_mag",
"tests/test_cli.py::CLITestCase::test_parse",
"tests/test_cli.py::CLITestCase::test_parse_codes",
"tests/test_cli.py::CLITestCase::test_plot",
"tests/test_cli.py::CLITestCase::test_regenerate",
"tests/test_cli.py::CLITestCase::test_run",
"tests/test_cli.py::CLITestCase::test_snb_generate",
"tests/test_cli.py::CLITestCase::test_snb_generate_all",
"tests/test_cli.py::CLITestCase::test_snb_generate_config",
"tests/test_input.py::InputTestCase::test_apply_rattle_bond_distortions_V_Cd_dimer",
"tests/test_input.py::InputTestCase::test_apply_snb_distortions_V_Cd_Dimer",
"tests/test_input.py::InputTestCase::test_apply_snb_distortions_V_Cd_dimer",
Expand Down Expand Up @@ -37,4 +48,4 @@
"tests/test_plotting.py::PlottingDefectsTestCase::test_remove_high_energy_points",
"tests/test_plotting.py::PlottingDefectsTestCase::test_save_plot",
"tests/test_plotting.py::PlottingDefectsTestCase::test_verify_data_directories_exist"
]
]
2 changes: 1 addition & 1 deletion .pytest_cache/v/cache/stepwise
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[]
[]
40 changes: 23 additions & 17 deletions shakenbreak/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ def _get_distortion_filename(distortion) -> str:
distortion_label = f"Bond_Distortion_{distortion:.1f}%"
elif isinstance(distortion, str):
if "_from_" in distortion and (
"Rattled" not in distortion
and "Dimer" not in distortion
"Rattled" not in distortion and "Dimer" not in distortion
):
distortion_label = f"Bond_Distortion_{distortion}"
# runs from other charge states
elif (
"Rattled_from_" in distortion
or "Dimer_from" in distortion
or distortion in [
or distortion
in [
"Unperturbed",
"Rattled",
"Dimer",
Expand Down Expand Up @@ -194,9 +194,8 @@ def _format_distortion_names(
):
return float(distortion_label.split("Bond_Distortion_")[-1].split("%")[0]) / 100
elif (
("Dimer" in distortion_label or "Rattled" in distortion_label)
and "High_Energy" in distortion_label
):
"Dimer" in distortion_label or "Rattled" in distortion_label
) and "High_Energy" in distortion_label:
return distortion_label
else:
return "Label_not_recognized"
Expand Down Expand Up @@ -505,7 +504,10 @@ def get_structures(
distortion_subdirectories = [
i
for i in next(os.walk(f"{output_path}/{defect_species}"))[1]
if ("Bond_Distortion" in i) or ("Unperturbed" in i) or ("Rattled" in i) or ("Dimer" in i)
if ("Bond_Distortion" in i)
or ("Unperturbed" in i)
or ("Rattled" in i)
or ("Dimer" in i)
] # distortion subdirectories
if not distortion_subdirectories:
raise FileNotFoundError(
Expand Down Expand Up @@ -1000,8 +1002,10 @@ def get_homoionic_bonds(
homoionic neighbours and distances (A) (e.g.
{'O(1)': {'O(2)': '2.0 A', 'O(3)': '2.0 A'}})
"""
if isinstance(elements, str): # For backward compatibility
elements = [elements,]
if isinstance(elements, str): # For backward compatibility
elements = [
elements,
]
structure = structure.copy()
structure.remove_oxidation_states()
for element in elements:
Expand All @@ -1012,11 +1016,13 @@ def get_homoionic_bonds(
# Search for homoionic bonds in the whole structure
sites = []
for element in elements:
sites.extend([
(site_index, site)
for site_index, site in enumerate(structure)
if site.species_string == element
])
sites.extend(
[
(site_index, site)
for site_index, site in enumerate(structure)
if site.species_string == element
]
)
homoionic_bonds = {}
for site_index, site in sites:
neighbours = structure.get_neighbors(site, r=radius)
Expand All @@ -1040,9 +1046,9 @@ def get_homoionic_bonds(
if neighbour[0] == element
}
if f"{site.species_string}({site_index})" in homoionic_bonds:
homoionic_bonds[
f"{site.species_string}({site_index})"
].update(homoionic_neighbours)
homoionic_bonds[f"{site.species_string}({site_index})"].update(
homoionic_neighbours
)
else:
homoionic_bonds[
f"{site.species_string}({site_index})"
Expand Down
2 changes: 1 addition & 1 deletion shakenbreak/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import click
from doped.generation import get_defect_name_from_entry
from doped.utils.plotting import _format_defect_name
from doped.utils.parsing import get_outcar
from doped.utils.plotting import _format_defect_name

# Monty and pymatgen
from monty.serialization import dumpfn, loadfn
Expand Down
24 changes: 15 additions & 9 deletions shakenbreak/distortions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
_probability_mc_rattle,
generate_mc_rattled_structures,
)
from pymatgen.analysis.local_env import MinimumDistanceNN
from pymatgen.analysis.local_env import CrystalNN, MinimumDistanceNN
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.analysis.local_env import CrystalNN


def _warning_on_one_line(message, category, filename, lineno, file=None, line=None):
"""Format warnings output"""
Expand Down Expand Up @@ -89,7 +89,9 @@ def distort(
if distorted_atoms and len(distorted_atoms) >= num_nearest_neighbours:
nearest = [
(
round(input_structure_ase.get_distance(atom_number, index, mic=True), 4),
round(
input_structure_ase.get_distance(atom_number, index, mic=True), 4
),
index + 1,
input_structure_ase.get_chemical_symbols()[index],
)
Expand All @@ -106,7 +108,9 @@ def distort(
)
distances = [ # Get all distances between the selected atom and all other atoms
(
round(input_structure_ase.get_distance(atom_number, index, mic=True), 4),
round(
input_structure_ase.get_distance(atom_number, index, mic=True), 4
),
index + 1, # Indices start from 1
symbol,
)
Expand All @@ -124,7 +128,9 @@ def distort(
): # filter the neighbours that match the element criteria and are
# closer than 4.5 Angstroms
nearest = [] # list of nearest neighbours
for dist, index, element in distances[1:]: # starting from 1 to exclude defect atom
for dist, index, element in distances[
1:
]: # starting from 1 to exclude defect atom
if (
element == distorted_element
and dist < 4.5
Expand Down Expand Up @@ -228,7 +234,7 @@ def apply_dimer_distortion(

if site_index is not None: # site_index can be 0
atom_number = site_index - 1 # Align atom number with python 0-indexing
elif type(frac_coords) in [list, tuple, np.ndarray]: # Only for vacancies!
elif type(frac_coords) in [list, tuple, np.ndarray]: # Only for vacancies!
input_structure_ase.append("V") # fake "V" at vacancy
input_structure_ase.positions[-1] = np.dot(
frac_coords, input_structure_ase.cell
Expand All @@ -243,20 +249,20 @@ def apply_dimer_distortion(
# Get defect nn
struct = aaa.get_structure(input_structure_ase)
cnn = CrystalNN()
sites = [d['site'] for d in cnn.get_nn_info(struct, atom_number)]
sites = [d["site"] for d in cnn.get_nn_info(struct, atom_number)]

# Get distances between NN
distances = {}
for i, site in enumerate(sites):
for other_site in sites[i+1:]:
for other_site in sites[i + 1 :]:
distances[(site.index, other_site.index)] = site.distance(other_site)
# Get defect NN with smallest distance
site_indexes = min(distances, key=distances.get)
# Set their distance to 2 A
input_structure_ase.set_distance(
a0=site_indexes[0], a1=site_indexes[1], distance=2.0, fix=0.5, mic=True
)
if type(frac_coords) in [list, tuple, np.ndarray]:
if type(frac_coords) in [list, tuple, np.ndarray]:
input_structure_ase.pop(-1) # remove fake V from vacancy structure

distorted_structure = aaa.get_structure(input_structure_ase)
Expand Down
45 changes: 28 additions & 17 deletions shakenbreak/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,10 @@ def _apply_rattle_bond_distortions(
# unit cell is conserved in the supercell
frac_coords = None # only for vacancies
if defect_site_index is not None:
if isinstance(distortion_factor, str) and distortion_factor.lower() == "dimer":
if (
isinstance(distortion_factor, str)
and distortion_factor.lower() == "dimer"
):
bond_distorted_defect = distortions.apply_dimer_distortion(
structure=defect_structure,
site_index=defect_site_index,
Expand Down Expand Up @@ -1568,14 +1571,18 @@ def apply_snb_distortions(
seed=seed,
**mc_rattle_kwargs,
)
distorted_defect_dict["distortions"]["Dimer"] = (
bond_distorted_defect["distorted_structure"]
distorted_defect_dict["distortions"]["Dimer"] = bond_distorted_defect[
"distorted_structure"
]
distorted_defect_dict["distortion_parameters"].update(
{
"unique_site": bulk_supercell_site.frac_coords,
"num_distorted_neighbours_in_dimer": 2, # Dimer distortion only affects 2 atoms
"distorted_atoms_in_dimer": bond_distorted_defect[
"distorted_atoms"
],
}
)
distorted_defect_dict["distortion_parameters"].update({
"unique_site": bulk_supercell_site.frac_coords,
"num_distorted_neighbours_in_dimer": 2, # Dimer distortion only affects 2 atoms
"distorted_atoms_in_dimer": bond_distorted_defect["distorted_atoms"],
})
if defect_site_index: # only add site index if not vacancy
distorted_defect_dict["distortion_parameters"][
"defect_site_index"
Expand Down Expand Up @@ -1635,11 +1642,15 @@ def apply_snb_distortions(
distorted_defect_dict["distortions"]["Dimer"] = bond_distorted_defect[
"distorted_structure"
]
distorted_defect_dict["distortion_parameters"].update({
"unique_site": bulk_supercell_site.frac_coords,
"num_distorted_neighbours_in_dimer": 2, # Dimer distortion only affects 2 atoms
"distorted_atoms_in_dimer": bond_distorted_defect["distorted_atoms"],
})
distorted_defect_dict["distortion_parameters"].update(
{
"unique_site": bulk_supercell_site.frac_coords,
"num_distorted_neighbours_in_dimer": 2, # Dimer distortion only affects 2 atoms
"distorted_atoms_in_dimer": bond_distorted_defect[
"distorted_atoms"
],
}
)
return distorted_defect_dict


Expand Down Expand Up @@ -1961,9 +1972,9 @@ def guess_oxidation_states(bulk_comp):
self.bond_distortions.append("Dimer")
bond_distortions.remove("Dimer")

self.bond_distortions.extend(list(
np.around(bond_distortions, 3)
)) # round to 3 decimal places
self.bond_distortions.extend(
list(np.around(bond_distortions, 3))
) # round to 3 decimal places
else:
# If the user does not specify bond_distortions, use
# distortion_increment:
Expand Down Expand Up @@ -2123,7 +2134,7 @@ def _print_distortion_info(
) -> None:
"""Print applied bond distortions and rattle standard deviation."""
rounded_distortions = [
f'{round(i,3)+0}' if isinstance(i, float) else i for i in bond_distortions
f"{round(i,3)+0}" if isinstance(i, float) else i for i in bond_distortions
]
print(
"Applying ShakeNBreak...",
Expand Down
2 changes: 1 addition & 1 deletion shakenbreak/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def parse_fhi_aims_energy(defect_dir, dist, energy, aims_out):
"Bond_Distortion",
"Rattled",
"Unperturbed",
"Dimer",
"Dimer",
]
)
and "High_Energy" in dir
Expand Down
13 changes: 6 additions & 7 deletions shakenbreak/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _format_datapoints_from_other_chargestates(
elif entry == "Rattled": # add 0.0 for Rattled
# (to avoid problems when sorting distortions)
keys.append(0.0)
elif entry == "Dimer": # add 0.0 for Dimer
elif entry == "Dimer": # add 0.0 for Dimer
# (to avoid problems when sorting distortions)
keys.append(0.0)
else:
Expand Down Expand Up @@ -1340,7 +1340,7 @@ def plot_colorbar(
energies_dict["distortions"]["Dimer"],
c=disp_dict["Dimer"],
s=50,
marker="s", #default_style_settings["marker"],
marker="s", # default_style_settings["marker"],
label="Dimer",
cmap=colormap,
norm=norm,
Expand Down Expand Up @@ -1728,15 +1728,14 @@ def plot_datasets(
dataset["distortions"]["Dimer"],
c=colors[dataset_number],
s=50,
marker="s", #default_style_settings["marker"],
marker="s", # default_style_settings["marker"],
label="Dimer",
)

if len(sorted_distortions) > 0 and [

key for key in dataset["distortions"]
if (key != "Rattled" and key != "Dimer")

key
for key in dataset["distortions"]
if (key != "Rattled" and key != "Dimer")
]: # more than just Rattled
if imported_indices: # Exclude datapoints from other charge states
non_imported_sorted_indices = [
Expand Down

0 comments on commit 347c681

Please sign in to comment.