diff --git a/pycrostates/cluster/_base.py b/pycrostates/cluster/_base.py index 187b7b83..aaf3dfac 100644 --- a/pycrostates/cluster/_base.py +++ b/pycrostates/cluster/_base.py @@ -2,7 +2,7 @@ from copy import copy, deepcopy from itertools import groupby from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import numpy as np from matplotlib.axes import Axes @@ -277,11 +277,11 @@ def fit( def rename_clusters( self, - mapping: Optional[Dict[str, str]] = None, + mapping: Optional[dict[str, str]] = None, new_names: Optional[ Union[ - List[str], - Tuple[str, ...], + list[str], + tuple[str, ...], ] ] = None, ) -> None: @@ -341,11 +341,11 @@ def rename_clusters( def reorder_clusters( self, - mapping: Optional[Dict[int, int]] = None, + mapping: Optional[dict[int, int]] = None, order: Optional[ Union[ - List[int], - Tuple[int, ...], + list[int], + tuple[int, ...], NDArray[int], ] ] = None, @@ -458,8 +458,8 @@ def invert_polarity( self, invert: Union[ bool, - List[bool], - Tuple[bool, ...], + list[bool], + tuple[bool, ...], NDArray[bool], ], ) -> None: @@ -518,7 +518,7 @@ def plot( self, axes: Optional[Union[Axes, NDArray[Axes]]] = None, show_gradient: Optional[bool] = False, - gradient_kwargs: Dict[str, Any] = { + gradient_kwargs: dict[str, Any] = { "color": "black", "linestyle": "-", "marker": "P", @@ -1180,7 +1180,7 @@ def labels_(self) -> NDArray[int]: return self._labels_.copy() @property - def cluster_names(self) -> List[str]: + def cluster_names(self) -> list[str]: """Name of the clusters. :type: `list` diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 87099b16..75df046e 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -1,7 +1,7 @@ """Atomize and Agglomerate Hierarchical Clustering (AAHC).""" from pathlib import Path -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Union import numpy as np from mne import BaseEpochs @@ -183,7 +183,7 @@ def _aahc( n_clusters: int, ignore_polarity: bool, normalize_input: bool, - ) -> Tuple[float, NDArray[float], NDArray[int]]: + ) -> tuple[float, NDArray[float], NDArray[int]]: """Run the AAHC algorithm.""" gfp_sum_sq = np.sum(data**2) maps, segmentation = AAHCluster._compute_maps( @@ -200,7 +200,7 @@ def _compute_maps( n_clusters: int, ignore_polarity: bool, normalize_input: bool, - ) -> Tuple[NDArray[float], NDArray[int]]: + ) -> tuple[NDArray[float], NDArray[int]]: """Compute microstates maps.""" n_chan, n_frame = data.shape diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index c7de600c..b69d4081 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -1,7 +1,7 @@ """Class and functions to use modified Kmeans.""" from pathlib import Path -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Union import numpy as np from mne import BaseEpochs @@ -263,7 +263,7 @@ def _kmeans( max_iter: int, random_state: Union[RandomState, Generator], tol: Union[int, float], - ) -> Tuple[float, NDArray[float], NDArray[int], bool]: + ) -> tuple[float, NDArray[float], NDArray[int], bool]: """Run the k-means algorithm.""" gfp_sum_sq = np.sum(data**2) maps, converged = ModKMeans._compute_maps( @@ -282,7 +282,7 @@ def _compute_maps( max_iter: int, random_state: Union[RandomState, Generator], tol: Union[int, float], - ) -> Tuple[NDArray[float], bool]: + ) -> tuple[NDArray[float], bool]: """Compute microstates maps. Based on mne_microstates by Marijn van Vliet diff --git a/pycrostates/io/fiff.py b/pycrostates/io/fiff.py index bee0bc9e..3b56dbfa 100644 --- a/pycrostates/io/fiff.py +++ b/pycrostates/io/fiff.py @@ -5,7 +5,7 @@ from functools import reduce from numbers import Integral from pathlib import Path -from typing import List, Union +from typing import Union import numpy as np from mne import Info, Transform @@ -98,7 +98,7 @@ def _write_cluster( cluster_centers_: NDArray[float], chinfo: Union[CHInfo, Info], algorithm: str, - cluster_names: List[str], + cluster_names: list[str], fitted_data: NDArray[float], labels_: NDArray[int], **kwargs, @@ -398,7 +398,7 @@ def _check_fit_parameters_and_variables( def _create_ModKMeans( cluster_centers_: NDArray[float], info: CHInfo, - cluster_names: List[str], + cluster_names: list[str], fitted_data: NDArray[float], labels_: NDArray[int], n_init: int, @@ -423,7 +423,7 @@ def _create_ModKMeans( def _create_AAHCluster( cluster_centers_: NDArray[float], info: CHInfo, - cluster_names: List[str], + cluster_names: list[str], fitted_data: NDArray[float], labels_: NDArray[int], ignore_polarity: bool, # pylint: disable=unused-argument diff --git a/pycrostates/io/meas_info.py b/pycrostates/io/meas_info.py index 41261a57..91a6ec54 100644 --- a/pycrostates/io/meas_info.py +++ b/pycrostates/io/meas_info.py @@ -2,7 +2,7 @@ from copy import deepcopy from numbers import Number -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np from mne import Info, Projection, Transform @@ -155,11 +155,11 @@ def __init__( ch_names: Optional[ Union[ int, - List[str], - Tuple[str, ...], + list[str], + tuple[str, ...], ] ] = None, - ch_types: Optional[Union[str, List[str], Tuple[str, ...]]] = None, + ch_types: Optional[Union[str, list[str], tuple[str, ...]]] = None, ): if all(arg is None for arg in (info, ch_names, ch_types)): raise RuntimeError( @@ -195,8 +195,8 @@ def _init_from_info(self, info: Info): def _init_from_channels( self, - ch_names: Union[int, List[str], Tuple[str, ...]], - ch_types: Union[str, List[str], Tuple[str, ...]], + ch_names: Union[int, list[str], tuple[str, ...]], + ch_types: Union[str, list[str], tuple[str, ...]], ): """Init instance from channel names and types.""" self._unlocked = True diff --git a/pycrostates/preprocessing/resample.py b/pycrostates/preprocessing/resample.py index 7b2b6c54..eadcba7a 100644 --- a/pycrostates/preprocessing/resample.py +++ b/pycrostates/preprocessing/resample.py @@ -1,6 +1,6 @@ """Preprocessing functions to create resamples from raw or epochs instances.""" -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np from mne import BaseEpochs, pick_info @@ -37,7 +37,7 @@ def resample( replace: bool = True, random_state: RANDomState = None, verbose=None, -) -> List[CHData]: +) -> list[CHData]: """Resample a recording into epochs of random samples. Resample :class:`~mne.io.Raw`. :class:`~mne.Epochs` or diff --git a/pycrostates/segmentation/_base.py b/pycrostates/segmentation/_base.py index 553a0ecd..41ed5a37 100644 --- a/pycrostates/segmentation/_base.py +++ b/pycrostates/segmentation/_base.py @@ -2,12 +2,13 @@ import itertools from abc import abstractmethod -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np from matplotlib.axes import Axes from mne import BaseEpochs from mne.io import BaseRaw +from mne.utils import check_version from numpy.typing import NDArray from .._typing import Segmentation @@ -40,7 +41,7 @@ def __init__( labels: NDArray[int], inst: Union[BaseRaw, BaseEpochs], cluster_centers_: NDArray[float], - cluster_names: Optional[List[str]] = None, + cluster_names: Optional[list[str]] = None, predict_parameters: Optional[dict] = None, ): # check input @@ -141,7 +142,8 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False): assert data.ndim == 2 assert labels.size == data.shape[1] elif isinstance(self._inst, BaseEpochs): - data = self._inst.get_data() + kwargs_epochs = dict(copy=False) if check_version("mne", "1.6") else dict() + data = self._inst.get_data(**kwargs_epochs) # sanity-checks assert labels.ndim == 2 assert data.ndim == 3 @@ -291,7 +293,7 @@ def plot_cluster_centers( # -------------------------------------------------------------------- @staticmethod def _check_cluster_names( - cluster_names: List[str], + cluster_names: list[str], cluster_centers_: NDArray[float], ): """Check that the argument 'cluster_names' is valid.""" @@ -367,7 +369,7 @@ def cluster_centers_(self) -> NDArray[float]: return self._cluster_centers_.copy() @property - def cluster_names(self) -> List[str]: + def cluster_names(self) -> list[str]: """Name of the cluster centers. :type: `list` diff --git a/pycrostates/utils/_docs.py b/pycrostates/utils/_docs.py index dc326f3f..7e080d96 100644 --- a/pycrostates/utils/_docs.py +++ b/pycrostates/utils/_docs.py @@ -5,15 +5,15 @@ Inspired from mne.utils.docs.py by Eric Larson """ import sys -from typing import Callable, Dict, List, Tuple +from typing import Callable from mne.utils.docs import docdict as docdict_mne # ------------------------- Documentation dictionary ------------------------- -docdict: Dict[str, str] = {} +docdict: dict[str, str] = {} # ---- Documentation to inc. from MNE ---- -keys: Tuple[str, ...] = ( +keys: tuple[str, ...] = ( "n_jobs", "picks_all", "random_state", @@ -169,7 +169,7 @@ axes.""" # ------------------------- Documentation functions -------------------------- -docdict_indented: Dict[int, Dict[str, str]] = {} +docdict_indented: dict[int, dict[str, str]] = {} def fill_doc(f: Callable) -> Callable: @@ -215,7 +215,7 @@ def fill_doc(f: Callable) -> Callable: return f -def _indentcount_lines(lines: List[str]) -> int: +def _indentcount_lines(lines: list[str]) -> int: """Minimum indent for all lines in line list. >>> lines = [' one', ' two', ' three'] diff --git a/pycrostates/utils/sys_info.py b/pycrostates/utils/sys_info.py index 34ed5af9..710d4e0d 100644 --- a/pycrostates/utils/sys_info.py +++ b/pycrostates/utils/sys_info.py @@ -2,7 +2,7 @@ import sys from functools import partial from importlib.metadata import requires, version -from typing import IO, Callable, List, Optional +from typing import IO, Callable, Optional import psutil from packaging.requirements import Requirement @@ -71,14 +71,14 @@ def sys_info(fid: Optional[IO] = None, developer: bool = False): def _list_dependencies_info( - out: Callable, ljust: int, package: str, dependencies: List[Requirement] + out: Callable, ljust: int, package: str, dependencies: list[Requirement] ): """List dependencies names and versions.""" unicode = sys.stdout.encoding.lower().startswith("utf") if unicode: ljust += 1 - not_found: List[Requirement] = list() + not_found: list[Requirement] = list() for dep in dependencies: if dep.name == package: continue diff --git a/pycrostates/viz/cluster_centers.py b/pycrostates/viz/cluster_centers.py index 8a27c675..9d833313 100644 --- a/pycrostates/viz/cluster_centers.py +++ b/pycrostates/viz/cluster_centers.py @@ -1,6 +1,6 @@ """Visualization module for plotting cluster centers.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import numpy as np from matplotlib import pyplot as plt @@ -15,7 +15,7 @@ from ..utils._docs import fill_doc from ..utils._logs import logger, verbose -_GRADIENT_KWARGS_DEFAULTS: Dict[str, str] = { +_GRADIENT_KWARGS_DEFAULTS: dict[str, str] = { "color": "black", "linestyle": "-", "marker": "P", @@ -27,10 +27,10 @@ def plot_cluster_centers( cluster_centers: NDArray[float], info: Union[Info, CHInfo], - cluster_names: List[str] = None, + cluster_names: list[str] = None, axes: Optional[Union[Axes, NDArray[Axes]]] = None, show_gradient: Optional[bool] = False, - gradient_kwargs: Dict[str, Any] = _GRADIENT_KWARGS_DEFAULTS, + gradient_kwargs: dict[str, Any] = _GRADIENT_KWARGS_DEFAULTS, *, block: bool = False, verbose: Optional[str] = None, diff --git a/pycrostates/viz/segmentation.py b/pycrostates/viz/segmentation.py index 5e226278..1b8f7c48 100644 --- a/pycrostates/viz/segmentation.py +++ b/pycrostates/viz/segmentation.py @@ -1,6 +1,6 @@ """Visualisation module for plotting segmentations.""" -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np from matplotlib import colormaps, colors @@ -21,7 +21,7 @@ def plot_raw_segmentation( labels: NDArray[int], raw: BaseRaw, n_clusters: int, - cluster_names: List[str] = None, + cluster_names: list[str] = None, tmin: Optional[Union[int, float]] = None, tmax: Optional[Union[int, float]] = None, cmap: Optional[str] = None, @@ -107,7 +107,7 @@ def plot_epoch_segmentation( labels: NDArray[int], epochs: BaseEpochs, n_clusters: int, - cluster_names: List[str] = None, + cluster_names: list[str] = None, cmap: Optional[str] = None, axes: Optional[Axes] = None, cbar_axes: Optional[Axes] = None, @@ -145,7 +145,8 @@ def plot_epoch_segmentation( _check_type(epochs, (BaseEpochs,), "epochs") _check_type(block, (bool,), "block") - data = epochs.get_data().swapaxes(0, 1) + kwargs_epochs = dict(copy=False) if check_version("mne", "1.6") else dict() + data = epochs.get_data(**kwargs_epochs).swapaxes(0, 1) data = data.reshape(data.shape[0], -1) gfp = np.std(data, axis=0) times = np.arange(0, data.shape[-1]) @@ -199,7 +200,7 @@ def _plot_segmentation( gfp: NDArray[float], times: NDArray[float], n_clusters: int, - cluster_names: List[str] = None, + cluster_names: list[str] = None, cmap: Optional[Union[str, colors.Colormap]] = None, axes: Optional[Axes] = None, cbar_axes: Optional[Axes] = None,