diff --git a/README.rst b/README.rst index 071866f28..51e503525 100644 --- a/README.rst +++ b/README.rst @@ -57,8 +57,10 @@ Optional dependencies GUI ~~~ -* ipywidgets (<=7.7.1) -* voila (<=0.3.6) +* ipywidgets (>=8.0.0) +* voila +* ipympl +* ipykernel Optimization ~~~~~~~~~~~~ diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py index 883a281d7..4f9437454 100644 --- a/hnn_core/gui/_viz_manager.py +++ b/hnn_core/gui/_viz_manager.py @@ -4,7 +4,7 @@ import copy import io -from functools import partial +from functools import partial, wraps import matplotlib import matplotlib.pyplot as plt @@ -99,6 +99,38 @@ def plot_type_coupled_change(new_plot_type, target_data_selection): target_data_selection.disabled = False +def unlink_relink(attribute): + """ + Decorator function to unlink widgets and re-link widgets. + + Unlinks linked widgets, runs the wrapped function, and relinks the widgets + upon completion. To be used as a decorator on class methods. The class must + have an attribute containing an ipywidgets/traitlets link object. + + Parameters + ---------- + attribute: (str) The class attribute containing link object of ipywidgets + widgets + + """ + def _unlink_relink(f): + @wraps(f) + def wrapper(self, *args, **kwargs): + # Unlink the widgets using the provided link object + link_attribute: link = getattr(self, attribute) + link_attribute.unlink() + + # Call the original function + result = f(self, *args, **kwargs) + + # Re-link the widgets + link_attribute.link() + + return result + return wrapper + return _unlink_relink + + def _idx2figname(idx): return f"Figure {idx}" @@ -150,7 +182,7 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): elif plot_type == 'PSD': if len(dpls_copied) > 0: - color = next(ax._get_lines.prop_cycler)['color'] + color = ax._get_lines.get_next_color() dpls_copied[0].plot_psd(fmin=0, fmax=50, ax=ax, color=color, label=sim_name, show=False) @@ -175,7 +207,7 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): else: label = sim_name - color = next(ax._get_lines.prop_cycler)['color'] + color = ax._get_lines.get_next_color() if plot_type == 'current dipole': plot_dipole(dpls_copied, ax=ax, @@ -226,9 +258,7 @@ def _update_ax(fig, ax, single_simulation, sim_name, plot_type, plot_config): def _static_rerender(widgets, fig, fig_idx): logger.debug('_static_re_render is called') figs_tabs = widgets['figs_tabs'] - titles = [ - figs_tabs.get_title(idx) for idx in range(len(figs_tabs.children)) - ] + titles = figs_tabs.titles fig_tab_idx = titles.index(_idx2figname(fig_idx)) fig_output = widgets['figs_tabs'].children[fig_tab_idx] fig_output.clear_output() @@ -501,18 +531,25 @@ def _on_plot_type_change(new_plot_type): def _close_figure(b, widgets, data, fig_idx): fig_related_widgets = [widgets['figs_tabs'], widgets['axes_config_tabs']] for w_idx, tab in enumerate(fig_related_widgets): + # Get tab object's list of children and their titles tab_children = list(tab.children) - titles = [tab.get_title(idx) for idx in range(len(tab.children))] + titles = list(tab.titles) + # Get the index based on the title tab_idx = titles.index(_idx2figname(fig_idx)) + # Remove the child and title specified print(f"Del fig_idx={fig_idx}, fig_idx={fig_idx}") - del tab_children[tab_idx], titles[tab_idx] - - tab.children = tuple(tab_children) - [tab.set_title(idx, title) for idx, title in enumerate(titles)] + tab_children.pop(tab_idx) + titles.pop(tab_idx) + # Reset children and titles of the tab object + tab.children = tab_children + tab.titles = titles + # If the figure tab group... if w_idx == 0: + # Close figure and delete the data plt.close(data['figs'][fig_idx]) - del data['figs'][fig_idx] + data['figs'].pop(fig_idx) + # Redisplay the remaining children n_tabs = len(tab.children) for idx in range(n_tabs): _fig_idx = _figname2idx(tab.get_title(idx)) @@ -522,10 +559,11 @@ def _close_figure(b, widgets, data, fig_idx): with tab.children[idx]: display(data['figs'][_fig_idx].canvas) - if n_tabs == 0: - widgets['figs_output'].clear_output() - with widgets['figs_output']: - display(Label(_fig_placeholder)) + # If all children have been deleted display the placeholder + if n_tabs == 0: + widgets['figs_output'].clear_output() + with widgets['figs_output']: + display(Label(_fig_placeholder)) def _add_axes_controls(widgets, data, fig, axd): @@ -565,8 +603,9 @@ def _add_figure(b, widgets, data, scale=0.95, dpi=96): with widgets['figs_output']: display(widgets['figs_tabs']) - widgets['figs_tabs'].children = widgets['figs_tabs'].children + ( - fig_outputs, ) + widgets['figs_tabs'].children = ( + [s for s in widgets['figs_tabs'].children] + [fig_outputs] + ) widgets['figs_tabs'].set_title(n_tabs, _idx2figname(fig_idx)) with fig_outputs: @@ -627,7 +666,7 @@ def __init__(self, gui_data, viz_layout): self.figs_tabs = Tab() self.axes_config_tabs.selected_index = None self.figs_tabs.selected_index = None - link( + self.figs_config_tab_link = link( (self.axes_config_tabs, 'selected_index'), (self.figs_tabs, 'selected_index'), ) @@ -711,6 +750,7 @@ def compose(self): ]) return config_panel, fig_output_container + @unlink_relink(attribute='figs_config_tab_link') def add_figure(self, b=None): """Add a figure and corresponding config tabs to the dashboard. """ @@ -729,7 +769,7 @@ def _simulate_switch_fig_template(self, template_name): def _simulate_delete_figure(self, fig_name): tab = self.axes_config_tabs - titles = [tab.get_title(idx) for idx in range(len(tab.children))] + titles = tab.titles assert fig_name in titles tab_idx = titles.index(fig_name) @@ -764,16 +804,13 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name, assert operation in ("plot", "clear") tab = self.axes_config_tabs - titles = [tab.get_title(idx) for idx in range(len(tab.children))] + titles = tab.titles assert fig_name in titles, "No such figure" tab_idx = titles.index(fig_name) self.axes_config_tabs.selected_index = tab_idx ax_control_tabs = self.axes_config_tabs.children[tab_idx].children[1] - ax_titles = [ - ax_control_tabs.get_title(idx) - for idx in range(len(ax_control_tabs.children)) - ] + ax_titles = ax_control_tabs.titles assert ax_name in ax_titles, "No such axis" ax_idx = ax_titles.index(ax_name) ax_control_tabs.selected_index = ax_idx diff --git a/hnn_core/gui/gui.py b/hnn_core/gui/gui.py index 54ab6255e..70622d1d8 100644 --- a/hnn_core/gui/gui.py +++ b/hnn_core/gui/gui.py @@ -11,6 +11,7 @@ import urllib.request from collections import defaultdict from pathlib import Path +from datetime import datetime from IPython.display import IFrame, display from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText, BoundedIntText, Button, Dropdown, FileUpload, VBox, @@ -438,8 +439,7 @@ def compose(self, return_layout=True): 'Layer 2/3 Pyramidal', 'Layer 5 Pyramidal', 'Layer 2 Basket', 'Layer 5 Basket') cell_connectivity = Accordion(children=connectivity_boxes) - for idx, connectivity_name in enumerate(connectivity_names): - cell_connectivity.set_title(idx, connectivity_name) + cell_connectivity.titles = [s for s in connectivity_names] drive_selections = VBox([ self.add_drive_button, self.widget_drive_type_selection, @@ -616,15 +616,14 @@ def _simulate_upload_drives(self, file_url): self.load_drives_button.set_trait('value', uploaded_value) def _simulate_left_tab_click(self, tab_title): - tab_index = None + # Get left tab group object left_tab = self.app_layout.left_sidebar.children[0].children[0] - for idx in left_tab._titles.keys(): - if tab_title == left_tab._titles[idx]: - tab_index = int(idx) - break - if tab_index is None: - raise ValueError("Incorrect tab title") - left_tab.selected_index = tab_index + # Check that the title is in the tab group + if tab_title in left_tab.titles: + # Simulate the user clicking on the tab + left_tab.selected_index = left_tab.titles.index(tab_title) + else: + raise ValueError("Tab title does not exist.") def _simulate_make_figure(self,): self._simulate_left_tab_click("Visualization") @@ -655,16 +654,13 @@ def _prepare_upload_file_from_url(file_url): for line in data: content += line - return { - params_name: { - 'metadata': { - 'name': params_name, - 'type': 'application/json', - 'size': len(content), - }, - 'content': content - } - } + return [{ + 'name': params_name, + 'type': 'application/json', + 'size': len(content), + 'content': content, + 'last_modified': datetime.now() + }] def create_expanded_button(description, button_style, layout, disabled=False, @@ -1133,14 +1129,14 @@ def on_upload_data_change(change, data, viz_manager, log_out): logger.info("Empty change") return - key = list(change['new'].keys())[0] + data_dict = change['new'][0] - data_fname = change['new'][key]['metadata']['name'].rstrip('.txt') + data_fname = data_dict['name'].rstrip('.txt') if data_fname in data['simulation_data'].keys(): logger.error(f"Found existing data: {data_fname}.") return - ext_content = change['new'][key]['content'] + ext_content = data_dict['content'] ext_content = codecs.decode(ext_content, encoding="utf-8") with log_out: data['simulation_data'][data_fname] = {'net': None, 'dpls': [ @@ -1163,10 +1159,9 @@ def on_upload_params_change(change, params, tstop, dt, log_out, drive_boxes, logger.info("Empty change") return logger.info("Loading connectivity...") - key = list(change['new'].keys())[0] - - params_fname = change['new'][key]['metadata']['name'] - param_data = change['new'][key]['content'] + param_dict = change['new'][0] + params_fname = param_dict['name'] + param_data = param_dict['content'] param_data = codecs.decode(param_data, encoding="utf-8") @@ -1191,9 +1186,8 @@ def on_upload_params_change(change, params, tstop, dt, log_out, drive_boxes, layout) else: raise ValueError - - change['owner'].set_trait('_counter', 0) - change['owner'].set_trait('value', {}) + # Resets file counter to 0 + change['owner'].set_trait('value', ([])) def _init_network_from_widgets(params, dt, tstop, single_simulation_data, diff --git a/hnn_core/tests/test_gui.py b/hnn_core/tests/test_gui.py index 06b1e0f6e..b8c976b51 100644 --- a/hnn_core/tests/test_gui.py +++ b/hnn_core/tests/test_gui.py @@ -3,14 +3,18 @@ import matplotlib.pyplot as plt import numpy as np import pytest +import traitlets + from hnn_core import Dipole, Network, Params from hnn_core.gui import HNNGUI -from hnn_core.gui._viz_manager import _idx2figname, _no_overlay_plot_types +from hnn_core.gui._viz_manager import (_idx2figname, _no_overlay_plot_types, + unlink_relink) from hnn_core.gui.gui import _init_network_from_widgets from hnn_core.network import pick_connection from hnn_core.network_models import jones_2009_model from hnn_core.parallel_backends import requires_mpi4py, requires_psutil from IPython.display import IFrame +from ipywidgets import Tab, Text, link matplotlib.use('agg') @@ -413,3 +417,50 @@ def test_gui_adaptive_spectrogram(): for attr in dir(gui.viz_manager.figs[figid])]) is False assert len(gui.viz_manager.figs[1].axes) == 2 plt.close('all') + + +def test_unlink_relink_widget(): + """Tests the unlinking and relinking of widgets decorator.""" + + # Create a basic version of the VizManager class + class MiniViz: + def __init__(self): + self.tab_group_1 = Tab() + self.tab_group_2 = Tab() + self.tab_link = link( + (self.tab_group_1, 'selected_index'), + (self.tab_group_2, 'selected_index'), + ) + + def add_child(self, to_add=1): + n_tabs = len(self.tab_group_2.children) + to_add + # Add tab and select latest tab + self.tab_group_1.children = \ + [Text(f'Test{s}') for s in np.arange(n_tabs)] + self.tab_group_1.selected_index = n_tabs - 1 + + self.tab_group_2.children = \ + [Text(f'Test{s}') for s in np.arange(n_tabs)] + self.tab_group_2.selected_index = n_tabs - 1 + + @unlink_relink(attribute='tab_link') + def add_child_decorated(self, to_add): + self.add_child(to_add) + + # Check that widgets are linked. + # Error from tab groups momentarily having a different number of children + gui = MiniViz() + with pytest.raises(traitlets.TraitError, match='.*index out of bounds.*'): + gui.add_child(2) + + # Check decorator unlinks and is able to make a change + gui = MiniViz() + gui.add_child_decorated(2) + assert len(gui.tab_group_1.children) == 2 + assert gui.tab_group_1.selected_index == 1 + assert len(gui.tab_group_2.children) == 2 + assert gui.tab_group_2.selected_index == 1 + + # Check if the widgets are relinked, the selected index should be synced + gui.tab_group_1.selected_index = 0 + assert gui.tab_group_2.selected_index == 0 diff --git a/setup.py b/setup.py index ada4bfd3b..c39f0dcdb 100644 --- a/setup.py +++ b/setup.py @@ -106,7 +106,7 @@ def run(self): 'h5io' ], extras_require={ - 'gui': ['ipywidgets <=7.7.1', 'ipympl<0.9', 'voila<=0.3.6'], + 'gui': ['ipywidgets>=8.0.0', 'ipykernel', 'ipympl', 'voila'], 'opt': ['scikit-learn'] }, python_requires='>=3.8',