Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 14, 2024
1 parent c5897e1 commit 88d0619
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 52 deletions.
8 changes: 2 additions & 6 deletions mne_icalabel/megnet/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ def _cart2sph(x, y, z):
return r, theta, phi


def _make_head_outlines(
sphere: NDArray,
pos: NDArray,
clip_origin: tuple
) -> dict:
"""a modified version of mne.viz.topomap._make_head_outlines.
def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple) -> dict:
"""A modified version of mne.viz.topomap._make_head_outlines.
This function is used to generate head outlines for topomap plotting.
The difference between this function and the original one is that
Expand Down
17 changes: 6 additions & 11 deletions mne_icalabel/megnet/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from mne.io import BaseRaw
from mne.preprocessing import ICA
from mne.utils import _validate_type, warn
from mne_icalabel.iclabel._utils import _pol2cart
from numpy.typing import NDArray
from PIL import Image
from scipy import interpolate
from scipy.spatial import ConvexHull

from mne_icalabel.iclabel._utils import _pol2cart

from ._utils import _cart2sph, _make_head_outlines


Expand Down Expand Up @@ -42,8 +43,7 @@ def get_megnet_features(raw: BaseRaw, ica: ICA):
_validate_type(raw, BaseRaw, "raw")
_validate_type(ica, ICA, "ica")
if not any(
ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(
unique=True)
ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(unique=True)
):
raise RuntimeError(
"Could not find MEG channels in the provided Raw instance."
Expand Down Expand Up @@ -141,8 +141,7 @@ def _get_topomaps_data(ica: ICA):
Xnew, Ynew = _pol2cart(TH, adjusted_R)
pos_new = np.vstack((Xnew, Ynew)).T

outlines = _make_head_outlines(
np.array([0, 0, 0, 1]), pos_new, (0, 0))
outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0))
return pos_new, outlines


Expand All @@ -154,8 +153,7 @@ def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):

