Model and visualization updates
celikbasak committed Nov 16, 2023
1 parent c2f23ae commit ea5be74
Showing 3 changed files with 114 additions and 40 deletions.
23 changes: 22 additions & 1 deletion bcipy/helpers/visualization.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Dict, Optional, Tuple

import matplotlib.pyplot as plt
import mne
@@ -725,3 +725,24 @@ def visualize_session_data(session_path: str, parameters: dict, show=True) -> Fi
save_path=session_path,
show=show,
)


def visualize_gaze_accuracies(accuracy_dict: Dict[str, np.ndarray],
accuracy: float,
save_path: Optional[str] = None,
show: Optional[bool] = False) -> Figure:
"""
Visualize Gaze Accuracies.
Plot the accuracies of each symbol using a bar plot.
Returns the figure handle.
"""

fig, ax = plt.subplots()
ax.bar(accuracy_dict.keys(), accuracy_dict.values())
ax.set_xlabel('Symbol')
ax.set_ylabel('Accuracy')
ax.set_title('Overall Accuracy: ' + str(round(accuracy, 2)))

return fig
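
As a usage sketch only: the helper added above takes a dict of per-symbol accuracies and an overall accuracy, and returns the bar-plot figure. Note that in this version the save_path and show arguments are accepted but not acted on inside the function, so the caller displays or saves the figure itself. The symbol keys and accuracy values below are invented for illustration.

import matplotlib.pyplot as plt

from bcipy.helpers.visualization import visualize_gaze_accuracies

# Illustrative per-symbol accuracies (percent correct); not real results.
acc_all_symbols = {'A': 80.0, 'B': 66.7, 'C': 92.3}
overall_accuracy = sum(acc_all_symbols.values()) / len(acc_all_symbols)

fig = visualize_gaze_accuracies(acc_all_symbols, overall_accuracy)
plt.show()  # shown explicitly, since the helper does not act on show/save_path here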
64 changes: 42 additions & 22 deletions bcipy/signal/model/fusion_model.py
@@ -6,6 +6,10 @@
from bcipy.helpers.stimuli import GazeReshaper
from sklearn.model_selection import cross_val_score # noqa
from sklearn.utils.estimator_checks import check_estimator # noqa
import cvxpy as cp
import scipy.stats as stats

from typing import Optional


class GazeModel(SignalModel):
@@ -15,40 +19,55 @@ def __init__(self, num_components=2):
self.num_components = num_components # number of gaussians to fit

def fit(self, train_data: np.array):
model = GaussianMixture(n_components=2, random_state=0, init_params='kmeans')
model = GaussianMixture(n_components=1, random_state=0, init_params='kmeans')
model.fit(train_data)
self.model = model

return self

def get_scores(self, test_data: np.array):
def evaluate(self, test_data: np.array):
'''
Compute the log-likelihood of each sample.
Compute the mean and covariance of each mixture component.
Compute mean and covariance of each mixture component.
'''

scores = self.model.score_samples(test_data)
means = self.model.means_
covs = self.model.covariances_

return scores, means, covs
return means, covs

def predict(self, scores: np.array):
def predict(self, test_data: np.array, means: np.array, covs: np.array):
'''
Compute log-likelihood of each sample.
Predict the labels for the test data.
'''
# Compute over log-likelihood scores
# Get the argmax of the scores

# return predictions
num_components = len(means)

N, D = test_data.shape
K = num_components

likelihoods = np.zeros((N, K), dtype=object)
predictions = np.zeros(N, dtype=object)

def evaluate(self, test_data: np.ndarray, test_labels: np.ndarray):
# Find the likelihoods by inserting the test data into the pdf of each component
for i in range(N):
for k in range(K):
mu = means[k]
sigma = covs[k]

likelihoods[i,k] = stats.multivariate_normal.pdf(test_data[i], mu, sigma)

# Find the argmax of the likelihoods to get the predictions
predictions[i] = np.argmax(likelihoods[i])

return likelihoods, predictions

def calculate_acc(self, test_data: np.ndarray, test_labels: Optional[np.ndarray]):
'''
Compute model performance characteristics on the provided test data and labels.
'''
# Compute the AUC

# return ModelEvaluationReport(auc)
# return accuracy

def save(self, path: Path):
"""Save model state to the provided checkpoint"""
@@ -73,37 +92,38 @@ def fit(self, train_data: np.array):

return self

def get_scores(self, test_data: np.array, sym_pos: np.array):
def evaluate(self, test_data: np.array, sym_pos: np.array):
'''
Compute the log-likelihood of each sample.
Return the mean and covariance of each mixture component.
Return mean and covariance of each mixture component.
test_data: gaze data for each symbol
sym_pos: mid positions for each symbol in Tobii coordinates
'''

scores = self.model.score_samples(test_data)
means = self.model.means_ + sym_pos
covs = self.model.covariances_

return scores, means, covs
return means, covs

def predict(self, scores: np.array):
def predict(self, test_data: np.array):
'''
Compute log-likelihood of each sample.
Predict the labels for the test data.
'''
# Compute over log-likelihood scores
# Get the argmax of the scores

scores = self.model.score_samples(test_data)

# return predictions

def evaluate(self, test_data: np.ndarray, test_labels: np.ndarray):
def calculate_acc(self, test_data: np.ndarray, sym_pos: np.array):
'''
Compute model performance characteristics on the provided test data and labels.
'''
# Compute the AUC

# return ModelEvaluationReport(auc)

# return accuracy

def save(self, path: Path):
"""Save model state to the provided checkpoint"""
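To make the reworked GazeModel flow concrete, here is a minimal, self-contained sketch of the same idea outside of BciPy: fit a one-component Gaussian mixture per symbol (fit), keep each fitted mean and covariance (evaluate), then label a test point by the argmax of the per-symbol Gaussian densities (predict). The toy gaze clusters and symbol set are invented for illustration; the real pipeline works on reshaped per-symbol gaze data.

import numpy as np
import scipy.stats as stats
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)

# Toy 2D gaze clusters for two symbols (illustrative only).
train = {
    'A': rng.normal(loc=[-0.5, 0.0], scale=0.05, size=(50, 2)),
    'B': rng.normal(loc=[0.5, 0.0], scale=0.05, size=(50, 2)),
}

# fit: one Gaussian per symbol; evaluate: collect its mean and covariance.
means, covs = [], []
for sym, data in train.items():
    gmm = GaussianMixture(n_components=1, random_state=0, init_params='kmeans').fit(data)
    means.append(gmm.means_[0])
    covs.append(gmm.covariances_[0])

# predict: likelihood of a test point under each symbol's Gaussian, then argmax.
test_point = np.array([0.48, 0.02])
likelihoods = [stats.multivariate_normal.pdf(test_point, mean=mu, cov=sigma)
               for mu, sigma in zip(means, covs)]
print(list(train.keys())[int(np.argmax(likelihoods))])  # expected: 'B'

The centralized variant (GazeModel_AllSymbols) instead pools data from all symbols after subtracting each symbol's screen position, fits a single mixture, and adds the positions back when evaluating per symbol, as in the offline_analysis changes below.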
67 changes: 50 additions & 17 deletions bcipy/signal/model/offline_analysis.py
@@ -23,7 +23,8 @@
from bcipy.helpers.triggers import TriggerType, trigger_decoder
from bcipy.helpers.visualization import (visualize_erp, visualize_gaze,
visualize_centralized_data,
visualize_results_all_symbols)
visualize_results_all_symbols,
visualize_gaze_accuracies)
from bcipy.preferences import preferences
from bcipy.signal.model.base_model import SignalModel, SignalModelMetadata
from bcipy.signal.model.fusion_model import GazeModel, GazeModel_AllSymbols
@@ -312,31 +313,25 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f
means_all = []
covs_all = []
for sym in symbol_set:
# Skip if there's no evidence for this symbol:
if len(inquiries[sym]) == 0:
test_dict[sym] = []
continue
le = preprocessed_data[sym][0]
re = preprocessed_data[sym][1]

# Train test split:
labels = np.array([sym] * len(le)) # Labels are the same for both eyes
train_le, test_le, train_labels_le, test_labels_le = subset_data(le, labels, test_size=0.2, swap_axes=False)
train_re, test_re, train_labels_re, test_labels_re = subset_data(re, labels, test_size=0.2, swap_axes=False)
test_dict[sym] = [test_le, test_re]

if model_type == "Centralized":
# Centralize the data using symbol positions:
# Load json file.
# TODO: move this to a helper function, or get the symbol positions from the build_grid method
with open(f"{BCIPY_ROOT}/parameters/symbol_positions.json", 'r') as params_file:
symbol_positions = json.load(params_file)
# Subtract the symbol positions from the data:
centralized_data_left.append(model.reshaper.centralize_all_data(train_le, symbol_positions[sym]))
centralized_data_right.append(model.reshaper.centralize_all_data(train_re, symbol_positions[sym]))
test_dict[sym] = np.concatenate((test_le, test_re), axis=0)

# Fit the model based on model type.
# Model 1: Fit Gaussian mixture (comp=2) on each symbol and each eye separately
if model_type == "Individual":
model.fit(train_re)

scores, means, covs = model.get_scores(test_re)
means, covs = model.evaluate(test_re)

# Visualize the results:
# figure_handles = visualize_gaze_inquiries(
@@ -351,8 +346,45 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f
means_all.append(means)
covs_all.append(covs)

if model_type == "Centralized":
# scores, predictions = model.predict(test_re)

# Model 2: Fit Gaussian mixture (comp=1) on a centralized data
if model_type == "Centralized":
# Centralize the data using symbol positions:
# Load json file.
# TODO: move this to a helper function, or get the symbol positions from the build_grid method
with open(f"{BCIPY_ROOT}/parameters/symbol_positions.json", 'r') as params_file:
symbol_positions = json.load(params_file)
# Subtract the symbol positions from the data:
centralized_data_left.append(model.reshaper.centralize_all_data(train_le, symbol_positions[sym]))
centralized_data_right.append(model.reshaper.centralize_all_data(train_re, symbol_positions[sym]))

if model_type == "Individual":
# Takes in all means and covs from individual symbols, and
# calculates the likelihoods and predictions for each test point.
# Convert test_dict to a list of arrays:
accuracy = 0
acc_all_symbols = {}
counter = 0
for sym in symbol_set:
# Skip if test_dict is empty for certain symbols:
if len(test_dict[sym]) == 0:
acc_all_symbols[sym] = 0
continue
likelihoods, predictions = model.predict(test_dict[sym], np.squeeze(np.array(means_all)), np.squeeze(np.array(covs_all)))
acc_all_symbols[sym] = np.sum(predictions == counter) / len(predictions) * 100
accuracy += np.sum(predictions == counter) / len(predictions) * 100
print(f"""Correct predictions for {sym}: {np.sum(predictions == counter)} / {len(predictions)},
Accuracy {sym}: {np.sum(predictions == counter) / len(predictions) * 100:.2f}""")
counter += 1
accuracy /= counter
print(f"Overall accuracy: {accuracy:.2f}")

# Plot all accuracies as bar plot:
figure_handles = visualize_gaze_accuracies(acc_all_symbols, accuracy, save_path=None, show=True)


if model_type == "Centralized":
cent_left = np.concatenate(np.array(centralized_data_left, dtype=object))
cent_right = np.concatenate(np.array(centralized_data_right, dtype=object))

@@ -370,7 +402,7 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f
# Add the means back to the symbol positions.
# Calculate scores for the test set.
for sym in symbol_set:
scores, means, covs = model.get_scores(test_dict[sym][0], symbol_positions[sym])
means, covs = model.evaluate(test_dict[sym][0], symbol_positions[sym])

le = preprocessed_data[sym][0]
re = preprocessed_data[sym][1]
@@ -387,6 +419,7 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f
means_all.append(means)
covs_all.append(covs)


fig_handles = visualize_results_all_symbols(
left_eye_all, right_eye_all,
means_all, covs_all,
@@ -465,7 +498,7 @@ def offline_analysis(

if device_spec.content_type == "Eyetracker":
analyze_gaze(raw_data, device_spec, data_folder, save_figures, show_figures,
model_type="Centralized")
model_type="Individual")

if alert_finished:
play_sound(f"{STATIC_AUDIO_PATH}/{parameters['alert_sound_file']}")
@@ -484,7 +517,7 @@
parser.add_argument("--alert", dest="alert", action="store_true")
parser.add_argument("--balanced-acc", dest="balanced", action="store_true")
parser.set_defaults(alert=False)
parser.set_defaults(balanced=False)
parser.set_defaults(balanced=True)
parser.set_defaults(save_figures=False)
parser.set_defaults(show_figures=True)
args = parser.parse_args()
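For clarity on the accuracy bookkeeping added to analyze_gaze above (the "Individual" branch): each symbol's accuracy is the percentage of its test samples whose predicted component index matches that symbol's position among the symbols that had test data, and the overall accuracy is the unweighted mean of those per-symbol values. A small sketch of that arithmetic with made-up predictions:

import numpy as np

# Made-up argmax predictions for two symbols; 'A' has index 0, 'B' has index 1.
predictions = {'A': np.array([0, 0, 1, 0, 0]),  # 4 of 5 correct
               'B': np.array([1, 1, 0, 1])}     # 3 of 4 correct

acc_all_symbols = {}
accuracy = 0.0
counter = 0
for sym, preds in predictions.items():
    acc_all_symbols[sym] = np.sum(preds == counter) / len(preds) * 100
    accuracy += acc_all_symbols[sym]
    counter += 1
accuracy /= counter

print(acc_all_symbols)       # {'A': 80.0, 'B': 75.0}
print(round(accuracy, 2))    # 77.5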
