Skip to content

Commit

Permalink
Offline analysis bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
celikbasak committed Nov 6, 2023
1 parent 9848765 commit 6df9ee8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
1 change: 1 addition & 0 deletions bcipy/helpers/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def visualize_results_all_symbols(
if means is not None:
means[:, 0] = np.clip(means[:, 0], 0, 1)
means[:, 1] = np.clip(means[:, 1], 0, 1)
means[:, 1] = 1 - means[:, 1]

if heatmap:
# create a dataframe making a column for each x, y pair for both eyes and
Expand Down
36 changes: 18 additions & 18 deletions bcipy/signal/model/offline_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn.model_selection import train_test_split

import bcipy.acquisition.devices as devices
from bcipy.config import (DEFAULT_DEVICE_SPEC_FILENAME,
from bcipy.config import (DEFAULT_DEVICE_SPEC_FILENAME, BCIPY_ROOT,
DEFAULT_PARAMETERS_PATH, STATIC_AUDIO_PATH,
STATIC_IMAGES_PATH, TRIGGER_FILENAME)
from bcipy.helpers.acquisition import analysis_channels, raw_data_filename
Expand Down Expand Up @@ -325,7 +325,7 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f
# Centralize the data using symbol positions:
# Load json file.
# TODO: move this to a helper function, or get the symbol positions from the build_grid method
with open(f"{DEFAULT_PARAMETERS_PATH}/symbol_positions.json", 'r') as params_file:
with open(f"{BCIPY_ROOT}/parameters/symbol_positions.json", 'r') as params_file:
symbol_positions = json.load(params_file)
# Subtract the symbol positions from the data:
centralized_data_left.append(model.reshaper.centralize_all_data(train_le, symbol_positions[sym]))
Expand All @@ -339,13 +339,13 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f
scores, means, covs = model.get_scores(test_re)

# Visualize the results:
figure_handles = visualize_gaze_inquiries(
le, re,
means, covs,
save_path=None,
show=show_figures,
raw_plot=True,
)
# figure_handles = visualize_gaze_inquiries(
# le, re,
# means, covs,
# save_path=None,
# show=show_figures,
# raw_plot=True,
# )
left_eye_all.append(le)
right_eye_all.append(re)
means_all.append(means)
Expand Down Expand Up @@ -375,13 +375,13 @@ def analyze_gaze(gaze_data, device_spec, data_folder, save_figures=False, show_f
le = preprocessed_data[sym][0]
re = preprocessed_data[sym][1]
# Visualize the results:
figure_handles = visualize_gaze_inquiries(
le, re,
means, covs,
save_path=None,
show=show_figures,
raw_plot=True,
)
# figure_handles = visualize_gaze_inquiries(
# le, re,
# means, covs,
# save_path=None,
# show=show_figures,
# raw_plot=True,
# )
left_eye_all.append(le)
right_eye_all.append(re)
means_all.append(means)
Expand Down Expand Up @@ -465,7 +465,7 @@ def offline_analysis(

if device_spec.content_type == "Eyetracker":
analyze_gaze(raw_data, device_spec, data_folder, save_figures, show_figures,
model_type="Individual")
model_type="Centralized")

if alert_finished:
play_sound(f"{STATIC_AUDIO_PATH}/{parameters['alert_sound_file']}")
Expand All @@ -486,7 +486,7 @@ def offline_analysis(
parser.set_defaults(alert=False)
parser.set_defaults(balanced=False)
parser.set_defaults(save_figures=False)
parser.set_defaults(show_figures=False)
parser.set_defaults(show_figures=True)
args = parser.parse_args()

log.info(f"Loading params from {args.parameters_file}")
Expand Down

0 comments on commit 6df9ee8

Please sign in to comment.