Skip to content

Commit

Permalink
feat: add celltype-rename to remaining net attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
asoplata committed Feb 7, 2025
1 parent 6e98c83 commit f29dcd4
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 22 deletions.
49 changes: 42 additions & 7 deletions hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,9 +1203,8 @@ def _add_cell_type(self, cell_name, pos, cell_template=None):
self.cell_types.update({cell_name: cell_template})
self._n_cells += len(pos)

def rename_cell(self, original_name, new_name):
"""Renames cells in the network and clears connectivity so user can
set new connections.
def rename_cell_type(self, original_name, new_name):
"""Renames cell types in the network.
Parameters
----------
Expand All @@ -1223,11 +1222,48 @@ def rename_cell(self, original_name, new_name):
# Raises error if the new name is already in cell_types
raise ValueError(f"'{new_name}' is already in cell_types!")
elif original_name in self.cell_types.keys():
# Update cell name in places where order doesn't matter
# Update cell name in dicts/etc. where order doesn't matter
# Update Network.cell_types
self.cell_types[new_name] = self.cell_types.pop(original_name)
# `Cell.name` does not currently provide a way to change its value,
# so we will leave that untouched for now.
# Update Network.pos_dict
self.pos_dict[new_name] = self.pos_dict.pop(original_name)

# Update cell name in gid_ranges: order matters for consistency!
# Update Network.external_biases
for bias_key, bias_value in self.external_biases.items():
if original_name in self.external_biases[bias_key].keys():
self.external_biases[bias_key][new_name] = \
self.external_biases[bias_key].pop(original_name)
# Update Network.external_drives
for drive_key, drive_config in self.external_drives.items():
if original_name in drive_config['target_types']:
drive_config['target_types'].remove(original_name)
drive_config['target_types'].append(new_name)
drive_config['target_types'].sort()
for config_key, config_value in drive_config.items():
if (config_key == 'dynamics'):
if 'rate_constant' in config_value.keys():
config_value['rate_constant'][new_name] = \
config_value['rate_constant'].pop(
original_name
)
elif ((config_key == 'synaptic_delays')
and (isinstance(config_value, dict))):
drive_config[config_key][new_name] = \
drive_config[config_key].pop(original_name)
elif ('weights_' in config_key):
if config_value is not None:
drive_config[config_key][new_name] = \
drive_config[config_key].pop(original_name)

# Update Network.connectivity
for connection in self.connectivity:
if connection['src_type'] == original_name:
connection['src_type'] = new_name
if connection['target_type'] == original_name:
connection['target_type'] = new_name

# Update Network.gid_ranges: order matters for consistency!
for _ in range(len(self.gid_ranges)):
name, gid_range = self.gid_ranges.popitem(last=False)
if name == original_name:
Expand All @@ -1236,7 +1272,6 @@ def rename_cell(self, original_name, new_name):
else:
# Insert the value as it is
self.gid_ranges[name] = gid_range
self.clear_connectivity()

def gid_to_type(self, gid):
"""Reverse lookup of gid to type."""
Expand Down
72 changes: 57 additions & 15 deletions hnn_core/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from hnn_core.network_models import add_erp_drives_to_jones_model
from hnn_core.network_builder import NetworkBuilder
from hnn_core.network import pick_connection
from hnn_core.viz import plot_dipole

hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
Expand Down Expand Up @@ -1165,37 +1166,78 @@ def test_only_drives_specified(self, base_network, src_gids,
assert len(indices) == expected


def test_rename_cell():
def test_rename_cell(base_network):
"""Tests renaming cell function"""
params = read_params(params_fname)
net = jones_2009_model(params)
assert net.connectivity
net1, params = base_network
#
# Make a new network, rename all the cell type names, then test it
#
net2 = net1.copy()
assert net2.connectivity
# adding a list of new_names
new_names = ['L2_basket_test', 'L2_pyramidal_test',
'L5_basket_test', 'L5_pyrmidal_test']
'L5_basket_test', 'L5_pyramidal_test']
# avoid iteration through net.cell_type.keys() by creating tuples of old and new names
rename_pairs = list(zip(net.cell_types.keys(), new_names))
rename_pairs = list(zip(net2.cell_types.keys(), new_names))
for original_name, new_name in rename_pairs:
net.rename_cell(original_name, new_name)
net2.rename_cell_type(original_name, new_name)
for new_name in new_names:
assert new_name in net.cell_types.keys()
assert new_name in net.pos_dict.keys()
assert original_name not in net.cell_types.keys()
assert original_name not in net.pos_dict.keys()
assert not net.connectivity
assert new_name in net2.cell_types.keys()
assert new_name in net2.pos_dict.keys()
assert original_name not in net2.cell_types.keys()
assert original_name not in net2.pos_dict.keys()
# Tests for non-existent original_name
original_name = 'original_name'
with pytest.raises(ValueError,
match=f"'{original_name}' is not in cell_types!"):
net.rename_cell('original_name', 'L2_basket_2')
net2.rename_cell_type('original_name', 'L2_basket_2')

# Test for already existing new_name
new_name = 'L2_basket_test'
with pytest.raises(ValueError,
match=f"'{new_name}' is already in cell_types!"):
net.rename_cell('L2_basket_test', new_name)
net2.rename_cell_type('L2_basket_test', new_name)

# Tests for non-string new_name
new_name = 5
with pytest.raises(TypeError, match="new_name must be an instance of str"):
net.rename_cell('L2_basket_test', 5)
net2.rename_cell_type('L2_basket_test', 5)

#
# Make another new network, but rename all the celltypes back to their old
# names, then test that everything works the same
#
net3 = net2.copy()
old_names = ['L2_basket', 'L2_pyramidal',
'L5_basket', 'L5_pyramidal']
rename_pairs = list(zip(net3.cell_types.keys(), old_names))
for new_name, old_name in rename_pairs:
net3.rename_cell_type(new_name, old_name)
for old_name in old_names:
assert old_name in net3.cell_types.keys()
assert old_name in net3.pos_dict.keys()
assert new_name not in net3.cell_types.keys()
assert new_name not in net3.pos_dict.keys()

assert net3 == net1

#
# Test that the networks actually run
#
dpls1 = simulate_dipole(net1, tstop=100., n_trials=1)
plot_dipole(dpls1, show=False)

dpls2 = simulate_dipole(net2, tstop=100., n_trials=1)
plot_dipole(dpls2, show=False)

dpls3 = simulate_dipole(net2, tstop=100., n_trials=1)
plot_dipole(dpls3, show=False)

# Test the other main network we use for testing
net4 = hnn_core.hnn_io.read_network_configuration(
op.join(hnn_core_root, 'tests', 'assets', 'jones2009_3x3_drives.json'))
rename_pairs = list(zip(net4.cell_types.keys(), new_names))
for original_name, new_name in rename_pairs:
net4.rename_cell_type(original_name, new_name)
dpls4 = simulate_dipole(net4, tstop=100., n_trials=1)
plot_dipole(dpls4, show=False)

0 comments on commit f29dcd4

Please sign in to comment.