Skip to content

Commit

Permalink
use binned spike counts
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley committed Jul 17, 2024
1 parent f0e5265 commit f8c18e6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 26 deletions.
4 changes: 3 additions & 1 deletion hnn_core/cells_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def _cell_L2Pyr(override_params, pos=(0., 0., 0), gid=0.):

if sec_name == 'soma':
section.syns = ['gabaa', 'gabab']
section.syns = ['ampa', 'nmda', 'gabaa', 'gabab']
else:
section.syns = ['ampa', 'nmda', 'gabaa', 'gabab']

Expand Down Expand Up @@ -170,7 +171,8 @@ def _cell_L5Pyr(override_params, pos=(0., 0., 0), gid=0.):
section._end_pts = end_pts[sec_name]

if sec_name == 'soma':
section.syns = ['gabaa', 'gabab']
# section.syns = ['gabaa', 'gabab']
section.syns = ['ampa', 'nmda', 'gabaa', 'gabab']
else:
section.syns = ['ampa', 'nmda', 'gabaa', 'gabab']

Expand Down
67 changes: 42 additions & 25 deletions hnn_core/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,27 @@ def update_hnn_params(neuron_net, param_dict):
syn_names = param_dict['syn_names']
for syn_name in syn_names:
mask = param_dict[f'mask_{syn_name}'] < param_dict[f'p_{syn_name}']
weights = (10 ** param_dict[f'w_{syn_name}']) * (param_dict[f'g_{syn_name}']) * mask
for nc, w in zip(neuron_net.ncs[syn_name], weights):
weights = (10 ** param_dict[f'w_{syn_name}']) * (10 ** param_dict[f'g_{syn_name}']) * mask
delays = param_dict[f'delay_{syn_name}']
for nc, w, delay in zip(neuron_net.ncs[syn_name], weights, delays):
nc.weight[0] = w
nc.delay = delay

rate_rec_list = list()
for cell in neuron_net._cells:
rate_mech = h.Rates(cell._nrn_sections['soma'](0.5))
rate_mech.tau1 = param_dict['tau1']
rate_mech.tau2 = param_dict['tau1']
cell._nrn_synapses['rates'] = rate_mech
# rate_rec_list = list()
# for cell in neuron_net._cells:
# rate_mech = h.Rates(cell._nrn_sections['soma'](0.5))
# rate_mech.tau1 = param_dict['tau1']
# rate_mech.tau2 = param_dict['tau2']
# cell._nrn_synapses['rates'] = rate_mech

nc = _PC.gid_connect(cell.gid, cell._nrn_synapses['rates'])
nc.weight[0] = 1.0
neuron_net.ncs['rates'].append(nc)
# nc = _PC.gid_connect(cell.gid, cell._nrn_synapses['rates'])
# nc.weight[0] = 1.0
# neuron_net.ncs['rates'].append(nc)

rate_rec = h.Vector()
rate_rec.record(cell._nrn_synapses['rates']._ref_g)
rate_rec_list.append(rate_rec)

net.param_dict['rate_rec'] = rate_rec_list
# rate_rec = h.Vector()
# rate_rec.record(cell._nrn_synapses['rates']._ref_g)
# rate_rec_list.append(rate_rec)
# net.param_dict['rate_rec'] = rate_rec_list

update_hnn_params(neuron_net, param_dict)
# _________________________________________________________________________
Expand Down Expand Up @@ -117,8 +118,17 @@ def sine_wave(t, freq=1):

output_gids = param_dict['output_gids']

spike_counts_all = np.zeros(net._n_gids)
def update_drives():
spike_counts = np.array([neuron_net._cells[gid]._nrn_synapses['rates'].g for gid in output_gids])
# spike_counts = np.array([neuron_net._cells[gid]._nrn_synapses['rates'].g for gid in output_gids])
# spike_counts = np.array([neuron_net._drive_cells[gid - net._n_cells].nrn_nsloc.interval for gid in net.gid_ranges['p_drive1']])

spike_counts_all.fill(0.0)
indices = neuron_net._spike_times.as_numpy() > h.t - param_dict['move_dt']
values, counts = np.unique(neuron_net._spike_gids.as_numpy()[indices],
return_counts=True)
spike_counts_all[values.astype(int)] = counts
spike_counts = spike_counts_all[output_gids]

if param_dict['save_frames']:
param_dict['frame_list'].append(param_dict['env'].render())
Expand All @@ -131,7 +141,7 @@ def update_drives():
param_dict['left_output'].append(left_output)
param_dict['right_output'].append(right_output)

# Perform action
# Perform action after burnin period
observation, reward, terminated, truncated, info = param_dict['env'].step(move)

param_dict['total_reward'] += reward
Expand All @@ -141,23 +151,30 @@ def update_drives():

# observation = np.array([sine_wave(h.t / 1000, freq=1), 0, 0, 0]) # testing tuning
input_intensity = param_dict['gaussian_tuning'](observation, param_dict['tuning'], param_dict['tuning_sigma'])
input_intensity /= param_dict['tuning_denom']
input_intensity = 1000.0 / np.clip(np.sum(input_intensity, axis=1) * param_dict['max_freq'], 1e-8, 1e8)
input_intensity = input_intensity / param_dict['tuning_denom']
input_intensity = 1000.0 / np.clip(np.sum(input_intensity, axis=1) * param_dict['max_freq'], 1e-6, 1000)

# For testing purposes
# input_intensity = 1000.0 / np.clip(sine_wave(h.t / 1000, freq=1) * param_dict['max_freq'], 1e-6, 1000)

# assert len(neuron_net._drive_cells) == len(input_intensity)
for drive_cell, cell_intensity in zip(neuron_net._drive_cells, input_intensity):
drive_cell.nrn_nsloc.interval = cell_intensity
if hasattr(drive_cell, 'nrn_nsloc'):
drive_cell.nrn_nsloc.interval = cell_intensity


if terminated:
param_dict['env'].reset(seed=param_dict['solution_idx']*10 + int(param_dict['num_attempts']) * 1)
param_dict['num_attempts'] += 1
#_________________________________________________________________
# for cell in neuron_net._cells:
# cell._nrn_synapses['rates'].g = 0.0

#_________________________________________________________________
if rank == 0:
# for tt in range(0, int(h.tstop), 100):
# _CVODE.event(tt, simulation_time)

for tt in range(0, int(h.tstop), int(param_dict['move_dt'])):
for tt in np.arange(300, h.tstop, param_dict['move_dt']):
_CVODE.event(tt, update_drives)

h.fcurrent()
Expand All @@ -167,9 +184,9 @@ def update_drives():

# actual simulation - run the solver
# _PC.psolve(h.tstop)
while h.t < h.tstop and param_dict['num_attempts'] < 3:
while h.t < h.tstop and param_dict['num_attempts'] < 2:
# while h.t < h.tstop:
h.fadvance(param_dict['move_dt'])
# update_drives()

_PC.barrier()

Expand Down

0 comments on commit f8c18e6

Please sign in to comment.