done plasticity_masks
zhaozewang committed Jan 7, 2024
1 parent b4dd1d9 commit a2069e6
Showing 7 changed files with 92 additions and 34 deletions.
5 changes: 4 additions & 1 deletion docs/change_logs/v1.1.0.md
@@ -30,4 +30,7 @@
- [ ] Test `init_state`.
- [x] Add `auto_rescale()` to `CTRNN`.
- [ ] Check _balance_excitatory_inhibitory().
- [ ] Put `self_connections` into `sparsity_masks`.
- [ ] Put `self_connections` into `sparsity_masks`.
- [ ] Rename `structure` module to `masks`.
- [x] Add function `adjust_gradients()` to accommodate the new `plasticity_masks`.
- [ ] Add a warning for unnecessary calls to `adjust_gradients()` when all plasticity masks are 1.
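For context, `adjust_gradients()` is meant to be called between `loss.backward()` and `optimizer.step()`, so that each weight's gradient is scaled by its plasticity value before the update. A minimal sketch of such a training step follows; the constructor arguments, data shapes, and loss are hypothetical and only illustrate where the call sits.

import torch
from nn4n.model import CTRNN

# hypothetical setup; dims follows the library default of [1, 100, 1]
model = CTRNN(dims=[1, 100, 1])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# hypothetical data; (time, batch, input_dim) is an assumed input layout
x = torch.rand(50, 1, 1)
target = torch.rand(50, 1, 1)

output, hidden_states = model(x)
loss = criterion(output, target)

optimizer.zero_grad()
loss.backward()
model.adjust_gradients()  # scale each gradient entry by its plasticity value
optimizer.step()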
15 changes: 12 additions & 3 deletions nn4n/layer/hidden_layer.py
@@ -22,9 +22,10 @@ def __init__(self, layer_struct, **kwargs):
# some params are for verbose printing
self.input_dim = layer_struct['input_dim']
self.output_dim = layer_struct['output_dim']
self.ei_mask = layer_struct['ei_mask']
self.sparsity_mask = layer_struct['sparsity_mask']
self.plasticity_mask = layer_struct['plasticity_mask']
self.ei_mask = layer_struct['ei_mask'].T if layer_struct['ei_mask'] is not None else None
self.sparsity_mask = layer_struct['sparsity_mask'].T if layer_struct['sparsity_mask'] is not None else None
self.plasticity_mask = layer_struct['plasticity_mask'].T if layer_struct['plasticity_mask'] is not None else None
self.plasticity_scales = torch.unique(self.plasticity_mask) # all unique plasticity values in the plasticity mask

# generate weights and bias
self.weight = self._generate_weight(layer_struct['weights'])
@@ -111,6 +112,12 @@ def forward(self, x):
""" Forward """
return x.float() @ self.weight.T + self.bias

def adjust_gradients(self):
with torch.no_grad():
# assume the plasticity masks are all valid; they are validated in the CTRNN class
for scale in self.plasticity_scales:
self.weight.grad[self.plasticity_mask == scale] *= scale

def _enforce_spec_rad(self):
""" Enforce spectral radius """
# Calculate scale
@@ -147,6 +154,8 @@ def to(self, device):
self.sparsity_mask = self.sparsity_mask.to(device)
if self.ei_mask is not None:
self.ei_mask = self.ei_mask.to(device)
if self.bias.requires_grad:
self.bias = self.bias.to(device)

def plot_layers(self):
""" Plot weight """
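The gradient scaling added in `adjust_gradients()` above can be seen in isolation with toy tensors; this standalone sketch (not part of the commit) mirrors the loop over `plasticity_scales`.

import torch

# toy 2x2 weight whose gradient starts as all ones
weight = torch.nn.Parameter(torch.ones(2, 2))
weight.grad = torch.ones(2, 2)

# plasticity mask: top row learns at full speed, bottom row at half speed
plasticity_mask = torch.tensor([[1.0, 1.0], [0.5, 0.5]])
plasticity_scales = torch.unique(plasticity_mask)

with torch.no_grad():
    for scale in plasticity_scales:
        weight.grad[plasticity_mask == scale] *= scale

print(weight.grad)  # [[1.0, 1.0], [0.5, 0.5]]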
15 changes: 12 additions & 3 deletions nn4n/layer/linear_layer.py
@@ -21,9 +21,10 @@ def __init__(self, layer_struct):
super().__init__()
self.input_dim = layer_struct['input_dim']
self.output_dim = layer_struct['output_dim']
self.ei_mask = layer_struct['ei_mask']
self.sparsity_mask = layer_struct['sparsity_mask']
self.plasticity_mask = layer_struct['plasticity_mask']
self.ei_mask = layer_struct['ei_mask'].T if layer_struct['ei_mask'] is not None else None
self.sparsity_mask = layer_struct['sparsity_mask'].T if layer_struct['sparsity_mask'] is not None else None
self.plasticity_mask = layer_struct['plasticity_mask'].T if layer_struct['plasticity_mask'] is not None else None
self.plasticity_scales = torch.unique(self.plasticity_mask) # all unique plasticity values in the plasticity mask

# generate weights
self.weight = self._generate_weight(layer_struct['weights'])
@@ -148,6 +149,12 @@ def _enforce_ei(self):
def forward(self, x):
""" Forward Pass """
return x.float() @ self.weight.T + self.bias

def adjust_gradients(self):
with torch.no_grad():
# assume the plasticity masks are all valid; they are validated in the CTRNN class
for scale in self.plasticity_scales:
self.weight.grad[self.plasticity_mask == scale] *= scale
# ======================================================================================

# HELPER FUNCTIONS
@@ -161,6 +168,8 @@ def to(self, device):
self.sparsity_mask = self.sparsity_mask.to(device)
if self.ei_mask is not None:
self.ei_mask = self.ei_mask.to(device)
if self.bias.requires_grad:
self.bias = self.bias.to(device)

def plot_layers(self):
# plot weight matrix
5 changes: 5 additions & 0 deletions nn4n/layer/recurrent_layer.py
@@ -104,6 +104,11 @@ def forward(self, input):
self.hidden_state = fr_t.detach().clone() # TODO: haven't tested this yet

return torch.stack(stacked_states, dim=0)


def adjust_gradients(self):
self.input_layer.adjust_gradients()
self.hidden_layer.adjust_gradients()
# ==================================================================================================

# HELPER FUNCTIONS
77 changes: 54 additions & 23 deletions nn4n/model/ctrnn.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import numpy as np

from nn4n.model import BaseNN
from nn4n.layer import RecurrentLayer
@@ -60,8 +61,8 @@ def _initialize(self, **kwargs):
# parameters that used in all layers
# base parameters
self.dims = kwargs.pop("dims", [1, 100, 1])
self.biases = kwargs.pop("biases", [None, None, None])
self.weights = kwargs.pop("weights", ['uniform', 'uniform', 'uniform'])
self.biases = kwargs.pop("biases", None)
self.weights = kwargs.pop("weights", 'uniform')

# network dynamics parameters
self.sparsity_masks = kwargs.pop("sparsity_masks", None)
@@ -135,37 +136,60 @@ def _check_masks(self, param, param_type, dims):

# Handle None cases
if param is None:
if param_type in ["ei_masks", "sparsity_masks", "plasticity_masks"]:
if param_type in ["ei_masks", "sparsity_masks", "plasticity_masks", "biases"]:
param = [None] * 3
else: raise ValueError(f"{param_type} cannot be None")
elif param is not None and type(param) != list and param_type in ["weights", "biases"]:
elif param is not None and type(param) != list and param_type in ["weights"]:
param = [param] * 3

if type(param) != list:
raise ValueError(f"{param_type} is/can not be broadcasted to a list")
if len(param) != 3:
raise ValueError(f"{param_type} is/can not be broadcasted to a list of length 3")

for i in range(3):
if param[i] is not None:
if param_type in ["ei_masks", "sparsity_masks", "plasticity_masks"]:
param[i] = self._check_array(param[i], target_dim[i], param_type, i)
if param_type == "ei_masks":
param[i] = np.where(param[i] > 0, 1, -1)
elif param_type == "sparsity_masks":
param[i] = np.where(param[i] == 0, 0, 1)
elif param_type == "plasticity_masks":
# Normalize plasticity_masks
min_val, max_val = param[i].min(), param[i].max()
param[i] = (param[i] - min_val) / (max_val - min_val)
elif param_type in ["weights", "biases"]:
self._check_distribution_or_array(param[i], target_dim_biases[i] if param_type == "biases" else target_dim[i], param_type, i)
# param_type values are always legal because they are passed by non-user code
if param_type == "plasticity_masks": param = self._reformat_plas_masks(param, target_dim)
else:
# if it is not plasticity_masks, it must be a list of 3 values
for i in range(3):
if param[i] is not None:
if param_type in ["ei_masks", "sparsity_masks"]:
param[i] = self._check_array(param[i], param_type, target_dim[i], i)
if param_type == "ei_masks":
param[i] = torch.where(param[i] > 0, torch.tensor(1), torch.tensor(-1))
elif param_type == "sparsity_masks":
param[i] = torch.where(param[i] == 0, torch.tensor(0), torch.tensor(1))
elif param_type in ["weights", "biases"]:
self._check_distribution_or_array(param[i], param_type, target_dim_biases[i] if param_type == "biases" else target_dim[i], i)
return param

def _reformat_plas_masks(self, masks, target_dim):
if any(mask is not None for mask in masks):
min_plas, max_plas = [], []
for mask in masks:
if mask is not None:
min_plas.append(mask.min())
max_plas.append(mask.max())
min_plas, max_plas = min(min_plas), max(max_plas)
if min_plas != max_plas:
params = []
for i in range(3):
if masks[i] is None: params.append(torch.ones(target_dim[i]))
else:
_temp_mask = (masks[i] - min_plas) / (max_plas - min_plas)
params.append(self._check_array(_temp_mask, "plasticity_masks", target_dim[i], i))
# check the total number of unique plasticity values
plasticity_scales = torch.unique(torch.cat([param.flatten() for param in params]))
if len(plasticity_scales) > 5:
raise ValueError("The number of unique plasticity values cannot be larger than 5")
return params
return [torch.ones(target_dim[i]) for i in range(3)]

def _check_array(self, param, param_type, dim, index):
if type(param) != np.ndarray:
if type(param) == torch.Tensor: return param.numpy()
else: raise ValueError(f"{param_type}[{index}] must be a numpy array")
if type(param) != torch.Tensor:
if type(param) == np.ndarray:
param = torch.from_numpy(param)
else: raise ValueError(f"{param_type}[{index}] must be a numpy array or torch tensor")
if param.shape != dim:
raise ValueError(f"{param_type}[{index}] must be a numpy array of shape {dim}")
return param
@@ -174,8 +198,8 @@ def _check_distribution_or_array(self, param, param_type, dim, index):
if type(param) == str:
if param not in ['uniform', 'normal']:
raise ValueError(f"{param_type}[{index}] must be a string of 'uniform' or 'normal'")
elif type(param) == np.ndarray:
# its already being converted to numpy array if it is a torch tensor, so no need to check
elif type(param) == torch.Tensor:
# it has already been converted to a torch.Tensor, so there is no need to check the np.ndarray case
if param.shape != dim:
raise ValueError(f"{param_type}[{index}] must be a numpy array of shape {dim}")
else:
@@ -269,6 +293,13 @@ def forward(self, x):
output = self.readout_layer(hidden_states.float())
return output, hidden_states

def adjust_gradients(self):
""" Update weights in the custom speed """
# no need to consider the case where plasticity_mask is None as
# it will be automatically converted to a tensor of ones in parameter initialization
self.recurrent_layer.adjust_gradients()
self.readout_layer.adjust_gradients()

def _enforce_constraints(self):
self.recurrent_layer.enforce_constraints()
self.readout_layer.enforce_constraints()
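As a usage sketch for the mask handling above: plasticity masks may be passed per layer, `None` entries fall back to all-ones masks, values are rescaled so the global minimum maps to 0 and the global maximum to 1, and at most 5 distinct values are allowed. The example below is hypothetical; in particular, the `[input, hidden, readout]` ordering and the exact constructor keyword are assumptions inferred from this diff.

import numpy as np
from nn4n.model import CTRNN

hidden_size = 100

# hypothetical mask: the first half of the recurrent weights is fully plastic (1),
# the second half is frozen (0); two unique values, well under the limit of 5
hidden_plasticity = np.ones((hidden_size, hidden_size))
hidden_plasticity[hidden_size // 2:, :] = 0.0

model = CTRNN(
    dims=[1, hidden_size, 1],
    plasticity_masks=[None, hidden_plasticity, None],  # assumed [input, hidden, readout] order
)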
7 changes: 4 additions & 3 deletions nn4n/utils.py
@@ -33,7 +33,7 @@ def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False):

fig, ax = plt.subplots(figsize=(img_width, plt_height))
ax.imshow(-w, cmap='bwr', vmin=-r, vmax=r)
ax.set_title(f'{title}')
ax.set_title(f'{title}' if not ignore_zeros else f'{title} (nonzero)')
if colorbar:
fig.colorbar(ax.imshow(-w, cmap='bwr', vmin=-r, vmax=r), ax=ax)
if w.shape[1] < 5:
@@ -46,9 +46,10 @@ def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False):
fig, ax = plt.subplots(figsize=(img_width, hist_height))
ax.set_title(f'{title} distribution')
if ignore_zeros:
ax.hist(w[np.abs(w) < np.mean(np.abs(w))*0.1].flatten(), bins=50)
mean_nonzero = np.mean(np.abs(w)[np.abs(w) != 0])
ax.hist(w[np.abs(w) > mean_nonzero*0.001].flatten(), bins=100)
else:
ax.hist(w.flatten(), bins=50)
ax.hist(w.flatten(), bins=100)
plt.tight_layout()
plt.show()

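The revised `ignore_zeros` branch above keeps only entries whose magnitude exceeds 0.1% of the mean nonzero magnitude before plotting the histogram. A standalone sketch of that filter on a toy matrix (not part of the commit):

import numpy as np

w = np.array([[0.0, 0.5], [-0.3, 0.0]])
mean_nonzero = np.mean(np.abs(w)[np.abs(w) != 0])     # (0.5 + 0.3) / 2 = 0.4
kept = w[np.abs(w) > mean_nonzero * 0.001].flatten()  # drops the exact zeros
print(kept)  # [ 0.5 -0.3]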
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
name='nn4n',
version='1.0.3',
version='1.1.0',
description='Neural Networks for Neuroscience Research',
long_description=long_description,
long_description_content_type='text/markdown',
