Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Added plot sets functionality #746

Merged
merged 6 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/gui/tutorial_erp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@
"metadata": {},
"outputs": [],
"source": [
"gui._simulate_viz_action('switch_fig_template', 'single figure')\n",
"gui._simulate_viz_action('switch_fig_template', '[Blank] single figure')\n",
"gui._simulate_viz_action('add_fig')\n",
"gui._simulate_viz_action(\n",
" \"edit_figure\",\n",
Expand Down
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ Changelog

- Added kwargs options to `plot_spikes_hist` for adjusting the histogram plots
of spiking activity, by `Abdul Samad Siddiqui`_ in :gh:`732`.

- Added pre defined plot sets for simulated data,
kmilo9999 marked this conversation as resolved.
Show resolved Hide resolved
by `Camilo Diaz`_ in :gh:`746`
kmilo9999 marked this conversation as resolved.
Show resolved Hide resolved

Bug
~~~
Expand Down
170 changes: 152 additions & 18 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from hnn_core.gui._logging import logger
from hnn_core.viz import plot_dipole


_fig_placeholder = 'Run simulation to add figures here.'

_plot_types = [
Expand Down Expand Up @@ -49,26 +48,91 @@
]

fig_templates = {
"2row x 1col (1:3)": {
"kwargs": "gridspec_kw={\"height_ratios\":[1,3]}",
"[Blank] 2row x 1col (1:3)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 3]}
},
"mosaic": "00\n11",
},
"2row x 1col (1:1)": {
"kwargs": "gridspec_kw={\"height_ratios\":[1,1]}",
"[Blank] 2row x 1col (1:1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 1]}
},
"mosaic": "00\n11",
},
"1row x 2col (1:1)": {
"kwargs": "gridspec_kw={\"height_ratios\":[1,1]}",
"[Blank] 1row x 2col (1:1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 1]}
},
"mosaic": "01\n01",
},
"single figure": {
"kwargs": "",
"[Blank] single figure": {
"kwargs": {
"gridspec_kw": ""
},
"mosaic": "00\n00",
},
"2row x 2col (1:1)": {
"kwargs": "gridspec_kw={\"height_ratios\":[1,1]}",
"[Blank] 2row x 2col (1:1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 1]}
},
"mosaic": "01\n23",
}
}

data_templates = {
"Drive-Dipole (2x1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 3]}
},
"mosaic": "00\n11",
"ax_plots": [("ax0", "input histogram"), ("ax1", "current dipole")]
},
"Dipole Layers (3x1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 1, 1]}
},
"mosaic": "0\n1\n2",
"ax_plots": [("ax0", "layer2 dipole"), ("ax1", "layer5 dipole"),
("ax2", "current dipole")]
},
"Drive-Spikes (2x1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 3]}
},
"mosaic": "00\n11",
"ax_plots": [("ax0", "input histogram"), ("ax1", "spikes")]
},
"Dipole-Spectrogram (2x1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 3]}
},
"mosaic": "00\n11",
"ax_plots": [("ax0", "current dipole"), ("ax1", "spectrogram")]
},
"Dipole-Spikes (2x1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 1]}
},
"mosaic": "00\n11",
"ax_plots": [("ax0", "current dipole"), ("ax1", "spikes")]
},
"Drive-Dipole-Spectrogram (3x1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 1, 2]}
},
"mosaic": "0\n1\n2",
"ax_plots": [("ax0", "input histogram"), ("ax1", "current dipole"),
("ax2", "spectrogram")]
},
"PSD Layers (3x1)": {
"kwargs": {
"gridspec_kw": {"height_ratios": [1, 1, 1]}
},
"mosaic": "0\n1\n2",
"ax_plots": [("ax0", "layer2 dipole"), ("ax1", "layer5 dipole"),
("ax2", "PSD")]
}
}


