Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] GUI synaptic gains implementation #918

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 62 additions & 21 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,11 @@ def __init__(self, theme_color="#802989",
# Connectivity list
self.connectivity_widgets = list()

# Cell parameter list
self.cell_pameters_widgets = dict()
# Cell parameter dict
self.cell_parameters_widgets = dict()

# Synaptic Gains dict
self.synaptic_gain_widgets = dict()

self._init_ui_components()
self.add_logging_window_logger()
Expand Down Expand Up @@ -473,6 +476,7 @@ def _init_ui_components(self):
self._drives_out = Output() # tab to add new drives
self._connectivity_out = Output() # tab to tune connectivity.
self._cell_params_out = Output()
self._syn_gain_out = Output()

self._log_out = Output()

Expand Down Expand Up @@ -569,7 +573,9 @@ def _run_button_clicked(b):
self.widget_mpi_cmd, self.widget_n_jobs, self.params,
self._simulation_status_bar, self._simulation_status_contents,
self.connectivity_widgets, self.viz_manager,
self.simulation_list_widget, self.cell_pameters_widgets)
self.simulation_list_widget, self.cell_parameters_widgets,
self.synaptic_gain_widgets
)

def _simulation_list_change(value):
# Simulation Data
Expand Down Expand Up @@ -612,13 +618,13 @@ def _driver_type_change(value):

def _cell_type_radio_change(value):
_update_cell_params_vbox(self._cell_params_out,
self.cell_pameters_widgets,
self.cell_parameters_widgets,
value.new,
self.cell_layer_radio_buttons.value)

def _cell_layer_radio_change(value):
_update_cell_params_vbox(self._cell_params_out,
self.cell_pameters_widgets,
self.cell_parameters_widgets,
self.cell_type_radio_buttons.value,
value.new)

Expand Down Expand Up @@ -673,7 +679,7 @@ def compose(self, return_layout=True):
self._backend_config_out]),
], layout=self.layout['config_box'])

connectivity_configuration = Tab()
network_configuration = Tab()

connectivity_box = VBox([
HBox([self.load_connectivity_button, ]),
Expand All @@ -686,10 +692,14 @@ def compose(self, return_layout=True):
self._cell_params_out
])

connectivity_configuration.children = [connectivity_box,
cell_parameters]
connectivity_configuration.titles = ['Connectivity',
'Cell parameters']
syn_gain = VBox([self._syn_gain_out])

network_configuration.children = [connectivity_box,
cell_parameters,
syn_gain]
network_configuration.titles = ['Connectivity',
'Cell parameters',
'Synaptic gains']

