Skip to content

Commit

Permalink
Fusion model patch for 2.0.1rc4 (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
celikbasak authored Dec 13, 2024
1 parent d91a878 commit a36ee91
Show file tree
Hide file tree
Showing 43 changed files with 2,022 additions and 238 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ omit =
*/gui/experiments/*
*/gui/viewer/*
*/gui/BCInterface.py
*/signal/model/offline_analysis.py
*/signal/evaluate/fusion.py

[report]
exclude_lines =
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ jobs:
sudo apt-get install xvfb
python -m pip install --upgrade pip
pip install attrdict3
conda install -c conda-forge liblsl
- name: Install dependencies
run: |
make dev-install
Expand Down Expand Up @@ -96,6 +97,9 @@ jobs:
- name: lint
run: |
make lint
- name: integration-test
run: |
make integration-test
build-macos:

Expand Down Expand Up @@ -129,5 +133,8 @@ jobs:
- name: lint
run: |
make lint
- name: integration-test
run: |
make integration-test
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# 2.0.1-rc.4

Patch on final release candidate

## Contributions

- Fusion model analysis and performance metrics support. Bugfixes in gaze model #366

# 2.0.0-rc.4

Our final release candidate before the official 2.0 release!
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ test-all:
make coverage-report
make type
make lint
make integration-test

unit-test:
pytest --mpl -k "not slow"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ Invoke an experiment protocol or task directly using command line utility `bcipy

##### Train a signal model with registered BciPy models

To train a signal model (currently `PCARDAKDE`), run the following command after installing BciPy:
To train a signal model (currently `PCARDAKDE` and `GazeModels`), run the following command after installing BciPy:

`bcipy-train`

Expand Down
1 change: 1 addition & 0 deletions bcipy/acquisition/tests/datastream/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
class TestProducer(unittest.TestCase):
"""Tests for Producer"""

@pytest.mark.skip(reason="Skipping due to CI failures. Run locally to test.")
def test_frequency(self):
"""Data should be generated at the provided frequency"""
sample_hz = 300
Expand Down
3 changes: 1 addition & 2 deletions bcipy/acquisition/tests/protocols/lsl/test_lsl_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DEVICE = preconfigured_device(DEVICE_NAME)


@pytest.mark.slow
class TestDataAcquisitionClient(unittest.TestCase):
"""Main Test class for DataAcquisitionClient code."""

Expand Down Expand Up @@ -100,7 +101,6 @@ def test_with_unspecified_device(self):
client.stop_acquisition()
self.assertAlmostEqual(DEVICE.sample_rate, len(samples), delta=5.0)

@pytest.mark.slow
def test_get_data(self):
"""Test functionality with a provided device_spec"""
client = LslAcquisitionClient(max_buffer_len=1, device_spec=DEVICE)
Expand All @@ -125,7 +125,6 @@ def test_get_data(self):
start,
delta=0.002)

@pytest.mark.slow
def test_event_offset(self):
"""Test the offset in seconds of a given event relative to the first
sample time."""
Expand Down
6 changes: 4 additions & 2 deletions bcipy/helpers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def load_json_parameters(path: str, value_cast: bool = False) -> Parameters:

def load_experimental_data() -> str:
filename = ask_directory() # show dialog box and return the path
if not filename:
raise BciPyCoreException('No file selected in GUI. Exiting...')
log.info("Loaded Experimental Data From: %s" % filename)
return filename

Expand Down Expand Up @@ -255,7 +257,7 @@ def choose_csv_file(filename: Optional[str] = None) -> Optional[str]:
file_name = filename.split('/')[-1]

if 'csv' not in file_name:
raise Exception(
raise TypeError(
'File type unrecognized. Please use a supported csv type')

return filename
Expand All @@ -280,7 +282,7 @@ def load_txt_data() -> str:
file_name = filename.split('/')[-1]

if 'txt' not in file_name:
raise Exception(
raise TypeError(
'File type unrecognized. Please use a supported text type')

return filename
Expand Down
2 changes: 2 additions & 0 deletions bcipy/helpers/tests/test_offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import zipfile
from pathlib import Path
import tempfile
import pytest

from matplotlib import pyplot as plt

Expand All @@ -24,6 +25,7 @@
input_folder = pwd / "resources/mock_offset/time_test_data/"


@pytest.mark.slow
class TestOffset(unittest.TestCase):

def setUp(self) -> None:
Expand Down
Binary file not shown.
2 changes: 1 addition & 1 deletion bcipy/parameters/lm_params.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"kenlm": {
"model_file": {
"description": "Name of the pretrained model file",
"value": "lm_dec19_char_large_12gram.kenlm",
"value": "lm_dec19_char_tiny_12gram.kenlm",
"type": "filepath"
}
},
Expand Down
36 changes: 29 additions & 7 deletions bcipy/signal/README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,41 @@
# Signal

The BciPy Signal module contains all code needed to process, model, and generate signals for Brain Computer Interface control using EEG. Further documentation provided in submodule READMEs.
The BciPy Signal module contains all code needed to process, evaluate, model, and generate signals for Brain Computer Interface control using EEG and/or Eye Tracking. Further documentation provided in submodule READMEs.

# Evaluate
## Evaluate

Evaluates signal based on configured rules.
The evaluation module contains functions for evaluating signals based on configured rules. The module contains functionailty for detecting artifacts in EEG signals, and for evaluating the quality of the signal. In addition, analysis functions are provided to evaluate the performance of the BCI system. Currently, the fusion of the signals is evaluated using the `calculate_eeg_gaze_fusion_acc` function.

# Process
## Process

The process module contains functions for decomposing signals into frequency bands (psd, cwt), filtering signals (bandpass, notch), and other signal processing functions.

# Model
## Model

Modeling needed to classify signals. See signal/model/README.md for more detailed information.
The module contains functions for training and testing classifiers, and for evaluating the performance of the classifiers. Several classifiers are provided, including a PCA/RDA/KDE classifier and several Gaussian Mixture Model classifiers. See the submodule README for more information.

# Generator
### Model Training (offline analysis)

To train a signal model (such as, `PCARDAKDE`), run the following command after installing BciPy:

`bcipy-train`

Use the help flag to see other available input options: `bcipy-train --help` You can pass it attributes with flags, if desired.

Execute without a window prompting for data session folder: `bcipy-train -d path/to/data`

Execute with data visualizations (ERPs, etc.): `bcipy-train -v`

Execute with data visualizations that do not show, but save to file: `bcipy-train -s`

Execute with balanced accuracy: `bcipy-train --balanced-acc`

Execute with alerts after each Task execution: `bcipy-train --alert`

Execute with custom parameters: `bcipy-train -p "path/to/valid/parameters.json"`

Execute with custom number of iterations for fusion analysis (by default 10): `bcipy-train -i 10`

## Generator

Generates fake signal data.
28 changes: 25 additions & 3 deletions bcipy/signal/evaluate/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ steps:
4. If the signal exceeds the thresholds, it is labeled as an artifact. All artifacts are annotated in the
data file with the prefix `BAD_`.

### Usage
### Artifact Detection Usage

The `ArtifactDetection` class is used to detect artifacts in the data. The class takes in a `RawData` object, a `DeviceSpec` object, and a `Parameters` object. The `RawData` object contains the data to be analyzed, the `DeviceSpec` object contains the specifications of the device used to collect the data, and the `Parameters` object contains the parameters used to detect the artifacts. The `ArtifactDetection` class has a method called `detect_artifacts` that returns a list of the detected artifacts.

Expand All @@ -41,7 +41,7 @@ artifact_detector = ArtifactDetection(raw_data, parameters, device_spec, session
detected_artifacts = artifact_detector.detect_artifacts()
```

This can be used in conjunction with the `ArtifactDetection` semiautomatic mode to determine artifacts that overlap with triggers of interest and correct any labels before removal. To use the semiautomatic mode, the user must provide a list of triggers of interest. The `ArtifactDetection` class can be inititalized with `semi_automatic`.
This can be used in conjunction with the `ArtifactDetection` semiautomatic mode to determine artifacts that overlap with triggers of interest and correct any labels before removal. To use the semiautomatic mode, the user must provide a list of triggers of interest. The `ArtifactDetection` class can be inititalized with `semi_automatic`. The `semi_automatic` parameter is a boolean that determines if the user wants to manually correct or add to the detected artifacts.

```python

Expand Down Expand Up @@ -79,7 +79,7 @@ write_mne_annotations(
'artifact_annotations.txt')
```

## Artifact Correction
### Artifact Correction

Artifact correction is the process of removing unwanted signals from the data. After detection is complete, the user may use the MNE epoching tool to remove the unwanted epochs and channels.

Expand All @@ -99,3 +99,25 @@ epochs = mne_epochs(mne_data, trial_length, preload=True, reject_by_annotation=T

# This will return the epochs object with the bad epochs removed. A drop log can be accessed to see which and how many epochs were removed.
```

## Fusion Accuracy

The `calculate_eeg_gaze_fusion_acc` function is used to evaluate the performance of the BCI system. The function takes in a list of EEG and gaze data, and returns the accuracy of the fusion of the two signals. The function uses the following steps to calculate the accuracy:

1. The data is loaded into the system and preprocessed.
2. The data is passed through the EEG and gaze models to generate predictions.
3. The predictions are fused together to generate a final prediction.
4. The final prediction is compared to the actual data to calculate the accuracy.
5. The accuracy is returned to the user.

### Fusion Usage

The `calculate_eeg_gaze_fusion_acc` function is used to evaluate the performance of the BCI system. The function takes in a list of EEG and gaze data, and returns the accuracy of the fusion of the two signals.

```python
from bcipy.signal.evaluate.fusion import calculate_eeg_gaze_fusion_acc

# Assuming BciPy raw data objects, device specs and parameters object are already defined.

result = calculate_eeg_gaze_fusion_acc(eeg_data, gaze_data, eeg_spec, gaze_spec, symbol_set, parameters, data_folder)
```
Loading

0 comments on commit a36ee91

Please sign in to comment.