From ea5be74346b693297cc281d252c1bd1ea8b8f9e2 Mon Sep 17 00:00:00 2001 From: Basak Celik Date: Thu, 16 Nov 2023 14:20:37 -0500 Subject: [PATCH] Model and visualization updates --- bcipy/helpers/visualization.py | 23 ++++++++- bcipy/signal/model/fusion_model.py | 64 +++++++++++++++--------- bcipy/signal/model/offline_analysis.py | 67 +++++++++++++++++++------- 3 files changed, 114 insertions(+), 40 deletions(-) diff --git a/bcipy/helpers/visualization.py b/bcipy/helpers/visualization.py index e1ff25ea2..02767b536 100644 --- a/bcipy/helpers/visualization.py +++ b/bcipy/helpers/visualization.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Dict, Optional, Tuple import matplotlib.pyplot as plt import mne @@ -725,3 +725,24 @@ def visualize_session_data(session_path: str, parameters: dict, show=True) -> Fi save_path=session_path, show=show, ) + + +def visualize_gaze_accuracies(accuracy_dict: Dict[str, np.ndarray], + accuracy: float, + save_path: Optional[str] = None, + show: Optional[bool] = False) -> Figure: + """ + Visualize Gaze Accuracies. + + Plot the accuracies of each symbol using a bar plot. + + Returns a list of the figure handles created. + """ + + fig, ax = plt.subplots() + ax.bar(accuracy_dict.keys(), accuracy_dict.values()) + ax.set_xlabel('Symbol') + ax.set_ylabel('Accuracy') + ax.set_title('Overall Accuracy: ' + str(round(accuracy, 2))) + + return fig \ No newline at end of file diff --git a/bcipy/signal/model/fusion_model.py b/bcipy/signal/model/fusion_model.py index cd4c03ce6..e47178697 100644 --- a/bcipy/signal/model/fusion_model.py +++ b/bcipy/signal/model/fusion_model.py @@ -6,6 +6,10 @@ from bcipy.helpers.stimuli import GazeReshaper from sklearn.model_selection import cross_val_score # noqa from sklearn.utils.estimator_checks import check_estimator # noqa +import cvxpy as cp +import scipy.stats as stats + +from typing import Optional class GazeModel(SignalModel): @@ -15,40 +19,55 @@ def __init__(self, num_components=2): self.num_components = num_components # number of gaussians to fit def fit(self, train_data: np.array): - model = GaussianMixture(n_components=2, random_state=0, init_params='kmeans') + model = GaussianMixture(n_components=1, random_state=0, init_params='kmeans') model.fit(train_data) self.model = model return self - def get_scores(self, test_data: np.array): + def evaluate(self, test_data: np.array): ''' - Compute the log-likelihood of each sample. - Compute the mean and covariance of each mixture component. + Compute mean and covariance of each mixture component. ''' - scores = self.model.score_samples(test_data) means = self.model.means_ covs = self.model.covariances_ - return scores, means, covs + return means, covs - def predict(self, scores: np.array): + def predict(self, test_data: np.array, means: np.array, covs: np.array): ''' + Compute log-likelihood of each sample. Predict the labels for the test data. ''' - # Compute over log-likelihood scores - # Get the argmax of the scores - # return predictions + num_components = len(means) + + N, D = test_data.shape + K = num_components + + likelihoods = np.zeros((N, K), dtype=object) + predictions = np.zeros(N, dtype=object) - def evaluate(self, test_data: np.ndarray, test_labels: np.ndarray): + # Find the likelihoods by insterting the test data into the pdf of each component + for i in range(N): + for k in range(K): + mu = means[k] + sigma = covs[k] + + likelihoods[i,k] = stats.multivariate_normal.pdf(test_data[i], mu, sigma) + + # Find the argmax of the likelihoods to get the predictions + predictions[i] = np.argmax(likelihoods[i]) + + return likelihoods, predictions + + def calculate_acc(self, test_data: np.ndarray, test_labels: Optional[np.ndarray]): ''' Compute model performance characteristics on the provided test data and labels. ''' - # Compute the AUC - # return ModelEvaluationReport(auc) + # return accuracy def save(self, path: Path): """Save model state to the provided checkpoint""" @@ -73,37 +92,38 @@ def fit(self, train_data: np.array): return self - def get_scores(self, test_data: np.array, sym_pos: np.array): + def evaluate(self, test_data: np.array, sym_pos: np.array): ''' - Compute the log-likelihood of each sample. - Return the mean and covariance of each mixture component. + Return mean and covariance of each mixture component. test_data: gaze data for each symbol sym_pos: mid positions for each symbol in Tobii coordinates ''' - scores = self.model.score_samples(test_data) means = self.model.means_ + sym_pos covs = self.model.covariances_ - return scores, means, covs + return means, covs - def predict(self, scores: np.array): + def predict(self, test_data: np.array): ''' + Compute log-likelihood of each sample. Predict the labels for the test data. ''' # Compute over log-likelihood scores # Get the argmax of the scores + scores = self.model.score_samples(test_data) + # return predictions - def evaluate(self, test_data: np.ndarray, test_labels: np.ndarray): + def calculate_acc(self, test_data: np.ndarray, sym_pos: np.array): ''' Compute model performance characteristics on the provided test data and labels. ''' - # Compute the AUC - # return ModelEvaluationReport(auc) + + # return accuracy def save(self, path: Path): """Save model state to the provided checkpoint""" diff --git a/bcipy/signal/model/offline_analysis.py b/bcipy/signal/model/offline_analysis.py index 100c34c64..4e76554d8 100644 --- a/bcipy/signal/model/offline_analysis.py +++ b/bcipy/signal/model/offline_analysis.py @@ -23,7 +23,8 @@ from bcipy.helpers.triggers import TriggerType, trigger_decoder from bcipy.helpers.visualization import (visualize_erp, visualize_gaze, visualize_centralized_data, - visualize_results_all_symbols) + visualize_results_all_symbols, + visualize_gaze_accuracies) from bcipy.preferences import preferences from bcipy.signal.model.base_model import SignalModel, SignalModelMetadata from bcipy.signal.model.fusion_model import GazeModel, GazeModel_AllSymbols @@ -312,6 +313,10 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f means_all = [] covs_all = [] for sym in symbol_set: + # Skip if there's no evidence for this symbol: + if len(inquiries[sym]) == 0: + test_dict[sym] = [] + continue le = preprocessed_data[sym][0] re = preprocessed_data[sym][1] @@ -319,24 +324,14 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f labels = np.array([sym] * len(le)) # Labels are the same for both eyes train_le, test_le, train_labels_le, test_labels_le = subset_data(le, labels, test_size=0.2, swap_axes=False) train_re, test_re, train_labels_re, test_labels_re = subset_data(re, labels, test_size=0.2, swap_axes=False) - test_dict[sym] = [test_le, test_re] - - if model_type == "Centralized": - # Centralize the data using symbol positions: - # Load json file. - # TODO: move this to a helper function, or get the symbol positions from the build_grid method - with open(f"{BCIPY_ROOT}/parameters/symbol_positions.json", 'r') as params_file: - symbol_positions = json.load(params_file) - # Subtract the symbol positions from the data: - centralized_data_left.append(model.reshaper.centralize_all_data(train_le, symbol_positions[sym])) - centralized_data_right.append(model.reshaper.centralize_all_data(train_re, symbol_positions[sym])) + test_dict[sym] = np.concatenate((test_le, test_re), axis=0) # Fit the model based on model type. # Model 1: Fit Gaussian mixture (comp=2) on each symbol and each eye separately if model_type == "Individual": model.fit(train_re) - scores, means, covs = model.get_scores(test_re) + means, covs = model.evaluate(test_re) # Visualize the results: # figure_handles = visualize_gaze_inquiries( @@ -351,8 +346,45 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f means_all.append(means) covs_all.append(covs) - if model_type == "Centralized": + # scores, predictions = model.predict(test_re) + # Model 2: Fit Gaussian mixture (comp=1) on a centralized data + if model_type == "Centralized": + # Centralize the data using symbol positions: + # Load json file. + # TODO: move this to a helper function, or get the symbol positions from the build_grid method + with open(f"{BCIPY_ROOT}/parameters/symbol_positions.json", 'r') as params_file: + symbol_positions = json.load(params_file) + # Subtract the symbol positions from the data: + centralized_data_left.append(model.reshaper.centralize_all_data(train_le, symbol_positions[sym])) + centralized_data_right.append(model.reshaper.centralize_all_data(train_re, symbol_positions[sym])) + + if model_type == "Individual": + # Takes in all means and covs from individual symbols, and + # calculates the likelihoods and predictions for each test point. + # Convert test_dict to a list of arrays: + accuracy = 0 + acc_all_symbols = {} + counter = 0 + for sym in symbol_set: + # Skip if test_dict is empty for certain symbols: + if len(test_dict[sym]) == 0: + acc_all_symbols[sym] = 0 + continue + likelihoods, predictions = model.predict(test_dict[sym], np.squeeze(np.array(means_all)), np.squeeze(np.array(covs_all))) + acc_all_symbols[sym] = np.sum(predictions == counter) / len(predictions) * 100 + accuracy += np.sum(predictions == counter) / len(predictions) * 100 + print(f"""Correct predictions for {sym}: {np.sum(predictions == counter)} / {len(predictions)}, + Accuracy {sym}: {np.sum(predictions == counter) / len(predictions) * 100:.2f}""") + counter += 1 + accuracy /= counter + print(f"Overall accuracy: {accuracy:.2f}") + + # Plot all accuracies as bar plot: + figure_handles = visualize_gaze_accuracies(acc_all_symbols, accuracy, save_path=None, show=True) + + + if model_type == "Centralized": cent_left = np.concatenate(np.array(centralized_data_left, dtype=object)) cent_right = np.concatenate(np.array(centralized_data_right, dtype=object)) @@ -370,7 +402,7 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f # Add the means back to the symbol positions. # Calculate scores for the test set. for sym in symbol_set: - scores, means, covs = model.get_scores(test_dict[sym][0], symbol_positions[sym]) + means, covs = model.evaluate(test_dict[sym][0], symbol_positions[sym]) le = preprocessed_data[sym][0] re = preprocessed_data[sym][1] @@ -387,6 +419,7 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f means_all.append(means) covs_all.append(covs) + fig_handles = visualize_results_all_symbols( left_eye_all, right_eye_all, means_all, covs_all, @@ -465,7 +498,7 @@ def offline_analysis( if device_spec.content_type == "Eyetracker": analyze_gaze(raw_data, device_spec, data_folder, save_figures, show_figures, - model_type="Centralized") + model_type="Individual") if alert_finished: play_sound(f"{STATIC_AUDIO_PATH}/{parameters['alert_sound_file']}") @@ -484,7 +517,7 @@ def offline_analysis( parser.add_argument("--alert", dest="alert", action="store_true") parser.add_argument("--balanced-acc", dest="balanced", action="store_true") parser.set_defaults(alert=False) - parser.set_defaults(balanced=False) + parser.set_defaults(balanced=True) parser.set_defaults(save_figures=False) parser.set_defaults(show_figures=True) args = parser.parse_args()