drive_selections = VBox([
self.add_drive_button, self.widget_drive_type_selection,
Expand All @@ -709,7 +719,7 @@ def compose(self, return_layout=True):
# Tabs for left pane
left_tab = Tab()
left_tab.children = [
simulation_box, connectivity_configuration, drives_options,
simulation_box, network_configuration, drives_options,
config_panel,
]
titles = ('Simulation', 'Network', 'External drives',
Expand Down Expand Up @@ -902,9 +912,11 @@ def load_drive_and_connectivity(self):
self._connectivity_out,
self.connectivity_widgets,
self._cell_params_out,
self.cell_pameters_widgets,
self.cell_parameters_widgets,
self.cell_layer_radio_buttons,
self.cell_type_radio_buttons,
self._syn_gain_out,
self.synaptic_gain_widgets,
self.layout)

# Add drives
Expand Down Expand Up @@ -1034,9 +1046,11 @@ def on_upload_params_change(self, change, layout, load_type):
if load_type == 'connectivity':
add_connectivity_tab(
params, self._connectivity_out, self.connectivity_widgets,
self._cell_params_out, self.cell_pameters_widgets,
self._cell_params_out, self.cell_parameters_widgets,
self.cell_layer_radio_buttons,
self.cell_type_radio_buttons, layout)
self.cell_type_radio_buttons,
self._syn_gain_out, self.synaptic_gain_widgets,
layout)
elif load_type == 'drives':
self.add_drive_tab(params)
else:
Expand Down Expand Up @@ -1598,9 +1612,9 @@ def _build_drive_objects(drive_type, name, tstop_widget, layout, style,


def add_connectivity_tab(params, connectivity_out, connectivity_textfields,
cell_params_out, cell_pameters_vboxes,
cell_params_out, cell_parameters_vboxes,
cell_layer_radio_button, cell_type_radio_button,
layout):
syn_gain_out, syn_gain_textfields, layout):
"""Add all possible connectivity boxes to connectivity tab."""
net = dict_to_network(params)

Expand All @@ -1609,9 +1623,13 @@ def add_connectivity_tab(params, connectivity_out, connectivity_textfields,
connectivity_textfields)

# build cell parameters tab
add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes,
add_cell_parameters_tab(cell_params_out, cell_parameters_vboxes,
cell_layer_radio_button, cell_type_radio_button,
layout)

# build synaptic gains tab
add_synaptic_gain_tab(net, syn_gain_out, syn_gain_textfields, layout)

return net


Expand Down Expand Up @@ -1719,6 +1737,24 @@ def add_cell_parameters_tab(cell_params_out, cell_pameters_vboxes,
cell_layer_radio_button.value)


def add_synaptic_gain_tab(net, syn_gain_out, syn_gain_textfields, layout):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing docstring!


gain_values = net.get_synaptic_gains()
gain_types = ('e_e', 'e_i', 'i_e', 'i_i')
for gain_type in gain_types:
gain_widget = BoundedFloatText(
value=gain_values[gain_type],
description=f'{gain_type}',
min=0, max=1e6, step=.1,
disabled=False, layout=layout)
syn_gain_textfields[gain_type] = gain_widget

gain_vbox = VBox([widget for widget in syn_gain_textfields.values()])

with syn_gain_out:
display(gain_vbox)


def get_cell_param_default_value(cell_type_key, param_dict):
return param_dict[cell_type_key]

Expand Down Expand Up @@ -1794,7 +1830,7 @@ def _drive_widget_to_dict(drive, name):

def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
drive_widgets, connectivity_textfields,
cell_params_vboxes,
cell_params_vboxes, syn_gain_textfields,
add_drive=True):
"""Construct network and add drives."""
print("init network")
Expand All @@ -1819,7 +1855,6 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
'nc_dict']['A_weight'] = vbox_key.children[1].value

# Update cell params

update_functions = {
'L2 Geometry': _update_L2_geometry_cell_params,
'L5 Geometry': _update_L5_geometry_cell_params,
Expand All @@ -1842,6 +1877,11 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
single_simulation_data['net'].cell_types[
cell_type]._compute_section_mechs()

# Update with synaptic gains
syn_gain_values = {key: widget.value
for key, widget in syn_gain_textfields.items()}
single_simulation_data['net'].set_synaptic_gains(**syn_gain_values)

if add_drive is False:
return
# add drives to network
Expand Down Expand Up @@ -1914,7 +1954,7 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
mpi_cmd, n_jobs, params, simulation_status_bar,
simulation_status_contents, connectivity_textfields,
viz_manager, simulations_list_widget,
cell_pameters_widgets):
cell_parameters_widgets, syn_gain_textfields):
"""Run the simulation and plot outputs."""
simulation_data = all_data["simulation_data"]
with log_out:
Expand All @@ -1933,7 +1973,8 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
_init_network_from_widgets(params, dt, tstop,
simulation_data[_sim_name], drive_widgets,
connectivity_textfields,
cell_pameters_widgets)
cell_parameters_widgets,
syn_gain_textfields)