for comp in range(ica.n_components_):
data = components[data_picks, comp]
fig = plt.figure(
figsize=(1.3, 1.3), dpi=100, facecolor="black")
fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black")
ax = fig.add_subplot(111)
mnefig, _ = mne.viz.plot_topomap(
data,
Expand Down Expand Up @@ -201,10 +199,7 @@ def _check_line_noise(
# a sampling rate extremely low (100 Hz?) and (1)
# either they missed all of the previous warnings
# encountered or (2) they know what they are doing.
warn(
"The sampling rate raw.info['sfreq'] is too low"
"to estimate line niose."
)
warn("The sampling rate raw.info['sfreq'] is too low" "to estimate line niose.")
return False
# compute the power spectrum and retrieve the frequencies of interest
spectrum = raw.compute_psd(picks="meg", exclude="bads")
Expand Down
26 changes: 9 additions & 17 deletions mne_icalabel/megnet/label_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def _chunk_predicting(

chunk_votes = {start: 0 for start in start_times}
for t in range(time_len):
in_chunks = [
start <= t < start + chunk_len for start in start_times
]
in_chunks = [start <= t < start + chunk_len for start in start_times]
# how many chunks the time point is in
num_chunks = np.sum(in_chunks)
for start_time, is_in_chunk in zip(start_times, in_chunks):
Expand All @@ -79,24 +77,18 @@ def _chunk_predicting(
weighted_predictions = {}
for start_time in chunk_votes.keys():
onnx_inputs = {
session.get_inputs()[0]
.name: np.expand_dims(comp_map, 0)
.astype(np.float32),
session.get_inputs()[1]
.name: np.expand_dims(
np.expand_dims(
comp_series[start_time: start_time + chunk_len], 0),
session.get_inputs()[0].name: np.expand_dims(comp_map, 0).astype(
np.float32
),
session.get_inputs()[1].name: np.expand_dims(
np.expand_dims(comp_series[start_time : start_time + chunk_len], 0),
-1,
)
.astype(np.float32),
).astype(np.float32),
}
prediction = session.run(None, onnx_inputs)[0]
weighted_predictions[start_time] = (
prediction * chunk_votes[start_time]
)
weighted_predictions[start_time] = prediction * chunk_votes[start_time]

comp_prediction = np.stack(
list(weighted_predictions.values())).mean(axis=0)
comp_prediction = np.stack(list(weighted_predictions.values())).mean(axis=0)
comp_prediction /= comp_prediction.sum()
predction_vote.append(comp_prediction)

Expand Down
11 changes: 5 additions & 6 deletions mne_icalabel/megnet/tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mne import create_info
from mne.io import RawArray
from mne.preprocessing import ICA

from mne_icalabel.megnet.features import _check_line_noise, get_megnet_features


Expand All @@ -13,8 +14,7 @@ def raw_with_line_noise():
data1 = np.sin(2 * np.pi * 10 * times) + np.sin(2 * np.pi * 30 * times)
data2 = np.sin(2 * np.pi * 30 * times) + np.sin(2 * np.pi * 80 * times)
data = np.vstack([data1, data2])
info = create_info(
ch_names=["10-30", "30-80"], sfreq=1000, ch_types="mag")
info = create_info(ch_names=["10-30", "30-80"], sfreq=1000, ch_types="mag")
return RawArray(data, info)


Expand Down Expand Up @@ -53,8 +53,7 @@ def create_raw_ica(
channel_locs[:, 1] += 0.1
channel_locs[:, 2] += 0.1

info = create_info(
ch_names=ch_names, sfreq=sfreq, ch_types=ch_type)
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_type)
for i, loc in enumerate(channel_locs):
info["chs"][i]["loc"][:3] = loc

Expand All @@ -79,7 +78,7 @@ def raw_ica_valid():


def test_get_megnet_features(raw_ica_valid):
"""test whether the function returns the correct features."""
"""Test whether the function returns the correct features."""
time_series, topomaps = get_megnet_features(*raw_ica_valid)
n_components = raw_ica_valid[1].n_components
n_times = raw_ica_valid[0].times.shape[0]
Expand Down Expand Up @@ -138,7 +137,7 @@ def test_get_megnet_features_invalid(
raw_ica_invalid_ncomp,
raw_ica_invalid_method,
):
"""test whether the function raises the correct exceptions"""
"""Test whether the function raises the correct exceptions"""
test_cases = [
(raw_ica_invalid_channel, RuntimeError, "Could not find MEG channels"),
(
Expand Down
18 changes: 6 additions & 12 deletions mne_icalabel/megnet/tests/test_label_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import onnxruntime as ort
import pytest

from mne_icalabel.megnet.label_components import (
_chunk_predicting,
_get_chunk_start,
Expand All @@ -22,25 +23,22 @@ def raw_ica():
raw.notch_filter(60)
raw.filter(1, 100)

ica = mne.preprocessing.ICA(
n_components=20,
method="infomax",
random_state=88)
ica = mne.preprocessing.ICA(n_components=20, method="infomax", random_state=88)
ica.fit(raw)

return raw, ica


def test_megnet_label_components(raw_ica):
"""test whether the function returns the correct artifact index"""
"""Test whether the function returns the correct artifact index"""
real_atrifact_idx = [0, 3, 5] # heart beat, eye movement, heart beat
prob = megnet_label_components(*raw_ica)
this_atrifact_idx = list(np.nonzero(prob.argmax(axis=1))[0])
assert this_atrifact_idx == real_atrifact_idx


def test_get_chunk_start():
"""test whether the function returns the correct start times"""
"""Test whether the function returns the correct start times"""
input_len = 10000
chunk_len = 3000
overlap_len = 750
Expand All @@ -52,19 +50,15 @@ def test_get_chunk_start():


def test_chunk_predicting():
"""test whether MEGnet's chunk volte algorithm returns the correct shape"""
"""Test whether MEGnet's chunk volte algorithm returns the correct shape"""
time_series = np.random.rand(5, 10000)
spatial_maps = np.random.rand(5, 120, 120, 3)

mock_session = MagicMock(spec=ort.InferenceSession)
mock_session.run.return_value = [np.random.rand(4)]

predictions = _chunk_predicting(
mock_session,
time_series,
spatial_maps,
chunk_len=3000,
overlap_len=750
mock_session, time_series, spatial_maps, chunk_len=3000, overlap_len=750
)

assert predictions.shape == (5, 4)
Expand Down

0 comments on commit 88d0619

Please sign in to comment.