Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add Min/Max Frequency and Add Visual Updates Ahead of 0.4 #952

Merged
merged 8 commits into from
Nov 27, 2024
14 changes: 8 additions & 6 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ def _get_ax_control(widgets, data, fig_default_params, fig_idx, fig, ax):
simulation_names = tuple(data['simulations'].keys())
sim_index = 0
default_smoothing = fig_default_params['default_smoothing']
default_min_frequency = fig_default_params['default_min_frequency']
default_max_frequency = fig_default_params['default_max_frequency']
if not simulation_names:
simulation_names = ("None",)
else:
Expand Down Expand Up @@ -639,7 +641,7 @@ def _get_ax_control(widgets, data, fig_default_params, fig_idx, fig, ax):
style=analysis_style)

min_spectral_frequency = BoundedFloatText(
value=10,
value=default_min_frequency,
min=0.1,
max=1000,
description='Min Spectral Frequency (Hz):',
Expand All @@ -648,7 +650,7 @@ def _get_ax_control(widgets, data, fig_default_params, fig_idx, fig, ax):
style=analysis_style)

max_spectral_frequency = BoundedFloatText(
value=100,
value=default_max_frequency,
min=0.1,
max=1000,
description='Max Spectral Frequency (Hz):',
Expand Down Expand Up @@ -762,12 +764,12 @@ def _close_figure(b, widgets, data, fig_idx):
display(Label(_fig_placeholder))


def _add_axes_controls(widgets, data, fig_default_smoothing, fig, axd):
def _add_axes_controls(widgets, data, fig_default_params, fig, axd):
fig_idx = data['fig_idx']['idx']

controls = Tab()
children = [
_get_ax_control(widgets, data, fig_default_smoothing, fig_idx=fig_idx,
_get_ax_control(widgets, data, fig_default_params, fig_idx=fig_idx,
fig=fig, ax=ax)
for ax_key, ax in axd.items()
]
Expand All @@ -788,7 +790,7 @@ def _add_axes_controls(widgets, data, fig_default_smoothing, fig, axd):
widgets['axes_config_tabs'].set_title(n_tabs, _idx2figname(fig_idx))


def _add_figure(b, widgets, data, fig_default_smoothing,
def _add_figure(b, widgets, data, fig_default_params,
template_type, scale=0.95, dpi=96):
fig_idx = data['fig_idx']['idx']
viz_output_layout = data['visualization_output']
Expand Down Expand Up @@ -821,7 +823,7 @@ def _add_figure(b, widgets, data, fig_default_smoothing,
else:
display(fig.canvas)

_add_axes_controls(widgets, data, fig_default_smoothing, fig=fig, axd=axd)
_add_axes_controls(widgets, data, fig_default_params, fig=fig, axd=axd)

data['figs'][fig_idx] = fig
widgets['figs_tabs'].selected_index = n_tabs
Expand Down
149 changes: 124 additions & 25 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText,
BoundedIntText, Button, Dropdown, FileUpload, VBox,
HBox, IntText, Layout, Output, RadioButtons, Tab, Text,
Checkbox)
Checkbox, Box)
from ipywidgets.embed import embed_minimal_html
import hnn_core
from hnn_core import JoblibBackend, MPIBackend, simulate_dipole
Expand Down Expand Up @@ -290,7 +290,7 @@ def __init__(self, theme_color="#802989",
height=f"{operation_box_height}px",
flex_wrap="wrap",
),
"config_box": Layout(width=f"{left_sidebar_width}px",
"config_box": Layout(width=f"{left_sidebar_width - 40}px",
height=f"{config_box_height - 100}px"),
"drive_widget": Layout(width="auto"),
"drive_textbox": Layout(width='270px', height='auto'),
Expand Down Expand Up @@ -324,12 +324,37 @@ def __init__(self, theme_color="#802989",
self.simulation_data = defaultdict(lambda: dict(net=None, dpls=list()))

# Default visualization params for figures
analysis_style = {'description_width': '200px'}
layout = Layout(width="300px")

self.widget_default_smoothing = BoundedFloatText(
value=30.0, description='Smoothing:',
min=0.0, max=100.0, step=1.0, disabled=False)
min=0.0, max=100.0, step=1.0, disabled=False,
layout=layout, style=analysis_style,
)

self.widget_min_frequency = BoundedFloatText(
value=10,
min=0.1,
max=1000,
description='Min Spectral Frequency (Hz):',
disabled=False,
layout=layout,
style=analysis_style)

self.widget_max_frequency = BoundedFloatText(
value=100,
min=0.1,
max=1000,
description='Max Spectral Frequency (Hz):',
disabled=False,
layout=layout,
style=analysis_style)

self.fig_default_params = {
'default_smoothing': self.widget_default_smoothing.value
'default_smoothing': self.widget_default_smoothing.value,
'default_min_frequency': self.widget_min_frequency.value,
'default_max_frequency': self.widget_max_frequency.value,
}

# Simulation parameters
Expand Down Expand Up @@ -374,16 +399,20 @@ def __init__(self, theme_color="#802989",
description='',
layout={'width': '15%'})
# Drive selection
self.widget_drive_type_selection = RadioButtons(
self.widget_drive_type_selection = Dropdown(
options=['Evoked', 'Poisson', 'Rhythmic', 'Tonic'],
value='Evoked',
description='Drive:',
description='Drive type:',
disabled=False,
layout=self.layout['drive_widget'])
self.widget_location_selection = RadioButtons(
options=['proximal', 'distal'], value='proximal',
description='Location', disabled=False,
layout=self.layout['drive_widget'])
layout=self.layout['drive_widget'],
style={'description_width': '100px'}
)
self.widget_location_selection = Dropdown(
options=['Proximal', 'Distal'], value='Proximal',
description='Drive location:', disabled=False,
layout=self.layout['drive_widget'],
style={'description_width': '100px'},
)
self.add_drive_button = create_expanded_button(
'Add drive', 'primary', layout=self.layout['btn'],
button_color=self.layout['theme_color'])
Expand All @@ -405,7 +434,7 @@ def __init__(self, theme_color="#802989",
button_style='success')

self.delete_drive_button = create_expanded_button(
'Delete drives', 'success', layout=self.layout['btn'],
'Delete all drives', 'success', layout=self.layout['btn'],
button_color=self.layout['theme_color'])

self.cell_type_radio_buttons = RadioButtons(
Expand Down Expand Up @@ -543,9 +572,10 @@ def _handle_backend_change(backend_type):
self.widget_n_jobs)

def _add_drive_button_clicked(b):
location = self.widget_location_selection.value.lower()
return self.add_drive_widget(
self.widget_drive_type_selection.value,
self.widget_location_selection.value,
location,
)

def _delete_drives_clicked(b):
Expand Down Expand Up @@ -576,6 +606,7 @@ def _run_button_clicked(b):
self.widget_simulation_name, self._log_out, self.drive_widgets,
self.data, self.widget_dt, self.widget_tstop,
self.fig_default_params, self.widget_default_smoothing,
self.widget_min_frequency, self.widget_max_frequency,
self.widget_ntrials, self.widget_backend_selection,
self.widget_mpi_cmd, self.widget_n_jobs, self.params,
self._simulation_status_bar, self._simulation_status_contents,
Expand Down Expand Up @@ -677,11 +708,31 @@ def compose(self, return_layout=True):
If the method returns the layout object which can be rendered by
IPython.display.display() method.
"""
box_style = """
style="
background: gray;
color: white;
# font-weight: bold;
width: 290px;
padding: 0px 5px;
margin-bottom: 2px;
"
"""
simulation_box = VBox([
HTML(f"<div {box_style}>Simulation Parameters</div>"),
VBox([
self.widget_simulation_name, self.widget_tstop, self.widget_dt,
self.widget_ntrials, self.widget_default_smoothing,
self.widget_ntrials,
self.widget_backend_selection, self._backend_config_out]),
Box(layout=Layout(height="20px")),
HTML(
f"<div {box_style}'>Default Visualization Parameters</div>",
),
VBox([
self.widget_default_smoothing,
self.widget_min_frequency,
self.widget_max_frequency,
])
], layout=self.layout['config_box'])

connectivity_configuration = Tab()
Expand Down Expand Up @@ -1157,21 +1208,35 @@ def create_expanded_button(description, button_style, layout, disabled=False,
def _get_connectivity_widgets(conn_data):
"""Create connectivity box widgets from specified weight and probability"""

style = {'description_width': '150px'}
style = {}
style = {'description_width': '100px'}
sliders = list()
for receptor_name in conn_data.keys():
w_text_input = BoundedFloatText(
value=conn_data[receptor_name]['weight'], disabled=False,
continuous_update=False, min=0, max=1e6, step=0.01,
description="weight", style=style)
description="Weight:", style=style)

display_name = conn_data[receptor_name]['receptor'].upper()

map_display_names = {
'GABAA': 'GABA<sub>A</sub>',
'GABAB': 'GABA<sub>B</sub>',
}

if display_name in map_display_names:
display_name = map_display_names[display_name]

html_tab = '&emsp;'

conn_widget = VBox([
HTML(value=f"""<p>
Receptor: {conn_data[receptor_name]['receptor']}</p>"""),
w_text_input, HTML(value="<hr style='margin-bottom:5px'/>")
HTML(value=f"""<p style='margin:5px;'><b>{html_tab}{html_tab}
Receptor: {display_name}</b></p>"""),
w_text_input
])

# Add class to child Vboxes for targeted CSS
conn_widget.add_class('connectivity-subsection')

conn_widget._belongsto = {
"receptor": conn_data[receptor_name]['receptor'],
"location": conn_data[receptor_name]['location'],
Expand Down Expand Up @@ -1672,13 +1737,44 @@ def add_network_connectivity_tab(net, connectivity_out,
connectivity_textfields.append(
_get_connectivity_widgets(receptor_related_conn))

# Style the contents of the Connectivity Tab
# -------------------------------------------------------------------------

# define custom Vbox layout
# no_padding_layout = Layout(padding="0", margin="0") # unused

# Initialize sections within the Accordion

connectivity_boxes = [VBox(slider) for slider in connectivity_textfields]

# Add class to child Vboxes for targeted CSS
for box in connectivity_boxes:
box.add_class("connectivity-contents")

# Initialize the Accordion section

cell_connectivity = Accordion(children=connectivity_boxes)

# Add class to Accordion section for targeted CSS
cell_connectivity.add_class("connectivity-section")

for idx, connectivity_name in enumerate(connectivity_names):
cell_connectivity.set_title(idx, connectivity_name)

# Style the <div> automatically created around connectivity boxes
connectivity_out_style = HTML("""
<style>
/* CSS to style elements inside the Accordion */
.connectivity-section .jupyter-widget-Collapse-contents {
padding: 0px 0px 10px 0px !important;
margin: 0 !important;
}
</style>
""")

# Display the Accordion with styling
with connectivity_out:
display(cell_connectivity)
display(connectivity_out_style, cell_connectivity)

return net

Expand Down Expand Up @@ -1923,6 +2019,7 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
all_data, dt, tstop,
fig_default_params, widget_default_smoothing,
widget_min_frequency, widget_max_frequency,
ntrials, backend_selection,
mpi_cmd, n_jobs, params, simulation_status_bar,
simulation_status_contents, connectivity_textfields,
Expand Down Expand Up @@ -1974,12 +2071,14 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,

viz_manager.reset_fig_config_tabs()

# update default_smoothing in gui based on widget
# update default visualization params in gui based on widget
fig_default_params['default_smoothing'] = widget_default_smoothing.value
fig_default_params['default_min_frequency'] = widget_min_frequency.value
fig_default_params['default_max_frequency'] = widget_max_frequency.value

# change default smoothing in viz_manager to mirror gui
new_default_smoothing = fig_default_params['default_smoothing']
viz_manager.fig_default_params['default_smoothing'] = new_default_smoothing
# change default visualization params in viz_manager to mirror gui
for widget, value in fig_default_params.items():
viz_manager.fig_default_params[widget] = value

viz_manager.add_figure()
fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
Expand Down
45 changes: 41 additions & 4 deletions hnn_core/tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_gui_add_drives():
_ = gui.compose()

for val_drive_type in ("Poisson", "Evoked", "Rhythmic"):
for val_location in ("distal", "proximal"):
for val_location in ("Distal", "Proximal"):
gui.delete_drive_button.click()
assert len(gui.drive_widgets) == 0

Expand All @@ -334,7 +334,9 @@ def test_gui_add_drives():

assert len(gui.drive_widgets) == 1
assert gui.drive_widgets[0]['type'] == val_drive_type
assert gui.drive_widgets[0]['location'] == val_location
# note that val_location is transformed to .lower() after the
# add_drive_button.click() action
assert gui.drive_widgets[0]['location'] == val_location.lower()
assert val_drive_type in gui.drive_widgets[0]['name']
plt.close('all')

Expand Down Expand Up @@ -1128,8 +1130,7 @@ def test_default_smoothing(setup_gui):
gui_smooth_value = gui.fig_default_params['default_smoothing']
viz_smooth_value = gui.viz_manager.fig_default_params['default_smoothing']

assert gui_smooth_value == 30
assert viz_smooth_value == 30
assert gui_smooth_value == viz_smooth_value

# update simulation name
gui.widget_simulation_name.value = 'no_smoothing'
Expand Down Expand Up @@ -1174,3 +1175,39 @@ def test_default_smoothing(setup_gui):
assert gui.viz_manager.figs[figid].axes[0].has_data()

plt.close('all')


def test_default_frequencies(setup_gui):
"""Tests that default min/max frequency are inherited correctly"""
gui = setup_gui

# check that the defaults are the same everywhere after running
# the default simulation
gui.run_button.click()

gui_min = gui.fig_default_params['default_min_frequency']
viz_min = gui.viz_manager.fig_default_params['default_min_frequency']
gui_max = gui.fig_default_params['default_max_frequency']
viz_max = gui.viz_manager.fig_default_params['default_max_frequency']

assert gui_min == viz_min
assert gui_max == viz_max

# change value of default min/max frequencies in the widget
new_min = 5
new_max = 50
gui.widget_min_frequency.value = new_min
gui.widget_max_frequency.value = new_max

# update simulation name
gui.widget_simulation_name.value = 'new_defaults'
gui.run_button.click()

# check that the new default smoothing value is set everywhere
gui_min = gui.fig_default_params['default_min_frequency']
viz_min = gui.viz_manager.fig_default_params['default_min_frequency']
gui_max = gui.fig_default_params['default_max_frequency']
viz_max = gui.viz_manager.fig_default_params['default_max_frequency']

assert gui_min == viz_min == new_min
assert gui_max == viz_max == new_max
Loading