Skip to content

Commit

Permalink
[MRG] Fix dipole plot scale and smooth factors (jonescompneurolab#730)
Browse files Browse the repository at this point in the history
* reorganizing visualization tab widgets. adding two additional float fields for data comparison

* fixing gui test and funciton args name

* fix linting errors

* fix lint errors in gui.py

* fixing gui tests

* adding scaling and smooth input parameters for data comparison visualization

* applied code review suggestions. Added functio to avoid repeated code

* adding docustring to new funtion. removing unused is_loaded_data

* fixing flake8 erros

* removed logic to determine if data is a sim or loaded data

* MAINT: commenting out try-except block

* MAINT: uncommenting atry-except dev code in viz_manager

* MAINT: remove try-except block
  • Loading branch information
kmilo9999 authored and gtdang committed Mar 29, 2024
1 parent 8f78e57 commit fe71cd2
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 39 deletions.
101 changes: 64 additions & 37 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from hnn_core.gui._logging import logger
from hnn_core.viz import plot_dipole


_fig_placeholder = 'Run simulation to add figures here.'

_plot_types = [
Expand Down Expand Up @@ -273,11 +274,11 @@ def _dynamic_rerender(fig):
fig.tight_layout()


def _plot_on_axes(b, widgets_simulation, widgets_plot_type,
target_simulations,
spectrogram_colormap_selection, dipole_smooth,
max_spectral_frequency, dipole_scaling, widgets, data,
fig_idx, fig, ax, existing_plots):
def _plot_on_axes(b, simulations_widget, widgets_plot_type,
data_widget,
spectrogram_colormap_selection, max_spectral_frequency,
dipole_smooth, dipole_scaling, data_smooth, data_scaling,
widgets, data, fig_idx, fig, ax, existing_plots):
"""Plotting different types of data on the given axes.
Now this function is also responsible for comparing multiple simulations,
Expand Down Expand Up @@ -314,7 +315,7 @@ def _plot_on_axes(b, widgets_simulation, widgets_plot_type,
existing_plots : ipywidgets.VBox
A VBox widget that contains all the existing plots.
"""
sim_name = widgets_simulation.value
sim_name = simulations_widget.value
plot_type = widgets_plot_type.value
# disable add plots for types that do not support overlay
if plot_type in _no_overlay_plot_types:
Expand All @@ -324,33 +325,34 @@ def _plot_on_axes(b, widgets_simulation, widgets_plot_type,
widgets_plot_type.disabled = True

single_simulation = data['simulations'][sim_name]

plot_config = {
"max_spectral_frequency": max_spectral_frequency.value,
simulation_plot_config = {
"dipole_scaling": dipole_scaling.value,
"dipole_smooth": dipole_smooth.value,
"max_spectral_frequency": max_spectral_frequency.value,
"spectrogram_cm": spectrogram_colormap_selection.value
}

dpls_processed = _update_ax(fig, ax, single_simulation, sim_name,
plot_type, plot_config)
plot_type, simulation_plot_config)

# If target_simulations is not None and we are plotting a dipole,
# we need to plot the target dipole as well.
if target_simulations.value in data['simulations'].keys(
if data_widget.value in data['simulations'].keys(
) and plot_type == 'current dipole':

target_sim_name = target_simulations.value
target_sim_name = data_widget.value
target_sim = data['simulations'][target_sim_name]

# plot the target dipole.
# disable scaling for the target dipole.
plot_config['dipole_scaling'] = 1.
data_plot_config = {
"dipole_scaling": data_scaling.value,
"dipole_smooth": data_smooth.value,
"max_spectral_frequency": max_spectral_frequency.value,
"spectrogram_cm": spectrogram_colormap_selection.value
}

# plot the target dipole.
target_dpl_processed = _update_ax(
fig, ax, target_sim, target_sim_name, plot_type,
plot_config)[0] # we assume there is only one dipole.
data_plot_config)[0] # we assume there is only one dipole.

# calculate the RMSE between the two dipoles.
t0 = 0.0
Expand Down Expand Up @@ -401,14 +403,15 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax):
layout = Layout(width="98%")
simulation_names = tuple(data['simulations'].keys())
sim_name_default = simulation_names[-1]

if len(simulation_names) == 0:
simulation_names = [
"None",
]

simulation_selection = Dropdown(
options=simulation_names,
value=sim_name_default,
value=simulation_names[0],
description='Simulation Data:',
disabled=False,
layout=layout,
Expand All @@ -431,8 +434,12 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax):
style=analysis_style,
)

tagert_names = simulation_names[:-1]
if len(simulation_names) > 1:
tagert_names = simulation_names[1:]

target_data_selection = Dropdown(
options=simulation_names[:-1] + ('None',),
options=tagert_names + ('None',),
value='None',
description='Data to Compare:',
disabled=False,
Expand All @@ -447,16 +454,33 @@ def _get_ax_control(widgets, data, fig_idx, fig, ax):
layout=layout,
style=analysis_style,
)
dipole_smooth = FloatText(value=30,
description='Dipole Smooth Window (ms):',
disabled=False,
layout=layout,
style=analysis_style)
dipole_scaling = FloatText(value=3000,
description='Dipole Scaling:',
disabled=False,
layout=layout,
style=analysis_style)
simulation_dipole_smooth = FloatText(
value=30,
description='Dipole Smooth Window (ms):',
disabled=False,
layout=layout,
style=analysis_style)

simulation_dipole_scaling = FloatText(
value=3000,
description='Simulation Dipole Scaling:',
disabled=False,
layout=layout,
style=analysis_style)

data_dipole_smooth = FloatText(
value=0,
description='Data Smooth Window (ms):',
disabled=False,
layout=layout,
style=analysis_style)

data_dipole_scaling = FloatText(
value=1,
description='Data Dipole Scaling:',
disabled=False,
layout=layout,
style=analysis_style)

max_spectral_frequency = FloatText(
value=100,
Expand Down Expand Up @@ -501,13 +525,15 @@ def _on_plot_type_change(new_plot_type):
plot_button.on_click(
partial(
_plot_on_axes,
widgets_simulation=simulation_selection,
simulations_widget=simulation_selection,
widgets_plot_type=plot_type_selection,
target_simulations=target_data_selection,
data_widget=target_data_selection,
spectrogram_colormap_selection=spectrogram_colormap_selection,
dipole_smooth=dipole_smooth,
max_spectral_frequency=max_spectral_frequency,
dipole_scaling=dipole_scaling,
dipole_smooth=simulation_dipole_smooth,
dipole_scaling=simulation_dipole_scaling,
data_smooth=data_dipole_smooth,
data_scaling=data_dipole_scaling,
widgets=widgets,
data=data,
fig_idx=fig_idx,
Expand All @@ -517,8 +543,9 @@ def _on_plot_type_change(new_plot_type):
))

vbox = VBox([
simulation_selection, plot_type_selection, target_data_selection,
dipole_smooth, dipole_scaling, max_spectral_frequency,
plot_type_selection, simulation_selection, simulation_dipole_smooth,
simulation_dipole_scaling, target_data_selection, data_dipole_smooth,
data_dipole_scaling, max_spectral_frequency,
spectrogram_colormap_selection,
HBox(
[plot_button, clear_button],
Expand Down Expand Up @@ -816,11 +843,11 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name,
ax_control_tabs.selected_index = ax_idx

# ax config
simulation_ctrl = ax_control_tabs.children[ax_idx].children[0]
simulation_ctrl = ax_control_tabs.children[ax_idx].children[1]
# return simulation_ctrl
simulation_ctrl.value = simulation_name

plot_type_ctrl = ax_control_tabs.children[ax_idx].children[1]
plot_type_ctrl = ax_control_tabs.children[ax_idx].children[0]
plot_type_ctrl.value = plot_type

config_name_idx = {
Expand Down
4 changes: 2 additions & 2 deletions hnn_core/tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_gui_edit_figure():
assert len(axes_config_tabs.children) == n_figs

axes_config = axes_config_tabs.children[-1].children[1]
simulation_selection = axes_config.children[0].children[0]
simulation_selection = axes_config.children[0].children[1]
assert simulation_selection.options == tuple(sim_names[:n_figs])
plt.close('all')

Expand All @@ -378,7 +378,7 @@ def test_gui_figure_overlay():
for controls in tab.children[1].children:
add_plot_button = controls.children[-2].children[0]
clear_ax_button = controls.children[-2].children[1]
plot_type_selection = controls.children[1]
plot_type_selection = controls.children[0]

assert plot_type_selection.disabled is True
clear_ax_button.click()
Expand Down

0 comments on commit fe71cd2

Please sign in to comment.