Expand All @@ -87,6 +151,11 @@ def check_sim_plot_types(
target_selection.value = 'None'


def _check_template_type_is_data_dependant(template_name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have a unit test

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kmilo9999 do you think this is sufficiently covered by the current unit tests? It's a pretty trivial function so I'll leave it to your judgement

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kmilo9999 do you think this is sufficiently covered by the current unit tests? It's a pretty trivial function so I'll leave it to your judgement

Yes, the unit test already covers this function inside the gui.viz_manager.make_fig_button call.

sim_data_options = list(data_templates.keys())
return template_name in sim_data_options


def target_comparison_change(new_target_name, simulation_selection, data):
"""Triggered when the target data is turned on or changed.
"""
Expand Down Expand Up @@ -636,8 +705,7 @@ def _add_axes_controls(widgets, data, fig, axd):
widgets['axes_config_tabs'].set_title(n_tabs, _idx2figname(fig_idx))


def _add_figure(b, widgets, data, scale=0.95, dpi=96):
template_name = widgets['templates_dropdown'].value
def _add_figure(b, widgets, data, template_type, scale=0.95, dpi=96):
fig_idx = data['fig_idx']['idx']
viz_output_layout = data['visualization_output']
fig_outputs = Output()
Expand All @@ -656,8 +724,8 @@ def _add_figure(b, widgets, data, scale=0.95, dpi=96):
with fig_outputs:
figsize = (scale * ((int(viz_output_layout.width[:-2]) - 10) / dpi),
scale * ((int(viz_output_layout.height[:-2]) - 10) / dpi))
mosaic = fig_templates[template_name]['mosaic']
kwargs = eval(f"dict({fig_templates[template_name]['kwargs']})")
mosaic = template_type['mosaic']
kwargs = template_type['kwargs']
kmilo9999 marked this conversation as resolved.
Show resolved Hide resolved
plt.ioff()
fig, axd = plt.subplot_mosaic(mosaic,
figsize=figsize,
Expand Down Expand Up @@ -716,20 +784,30 @@ def __init__(self, gui_data, viz_layout):
(self.figs_tabs, 'selected_index'),
)

template_names = list(fig_templates.keys())
template_names = list(data_templates.keys())
template_names.extend(list(fig_templates.keys()))
self.templates_dropdown = Dropdown(
description='Layout template:',
options=template_names,
value=template_names[0],
style={'description_width': 'initial'},
layout=Layout(width="98%"))
self.templates_dropdown.observe(self._layout_template_change, 'value')

self.make_fig_button = Button(
description='Make figure',
button_style="primary",
style={'button_color': self.viz_layout['theme_color']},
layout=self.viz_layout['btn'])
self.make_fig_button.on_click(self.add_figure)

self.datasets_dropdown = Dropdown(
description='Dataset:',
options=[],
value=None,
style={'description_width': 'initial'},
layout=Layout(width="98%"))

# data
self.fig_idx = {"idx": 1}
self.figs = {}
Expand All @@ -741,7 +819,8 @@ def widgets(self):
"figs_output": self.figs_output,
"axes_config_tabs": self.axes_config_tabs,
"figs_tabs": self.figs_tabs,
"templates_dropdown": self.templates_dropdown
"templates_dropdown": self.templates_dropdown,
"dataset_dropdown": self.datasets_dropdown
}

@property
Expand Down Expand Up @@ -782,6 +861,7 @@ def compose(self):
Box(
[
self.templates_dropdown,
self.datasets_dropdown,
self.make_fig_button,
],
layout=Layout(
Expand All @@ -795,21 +875,75 @@ def compose(self):
])
return config_panel, fig_output_container

def _layout_template_change(self, template_type):
# check if plot set type requires loaded sim-data
if _check_template_type_is_data_dependant(template_type.new):
# Add only simualated data
sim_names = [simulations for simulations, sim_name
in self.data["simulations"].items()
if sim_name['net'] is not None]

if len(sim_names) == 0:
sim_names = [" "]

self.datasets_dropdown.options = sim_names
self.datasets_dropdown.value = sim_names[0]
# show list of simulated to gui dropdown
self.datasets_dropdown.layout.visibility = "visible"
else:
# hide sim-data dropdown
self.datasets_dropdown.layout.visibility = "hidden"

@unlink_relink(attribute='figs_config_tab_link')
def add_figure(self, b=None):
"""Add a figure and corresponding config tabs to the dashboard.
"""
if len(self.data["simulations"]) == 0:
logger.error("No data has been loaded")
return

template_name = self.widgets['templates_dropdown'].value
is_data_template = (_check_template_type_is_data_dependant
(template_name))
if is_data_template:
sim_name = self.widgets["dataset_dropdown"].value
if sim_name not in self.data["simulations"]:
logger.error("No simulation data has been loaded")
return

# Use data_templates dictionary if it's a data dependent layout
template_type = (data_templates[template_name]
if is_data_template
else fig_templates[template_name])

# Add empty figure according to template arguments
_add_figure(None,
self.widgets,
self.data,
template_type,
scale=0.97,
dpi=self.viz_layout['dpi'])

# Plot data if it is a data-dependent template
if is_data_template:
fig_name = _idx2figname(self.data['fig_idx']['idx'] - 1)
# get figs per axis
ax_plots = data_templates[template_name]["ax_plots"]
for ax_name, plot_type in ax_plots:
# paint fig in axis
self._simulate_edit_figure(fig_name, ax_name, sim_name,
plot_type, {}, "plot")
logger.info(f"Figure {template_name} for "
f"simulation {sim_name} "
"has been created"
)

def _simulate_add_fig(self):
self.make_fig_button.click()

def _simulate_switch_fig_template(self, template_name):
assert template_name in fig_templates.keys(), "No such template"
assert (template_name in fig_templates.keys() or
data_templates.keys()), "No such template"
self.templates_dropdown.value = template_name

def _simulate_delete_figure(self, fig_name):
Expand Down
5 changes: 4 additions & 1 deletion hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,10 +1143,13 @@ def on_upload_data_change(change, data, viz_manager, log_out):
hnn_core.read_dipole(io.StringIO(ext_content))
]}
logger.info(f'External data {data_fname} loaded.')
viz_manager.reset_fig_config_tabs(template_name='single figure')
_template_name = "[Blank] single figure"
gtdang marked this conversation as resolved.
Show resolved Hide resolved
viz_manager.reset_fig_config_tabs(template_name=_template_name)
viz_manager.add_figure()
fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
ax_plots = [("ax0", "current dipole")]