print("start simulation")
if backend_selection.value == "MPI":
Expand Down
89 changes: 80 additions & 9 deletions hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,40 @@ def pick_connection(net, src_gids=None, target_gids=None,
return sorted(conn_set)


def _get_cell_index_by_synapse_type(net):
"""Returns the indexes of excitatory and inhibitory cells in the network.
gtdang marked this conversation as resolved.
Show resolved Hide resolved

This function extracts the source GIDs (Global Identifiers) of excitatory
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Global Identifiers" isn't too descriptive either ... it's a bit of a misnomer since it's mainly a programming construct. I like to say "cell ID" to be clear.

and inhibitory cells based on their connection types. Excitatory cells are
identified by their synaptic connections using AMPA and NMDA receptors,
while inhibitory cells are identified by their connections using GABAA and
GABAB receptors.

Parameters
----------
net : Instance of Network object
The Network object

Returns
-------
tuple: A tuple containing two lists:
- e_cells (list): The source GIDs of excitatory cells.
- i_cells (list): The source GIDs of inhibitory cells.
"""

def list_src_gids(indices):
return np.concatenate([list(net.connectivity[conn_idx]['src_gids'])
for conn_idx in indices]).tolist()

e_conns = pick_connection(net, receptor=['ampa', 'nmda'])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
e_conns = pick_connection(net, receptor=['ampa', 'nmda'])
picks_e = pick_connection(net, receptor=['ampa', 'nmda'])

I like to use e_conn after the indexing, i.e.:

e_conns = [net.connectivity[p] for p in picks_e]

so it's clear from the variable name that one is an element of net.connectivity and the other is simply an index

e_cells = list_src_gids(e_conns)

i_conns = pick_connection(net, receptor=['gabaa', 'gabab'])
i_cells = list_src_gids(i_conns)

return e_cells, i_cells


class Network:
"""The Network class.

Expand Down Expand Up @@ -1427,8 +1461,8 @@ def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3,
method=method,
min_distance=min_distance)})

def update_weights(self, e_e=None, e_i=None,
i_e=None, i_i=None, copy=False):
def set_synaptic_gains(self, e_e=None, e_i=None,
i_e=None, i_i=None, copy=False):
"""Update synaptic weights of the network.

Parameters
Expand Down Expand Up @@ -1466,13 +1500,7 @@ def update_weights(self, e_e=None, e_i=None,

net = self.copy() if copy else self

e_conns = pick_connection(self, receptor=['ampa', 'nmda'])
e_cells = np.concatenate([list(net.connectivity[
conn_idx]['src_gids']) for conn_idx in e_conns]).tolist()

i_conns = pick_connection(self, receptor=['gabaa', 'gabab'])
i_cells = np.concatenate([list(net.connectivity[
conn_idx]['src_gids']) for conn_idx in i_conns]).tolist()
e_cells, i_cells = _get_cell_index_by_synapse_type(net)
conn_types = {
'e_e': (e_e, e_cells, e_cells),
'e_i': (e_i, e_cells, i_cells),
Expand All @@ -1497,6 +1525,49 @@ def update_weights(self, e_e=None, e_i=None,
if copy:
return net

def get_synaptic_gains(self):
"""Retrieve gain values for different connection types in the network.

This function identifies excitatory and inhibitory cells in the network
and retrieves the gain value for each type of synaptic connection:
- excitatory to excitatory (e_e)
- excitatory to inhibitory (e_i)
- inhibitory to excitatory (i_e)
- inhibitory to inhibitory (i_i)

The gain is assumed to be uniform within each connection type, and only
the first connection's gain value is used for each type.

Returns
-------
dict: A dictionary with the connection types ('e_e', 'e_i', 'i_e',
gtdang marked this conversation as resolved.
Show resolved Hide resolved
'i_i') as keys and their corresponding gain values.
"""
values = {}
e_cells, i_cells = _get_cell_index_by_synapse_type(self)

# Define the connection types and source/target cell indexes
conn_types = {
'e_e': (e_cells, e_cells),
'e_i': (e_cells, i_cells),
'i_e': (i_cells, e_cells),
'i_i': (i_cells, i_cells)
}

# Retrieve the gain value for each connection type
for conn_type, (src_indexes, target_indexes) in conn_types.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for conn_type, (src_indexes, target_indexes) in conn_types.items():
for conn_type, (src_idxs, target_idxs) in conn_types.items():

just for sake of consistency

conn_indices = pick_connection(self,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
conn_indices = pick_connection(self,
picks = pick_connection(self,

see above comment too ... would try to be consistent with naming

src_gids=src_indexes,
target_gids=target_indexes)

if conn_indices:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if conn_indices:
if len(picks) > 0:

to be explicit that picks is a list, not a bool

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using the truthyness of collections is more concise and pythonic. If we really wanted to be explicit we could also name the variable "picks_list".

But I'll open it up to the rest of the group to define a style guide. @ntolley @asoplata @dylansdaniels

# Extract the gain from the first connection
values[conn_type] = (
self.connectivity[conn_indices[0]]['nc_dict']['gain']
)

return values

def plot_cells(self, ax=None, show=True):
"""Plot the cells using Network.pos_dict.

Expand Down
Loading
Loading