diff --git a/docs/change_logs/v1.1.0.md b/docs/change_logs/v1.1.0.md
index 51ded4e..7b304d3 100644
--- a/docs/change_logs/v1.1.0.md
+++ b/docs/change_logs/v1.1.0.md
@@ -30,4 +30,7 @@
 - [ ] Test `init_state`.
 - [x] Add `auto_rescale()` to `CTRNN`.
 - [ ] Check _balance_excitatory_inhibitory().
-- [ ] Put `self_connections` into `sparsity_masks`.
\ No newline at end of file
+- [ ] Put `self_connections` into `sparsity_masks`.
+- [ ] Rename `structure` module to `masks`.
+- [x] Add function `adjust_gradients()` to accommodate the new `plasticity_masks`.
+- [ ] Add warning for unnecessary call of `adjust_gradients()` when all plasticity masks are 1.
\ No newline at end of file
diff --git a/nn4n/layer/hidden_layer.py b/nn4n/layer/hidden_layer.py
index 5392dac..48f4226 100644
--- a/nn4n/layer/hidden_layer.py
+++ b/nn4n/layer/hidden_layer.py
@@ -22,9 +22,10 @@ def __init__(self, layer_struct, **kwargs):
         # some params are for verbose printing
         self.input_dim = layer_struct['input_dim']
         self.output_dim = layer_struct['output_dim']
-        self.ei_mask = layer_struct['ei_mask']
-        self.sparsity_mask = layer_struct['sparsity_mask']
-        self.plasticity_mask = layer_struct['plasticity_mask']
+        self.ei_mask = layer_struct['ei_mask'].T if layer_struct['ei_mask'] is not None else None
+        self.sparsity_mask = layer_struct['sparsity_mask'].T if layer_struct['sparsity_mask'] is not None else None
+        self.plasticity_mask = layer_struct['plasticity_mask'].T if layer_struct['plasticity_mask'] is not None else None
+        self.plasticity_scales = torch.unique(self.plasticity_mask)  # all unique plasticity values in the plasticity mask

         # generate weights and bias
         self.weight = self._generate_weight(layer_struct['weights'])
@@ -111,6 +112,12 @@ def forward(self, x):
         """ Forward """
         return x.float() @ self.weight.T + self.bias

+    def adjust_gradients(self):
+        with torch.no_grad():
+            # assume the plasticity mask are all valid and being checked in ctrnn class
+            for scale in self.plasticity_scales:
+                self.weight.grad[self.plasticity_mask == scale] *= scale
+
     def _enforce_spec_rad(self):
         """ Enforce spectral radius """
         # Calculate scale
@@ -147,6 +154,8 @@ def to(self, device):
             self.sparsity_mask = self.sparsity_mask.to(device)
         if self.ei_mask is not None:
             self.ei_mask = self.ei_mask.to(device)
+        if self.bias.requires_grad:
+            self.bias = self.bias.to(device)

     def plot_layers(self):
         """ Plot weight """
diff --git a/nn4n/layer/linear_layer.py b/nn4n/layer/linear_layer.py
index 69d05ec..cf1c291 100644
--- a/nn4n/layer/linear_layer.py
+++ b/nn4n/layer/linear_layer.py
@@ -21,9 +21,10 @@ def __init__(self, layer_struct):
         super().__init__()
         self.input_dim = layer_struct['input_dim']
         self.output_dim = layer_struct['output_dim']
-        self.ei_mask = layer_struct['ei_mask']
-        self.sparsity_mask = layer_struct['sparsity_mask']
-        self.plasticity_mask = layer_struct['plasticity_mask']
+        self.ei_mask = layer_struct['ei_mask'].T if layer_struct['ei_mask'] is not None else None
+        self.sparsity_mask = layer_struct['sparsity_mask'].T if layer_struct['sparsity_mask'] is not None else None
+        self.plasticity_mask = layer_struct['plasticity_mask'].T if layer_struct['plasticity_mask'] is not None else None
+        self.plasticity_scales = torch.unique(self.plasticity_mask)  # all unique plasticity values in the plasticity mask

         # generate weights
         self.weight = self._generate_weight(layer_struct['weights'])
@@ -148,6 +149,12 @@ def _enforce_ei(self):
     def forward(self, x):
         """ Forward Pass """
         return x.float() @ self.weight.T + self.bias
+
+    def adjust_gradients(self):
+        with torch.no_grad():
+            # assume the plasticity mask are all valid and being checked in ctrnn class
+            for scale in self.plasticity_scales:
+                self.weight.grad[self.plasticity_mask == scale] *= scale

     # ======================================================================================
     # HELPER FUNCTIONS
@@ -161,6 +168,8 @@ def to(self, device):
             self.sparsity_mask = self.sparsity_mask.to(device)
         if self.ei_mask is not None:
             self.ei_mask = self.ei_mask.to(device)
+        if self.bias.requires_grad:
+            self.bias = self.bias.to(device)

     def plot_layers(self):
         # plot weight matrix
diff --git a/nn4n/layer/recurrent_layer.py b/nn4n/layer/recurrent_layer.py
index 7f5917f..b4c5ad6 100644
--- a/nn4n/layer/recurrent_layer.py
+++ b/nn4n/layer/recurrent_layer.py
@@ -104,6 +104,11 @@ def forward(self, input):
         self.hidden_state = fr_t.detach().clone()  # TODO: haven't tested this yet

         return torch.stack(stacked_states, dim=0)
+
+
+    def adjust_gradients(self):
+        self.input_layer.adjust_gradients()
+        self.hidden_layer.adjust_gradients()

     # ==================================================================================================
     # HELPER FUNCTIONS
diff --git a/nn4n/model/ctrnn.py b/nn4n/model/ctrnn.py
index 361572e..3dac1ba 100644
--- a/nn4n/model/ctrnn.py
+++ b/nn4n/model/ctrnn.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import numpy as np

 from nn4n.model import BaseNN
 from nn4n.layer import RecurrentLayer
@@ -60,8 +61,8 @@ def _initialize(self, **kwargs):
         # parameters that used in all layers
         # base parameters
         self.dims = kwargs.pop("dims", [1, 100, 1])
-        self.biases = kwargs.pop("biases", [None, None, None])
-        self.weights = kwargs.pop("weights", ['uniform', 'uniform', 'uniform'])
+        self.biases = kwargs.pop("biases", None)
+        self.weights = kwargs.pop("weights", 'uniform')

         # network dynamics parameters
         self.sparsity_masks = kwargs.pop("sparsity_masks", None)
@@ -135,10 +136,10 @@ def _check_masks(self, param, param_type, dims):
         # Handle None cases
         if param is None:
-            if param_type in ["ei_masks", "sparsity_masks", "plasticity_masks"]:
+            if param_type in ["ei_masks", "sparsity_masks", "plasticity_masks", "biases"]:
                 param = [None] * 3
             else:
                 raise ValueError(f"{param_type} cannot be None when param_type is {param_type}")
-        elif param is not None and type(param) != list and param_type in ["weights", "biases"]:
+        elif param is not None and type(param) != list and param_type in ["weights"]:
            param = [param] * 3

        if type(param) != list:
@@ -146,26 +147,49 @@ def _check_masks(self, param, param_type, dims):
         if len(param) != 3:
             raise ValueError(f"{param_type} is/can not be broadcasted to a list of length 3")

-        for i in range(3):
-            if param[i] is not None:
-                if param_type in ["ei_masks", "sparsity_masks", "plasticity_masks"]:
-                    param[i] = self._check_array(param[i], target_dim[i], param_type, i)
-                    if param_type == "ei_masks":
-                        param[i] = np.where(param[i] > 0, 1, -1)
-                    elif param_type == "sparsity_masks":
-                        param[i] = np.where(param[i] == 0, 0, 1)
-                    elif param_type == "plasticity_masks":
-                        # Normalize plasticity_masks
-                        min_val, max_val = param[i].min(), param[i].max()
-                        param[i] = (param[i] - min_val) / (max_val - min_val)
-                elif param_type in ["weights", "biases"]:
-                    self._check_distribution_or_array(param[i], target_dim_biases[i] if param_type == "biases" else target_dim[i], param_type, i)
+        # param_type are all legal because it is passed by non-user code
+        if param_type == "plasticity_masks": param = self._reformat_plas_masks(param, target_dim)
+        else:
+            # if its not plasticity_masks, then it must be a list of 3 values
+            for i in range(3):
+                if param[i] is not None:
+                    if param_type in ["ei_masks", "sparsity_masks"]:
+                        param[i] = self._check_array(param[i], param_type, target_dim[i], i)
+                        if param_type == "ei_masks":
+                            param[i] = torch.where(param[i] > 0, torch.tensor(1), torch.tensor(-1))
+                        elif param_type == "sparsity_masks":
+                            param[i] = torch.where(param[i] == 0, torch.tensor(0), torch.tensor(1))
+                    elif param_type in ["weights", "biases"]:
+                        self._check_distribution_or_array(param[i], param_type, target_dim_biases[i] if param_type == "biases" else target_dim[i], i)

         return param

+    def _reformat_plas_masks(self, masks, target_dim):
+        if any(mask is not None for mask in masks):
+            min_plas, max_plas = [], []
+            for mask in masks:
+                if mask is not None:
+                    min_plas.append(mask.min())
+                    max_plas.append(mask.max())
+            min_plas, max_plas = min(min_plas), max(max_plas)
+            if min_plas != max_plas:
+                params = []
+                for i in range(3):
+                    if masks[i] is None: params.append(torch.ones(target_dim[i]))
+                    else:
+                        _temp_mask = (masks[i] - min_plas) / (max_plas - min_plas)
+                        params.append(self._check_array(_temp_mask, "plasticity_masks", target_dim[i], i))
+                # check the total number of unique plasticity values
+                plasticity_scales = torch.unique(torch.cat([param.flatten() for param in params]))
+                if len(plasticity_scales) > 5:
+                    raise ValueError("The number of unique plasticity values cannot be larger than 5")
+                return params
+        return [torch.ones(target_dim[i]) for i in range(3)]
+
     def _check_array(self, param, param_type, dim, index):
-        if type(param) != np.ndarray:
-            if type(param) == torch.Tensor: return param.numpy()
-            else: raise ValueError(f"{param_type}[{index}] must be a numpy array")
+        if type(param) != torch.Tensor:
+            if type(param) == np.ndarray:
+                param = torch.from_numpy(param)
+            else: raise ValueError(f"{param_type}[{index}] must be a numpy array or torch tensor")
         if param.shape != dim:
             raise ValueError(f"{param_type}[{index}] must be a numpy array of shape {dim}")
         return param
@@ -174,8 +198,8 @@ def _check_distribution_or_array(self, param, param_type, dim, index):
         if type(param) == str:
             if param not in ['uniform', 'normal']:
                 raise ValueError(f"{param_type}[{index}] must be a string of 'uniform' or 'normal'")
-        elif type(param) == np.ndarray:
-            # its already being converted to numpy array if it is a torch tensor, so no need to check
+        elif type(param) == torch.Tensor:
+            # its already being converted to torch.Tensor, so no need to check np.ndarray case
             if param.shape != dim:
                 raise ValueError(f"{param_type}[{index}] must be a numpy array of shape {dim}")
         else:
@@ -269,6 +293,13 @@ def forward(self, x):
         output = self.readout_layer(hidden_states.float())

         return output, hidden_states

+    def adjust_gradients(self):
+        """ Update weights in the custom speed """
+        # no need to consider the case where plasticity_mask is None as
+        # it will be automatically converted to a tensor of ones in parameter initialization
+        self.recurrent_layer.adjust_gradients()
+        self.readout_layer.adjust_gradients()
+
     def _enforce_constraints(self):
         self.recurrent_layer.enforce_constraints()
         self.readout_layer.enforce_constraints()
diff --git a/nn4n/utils.py b/nn4n/utils.py
index 87bfdfb..873f985 100644
--- a/nn4n/utils.py
+++ b/nn4n/utils.py
@@ -33,7 +33,7 @@ def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False):
     fig, ax = plt.subplots(figsize=(img_width, plt_height))
     ax.imshow(-w, cmap='bwr', vmin=-r, vmax=r)
-    ax.set_title(f'{title}')
+    ax.set_title(f'{title}' if not ignore_zeros else f'{title} (nonzero)')
     if colorbar:
         fig.colorbar(ax.imshow(-w, cmap='bwr', vmin=-r, vmax=r), ax=ax)

     if w.shape[1] < 5:
@@ -46,9 +46,10 @@ def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False):
     fig, ax = plt.subplots(figsize=(img_width, hist_height))
     ax.set_title(f'{title} distribution')
     if ignore_zeros:
-        ax.hist(w[np.abs(w) < np.mean(np.abs(w))*0.1].flatten(), bins=50)
+        mean_nonzero = np.mean(np.abs(w)[np.abs(w) != 0])
+        ax.hist(w[np.abs(w) > mean_nonzero*0.001].flatten(), bins=100)
     else:
-        ax.hist(w.flatten(), bins=50)
+        ax.hist(w.flatten(), bins=100)
     plt.tight_layout()
     plt.show()
diff --git a/setup.py b/setup.py
index 70e8582..0b57603 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
     name='nn4n',
-    version='1.0.3',
+    version='1.1.0',
     description='Neural Networks for Neuroscience Research',
     long_description=long_description,
     long_description_content_type='text/markdown',
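
Usage sketch (not part of the patch): a minimal example of how the new `plasticity_masks` argument and `adjust_gradients()` added in this diff are meant to fit into a training loop. Only `dims`, the three-element `plasticity_masks` list, the at-most-5-unique-values rule, and `CTRNN.adjust_gradients()` come from the diff itself; the import path, tensor layout, hyperparameters, and the toy regression task are illustrative assumptions.

```python
import numpy as np
import torch
from nn4n.model import CTRNN  # import path assumed

hidden_size = 100

# Plasticity mask for the 100x100 recurrent weights: entries set to 0 have their
# gradients multiplied by 0 in adjust_gradients() (effectively frozen), entries
# set to 1 learn normally. The model normalizes mask values to [0, 1] and allows
# at most 5 unique plasticity values in total.
hidden_plasticity = np.ones((hidden_size, hidden_size))
hidden_plasticity[hidden_size // 2:, :] = 0.0

model = CTRNN(
    dims=[1, hidden_size, 1],
    plasticity_masks=[None, hidden_plasticity, None],  # input / hidden / readout
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

x = torch.rand(50, 1, 1)       # (seq_len, batch, input_dim), layout assumed
target = torch.rand(50, 1, 1)

for _ in range(100):
    optimizer.zero_grad()
    output, hidden_states = model(x)
    loss = criterion(output, target)
    loss.backward()
    # Scale each weight's gradient by its plasticity value before the optimizer
    # step; this is the call path added in ctrnn.py -> recurrent/linear layers.
    model.adjust_gradients()
    optimizer.step()
```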