# these lines plot the data per axis
for ax_name, plot_type in ax_plots:
viz_manager._simulate_edit_figure(
fig_name, ax_name, data_fname, plot_type, {}, "plot")
Expand Down
40 changes: 40 additions & 0 deletions hnn_core/tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,46 @@ def test_gui_add_figure():
plt.close('all')


def test_gui_add_data_dependent_figure():
kmilo9999 marked this conversation as resolved.
Show resolved Hide resolved
"""Test if the GUI adds/deletes figs data dependent properly."""
gui = HNNGUI()
_ = gui.compose()
gui.params['N_pyr_x'] = 3
gui.params['N_pyr_y'] = 3

fig_tabs = gui.viz_manager.figs_tabs
axes_config_tabs = gui.viz_manager.axes_config_tabs
assert len(fig_tabs.children) == 0
assert len(axes_config_tabs.children) == 0

# after each run we should have a default fig
gui.run_button.click()
assert len(fig_tabs.children) == 1
assert len(axes_config_tabs.children) == 1
assert gui.viz_manager.fig_idx['idx'] == 2

template_names = [('Drive-Dipole (2x1)', 2),
('Dipole Layers (3x1)', 3),
('Drive-Spikes (2x1)', 2),
('Dipole-Spectrogram (2x1)', 2),
("Dipole-Spikes (2x1)", 2),
('Drive-Dipole-Spectrogram (3x1)', 3),
('PSD Layers (3x1)', 3)]

n_fig = 1
for template_name, num_axes in template_names:
gui.viz_manager.templates_dropdown.value = template_name
assert len(gui.viz_manager.datasets_dropdown.options) == 1
gui.viz_manager.make_fig_button.click()
# Check figs have data on their axis
for ax in range(num_axes):
assert gui.viz_manager.figs[n_fig + 1].axes[ax].has_data()
n_fig = n_fig + 1

# test number of created figures
assert len(fig_tabs.children) == n_fig


def test_gui_edit_figure():
"""Test if the GUI adds/deletes figs properly."""
gui = HNNGUI()
Expand Down
Loading