diff --git a/.coveragerc b/.coveragerc index c7ad07bca..443874bee 100644 --- a/.coveragerc +++ b/.coveragerc @@ -17,4 +17,4 @@ exclude_lines = raise NotImplementedError @abstract if __name__ == .__main__.: - log = logging.getLogger(__name__) + logging.getLogger(__name__) diff --git a/bcipy/config.py b/bcipy/config.py index e3ca262d2..0bbc4b562 100644 --- a/bcipy/config.py +++ b/bcipy/config.py @@ -28,6 +28,7 @@ TASK_SEPERATOR = '->' DEFAULT_PARAMETER_FILENAME = 'parameters.json' +DEFAULT_DEVICES_PATH = f"{BCIPY_ROOT}/parameters" DEFAULT_PARAMETERS_PATH = f'{BCIPY_ROOT}/parameters/{DEFAULT_PARAMETER_FILENAME}' DEFAULT_DEVICE_SPEC_FILENAME = 'devices.json' DEVICE_SPEC_PATH = f'{BCIPY_ROOT}/parameters/{DEFAULT_DEVICE_SPEC_FILENAME}' diff --git a/bcipy/gui/BCInterface.py b/bcipy/gui/BCInterface.py index 0e60ef938..3a4c79134 100644 --- a/bcipy/gui/BCInterface.py +++ b/bcipy/gui/BCInterface.py @@ -1,9 +1,10 @@ import subprocess import sys +import logging from typing import List from bcipy.config import (BCIPY_ROOT, DEFAULT_PARAMETERS_PATH, - STATIC_IMAGES_PATH) + STATIC_IMAGES_PATH, PROTOCOL_LOG_FILENAME) from bcipy.gui.main import (AlertMessageResponse, AlertMessageType, AlertResponse, BCIGui, app, contains_special_characters, contains_whitespaces, @@ -12,6 +13,8 @@ load_json_parameters, load_users) from bcipy.task import TaskRegistry +logger = logging.getLogger(PROTOCOL_LOG_FILENAME) + class BCInterface(BCIGui): """BCI Interface. @@ -29,7 +32,7 @@ class BCInterface(BCIGui): max_length = 25 min_length = 1 timeout = 3 - font = 'Consolas' + font = 'Courier New' def __init__(self, *args, **kwargs): super(BCInterface, self).__init__(*args, **kwargs) @@ -420,7 +423,13 @@ def start_experiment(self) -> None: ) if self.alert: cmd += ' -a' - subprocess.Popen(cmd, shell=True) + output = subprocess.run(cmd, shell=True) + if output.returncode != 0: + self.throw_alert_message( + title='BciPy Alert', + message=f'Error: {output.stderr.decode()}', + message_type=AlertMessageType.CRIT, + message_response=AlertMessageResponse.OTE) if self.autoclose: self.close() @@ -431,7 +440,7 @@ def offline_analysis(self) -> None: Run offline analysis as a script in a new process. """ if not self.action_disabled(): - cmd = f'python {BCIPY_ROOT}/signal/model/offline_analysis.py --alert --p "{self.parameter_location}"' + cmd = f'bcipy-train --alert --p "{self.parameter_location}" -v -s' subprocess.Popen(cmd, shell=True) def action_disabled(self) -> bool: diff --git a/bcipy/gui/alert.py b/bcipy/gui/alert.py index d974349d9..91fec4c9e 100644 --- a/bcipy/gui/alert.py +++ b/bcipy/gui/alert.py @@ -15,12 +15,13 @@ def confirm(message: str) -> bool: ------- users selection : True for selecting Ok, False for Cancel. """ - app = QApplication(sys.argv) + app = QApplication(sys.argv).instance() + if not app: + app = QApplication(sys.argv) dialog = alert_message(message, message_type=AlertMessageType.INFO, message_response=AlertMessageResponse.OCE) button = dialog.exec() - result = bool(button == AlertResponse.OK.value) app.quit() return result diff --git a/bcipy/gui/bcipy_stylesheet.css b/bcipy/gui/bcipy_stylesheet.css index 3cb7e1e79..58b53a88b 100644 --- a/bcipy/gui/bcipy_stylesheet.css +++ b/bcipy/gui/bcipy_stylesheet.css @@ -1,44 +1,87 @@ /* This stylesheet uses the QSS syntax, but is named as a CSS file to take advantage of IDE CSS tooling */ -* { - background-color: white; + +QWidget[class="experiment-registry"] { + background-color: black; +} + +QWidget[class="inter-task"] { + background-color: black; } -QWidget { + +QLabel { background-color: black; color: white; - font-size: 14px; } -QWidget > * { - background-color: none; +QLabel[class="task-label"] { + background-color: transparent; + color: black; } QPushButton { - background-color: green; + background-color: rgb(16, 173, 39); + color: white; padding: 10px; border-radius: 10px; } +QPushButton[class="remove-button"] { + background-color: rgb(243, 58, 58); +} + +QPushButton[class="remove-button"]:hover { + background-color: rgb(255, 0, 0); +} QPushButton[class="small-button"] { + background-color: darkslategray; + color: white; padding: 5px; border-radius: 5px; } -QPushButton:hover { +QPushButton[class="small-button"]:hover { background-color: darkgreen; } + QPushButton:pressed { background-color: darkslategrey; } -QComboBox, QLineEdit { +QComboBox { + background-color: white; + color: black; + padding: 4px; + border-radius: 1px; +} + +QComboBox:hover { + background-color: #e6f5ea; + color: black; +} + +QComboBox:on { + background-color: #e6f5ea; + color: black; +} + +QListView { + background-color: white; + color: black; + padding: 5px; +} + +QLineEdit { background-color: white; color: black; padding: 5px; border: none; + border-radius: 1px; } QScrollArea { background-color: white; -} + color: black; + border-radius: 1px; +} \ No newline at end of file diff --git a/bcipy/gui/bciui.py b/bcipy/gui/bciui.py index 585438cf7..b3c301832 100644 --- a/bcipy/gui/bciui.py +++ b/bcipy/gui/bciui.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Type from PyQt6.QtCore import pyqtSignal from PyQt6.QtWidgets import ( QWidget, @@ -9,9 +9,11 @@ QLayout, QSizePolicy, QMessageBox, + QApplication, ) from typing import Optional, List from bcipy.config import BCIPY_ROOT +import sys class BCIUI(QWidget): @@ -68,7 +70,6 @@ def centered(widget: QWidget) -> QHBoxLayout: @staticmethod def make_list_scroll_area(widget: QWidget) -> QScrollArea: - widget.setStyleSheet("background-color: transparent;") scroll_area = QScrollArea() scroll_area.setWidget(widget) scroll_area.setWidgetResizable(True) @@ -107,6 +108,10 @@ def toggle_on(): on_button.clicked.connect(toggle_off) off_button.clicked.connect(toggle_on) + def hide(self) -> None: + """Close the UI window""" + self.hide() + class SmallButton(QPushButton): """A small button with a fixed size""" @@ -221,3 +226,13 @@ def list_property(self, prop: str): A list of values for the given property. """ return [widget.data[prop] for widget in self.widgets] + + +def run_bciui(ui: Type[BCIUI], *args, **kwargs): + # add app to kwargs + app = QApplication(sys.argv).instance() + if not app: + app = QApplication(sys.argv) + ui_instance = ui(*args, **kwargs) + ui_instance.display() + return app.exec() diff --git a/bcipy/gui/experiments/ExperimentRegistry.py b/bcipy/gui/experiments/ExperimentRegistry.py index 6865cf75a..b1bd74de5 100644 --- a/bcipy/gui/experiments/ExperimentRegistry.py +++ b/bcipy/gui/experiments/ExperimentRegistry.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Type +from typing import List, Optional from PyQt6.QtWidgets import ( QComboBox, QVBoxLayout, @@ -7,9 +7,8 @@ QLineEdit, QPushButton, QScrollArea, - QApplication, ) -from bcipy.gui.bciui import BCIUI, DynamicItem, DynamicList, SmallButton +from bcipy.gui.bciui import BCIUI, DynamicItem, DynamicList, SmallButton, run_bciui from bcipy.helpers.load import load_fields, load_experiments from bcipy.helpers.save import save_experiment_data from bcipy.config import ( @@ -33,9 +32,14 @@ class ExperimentRegistry(BCIUI): def __init__(self): super().__init__("Experiment Registry", 600, 700) self.task_registry = TaskRegistry() + self.setProperty("class", "experiment-registry") def format_experiment_combobox( - self, label_text: str, combobox: QComboBox, buttons: Optional[List[QPushButton]] + self, + label_text: str, + combobox: QComboBox, + buttons: Optional[List[QPushButton]], + class_name: str = 'default', ) -> QVBoxLayout: """ Create a formatted widget for a the experiment comboboxes with optional buttons. @@ -51,11 +55,11 @@ def format_experiment_combobox( A QVBoxLayout with the label, combobox, and buttons. """ label = QLabel(label_text) - label.setStyleSheet("font-size: 18px") area = QVBoxLayout() input_area = QHBoxLayout() input_area.setContentsMargins(15, 0, 0, 15) area.addWidget(label) + combobox.setProperty("class", class_name) input_area.addWidget(combobox, 1) if buttons: for button in buttons: @@ -77,7 +81,7 @@ def make_task_entry(self, name: str) -> DynamicItem: """ layout = QHBoxLayout() label = QLabel(name) - label.setStyleSheet("color: black;") + label.setProperty("class", "task-label") layout.addWidget(label) widget = DynamicItem() @@ -102,7 +106,7 @@ def make_task_entry(self, name: str) -> DynamicItem: layout.addWidget(move_down_button) remove_button = SmallButton("Remove") - remove_button.setStyleSheet("background-color: red;") + remove_button.setProperty("class", "remove-button") remove_button.clicked.connect( lambda: layout.deleteLater() ) # This may not be needed @@ -127,7 +131,7 @@ def make_field_entry(self, name: str) -> DynamicItem: """ layout = QHBoxLayout() label = QLabel(name) - label.setStyleSheet("color: black;") + label.setProperty("class", "task-label") layout.addWidget(label) widget = DynamicItem() @@ -277,7 +281,10 @@ def add_task(): add_task_button = QPushButton("Add") add_task_button.clicked.connect(add_task) experiment_protocol_box = self.format_experiment_combobox( - "Protocol", self.experiment_protocol_input, [add_task_button] + "Protocol", + self.experiment_protocol_input, + [add_task_button], + "protocol", ) form_area.addLayout(experiment_protocol_box) @@ -288,7 +295,10 @@ def add_task(): new_field_button.clicked.connect(self.create_experiment_field) form_area.addLayout( self.format_experiment_combobox( - "Fields", self.field_input, [add_field_button, new_field_button] + "Fields", + self.field_input, + [add_field_button, new_field_button], + "fields", ) ) @@ -317,12 +327,5 @@ def add_task(): self.contents.addWidget(create_experiment_button) -def run_bciui(ui: Type[BCIUI], *args, **kwargs): - app = QApplication([]) - ui_instance = ui(*args, **kwargs) - ui_instance.display() - app.exec() - - if __name__ == "__main__": run_bciui(ExperimentRegistry) diff --git a/bcipy/gui/intertask_gui.py b/bcipy/gui/intertask_gui.py index c6e3b3979..87e1a8e47 100644 --- a/bcipy/gui/intertask_gui.py +++ b/bcipy/gui/intertask_gui.py @@ -1,36 +1,47 @@ -from typing import List -from bcipy.gui.bciui import BCIUI, run_bciui -from bcipy.task import Task +from typing import Callable, List + from PyQt6.QtWidgets import ( - QApplication, QLabel, QHBoxLayout, QPushButton, QProgressBar, + QApplication ) +from bcipy.gui.bciui import BCIUI, run_bciui +from bcipy.config import SESSION_LOG_FILENAME +import logging -from bcipy.task.main import TaskData +logger = logging.getLogger(SESSION_LOG_FILENAME) class IntertaskGUI(BCIUI): + action_name = "IntertaskAction" + def __init__( - self, next_task_name: str, current_task_index: int, total_tasks: int + self, + next_task_index: int, + tasks: List[str], + exit_callback: Callable, ): - self.total_tasks = total_tasks - self.current_task_index = current_task_index - self.next_task_name = next_task_name - super().__init__("Progress", 400, 100) + self.tasks = tasks + self.current_task_index = next_task_index + self.next_task_name = tasks[self.current_task_index] + self.total_tasks = len(tasks) + self.task_progress = next_task_index + self.callback = exit_callback + super().__init__("Progress", 800, 150) + self.setProperty("class", "inter-task") def app(self): self.contents.addLayout(BCIUI.centered(QLabel("Experiment Progress"))) progress_container = QHBoxLayout() progress_container.addWidget( - QLabel(f"({self.current_task_index}/{self.total_tasks})") + QLabel(f"({self.task_progress}/{self.total_tasks})") ) self.progress = QProgressBar() - self.progress.setValue(int(self.current_task_index / self.total_tasks * 100)) + self.progress.setValue(int(self.task_progress / self.total_tasks * 100)) self.progress.setTextVisible(False) progress_container.addWidget(self.progress) self.contents.addLayout(progress_container) @@ -51,22 +62,25 @@ def app(self): buttons_layout.addWidget(self.next_button) self.contents.addLayout(buttons_layout) - self.next_button.clicked.connect(self.close) - # This should be replaced with a method that stops orchestrator execution - self.stop_button.clicked.connect(QApplication.instance().quit) + self.next_button.clicked.connect(self.next) + self.stop_button.clicked.connect(self.stop_tasks) + def stop_tasks(self): + # This should exit Task executions + logger.info(f"Stopping Tasks... user requested. Using callback: {self.callback}") + self.callback() + self.quit() + logger.info("Tasks Stopped") -class IntertaskAction(Task): - name = "Intertask Action" - protocol: List[Task] - current_task_index: int + def next(self): + logger.info(f"Next Task=[{self.next_task_name}] requested") + self.quit() - def execute(self) -> TaskData: - task = self.protocol[self.current_task_index] - run_bciui(IntertaskGUI, task.name, self.current_task_index, len(self.protocol)) - return TaskData() + def quit(self): + QApplication.instance().quit() if __name__ == "__main__": - # test values - run_bciui(IntertaskGUI, "Placeholder Task Name", 1, 3) + tasks = ["RSVP Calibration", "IntertaskAction", "Matrix Calibration", "IntertaskAction"] + + run_bciui(IntertaskGUI, tasks=tasks, next_task_index=1, exit_callback=lambda: print("Stopping orchestrator")) diff --git a/bcipy/gui/main.py b/bcipy/gui/main.py index b7d76be03..9d032dfa6 100644 --- a/bcipy/gui/main.py +++ b/bcipy/gui/main.py @@ -1118,7 +1118,11 @@ def app(args) -> QApplication: Passes args from main and initializes the app """ - return QApplication(args) + + bci_app = QApplication(args).instance() + if not bci_app: + return QApplication(args) + return bci_app def start_app() -> None: @@ -1128,24 +1132,7 @@ def start_app() -> None: height=650, width=650, background_color='white') - - # ex.get_filename_dialog() - # ex.add_button(message='Test Button', position=[200, 300], size=[100, 100], id=1) - # ex.add_image(path='../static/images/gui/bci_cas_logo.png', position=[50, 50], size=200) - # ex.add_static_textbox( - # text='Test static text', - # background_color='black', - # text_color='white', - # position=[100, 20], - # wrap_text=True) - # ex.add_combobox(position=[100, 100], size=[100, 100], items=['first', 'second', 'third'], editable=True) - # ex.add_text_input(position=[100, 100], size=[100, 100]) ex.show_gui() - ex.throw_alert_message(title='title', - message='test', - message_response=AlertMessageResponse.OCE, - message_timeout=5) - sys.exit(bcipy_gui.exec()) diff --git a/bcipy/helpers/load.py b/bcipy/helpers/load.py index cd3f33be3..229c4b67e 100644 --- a/bcipy/helpers/load.py +++ b/bcipy/helpers/load.py @@ -6,7 +6,7 @@ from pathlib import Path from shutil import copyfile from time import localtime, strftime -from typing import List, Optional +from typing import List, Optional, Union from bcipy.config import (DEFAULT_ENCODING, DEFAULT_EXPERIMENT_PATH, DEFAULT_FIELD_PATH, DEFAULT_PARAMETERS_PATH, @@ -213,7 +213,7 @@ def choose_csv_file(filename: Optional[str] = None) -> Optional[str]: return filename -def load_raw_data(filename: str) -> RawData: +def load_raw_data(filename: Union[Path, str]) -> RawData: """Reads the data (.csv) file written by data acquisition. Parameters diff --git a/bcipy/helpers/report.py b/bcipy/helpers/report.py index 4836a7a0c..774e7ea8e 100644 --- a/bcipy/helpers/report.py +++ b/bcipy/helpers/report.py @@ -111,7 +111,7 @@ def _create_heatmap(self, onsets: List[float], range: Tuple[float, float], type: """ # create a heatmap with the onset values fig, ax = plt.subplots() - fig.set_size_inches(6, 3) + fig.set_size_inches(4, 2) ax.hist(onsets, bins=100, range=range, color='red', alpha=0.7) ax.set_title(f'{type} Artifact Onsets') ax.set_xlabel('Time (s)') @@ -156,8 +156,12 @@ class SessionReportSection(ReportSection): A class to handle the creation of a Session Report section in a BciPy Report using a summary dictionary. """ - def __init__(self, summary: Optional[dict] = None) -> None: + def __init__(self, summary: dict) -> None: self.summary = summary + if 'task' in self.summary: + self.session_name = self.summary['task'] + else: + self.session_name = 'Session Summary' self.style = getSampleStyleSheet() self.summary_table = None @@ -203,7 +207,7 @@ def _create_header(self) -> Paragraph: Creates a header for the Session Report section. """ - header = Paragraph('Session Summary', self.style['Heading3']) + header = Paragraph(f'{self.session_name}', self.style['Heading3']) return header diff --git a/bcipy/helpers/tests/resources/mock_session/parameters.json b/bcipy/helpers/tests/resources/mock_session/parameters.json index 1c75bf273..9331fe714 100644 --- a/bcipy/helpers/tests/resources/mock_session/parameters.json +++ b/bcipy/helpers/tests/resources/mock_session/parameters.json @@ -392,14 +392,13 @@ "type": "float" }, "font": { - "value": "Overpass Mono Medium", + "value": "Courier New", "section": "bci_config", "name": "Font", - "helpTip": "Specifies the font used for all text stimuli. Default: Consolas", + "helpTip": "Specifies the font used for all text stimuli. Default: Courier New", "recommended": [ "Courier New", - "Lucida Sans", - "Consolas" + "Lucida Sans" ], "editable": true, "type": "str" diff --git a/bcipy/helpers/tests/test_report.py b/bcipy/helpers/tests/test_report.py index 7df690816..3d28833d4 100644 --- a/bcipy/helpers/tests/test_report.py +++ b/bcipy/helpers/tests/test_report.py @@ -30,7 +30,8 @@ def test_init_no_name_default(self): self.assertEqual(report.name, Report.DEFAULT_NAME) def test_init_sections(self): - report_section = SessionReportSection() + summary = {} + report_section = SessionReportSection(summary) section = [report_section] report = Report(self.temp_dir, sections=section) self.assertEqual(report.sections, section) @@ -64,7 +65,8 @@ def test_add_section(self): def test_save(self): report = Report(self.temp_dir) - report_section = SessionReportSection() + summary = {} + report_section = SessionReportSection(summary) report.add(report_section) report.save() self.assertTrue(os.path.exists(os.path.join(self.temp_dir, report.name))) @@ -113,7 +115,8 @@ def setUp(self) -> None: } def test_init(self): - report_section = SessionReportSection() + summary = {} + report_section = SessionReportSection(summary) self.assertIsInstance(report_section, ReportSection) self.assertIsNotNone(report_section.style) @@ -124,7 +127,8 @@ def test_create_summary_text(self): self.assertIsInstance(table, Flowable) def test_create_header(self): - report_section = SessionReportSection() + summary = {} + report_section = SessionReportSection(summary) header = report_section._create_header() self.assertIsInstance(header, Paragraph) diff --git a/bcipy/helpers/visualization.py b/bcipy/helpers/visualization.py index 0dbb8623a..2c4402736 100644 --- a/bcipy/helpers/visualization.py +++ b/bcipy/helpers/visualization.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import seaborn as sns + from matplotlib.figure import Figure from matplotlib.patches import Ellipse from mne import Epochs @@ -18,10 +19,11 @@ import bcipy.acquisition.devices as devices from bcipy.config import (DEFAULT_DEVICE_SPEC_FILENAME, DEFAULT_GAZE_IMAGE_PATH, RAW_DATA_FILENAME, - TRIGGER_FILENAME, SESSION_LOG_FILENAME) + TRIGGER_FILENAME, SESSION_LOG_FILENAME, + DEFAULT_PARAMETERS_PATH) from bcipy.helpers.acquisition import analysis_channels from bcipy.helpers.convert import convert_to_mne -from bcipy.helpers.load import choose_csv_file, load_raw_data +from bcipy.helpers.load import choose_csv_file, load_raw_data, load_json_parameters from bcipy.helpers.parameters import Parameters from bcipy.helpers.raw_data import RawData from bcipy.helpers.stimuli import mne_epochs @@ -669,7 +671,8 @@ def visualize_evokeds(epochs: Tuple[Epochs, Epochs], def visualize_session_data( session_path: str, parameters: Union[dict, Parameters], - show=True) -> Figure: + show=True, + save=True) -> Figure: """Visualize Session Data. This method is used to load and visualize EEG data after a session. @@ -732,7 +735,7 @@ def visualize_session_data( transform=default_transform, plot_average=True, plot_topomaps=True, - save_path=session_path, + save_path=session_path if save else None, show=show, ) @@ -756,3 +759,34 @@ def visualize_gaze_accuracies(accuracy_dict: Dict[str, np.ndarray], ax.set_title('Overall Accuracy: ' + str(round(accuracy, 2))) return fig + + +def erp(): + import argparse + + parser = argparse.ArgumentParser(description='Visualize ERP data') + + parser.add_argument( + '-s', '--session_path', + type=str, + help='Path to the session directory', + required=True) + parser.add_argument( + '-p', '--parameters', + type=str, + help='Path to the parameters file', + default=DEFAULT_PARAMETERS_PATH) + parser.add_argument( + '--show', + action='store_true', + help='Whether to show the figure', + default=False) + parser.add_argument( + '--save', + action='store_true', + help='Whether to save the figure', default=True) + + args = parser.parse_args() + + parameters = load_json_parameters(args.parameters, value_cast=True) + visualize_session_data(args.session_path, parameters, args.show, args.save) diff --git a/bcipy/parameters/experiment/experiments.json b/bcipy/parameters/experiment/experiments.json index 9e164703c..add8e1184 100644 --- a/bcipy/parameters/experiment/experiments.json +++ b/bcipy/parameters/experiment/experiments.json @@ -3,5 +3,10 @@ "fields": [], "summary": "Default experiment to test various BciPy features without registering a full experiment.", "protocol": "RSVP Calibration -> Matrix Calibration" + }, + "BCIOne": { + "fields": [], + "summary": "BCIOne experiment", + "protocol": "RSVP Calibration -> IntertaskAction -> OfflineAnalysisAction -> IntertaskAction -> Matrix Calibration -> IntertaskAction -> OfflineAnalysisAction -> IntertaskAction -> BciPy Report Action" } } \ No newline at end of file diff --git a/bcipy/parameters/parameters.json b/bcipy/parameters/parameters.json index d7e24b73e..296d91858 100755 --- a/bcipy/parameters/parameters.json +++ b/bcipy/parameters/parameters.json @@ -329,10 +329,10 @@ "type": "int" }, "time_flash": { - "value": "0.25", + "value": "0.2", "section": "task_config", "name": "Stimulus Presentation Duration", - "helpTip": "Specifies the duration of time (in seconds) that each stimulus is displayed in an inquiry.", + "helpTip": "Specifies the duration of time (in seconds) that each stimulus is displayed in an inquiry. Default: 0.2", "recommended": "", "editable": false, "type": "float" @@ -376,7 +376,7 @@ "rsvp_stim_pos_x": { "value": "0", "section": "rsvp_task_config", - "name": "Stimulus Position Horizontal", + "name": "RSVP Stimulus Position Horizontal", "helpTip": "Specifies the center point of the stimulus position along the X axis. Possible values range from -1 to 1, with 0 representing the center. Default: 0", "recommended": "", "editable": false, @@ -385,7 +385,7 @@ "rsvp_stim_pos_y": { "value": "0", "section": "rsvp_task_config", - "name": "Stimulus Position Vertical", + "name": "RSVP Stimulus Position Vertical", "helpTip": "Specifies the center point of the stimulus position along the Y axis. Possible values range from -1 to 1, with 0 representing the center. Default: 0", "recommended": "", "editable": false, @@ -394,7 +394,7 @@ "matrix_stim_pos_x": { "value": "-0.6", "section": "matrix_task_config", - "name": "Stimulus Starting Position Horizontal", + "name": "Matrix Stimulus Starting Position Horizontal", "helpTip": "Specifies the center point of the stimulus position along the X axis. Possible values range from -1 to 1, with 0 representing the center. Default: -0.6", "recommended": "", "editable": false, @@ -403,7 +403,7 @@ "matrix_stim_pos_y": { "value": "0.4", "section": "matrix_task_config", - "name": "Stimulus Starting Position Vertical", + "name": "Matrix Stimulus Starting Position Vertical", "helpTip": "Specifies the center point of the stimulus position along the Y axis. Possible values range from -1 to 1, with 0 representing the center. Default: 0.4", "recommended": "", "editable": false, @@ -435,7 +435,7 @@ "vep_stim_height": { "value": "0.1", "section": "vep_task_config", - "name": "vep Stimulus Size", + "name": "VEP Stimulus Size", "helpTip": "Specifies the height of text stimuli. See https://www.psychopy.org/general/units.html. Default: 0.1", "recommended": "", "editable": false, @@ -453,7 +453,7 @@ "matrix_keyboard_layout": { "value": "ALP", "section": "matrix_task_config", - "name": "Keyboard Layout", + "name": "Matrix Keyboard Layout", "helpTip": "Specifies the keyboard layout to use for the Matrix task. Default: ALP (Alphabetical)", "recommended": [ "ALP", @@ -568,13 +568,13 @@ "type": "str" }, "copy_phrases_location": { - "value": "", + "value": "bcipy/parameters/experiment/phrases.json", "section": "online_config", "name": "Copy Phrases Location", "helpTip": "Specifies a list of copy phrases to execute during Task Orchestration. If provided, any copy phrases in the protocol will be executed in order pulling task text and starting locations from the file.", "recommended": "", "editable": true, - "type": "str" + "type": "filepath" }, "rsvp_task_height": { "value": "0.1", @@ -621,7 +621,7 @@ "rsvp_task_padding": { "value": "0.05", "section": "bci_config", - "name": "Task Bar Padding", + "name": "RSVP Task Bar Padding", "helpTip": "Specifies the padding around the task bar text for RSVP tasks. Default: 0.05", "recommended": [ "0.05" @@ -632,7 +632,7 @@ "matrix_task_padding": { "value": "0.05", "section": "bci_config", - "name": "Task Bar Padding", + "name": "Matrix Task Bar Padding", "helpTip": "Specifies the padding around the task bar text for Matrix Tasks. Default: 0.05", "recommended": [ "0.05" @@ -641,7 +641,7 @@ "type": "float" }, "stim_number": { - "value": "100", + "value": "5", "section": "bci_config", "name": "Number of Calibration inquiries", "helpTip": "Specifies the number of inquiries to present in a calibration session. Default: 100", diff --git a/bcipy/signal/evaluate/artifact.py b/bcipy/signal/evaluate/artifact.py index 3ab32eae6..19ec6180f 100644 --- a/bcipy/signal/evaluate/artifact.py +++ b/bcipy/signal/evaluate/artifact.py @@ -40,7 +40,7 @@ class DefaultArtifactParameters(Enum): """ # Voltage - PEAK_THRESHOLD = 100e-7 + PEAK_THRESHOLD = 75e-7 PEAK_MIN_DURATION = 0.005 FLAT_THRESHOLD = 0.5e-6 FLAT_MIN_DURATION = 0.1 @@ -133,6 +133,12 @@ class ArtifactDetection: detect_voltage : bool Whether to detect voltage artifacts. Defaults to True. + + semi_automatic : bool + Whether to use a semi-automatic approach to artifact detection. Defaults to False. + + session_triggers : tuple + A tuple of lists containing the trigger type, trigger timing, and trigger label for the session. """ supported_units: List[str] = ['volts', 'microvolts'] @@ -237,13 +243,13 @@ def label_artifacts( voltage = self.label_voltage_events() if voltage: voltage_annotations, bad_channels = voltage - log.info(f'Voltage violation events found: {len(voltage_annotations)}') if bad_channels: # add bad channel labels to the raw data self.mne_data.info['bads'] = bad_channels log.info(f'Bad channels detected: {bad_channels}') if voltage_annotations: + log.info(f'Voltage violation events found: {len(voltage_annotations)}') annotations += voltage_annotations self.voltage_annotations = voltage_annotations @@ -251,8 +257,9 @@ def label_artifacts( eog = self.label_eog_events() if eog: eog_annotations, eog_events = eog - log.info(f'EOG events found: {len(eog_events)}') + if eog_annotations: + log.info(f'EOG events found: {len(eog_events)}') annotations += eog_annotations self.eog_annotations = eog_annotations @@ -431,8 +438,10 @@ def label_voltage_events( # combine the bad channels bad_channels = bad_channels1 + bad_channels2 - if len(onsets) > 0 or len(bad_channels) > 0: + if len(onsets) > 0 and len(bad_channels) > 0: return mne.Annotations(onsets, durations, descriptions), bad_channels + elif len(bad_channels) > 0: + return None, bad_channels return None diff --git a/bcipy/signal/model/base_model.py b/bcipy/signal/model/base_model.py index 8a854f72b..828e3bb1e 100644 --- a/bcipy/signal/model/base_model.py +++ b/bcipy/signal/model/base_model.py @@ -16,10 +16,14 @@ class SignalModelMetadata(NamedTuple): device_spec: DeviceSpec # device used to train the model transform: Composition # data preprocessing steps evidence_type: str = None # optional; type of evidence produced + auc: float = None # optional; area under the curve + balanced_accuracy: float = None # optional; balanced accuracy class SignalModel(ABC): + name = "undefined" + @property def metadata(self) -> SignalModelMetadata: """Information regarding the data and parameters used to train the diff --git a/bcipy/signal/model/cross_validation.py b/bcipy/signal/model/cross_validation.py index 8709eaa40..238229b0a 100644 --- a/bcipy/signal/model/cross_validation.py +++ b/bcipy/signal/model/cross_validation.py @@ -15,7 +15,7 @@ def cost_cross_validation_auc(model, opt_el, x, y, param, k_folds=10, Cost function: given a particular architecture (model). Fits the parameters to the folds with leave one fold out procedure. Calculates scores for the validation fold. Concatenates all calculated scores - together and returns a -AUC vale. + together and returns a -AUC value. Args: model(pipeline): model to be iterated on opt_el(int): number of the element in pipeline to be optimized diff --git a/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py b/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py index 0f02f39e9..0f71f6c94 100644 --- a/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py +++ b/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py @@ -12,6 +12,7 @@ class GazeModelIndividual(SignalModel): reshaper = GazeReshaper() + name = "gaze_model_individual" def __init__(self, num_components=2): self.num_components = num_components # number of gaussians to fit @@ -79,6 +80,7 @@ def load(self, path: Path): class GazeModelCombined(SignalModel): '''Gaze model that uses all symbols to fit a single Gaussian ''' reshaper = GazeReshaper() + name = "gaze_model_combined" def __init__(self, num_components=1): self.num_components = num_components # number of gaussians to fit diff --git a/bcipy/signal/model/offline_analysis.py b/bcipy/signal/model/offline_analysis.py index 2a621ef02..bbbda64e4 100644 --- a/bcipy/signal/model/offline_analysis.py +++ b/bcipy/signal/model/offline_analysis.py @@ -1,30 +1,31 @@ # mypy: disable-error-code="attr-defined" -# needed for the ERPTransformParams import json import logging +import subprocess from pathlib import Path -from typing import Tuple +from typing import List import numpy as np -from matplotlib.figure import Figure + from sklearn.metrics import balanced_accuracy_score from sklearn.model_selection import train_test_split import bcipy.acquisition.devices as devices from bcipy.config import (BCIPY_ROOT, DEFAULT_DEVICE_SPEC_FILENAME, - DEFAULT_PARAMETERS_PATH, MATRIX_IMAGE_FILENAME, - STATIC_AUDIO_PATH, TRIGGER_FILENAME, SESSION_LOG_FILENAME) + DEFAULT_PARAMETERS_PATH, MATRIX_IMAGE_FILENAME, DEFAULT_DEVICES_PATH, + TRIGGER_FILENAME, SESSION_LOG_FILENAME) from bcipy.helpers.acquisition import analysis_channels, raw_data_filename from bcipy.helpers.load import (load_experimental_data, load_json_parameters, load_raw_data) +from bcipy.gui.alert import confirm from bcipy.helpers.parameters import Parameters from bcipy.helpers.save import save_model -from bcipy.helpers.stimuli import play_sound, update_inquiry_timing +from bcipy.helpers.stimuli import update_inquiry_timing from bcipy.helpers.symbols import alphabet from bcipy.helpers.system_utils import report_execution_time from bcipy.helpers.triggers import TriggerType, trigger_decoder from bcipy.helpers.visualization import (visualize_centralized_data, - visualize_erp, visualize_gaze, + visualize_gaze, visualize_gaze_accuracies, visualize_gaze_inquiries, visualize_results_all_symbols) @@ -74,7 +75,7 @@ def subset_data(data: np.ndarray, labels: np.ndarray, test_size: float, random_s return train_data, test_data, train_labels, test_labels -def analyze_erp(erp_data, parameters, device_spec, data_folder, estimate_balanced_acc, +def analyze_erp(erp_data, parameters, device_spec, data_folder, estimate_balanced_acc: bool, save_figures=False, show_figures=False): """Analyze ERP data and return/save the ERP model. Extract relevant information from raw data object. @@ -182,39 +183,46 @@ def analyze_erp(erp_data, parameters, device_spec, data_folder, estimate_balance # train and save the model as a pkl file log.info("Training model. This will take some time...") model = PcaRdaKdeModel(k_folds=k_folds) - model.fit(data, labels) - model.metadata = SignalModelMetadata(device_spec=device_spec, - transform=default_transform) - log.info(f"Training complete [AUC={model.auc:0.4f}]. Saving data...") + try: + model.fit(data, labels) + model.metadata = SignalModelMetadata(device_spec=device_spec, + transform=default_transform, + evidence_type="ERP", + auc=model.auc) + log.info(f"Training complete [AUC={model.auc:0.4f}]. Saving data...") + except Exception as e: + log.error(f"Error training model: {e}") + + try: + # Using an 80/20 split, report on balanced accuracy + if estimate_balanced_acc: + train_data, test_data, train_labels, test_labels = subset_data(data, labels, test_size=0.2) + dummy_model = PcaRdaKdeModel(k_folds=k_folds) + dummy_model.fit(train_data, train_labels) + probs = dummy_model.predict_proba(test_data) + preds = probs.argmax(-1) + score = balanced_accuracy_score(test_labels, preds) + log.info(f"Balanced acc with 80/20 split: {score}") + model.metadata.balanced_accuracy = score + del dummy_model, train_data, test_data, train_labels, test_labels, probs, preds + + except Exception as e: + log.error(f"Error calculating balanced accuracy: {e}") save_model(model, Path(data_folder, f"model_{model.auc:0.4f}.pkl")) preferences.signal_model_directory = data_folder - # Using an 80/20 split, report on balanced accuracy - if estimate_balanced_acc: - train_data, test_data, train_labels, test_labels = subset_data(data, labels, test_size=0.2) - dummy_model = PcaRdaKdeModel(k_folds=k_folds) - dummy_model.fit(train_data, train_labels) - probs = dummy_model.predict_proba(test_data) - preds = probs.argmax(-1) - score = balanced_accuracy_score(test_labels, preds) - log.info(f"Balanced acc with 80/20 split: {score}") - del dummy_model, train_data, test_data, train_labels, test_labels, probs, preds - - # this should have uncorrected trigger timing for display purposes - figure_handles = visualize_erp( - erp_data, - channel_map, - trigger_timing, - labels, - trial_window, - transform=default_transform, - plot_average=True, - plot_topomaps=True, - save_path=data_folder if save_figures else None, - show=show_figures - ) - return model, figure_handles + if save_figures or show_figures: + cmd = f'bcipy-erp-viz --session_path "{data_folder}" --parameters "{parameters["parameter_location"]}"' + if save_figures: + cmd += ' --save' + if show_figures: + cmd += ' --show' + subprocess.run( + cmd, + shell=True + ) + return model def analyze_gaze( @@ -248,15 +256,13 @@ def analyze_gaze( "Individual": Fits a separate Gaussian for each symbol. Default model "Centralized": Uses data from all symbols to fit a single centralized Gaussian """ - figures = [] - figure_handles = visualize_gaze( + visualize_gaze( gaze_data, save_path=save_figures, img_path=f'{data_folder}/{MATRIX_IMAGE_FILENAME}', show=show_figures, raw_plot=plot_points, ) - figures.extend(figure_handles) channels = gaze_data.channels type_amp = gaze_data.daq_type @@ -347,7 +353,7 @@ def analyze_gaze( means, covs = model.evaluate(test_re) # Visualize the results: - figure_handles = visualize_gaze_inquiries( + visualize_gaze_inquiries( le, re, means, covs, save_path=save_figures, @@ -355,7 +361,6 @@ def analyze_gaze( show=show_figures, raw_plot=plot_points, ) - figures.extend(figure_handles) left_eye_all.append(le) right_eye_all.append(re) means_all.append(means) @@ -399,22 +404,20 @@ def analyze_gaze( print(f"Overall accuracy: {accuracy:.2f}") # Plot all accuracies as bar plot: - figure_handles = visualize_gaze_accuracies(acc_all_symbols, accuracy, save_path=None, show=True) - figures.extend(figure_handles) + visualize_gaze_accuracies(acc_all_symbols, accuracy, save_path=None, show=True) if model_type == "Centralized": cent_left = np.concatenate(np.array(centralized_data_left, dtype=object)) cent_right = np.concatenate(np.array(centralized_data_right, dtype=object)) # Visualize the results: - figure_handles = visualize_centralized_data( + visualize_centralized_data( cent_left, cent_right, save_path=save_figures, img_path=f'{data_folder}/{MATRIX_IMAGE_FILENAME}', show=show_figures, raw_plot=plot_points, ) - figures.extend(figure_handles) # Fit the model: model.fit(cent_left) @@ -427,7 +430,7 @@ def analyze_gaze( le = preprocessed_data[sym][0] re = preprocessed_data[sym][1] # Visualize the results: - figure_handles = visualize_gaze_inquiries( + visualize_gaze_inquiries( le, re, means, covs, save_path=save_figures, @@ -435,13 +438,13 @@ def analyze_gaze( show=show_figures, raw_plot=plot_points, ) - figures.extend(figure_handles) left_eye_all.append(le) right_eye_all.append(re) means_all.append(means) covs_all.append(covs) - fig_handles = visualize_results_all_symbols( + # TODO: add visualizations to subprocess + visualize_results_all_symbols( left_eye_all, right_eye_all, means_all, covs_all, img_path=f'{data_folder}/{MATRIX_IMAGE_FILENAME}', @@ -449,7 +452,6 @@ def analyze_gaze( show=show_figures, raw_plot=plot_points, ) - figures.extend(fig_handles) model.metadata = SignalModelMetadata(device_spec=device_spec, transform=None) @@ -457,7 +459,7 @@ def analyze_gaze( save_model( model, Path(data_folder, f"model_{device_spec.content_type}_{model_type}.pkl")) - return model, figures + return model @report_execution_time @@ -468,7 +470,7 @@ def offline_analysis( estimate_balanced_acc: bool = False, show_figures: bool = False, save_figures: bool = False, -) -> Tuple[SignalModel, Figure]: +) -> List[SignalModel]: """Gets calibration data and trains the model in an offline fashion. pickle dumps the model into a .pkl folder @@ -499,45 +501,52 @@ def offline_analysis( Returns: -------- model (SignalModel): trained model - figure_handles (Figure): handles to the ERP figures """ assert parameters, "Parameters are required for offline analysis." if not data_folder: data_folder = load_experimental_data() + # Load default devices which are used for training the model with different channels, etc. devices_by_name = devices.load( - Path(data_folder, DEFAULT_DEVICE_SPEC_FILENAME), replace=True) + Path(DEFAULT_DEVICES_PATH, DEFAULT_DEVICE_SPEC_FILENAME), replace=True) - active_devices = (spec for spec in devices_by_name.values() + # Load the active devices used during a session; this will be used to exclude inactive devices + active_devices_by_name = devices.load( + Path(data_folder, DEFAULT_DEVICE_SPEC_FILENAME), replace=True) + active_devices = (spec for spec in active_devices_by_name.values() if spec.is_active) active_raw_data_paths = (Path(data_folder, raw_data_filename(device_spec)) for device_spec in active_devices) data_file_paths = [path for path in active_raw_data_paths if path.exists()] + assert len(data_file_paths) < 3, "BciPy only supports up to 2 devices for offline analysis." + assert len(data_file_paths) > 0, "No data files found for offline analysis." + models = [] - figure_handles = [] + log.info(f"Starting offline analysis for {data_file_paths}") for raw_data_path in data_file_paths: raw_data = load_raw_data(raw_data_path) device_spec = devices_by_name.get(raw_data.daq_type) # extract relevant information from raw data object eeg if device_spec.content_type == "EEG": - erp_model, erp_figure_handles = analyze_erp( + erp_model = analyze_erp( raw_data, parameters, device_spec, data_folder, estimate_balanced_acc, save_figures, show_figures) models.append(erp_model) - figure_handles.extend(erp_figure_handles) if device_spec.content_type == "Eyetracker": - et_model, et_figure_handles = analyze_gaze( + et_model = analyze_gaze( raw_data, parameters, device_spec, data_folder, save_figures, show_figures, model_type="Individual") models.append(et_model) - figure_handles.extend(et_figure_handles) if alert_finished: - play_sound(f"{STATIC_AUDIO_PATH}/{parameters['alert_sound_file']}") - return models, figure_handles + log.info("Alerting Offline Analysis Complete") + results = [f"{model.name}: {model.auc}" for model in models] + confirm(f"Offline analysis complete! \n Results={results}") + log.info("Offline analysis complete") + return models -if __name__ == "__main__": +def main(): import argparse parser = argparse.ArgumentParser() @@ -548,13 +557,17 @@ def offline_analysis( parser.add_argument("--alert", dest="alert", action="store_true") parser.add_argument("--balanced-acc", dest="balanced", action="store_true") parser.set_defaults(alert=False) - parser.set_defaults(balanced=True) + parser.set_defaults(balanced=False) parser.set_defaults(save_figures=False) - parser.set_defaults(show_figures=True) + parser.set_defaults(show_figures=False) args = parser.parse_args() log.info(f"Loading params from {args.parameters_file}") parameters = load_json_parameters(args.parameters_file, value_cast=True) + log.info( + f"Starting offline analysis client with the following: Data={args.data_folder} || " + f"Save Figures={args.save_figures} || Show Figures={args.show_figures} || " + f"Alert={args.alert} || Calculate Balanced Accuracy={args.balanced}") offline_analysis( args.data_folder, @@ -563,4 +576,7 @@ def offline_analysis( estimate_balanced_acc=args.balanced, save_figures=args.save_figures, show_figures=args.show_figures) - log.info("Offline Analysis complete.") + + +if __name__ == "__main__": + main() diff --git a/bcipy/signal/model/pca_rda_kde/pca_rda_kde.py b/bcipy/signal/model/pca_rda_kde/pca_rda_kde.py index 48b2701a6..28280c1d3 100644 --- a/bcipy/signal/model/pca_rda_kde/pca_rda_kde.py +++ b/bcipy/signal/model/pca_rda_kde/pca_rda_kde.py @@ -18,6 +18,7 @@ class PcaRdaKdeModel(SignalModel): reshaper: InquiryReshaper = InquiryReshaper() + name = "pca_rda_kde" def __init__(self, k_folds: int = 10, prior_type="uniform", pca_n_components=0.9): self.k_folds = k_folds diff --git a/bcipy/signal/model/rda_kde/rda_kde.py b/bcipy/signal/model/rda_kde/rda_kde.py index 759f090f4..6c716d569 100644 --- a/bcipy/signal/model/rda_kde/rda_kde.py +++ b/bcipy/signal/model/rda_kde/rda_kde.py @@ -15,6 +15,7 @@ class RdaKdeModel(SignalModel): reshaper = InquiryReshaper() + name = "rda_kde" def __init__(self, k_folds: int, prior_type: str = "uniform"): self.k_folds = k_folds diff --git a/bcipy/signal/tests/model/test_offline_analysis.py b/bcipy/signal/tests/model/test_offline_analysis.py index 3f8535e2e..54f84d8e3 100644 --- a/bcipy/signal/tests/model/test_offline_analysis.py +++ b/bcipy/signal/tests/model/test_offline_analysis.py @@ -19,7 +19,7 @@ @pytest.mark.slow -class TestOfflineAnalysis(unittest.TestCase): +class TestOfflineAnalysisEEG(unittest.TestCase): """Integration test of offline_analysis.py (slow) This test is slow because it runs the full offline analysis pipeline and compares its' output @@ -50,12 +50,14 @@ def setUpClass(cls): params_path = pwd.parent.parent.parent / "parameters" / DEFAULT_PARAMETER_FILENAME cls.parameters = load_json_parameters(params_path, value_cast=True) - models, fig_handles = offline_analysis( - str(cls.tmp_dir), cls.parameters, save_figures=True, show_figures=False, alert_finished=False) + models = offline_analysis( + str(cls.tmp_dir), + cls.parameters, + save_figures=False, + show_figures=False, + alert_finished=False) + # only one model is generated using the default parameters cls.model = models[0] - cls.mean_erp_fig_handle = fig_handles[0] - cls.mean_nontarget_topomap_handle = fig_handles[1] - cls.mean_target_topomap_handle = fig_handles[2] @classmethod def tearDownClass(cls): @@ -73,20 +75,6 @@ def test_model_AUC(self): found_auc = self.get_auc(list(self.tmp_dir.glob("model_*.pkl"))[0].name) self.assertAlmostEqual(expected_auc, found_auc, delta=0.005) - @pytest.mark.mpl_image_compare(baseline_dir=expected_output_folder, filename="test_mean_erp.png", remove_text=True) - def test_mean_erp(self): - return self.mean_erp_fig_handle - - @pytest.mark.mpl_image_compare(baseline_dir=expected_output_folder, - filename="test_target_topomap.png", remove_text=False) - def test_target_topomap(self): - return self.mean_target_topomap_handle - - @pytest.mark.mpl_image_compare(baseline_dir=expected_output_folder, - filename="test_nontarget_topomap.png", remove_text=False) - def test_nontarget_topomap(self): - return self.mean_nontarget_topomap_handle - if __name__ == "__main__": unittest.main() diff --git a/bcipy/simulator/data/sampler.py b/bcipy/simulator/data/sampler.py index fa1e3f8c4..5424c95e2 100644 --- a/bcipy/simulator/data/sampler.py +++ b/bcipy/simulator/data/sampler.py @@ -180,8 +180,8 @@ def sample(self, state: SimState) -> List[Trial]: inquiry_n = random.choice(source_inquiries[data_source]) # select all trials for the data_source and inquiry - inquiry_df = self.data.loc[(self.data['source'] == data_source) - & (self.data['inquiry_n'] == inquiry_n)] + inquiry_df = self.data.loc[(self.data['source'] == data_source) & + (self.data['inquiry_n'] == inquiry_n)] assert len(inquiry_df) == len( inquiry_letter_subset), f"Invalid data source {data_source}" diff --git a/bcipy/simulator/task/copy_phrase.py b/bcipy/simulator/task/copy_phrase.py index db4ea3c96..1dee73b09 100644 --- a/bcipy/simulator/task/copy_phrase.py +++ b/bcipy/simulator/task/copy_phrase.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="union-attr" """Simulates the Copy Phrase task""" from typing import Dict, List, Optional, Tuple import logging diff --git a/bcipy/task/actions.py b/bcipy/task/actions.py index 3b5b27e24..523105262 100644 --- a/bcipy/task/actions.py +++ b/bcipy/task/actions.py @@ -1,12 +1,30 @@ +# mypy: disable-error-code="assignment,arg-type" import subprocess -from typing import Any, Optional +from typing import Any, Optional, List, Callable, Tuple import logging +from pathlib import Path +import glob +from bcipy.gui.bciui import run_bciui +from matplotlib.figure import Figure + +from bcipy.gui.intertask_gui import IntertaskGUI from bcipy.gui.experiments.ExperimentField import start_experiment_field_collection_gui -from bcipy.task import Task, TaskData, TaskMode +from bcipy.task import Task, TaskMode, TaskData +from bcipy.helpers.triggers import trigger_decoder, TriggerType + +from bcipy.acquisition import devices +from bcipy.helpers.acquisition import analysis_channels from bcipy.helpers.parameters import Parameters -from bcipy.config import DEFAULT_PARAMETERS_PATH, SESSION_LOG_FILENAME -from bcipy.signal.model.offline_analysis import offline_analysis +from bcipy.acquisition.devices import DeviceSpec +from bcipy.helpers.load import load_raw_data +from bcipy.helpers.raw_data import RawData +from bcipy.signal.process import get_default_transform +from bcipy.helpers.report import SignalReportSection, SessionReportSection, Report, ReportSection +from bcipy.config import DEFAULT_PARAMETERS_PATH, SESSION_LOG_FILENAME, RAW_DATA_FILENAME, TRIGGER_FILENAME +from bcipy.helpers.visualization import visualize_erp +from bcipy.signal.evaluate.artifact import ArtifactDetection + logger = logging.getLogger(SESSION_LOG_FILENAME) @@ -54,13 +72,14 @@ def __init__( data_directory: str, parameters_path: str = f'{DEFAULT_PARAMETERS_PATH}', last_task_dir: Optional[str] = None, - alert: bool = False, + alert_finished: bool = False, **kwargs: Any) -> None: super().__init__() self.parameters = parameters self.parameters_path = parameters_path - self.alert_finished = alert + self.alert_finished = alert_finished + # TODO: add a feature to orchestrator to permit the user to select the last task directory or have it loaded. if last_task_dir: self.data_directory = last_task_dir else: @@ -75,9 +94,16 @@ def execute(self) -> TaskData: inconsistent. """ - logger.info(f"Running offline analysis on data in {self.data_directory}") + logger.info("Running offline analysis action") try: - response = offline_analysis(self.data_directory, self.parameters, alert_finished=self.alert_finished) + cmd = f"bcipy-train --parameters '{self.parameters_path}'" + if self.alert_finished: + cmd += " --alert" + response = subprocess.run( + cmd, + shell=True, + check=True, + ) except Exception as e: logger.exception(f"Error running offline analysis: {e}") raise e @@ -88,6 +114,52 @@ def execute(self) -> TaskData: ) +class IntertaskAction(Task): + name = "IntertaskAction" + mode = TaskMode.ACTION + tasks: List[Task] + current_task_index: int + + def __init__( + self, + parameters: Parameters, + save_path: str, + progress: Optional[int] = None, + tasks: Optional[List[Task]] = None, + exit_callback: Optional[Callable] = None, + **kwargs: Any) -> None: + super().__init__() + self.save_folder = save_path + self.parameters = parameters + assert progress is not None and tasks is not None, "Either progress or tasks must be provided" + self.next_task_index = progress # progress is 1-indexed, tasks is 0-indexed so we can use the same index + assert self.next_task_index >= 0, "Progress must be greater than 1 " + self.tasks = tasks + self.task_name = self.tasks[self.next_task_index].name + self.task_names = [task.name for task in self.tasks] + self.exit_callback = exit_callback + + def execute(self) -> TaskData: + + run_bciui( + IntertaskGUI, + tasks=self.task_names, + next_task_index=self.next_task_index, + exit_callback=self.exit_callback), + + return TaskData( + save_path=self.save_folder, + task_dict={ + "next_task_index": self.next_task_index, + "tasks": self.task_names, + "task_name": self.task_name, + }, + ) + + def alert(self): + pass + + class ExperimentFieldCollectionAction(Task): """ Action for collecting experiment field data. @@ -118,3 +190,190 @@ def execute(self) -> TaskData: "experiment_id": self.experiment_id, }, ) + + +class BciPyCalibrationReportAction(Task): + """ + Action for generating a report after calibration Tasks. + """ + + name = "BciPy Report Action" + mode = TaskMode.ACTION + + def __init__( + self, + parameters: Parameters, + save_path: str, + protocol_path: Optional[str] = None, + last_task_dir: Optional[str] = None, + trial_window: Optional[Tuple[float, float]] = None, + **kwargs: Any) -> None: + super().__init__() + self.save_folder = save_path + # Currently we assume all Tasks have the same parameters, this may change in the future. + self.parameters = parameters + self.protocol_path = protocol_path or '' + self.last_task_dir = last_task_dir + self.default_transform = None + self.trial_window = (-0.2, 1.0) + self.static_offset = self.parameters.get("static_offset") + self.report = Report(self.protocol_path) + self.report_sections: List[ReportSection] = [] + self.all_raw_data: List[RawData] = [] + self.type_amp = None + + def execute(self) -> TaskData: + """Excute the report generation action. + + This assumes all data were collected using the same protocol, device, and parameters. + """ + logger.info(f"Generating report in save folder {self.save_folder}") + # loop through all the files in the last_task_dir + + data_directories = [] + # If a protocol is given, loop over and look for any calibration directories + try: + if self.protocol_path: + # Use glob to find all directories with Calibration in the name + calibration_directories = glob.glob( + f"{self.protocol_path}/**/*Calibration*", + recursive=True) + for data_dir in calibration_directories: + path_data_dir = Path(data_dir) + # pull out the last directory name + task_name = path_data_dir.parts[-1].split('_')[0] + data_directories.append(path_data_dir) + # For each calibration directory, attempt to load the raw data + signal_report_section = self.create_signal_report(path_data_dir) + session_report = self.create_session_report(path_data_dir, task_name) + self.report_sections.append(session_report) + self.report.add(session_report) + self.report_sections.append(signal_report_section) + self.report.add(signal_report_section) + if data_directories: + logger.info(f"Saving report generated from: {self.protocol_path}") + else: + logger.info(f"No data found in {self.protocol_path}") + + except Exception as e: + logger.exception(f"Error generating report: {e}") + + self.report.compile() + self.report.save() + return TaskData( + save_path=self.save_folder, + task_dict={}, + ) + + def create_signal_report(self, data_dir: Path) -> SignalReportSection: + raw_data = load_raw_data(Path(data_dir, f'{RAW_DATA_FILENAME}.csv')) + if not self.type_amp: + self.type_amp = raw_data.daq_type + channels = raw_data.channels + sample_rate = raw_data.sample_rate + device_spec = devices.preconfigured_device(raw_data.daq_type) + channel_map = analysis_channels(channels, device_spec) + self.all_raw_data.append(raw_data) + + # Set the default transform if not already set + if not self.default_transform: + self.set_default_transform(sample_rate) + + triggers = self.get_triggers(data_dir) + # get figure handles + figure_handles = self.get_figure_handles(raw_data, channel_map, triggers) + artifact_detector = self.get_artifact_detector(raw_data, device_spec, triggers) + return SignalReportSection(figure_handles, artifact_detector) + + def create_session_report(self, data_dir, task_name) -> SessionReportSection: + # get task name + summary_dict = { + "task": task_name, + "data_location": data_dir, + "amplifier": self.type_amp + } + signal_model_metrics = self.get_signal_model_metrics(data_dir) + summary_dict.update(signal_model_metrics) + + return SessionReportSection(summary_dict) + + def get_signal_model_metrics(self, data_directory: Path) -> dict: + """Get the signal model metrics from the session folder. + + In the future, the model will save a ModelMetrics with the pkl file. + For now, we just look for the pkl file and extract the AUC from the filename. + """ + pkl_file = None + for file in data_directory.iterdir(): + if file.suffix == '.pkl': + pkl_file = file + break + + if pkl_file: + auc = pkl_file.stem.split('_')[-1] + else: + auc = 'No Signal Model found in session folder' + + return {'AUC': auc} + + def set_default_transform(self, sample_rate: int) -> None: + downsample_rate = self.parameters.get("down_sampling_rate") + notch_filter = self.parameters.get("notch_filter_frequency") + filter_high = self.parameters.get("filter_high") + filter_low = self.parameters.get("filter_low") + filter_order = self.parameters.get("filter_order") + self.default_transform = get_default_transform( + sample_rate_hz=sample_rate, + notch_freq_hz=notch_filter, + bandpass_low=filter_low, + bandpass_high=filter_high, + bandpass_order=filter_order, + downsample_factor=downsample_rate, + ) + + def find_eye_channels(self, device_spec: DeviceSpec) -> Optional[list]: + eye_channels = [] + for channel in device_spec.channels: + if 'F' in channel: + eye_channels.append(channel) + if len(eye_channels) == 0: + eye_channels = None + return eye_channels + + def get_triggers(self, session) -> tuple: + trigger_type, trigger_timing, trigger_label = trigger_decoder( + offset=self.static_offset, + trigger_path=f"{session}/{TRIGGER_FILENAME}", + exclusion=[ + TriggerType.PREVIEW, + TriggerType.EVENT, + TriggerType.FIXATION], + device_type='EEG' + ) + return trigger_type, trigger_timing, trigger_label + + def get_figure_handles(self, raw_data, channel_map, triggers) -> List[Figure]: + _, trigger_timing, trigger_label = triggers + figure_handles = visualize_erp( + raw_data, + channel_map, + trigger_timing, + trigger_label, + self.trial_window, + transform=self.default_transform, + plot_average=True, + plot_topomaps=True, + show=False, + ) + return figure_handles + + def get_artifact_detector(self, raw_data, device_spec, triggers) -> ArtifactDetection: + eye_channels = self.find_eye_channels(device_spec) + artifact_detector = ArtifactDetection( + raw_data, + self.parameters, + device_spec, + eye_channels=eye_channels, + session_triggers=triggers) + artifact_detector.detect_artifacts() + return artifact_detector diff --git a/bcipy/task/demo/demo_orchestrator.py b/bcipy/task/demo/demo_orchestrator.py index b596d376c..d57bd0013 100644 --- a/bcipy/task/demo/demo_orchestrator.py +++ b/bcipy/task/demo/demo_orchestrator.py @@ -1,11 +1,10 @@ from bcipy.config import DEFAULT_PARAMETERS_PATH from bcipy.task.orchestrator import SessionOrchestrator -# from bcipy.task.actions import (OfflineAnalysisAction) -from bcipy.task.paradigm.rsvp import RSVPCalibrationTask, RSVPCopyPhraseTask -# from bcipy.task.paradigm.rsvp import RSVPTimingVerificationCalibration +from bcipy.task.actions import (OfflineAnalysisAction, IntertaskAction, BciPyCalibrationReportAction) +from bcipy.task.paradigm.rsvp import RSVPCalibrationTask +# from bcipy.task.paradigm.rsvp import RSVPCopyPhraseTask, RSVPTimingVerificationCalibration from bcipy.task.paradigm.matrix import MatrixCalibrationTask # from bcipy.task.paradigm.matrix.timing_verification import MatrixTimingVerificationCalibration -from bcipy.task.paradigm.vep import VEPCalibrationTask def demo_orchestrator(parameters_path: str) -> None: @@ -18,17 +17,22 @@ def demo_orchestrator(parameters_path: str) -> None: fake_data = True alert_finished = True tasks = [ - VEPCalibrationTask, - MatrixCalibrationTask, RSVPCalibrationTask, + IntertaskAction, # OfflineAnalysisAction, - RSVPCopyPhraseTask, - RSVPCopyPhraseTask, - RSVPCopyPhraseTask, - RSVPCopyPhraseTask + # IntertaskAction, + MatrixCalibrationTask, + IntertaskAction, + OfflineAnalysisAction, + IntertaskAction, + BciPyCalibrationReportAction ] orchestrator = SessionOrchestrator( - user='demo_orchestrator', parameters_path=parameters_path, alert=alert_finished, fake=fake_data) + user='offline_testing', + parameters_path=parameters_path, + alert=alert_finished, + visualize=True, + fake=fake_data) orchestrator.add_tasks(tasks) orchestrator.execute() diff --git a/bcipy/task/orchestrator/orchestrator.py b/bcipy/task/orchestrator/orchestrator.py index 24ba2c77f..8ec3ce2f7 100644 --- a/bcipy/task/orchestrator/orchestrator.py +++ b/bcipy/task/orchestrator/orchestrator.py @@ -1,7 +1,10 @@ +# mypy: disable-error-code="arg-type, assignment" import errno import os import json +import subprocess from datetime import datetime +import random import logging from logging import Logger from typing import List, Type, Optional @@ -16,10 +19,9 @@ MULTIPHRASE_FILENAME, PROTOCOL_FILENAME, PROTOCOL_LOG_FILENAME, - SESSION_LOG_FILENAME + SESSION_LOG_FILENAME, ) from bcipy.helpers.load import load_json_parameters -from bcipy.helpers.visualization import visualize_session_data class SessionOrchestrator: @@ -74,10 +76,12 @@ def __init__( self.logger = self._init_orchestrator_logger(self.save_folder) self.alert = alert + self.logger.info("Alerts are on") if self.alert else self.logger.info("Alerts are off") self.visualize = visualize self.progress = 0 self.ready_to_execute = False + self.user_exit = False self.logger.info("Session Orchestrator initialized successfully") def add_task(self, task: Type[Task]) -> None: @@ -124,6 +128,8 @@ def initialize_copy_phrases(self) -> None: with open(self.parameters['copy_phrases_location'], 'r') as f: copy_phrases = json.load(f) self.copyphrases = copy_phrases['Phrases'] + # randomize the order of the phrases + random.shuffle(self.copyphrases) else: self.copyphrases = None self.next_phrase = self.parameters['task_text'] @@ -153,25 +159,39 @@ def execute(self) -> None: self.parameters, data_save_location, fake=self.fake, + alert_finished=self.alert, experiment_id=self.experiment_id, parameters_path=self.parameters_path, - last_task_dir=self.last_task_dir) + protocol_path=self.save_folder, + last_task_dir=self.last_task_dir, + progress=self.progress, + tasks=self.tasks, + exit_callback=self.close_experiment_callback) task_data = initialized_task.execute() self.session_data.append(task_data) self.logger.info(f"Task {task.name} completed successfully") # some tasks may need access to the previous task's data self.last_task_dir = data_save_location - if self.alert: - initialized_task.alert() + if self.user_exit: + break - if self.visualize: - # Visualize session data and fail silently if it errors - try: - visualize_session_data(data_save_location, self.parameters) - pass - except Exception as e: - self.logger.info(f'Error visualizing session data: {e}') + if initialized_task.mode != TaskMode.ACTION: + if self.alert: + initialized_task.alert() + + if self.visualize: + # Visualize session data and fail silently if it errors + try: + self.logger.info(f"Visualizing session data. Saving to {data_save_location}") + subprocess.run( + f'bcipy-erp-viz -s "{data_save_location}" ' + f'--parameters "{self.parameters_path}" --show --save', + shell=True) + except Exception as e: + self.logger.info(f'Error visualizing session data: {e}') + + initialized_task = None except Exception as e: self.logger.error(f"Task {task.name} failed to execute") @@ -210,7 +230,19 @@ def _init_task_save_folder(self, task: Type[Task]) -> str: # make a directory to save task data to os.makedirs(save_directory) os.makedirs(os.path.join(save_directory, 'logs'), exist_ok=True) - # save parameters to save directory + # save parameters to save directory with task name + self.parameters.add_entry( + "task", + { + "value": task.name, + "section": "task_congig", + "name": "BciPy Task", + "helpTip": "A string representing the task that was executed", + "recommended": "", + "editable": "false", + "type": "str", + } + ) self.parameters.save(save_directory) except OSError as error: @@ -252,3 +284,8 @@ def _save_copy_phrases(self) -> None: def get_system_info(self) -> dict: return get_system_info() + + def close_experiment_callback(self): + """Callback to close the experiment.""" + self.logger.info("User has exited the experiment.") + self.user_exit = True diff --git a/bcipy/task/paradigm/rsvp/copy_phrase.py b/bcipy/task/paradigm/rsvp/copy_phrase.py index f66f47cd7..da66a69f4 100644 --- a/bcipy/task/paradigm/rsvp/copy_phrase.py +++ b/bcipy/task/paradigm/rsvp/copy_phrase.py @@ -125,12 +125,13 @@ def __init__( ) -> None: super(RSVPCopyPhraseTask, self).__init__() self.fake = fake + self.parameters = parameters + self.language_model = self.get_language_model() + self.signal_models = self.get_signal_models() daq, servers, win = self.setup(parameters, file_save, fake) self.servers = servers self.window = win self.daq = daq - self.parameters = parameters - self.fake = fake self.validate_parameters() @@ -142,8 +143,6 @@ def __init__( self.button_press_error_prob = parameters['preview_inquiry_error_prob'] - self.language_model = self.get_language_model() - self.signal_models = self.get_signal_models() self.signal_model = self.signal_models[0] if self.signal_models else None self.evidence_evaluators = self.init_evidence_evaluators(self.signal_models) self.evidence_types = self.init_evidence_types(self.signal_models, self.evidence_evaluators) @@ -198,7 +197,7 @@ def get_language_model(self) -> LanguageModel: def get_signal_models(self) -> Optional[List[SignalModel]]: if not self.fake: try: - model_dir = self.parameters['signal_model_path'] + model_dir = self.parameters.get('signal_model_path', None) signal_models = load_signal_models(directory=model_dir) assert signal_models, f"No signal models found in {model_dir}" except Exception as error: diff --git a/bcipy/task/tests/core/test_actions.py b/bcipy/task/tests/core/test_actions.py index b70e8d55e..31b67c453 100644 --- a/bcipy/task/tests/core/test_actions.py +++ b/bcipy/task/tests/core/test_actions.py @@ -1,7 +1,7 @@ import unittest import subprocess -from mockito import mock, when, verify, unstub, any +from mockito import mock, when, verify, unstub from bcipy.task import actions, TaskData from bcipy.task.actions import CodeHookAction, OfflineAnalysisAction, ExperimentFieldCollectionAction @@ -43,15 +43,18 @@ def test_code_hook_action_no_subprocess(self) -> None: verify(subprocess, times=1).run(code_hook, shell=True) def test_offline_analysis_action(self) -> None: - when(actions).offline_analysis(self.data_directory, self.parameters, alert_finished=any()).thenReturn(None) + cmd_expected = f"bcipy-train --parameters '{self.parameters_path}'" + + when(subprocess).run(cmd_expected, shell=True, check=True).thenReturn(None) action = OfflineAnalysisAction( parameters=self.parameters, data_directory=self.data_directory, - parameters_path=self.parameters_path + parameters_path=self.parameters_path, ) response = action.execute() + cmd_expected = f"bcipy-train --parameters '{self.parameters_path}'" self.assertIsInstance(response, TaskData) - verify(actions, times=1).offline_analysis(self.data_directory, self.parameters, alert_finished=any()) + verify(subprocess, times=1).run(cmd_expected, shell=True, check=True) def test_experiment_field_collection_action(self) -> None: experiment_id = 'experiment_id' diff --git a/bcipy/task/tests/orchestrator/test_orchestrator.py b/bcipy/task/tests/orchestrator/test_orchestrator.py index 656c78cc3..f2805a849 100644 --- a/bcipy/task/tests/orchestrator/test_orchestrator.py +++ b/bcipy/task/tests/orchestrator/test_orchestrator.py @@ -71,15 +71,30 @@ def test_orchestrator_execute(self) -> None: any(), fake=False, experiment_id=any(), + alert_finished=any(), parameters_path=any(), - last_task_dir=None).thenReturn(task) + last_task_dir=None, + protocol_path=any(), + progress=any(), + tasks=any(), + exit_callback=any(), + ).thenReturn(task) orchestrator = SessionOrchestrator() orchestrator.add_task(task) orchestrator.execute() verify(task, times=1).__call__( - any(), any(), - fake=False, experiment_id=any(), parameters_path=any(), last_task_dir=None) + any(), + any(), + fake=False, + experiment_id=any(), + alert_finished=any(), + parameters_path=any(), + last_task_dir=None, + protocol_path=any(), + progress=any(), + tasks=any(), + exit_callback=any()) verify(SessionOrchestrator, times=1)._init_orchestrator_save_folder(any()) verify(SessionOrchestrator, times=1)._init_orchestrator_logger(any()) verify(SessionOrchestrator, times=1)._init_task_save_folder(any()) diff --git a/requirements.txt b/requirements.txt index f3f9df88e..f4a764fdd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ construct==2.8.14 mne==1.5.0 pyo==1.0.5 pyglet<=1.5.27,>=1.4 -PsychoPy==2023.2.1 +PsychoPy==2024.2.1 openpyxl==3.1.2 numpy==1.24.4 sounddevice==0.4.4 @@ -13,15 +13,15 @@ SoundFile==0.12.1 scipy==1.10.1 scikit-learn==1.2.2 seaborn==0.9.0 -matplotlib==3.7.2 +matplotlib==3.7.5 pylsl==1.16.2 -pandas==1.5.3 +pandas==2.0.3 psutil==5.7.2 Pillow==9.4.0 py-cpuinfo==9.0.0 pyedflib==0.1.34 -PyQt6==6.6.0 -PyQt6-Qt6==6.6.0 +pyopengl==3.1.7 +PyQt6==6.7.1 pywavelets==1.4.1 tqdm==4.62.2 reportlab==4.2.0 diff --git a/scripts/shell/run_gui.sh b/scripts/shell/run_gui.sh deleted file mode 100644 index 26065ada7..000000000 --- a/scripts/shell/run_gui.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -###### RUN BCI. ####### -# cd to path of bcipy code -# cd bcipy - -# Execute the gui code -python bcipy/gui/BCInterface.py diff --git a/setup.py b/setup.py index 7d0c34f83..76a95b252 100644 --- a/setup.py +++ b/setup.py @@ -107,7 +107,11 @@ def run(self): )), entry_points={ 'console_scripts': - ['bcipy = bcipy.main:bcipy_main', 'bcipy-sim = bcipy.simulator'], + [ + 'bcipy = bcipy.main:bcipy_main', + 'bcipy-erp-viz = bcipy.helpers.visualization:erp', + 'bcipy-sim = bcipy.simulator', + "bcipy-train = bcipy.signal.model.offline_analysis:main"], }, install_requires=REQUIRED, include_package_data=True,