diff --git a/shakenbreak/input.py b/shakenbreak/input.py index 6903bf3..136ca38 100644 --- a/shakenbreak/input.py +++ b/shakenbreak/input.py @@ -12,7 +12,12 @@ import warnings from importlib.metadata import version from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Optional, Union +import sys +if sys.version_info >= (3, 9): + tuple_type = tuple +else: + from typing import Tuple as tuple_type import ase import numpy as np @@ -2281,7 +2286,7 @@ def write_distortion_metadata( def apply_distortions( self, verbose: Optional[bool] = None, - ) -> Tuple[dict, dict]: + ) -> tuple_type[dict, dict]: """ Applies rattle and bond distortion to all defects in `defect_entries`. Returns a dictionary with the distorted (and undistorted) structures @@ -2474,7 +2479,7 @@ def write_vasp_files( output_path: str = ".", verbose: Optional[bool] = None, **kwargs, - ) -> Tuple[dict, dict]: + ) -> tuple_type[dict, dict]: """ Generates the input files for `vasp_gam` relaxations of all output structures. @@ -2581,7 +2586,7 @@ def write_espresso_files( output_path: str = ".", verbose: Optional[bool] = None, profile=None, - ) -> Tuple[dict, dict]: + ) -> tuple_type[dict, dict]: """ Generates input files for Quantum Espresso relaxations of all output structures. @@ -2715,7 +2720,7 @@ def write_cp2k_files( write_structures_only: Optional[bool] = False, output_path: str = ".", verbose: Optional[bool] = None, - ) -> Tuple[dict, dict]: + ) -> tuple_type[dict, dict]: """ Generates input files for CP2K relaxations of all output structures. @@ -2777,7 +2782,7 @@ def write_castep_files( write_structures_only: Optional[bool] = False, output_path: str = ".", verbose: Optional[bool] = None, - ) -> Tuple[dict, dict]: + ) -> tuple_type[dict, dict]: """ Generates input `.cell` and `.param` files for CASTEP relaxations of all output structures. @@ -2854,7 +2859,7 @@ def write_fhi_aims_files( output_path: str = ".", verbose: Optional[bool] = None, profile=None, - ) -> Tuple[dict, dict]: + ) -> tuple_type[dict, dict]: """ Generates input geometry and control files for FHI-aims relaxations of all output structures. diff --git a/shakenbreak/plotting.py b/shakenbreak/plotting.py index c7e0142..2b86164 100644 --- a/shakenbreak/plotting.py +++ b/shakenbreak/plotting.py @@ -8,7 +8,12 @@ import os import shutil import warnings -from typing import Optional, Tuple +from typing import Optional +import sys +if sys.version_info >= (3, 9): + tuple_type = tuple +else: + from typing import Tuple as tuple_type import matplotlib as mpl import matplotlib.pyplot as plt @@ -182,7 +187,7 @@ def _change_energy_units_to_meV( energies_dict: dict, max_energy_above_unperturbed: float, y_label: str, -) -> Tuple[dict, float, str]: +) -> tuple_type[dict, float, str]: """ Converts energy values from eV to meV and format y label accordingly. @@ -211,7 +216,7 @@ def _change_energy_units_to_meV( def _purge_data_dicts( disp_dict: dict, energies_dict: dict, -) -> Tuple[dict, dict]: +) -> tuple_type[dict, dict]: """ Purges dictionaries of displacements and energies so that they are consistent (i.e. contain data for same distortions). @@ -249,7 +254,7 @@ def _remove_high_energy_points( energies_dict: dict, max_energy_above_unperturbed: float, disp_dict: Optional[dict] = None, -) -> Tuple[dict, dict]: +) -> tuple_type[dict, dict]: """ Remove points whose energy is higher than the reference (Unperturbed) by more than `max_energy_above_unperturbed`. @@ -285,7 +290,7 @@ def _get_displacement_dict( energies_dict: dict, add_colorbar: bool, code: Optional[str] = "vasp", -) -> Tuple[bool, dict, dict]: +) -> tuple_type[bool, dict, dict]: """ Parses structures of `defect_species` to calculate displacements between each of them and the reference configuration (Unperturbed). These displacements @@ -635,7 +640,7 @@ def _get_line_colors(number_of_colors: int) -> list: def _setup_colormap( disp_dict: dict, -) -> Tuple[mpl.colors.Colormap, float, float, float, mpl.colors.Normalize]: +) -> tuple_type[mpl.colors.Colormap, float, float, float, mpl.colors.Normalize]: """ Setup colormap to measure structural similarity between structures. @@ -1234,7 +1239,7 @@ def _setup_plot( num_nearest_neighbours: Optional[int], neighbour_atom: Optional[str], **fig_kwargs, -) -> Tuple[plt.Figure, plt.Axes]: +) -> tuple_type[plt.Figure, plt.Axes]: _install_custom_font() fig, ax = plt.subplots(1, 1, **fig_kwargs)