Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into dev-aifpaif
Browse files Browse the repository at this point in the history
  • Loading branch information
mscheltienne committed Nov 10, 2023
2 parents 3c72e6f + deec3c8 commit c043df3
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 51 deletions.
22 changes: 11 additions & 11 deletions pycrostates/cluster/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from copy import copy, deepcopy
from itertools import groupby
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union

import numpy as np
from matplotlib.axes import Axes
Expand Down Expand Up @@ -277,11 +277,11 @@ def fit(

def rename_clusters(
self,
mapping: Optional[Dict[str, str]] = None,
mapping: Optional[dict[str, str]] = None,
new_names: Optional[
Union[
List[str],
Tuple[str, ...],
list[str],
tuple[str, ...],
]
] = None,
) -> None:
Expand Down Expand Up @@ -341,11 +341,11 @@ def rename_clusters(

def reorder_clusters(
self,
mapping: Optional[Dict[int, int]] = None,
mapping: Optional[dict[int, int]] = None,
order: Optional[
Union[
List[int],
Tuple[int, ...],
list[int],
tuple[int, ...],
NDArray[int],
]
] = None,
Expand Down Expand Up @@ -458,8 +458,8 @@ def invert_polarity(
self,
invert: Union[
bool,
List[bool],
Tuple[bool, ...],
list[bool],
tuple[bool, ...],
NDArray[bool],
],
) -> None:
Expand Down Expand Up @@ -518,7 +518,7 @@ def plot(
self,
axes: Optional[Union[Axes, NDArray[Axes]]] = None,
show_gradient: Optional[bool] = False,
gradient_kwargs: Dict[str, Any] = {
gradient_kwargs: dict[str, Any] = {
"color": "black",
"linestyle": "-",
"marker": "P",
Expand Down Expand Up @@ -1180,7 +1180,7 @@ def labels_(self) -> NDArray[int]:
return self._labels_.copy()

@property
def cluster_names(self) -> List[str]:
def cluster_names(self) -> list[str]:
"""Name of the clusters.
:type: `list`
Expand Down
6 changes: 3 additions & 3 deletions pycrostates/cluster/aahc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Atomize and Agglomerate Hierarchical Clustering (AAHC)."""

from pathlib import Path
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Union

import numpy as np
from mne import BaseEpochs
Expand Down Expand Up @@ -183,7 +183,7 @@ def _aahc(
n_clusters: int,
ignore_polarity: bool,
normalize_input: bool,
) -> Tuple[float, NDArray[float], NDArray[int]]:
) -> tuple[float, NDArray[float], NDArray[int]]:
"""Run the AAHC algorithm."""
gfp_sum_sq = np.sum(data**2)
maps, segmentation = AAHCluster._compute_maps(
Expand All @@ -200,7 +200,7 @@ def _compute_maps(
n_clusters: int,
ignore_polarity: bool,
normalize_input: bool,
) -> Tuple[NDArray[float], NDArray[int]]:
) -> tuple[NDArray[float], NDArray[int]]:
"""Compute microstates maps."""
n_chan, n_frame = data.shape

Expand Down
6 changes: 3 additions & 3 deletions pycrostates/cluster/kmeans.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Class and functions to use modified Kmeans."""

from pathlib import Path
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Union

import numpy as np
from mne import BaseEpochs
Expand Down Expand Up @@ -263,7 +263,7 @@ def _kmeans(
max_iter: int,
random_state: Union[RandomState, Generator],
tol: Union[int, float],
) -> Tuple[float, NDArray[float], NDArray[int], bool]:
) -> tuple[float, NDArray[float], NDArray[int], bool]:
"""Run the k-means algorithm."""
gfp_sum_sq = np.sum(data**2)
maps, converged = ModKMeans._compute_maps(
Expand All @@ -282,7 +282,7 @@ def _compute_maps(
max_iter: int,
random_state: Union[RandomState, Generator],
tol: Union[int, float],
) -> Tuple[NDArray[float], bool]:
) -> tuple[NDArray[float], bool]:
"""Compute microstates maps.
Based on mne_microstates by Marijn van Vliet <[email protected]>
Expand Down
8 changes: 4 additions & 4 deletions pycrostates/io/fiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import reduce
from numbers import Integral
from pathlib import Path
from typing import List, Union
from typing import Union

import numpy as np
from mne import Info, Transform
Expand Down Expand Up @@ -98,7 +98,7 @@ def _write_cluster(
cluster_centers_: NDArray[float],
chinfo: Union[CHInfo, Info],
algorithm: str,
cluster_names: List[str],
cluster_names: list[str],
fitted_data: NDArray[float],
labels_: NDArray[int],
**kwargs,
Expand Down Expand Up @@ -398,7 +398,7 @@ def _check_fit_parameters_and_variables(
def _create_ModKMeans(
cluster_centers_: NDArray[float],
info: CHInfo,
cluster_names: List[str],
cluster_names: list[str],
fitted_data: NDArray[float],
labels_: NDArray[int],
n_init: int,
Expand All @@ -423,7 +423,7 @@ def _create_ModKMeans(
def _create_AAHCluster(
cluster_centers_: NDArray[float],
info: CHInfo,
cluster_names: List[str],
cluster_names: list[str],
fitted_data: NDArray[float],
labels_: NDArray[int],
ignore_polarity: bool, # pylint: disable=unused-argument
Expand Down
12 changes: 6 additions & 6 deletions pycrostates/io/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from copy import deepcopy
from numbers import Number
from typing import List, Optional, Tuple, Union
from typing import Optional, Union

import numpy as np
from mne import Info, Projection, Transform
Expand Down Expand Up @@ -155,11 +155,11 @@ def __init__(
ch_names: Optional[
Union[
int,
List[str],
Tuple[str, ...],
list[str],
tuple[str, ...],
]
] = None,
ch_types: Optional[Union[str, List[str], Tuple[str, ...]]] = None,
ch_types: Optional[Union[str, list[str], tuple[str, ...]]] = None,
):
if all(arg is None for arg in (info, ch_names, ch_types)):
raise RuntimeError(
Expand Down Expand Up @@ -195,8 +195,8 @@ def _init_from_info(self, info: Info):

def _init_from_channels(
self,
ch_names: Union[int, List[str], Tuple[str, ...]],
ch_types: Union[str, List[str], Tuple[str, ...]],
ch_names: Union[int, list[str], tuple[str, ...]],
ch_types: Union[str, list[str], tuple[str, ...]],
):
"""Init instance from channel names and types."""
self._unlocked = True
Expand Down
4 changes: 2 additions & 2 deletions pycrostates/preprocessing/resample.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Preprocessing functions to create resamples from raw or epochs instances."""

from typing import List, Optional, Union
from typing import Optional, Union

import numpy as np
from mne import BaseEpochs, pick_info
Expand Down Expand Up @@ -37,7 +37,7 @@ def resample(
replace: bool = True,
random_state: RANDomState = None,
verbose=None,
) -> List[CHData]:
) -> list[CHData]:
"""Resample a recording into epochs of random samples.
Resample :class:`~mne.io.Raw`. :class:`~mne.Epochs` or
Expand Down
13 changes: 8 additions & 5 deletions pycrostates/segmentation/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import itertools
from abc import abstractmethod
from typing import List, Optional, Union
from typing import Optional, Union

import numpy as np
from matplotlib.axes import Axes
from mne import BaseEpochs
from mne.io import BaseRaw
from mne.utils import check_version
from numpy.typing import NDArray

from .._typing import Segmentation
Expand Down Expand Up @@ -41,7 +42,7 @@ def __init__(
labels: NDArray[int],
inst: Union[BaseRaw, BaseEpochs],
cluster_centers_: NDArray[float],
cluster_names: Optional[List[str]] = None,
cluster_names: Optional[list[str]] = None,
predict_parameters: Optional[dict] = None,
):
# check input
Expand Down Expand Up @@ -142,7 +143,8 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False):
assert data.ndim == 2
assert labels.size == data.shape[1]
elif isinstance(self._inst, BaseEpochs):
data = self._inst.get_data()
kwargs_epochs = dict(copy=False) if check_version("mne", "1.6") else dict()
data = self._inst.get_data(**kwargs_epochs)
# sanity-checks
assert labels.ndim == 2
assert data.ndim == 3
Expand Down Expand Up @@ -324,7 +326,8 @@ def plot_cluster_centers(
# --------------------------------------------------------------------
@staticmethod
def _check_cluster_names(
cluster_names: List[str], cluster_centers_: NDArray[float]
cluster_names: list[str],
cluster_centers_: NDArray[float],
):
"""Check that the argument 'cluster_names' is valid."""
_check_type(cluster_names, (list, None), "cluster_names")
Expand Down Expand Up @@ -399,7 +402,7 @@ def cluster_centers_(self) -> NDArray[float]:
return self._cluster_centers_.copy()

@property
def cluster_names(self) -> List[str]:
def cluster_names(self) -> list[str]:
"""Name of the cluster centers.
:type: `list`
Expand Down
10 changes: 5 additions & 5 deletions pycrostates/utils/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
Inspired from mne.utils.docs.py by Eric Larson <[email protected]>
"""
import sys
from typing import Callable, Dict, List, Tuple
from typing import Callable

from mne.utils.docs import docdict as docdict_mne

# ------------------------- Documentation dictionary -------------------------
docdict: Dict[str, str] = {}
docdict: dict[str, str] = {}

# ---- Documentation to inc. from MNE ----
keys: Tuple[str, ...] = (
keys: tuple[str, ...] = (
"n_jobs",
"picks_all",
"random_state",
Expand Down Expand Up @@ -196,7 +196,7 @@
axes."""

# ------------------------- Documentation functions --------------------------
docdict_indented: Dict[int, Dict[str, str]] = {}
docdict_indented: dict[int, dict[str, str]] = {}


def fill_doc(f: Callable) -> Callable:
Expand Down Expand Up @@ -242,7 +242,7 @@ def fill_doc(f: Callable) -> Callable:
return f


def _indentcount_lines(lines: List[str]) -> int:
def _indentcount_lines(lines: list[str]) -> int:
"""Minimum indent for all lines in line list.
>>> lines = [' one', ' two', ' three']
Expand Down
6 changes: 3 additions & 3 deletions pycrostates/utils/sys_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from functools import partial
from importlib.metadata import requires, version
from typing import IO, Callable, List, Optional
from typing import IO, Callable, Optional

import psutil
from packaging.requirements import Requirement
Expand Down Expand Up @@ -71,14 +71,14 @@ def sys_info(fid: Optional[IO] = None, developer: bool = False):


def _list_dependencies_info(
out: Callable, ljust: int, package: str, dependencies: List[Requirement]
out: Callable, ljust: int, package: str, dependencies: list[Requirement]
):
"""List dependencies names and versions."""
unicode = sys.stdout.encoding.lower().startswith("utf")
if unicode:
ljust += 1

not_found: List[Requirement] = list()
not_found: list[Requirement] = list()
for dep in dependencies:
if dep.name == package:
continue
Expand Down
8 changes: 4 additions & 4 deletions pycrostates/viz/cluster_centers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Visualization module for plotting cluster centers."""

from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

import numpy as np
from matplotlib import pyplot as plt
Expand All @@ -15,7 +15,7 @@
from ..utils._docs import fill_doc
from ..utils._logs import logger, verbose

_GRADIENT_KWARGS_DEFAULTS: Dict[str, str] = {
_GRADIENT_KWARGS_DEFAULTS: dict[str, str] = {
"color": "black",
"linestyle": "-",
"marker": "P",
Expand All @@ -27,10 +27,10 @@
def plot_cluster_centers(
cluster_centers: NDArray[float],
info: Union[Info, CHInfo],
cluster_names: List[str] = None,
cluster_names: list[str] = None,
axes: Optional[Union[Axes, NDArray[Axes]]] = None,
show_gradient: Optional[bool] = False,
gradient_kwargs: Dict[str, Any] = _GRADIENT_KWARGS_DEFAULTS,
gradient_kwargs: dict[str, Any] = _GRADIENT_KWARGS_DEFAULTS,
*,
block: bool = False,
verbose: Optional[str] = None,
Expand Down
Loading

0 comments on commit c043df3

Please sign in to comment.