From 71024ac8311fc6e8b662fc17711e9bac6cd25328 Mon Sep 17 00:00:00 2001 From: zhaozewang Date: Mon, 25 Nov 2024 14:49:04 -0500 Subject: [PATCH] complete modularized rnn --- nn4n/criterion/firing_rate_loss.py | 19 +- nn4n/criterion/rnn_loss.py | 7 - nn4n/layer/__init__.py | 2 +- nn4n/layer/base_layer.py | 286 +---------------- nn4n/layer/hidden_layer.py | 43 ++- nn4n/layer/linear_layer.py | 297 +++++++++++++++++- nn4n/layer/recurrent_layer.py | 21 +- nn4n/layer/rnn.py | 38 ++- nn4n/mask/base_mask.py | 8 +- nn4n/mask/multi_area.py | 2 +- nn4n/mask/multi_area_ei.py | 21 +- nn4n/mask/multi_io.py | 2 +- nn4n/mask/random_input.py | 2 +- nn4n/model/ctrnn.py | 11 - nn4n/utils/__init__.py | 2 +- nn4n/utils/area_manager.py | 5 +- ...{help_functions.py => helper_functions.py} | 49 +-- todo.md | 2 +- 18 files changed, 423 insertions(+), 394 deletions(-) rename nn4n/utils/{help_functions.py => helper_functions.py} (80%) diff --git a/nn4n/criterion/firing_rate_loss.py b/nn4n/criterion/firing_rate_loss.py index 3857551..ec62d3c 100644 --- a/nn4n/criterion/firing_rate_loss.py +++ b/nn4n/criterion/firing_rate_loss.py @@ -2,17 +2,7 @@ import torch.nn as nn import torch.nn.functional as F - -class CustomLoss(nn.Module): - def __init__(self, batch_first=True): - super().__init__() - self.batch_first = batch_first - - def forward(self, **kwargs): - pass - - -class FiringRateLoss(CustomLoss): +class FiringRateLoss(nn.Module): def __init__(self, metric="l2", **kwargs): super().__init__(**kwargs) assert metric in ["l1", "l2"], "metric must be either l1 or l2" @@ -29,7 +19,7 @@ def forward(self, states, **kwargs): return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean") -class FiringRateDistLoss(CustomLoss): +class FiringRateDistLoss(nn.Module): def __init__(self, metric="sd", **kwargs): super().__init__(**kwargs) valid_metrics = ["sd", "cv", "mean_ad", "max_ad"] @@ -63,15 +53,12 @@ def forward(self, states, **kwargs): return torch.max(torch.abs(mean_fr - avg_mean_fr)) -class StatePredictionLoss(CustomLoss): +class StatePredictionLoss(nn.Module): def __init__(self, tau=1, **kwargs): super().__init__(**kwargs) self.tau = tau def forward(self, states, **kwargs): - if not self.batch_first: - states = states.transpose(0, 1) - # Ensure the sequence is long enough for the prediction window assert ( states.shape[1] > self.tau diff --git a/nn4n/criterion/rnn_loss.py b/nn4n/criterion/rnn_loss.py index 0d32cad..b2cf35f 100644 --- a/nn4n/criterion/rnn_loss.py +++ b/nn4n/criterion/rnn_loss.py @@ -35,7 +35,6 @@ class RNNLoss(nn.Module): def __init__(self, model, **kwargs): super().__init__() self.model = model - self.batch_first = model.batch_first if type(self.model) != CTRNN: raise TypeError("model must be CTRNN") self._init_losses(**kwargs) @@ -103,8 +102,6 @@ def _loss_fr(self, states, **kwargs): This compute the L2 norm (for now) of the hidden states across all timesteps and batch_size Then take the square of the mean of the norm """ - if not self.batch_first: - states = states.transpose(0, 1) mean_fr = torch.mean(states, dim=(0, 1)) # return torch.pow(torch.mean(states, dim=(0, 1)), 2).mean() # this might not be correct # return torch.norm(states, p='fro')**2/states.numel() # this might not be correct @@ -119,8 +116,6 @@ def _loss_fr_sd(self, states, **kwargs): Parameters: - states: size=(batch_size, n_timesteps, hidden_size), hidden states of the network """ - if not self.batch_first: - states = states.transpose(0, 1) avg_fr = torch.mean(states, dim=(0, 1)) return avg_fr.std() @@ -133,8 +128,6 @@ def 
_loss_fr_cv(self, states, **kwargs): Parameters: - states: size=(batch_size, n_timesteps, hidden_size), hidden states of the network """ - if not self.batch_first: - states = states.transpose(0, 1) avg_fr = torch.mean(torch.sqrt(torch.square(states)), dim=(0, 1)) return avg_fr.std() / avg_fr.mean() diff --git a/nn4n/layer/__init__.py b/nn4n/layer/__init__.py index 10601a5..b4d43a6 100644 --- a/nn4n/layer/__init__.py +++ b/nn4n/layer/__init__.py @@ -1,4 +1,4 @@ from .linear_layer import LinearLayer from .hidden_layer import HiddenLayer from .recurrent_layer import RecurrentLayer -from .rnn import RNN \ No newline at end of file +from .rnn import RNN diff --git a/nn4n/layer/base_layer.py b/nn4n/layer/base_layer.py index 55da8cc..56f0a89 100644 --- a/nn4n/layer/base_layer.py +++ b/nn4n/layer/base_layer.py @@ -6,295 +6,17 @@ class BaseLayer(nn.Module): """ - Linear Layer with optional sparsity, excitatory/inhibitory, and plasticity constraints. - The layer is initialized by passing specs in layer_struct. - - Required keywords in layer_struct: - - input_dim: dimension of input - - output_dim: dimension of output - - weight: weight matrix init method/init weight matrix, default: 'uniform' - - bias: bias vector init method/init bias vector, default: 'uniform' - - sparsity_mask: mask for sparse connectivity - - ei_mask: mask for Dale's law - - plasticity_mask: mask for plasticity + nn4n Layer class """ - def __init__( - self, - input_dim: int, - output_dim: int, - weight: str = "uniform", - bias: str = "uniform", - ei_mask: torch.Tensor = None, - sparsity_mask: torch.Tensor = None, - plasticity_mask: torch.Tensor = None, - ): + def __init__(self): super().__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.weight_dist = weight - self.bias_dist = bias - self.weight = self._generate_weight(self.weight_dist) - self.bias = self._generate_bias(self.bias_dist) - self.ei_mask = ei_mask.T if ei_mask is not None else None - self.sparsity_mask = sparsity_mask.T if sparsity_mask is not None else None - self.plasticity_mask = ( - plasticity_mask.T if plasticity_mask is not None else None - ) - # All unique plasticity values in the plasticity mask - self.plasticity_scales = ( - torch.unique(self.plasticity_mask) - if self.plasticity_mask is not None - else None - ) - - self._init_trainable() - self._check_layer() - - # INITIALIZATION - # ====================================================================================== - @staticmethod - def _check_keys(layer_struct): - required_keys = ["input_dim", "output_dim"] - for key in required_keys: - if key not in layer_struct: - raise ValueError(f"Key '{key}' is missing in layer_struct") - - valid_keys = ["input_dim", "output_dim", "weight", "bias", "ei_mask", "sparsity_mask", "plasticity_mask"] - for key in layer_struct.keys(): - if key not in valid_keys: - raise ValueError(f"Key '{key}' is not a valid key in layer_struct") - - @classmethod - def from_dict(cls, layer_struct): - """ - Alternative constructor to initialize LinearLayer from a dictionary. 
- """ - # Create an instance using the dictionary values - cls._check_keys(layer_struct) - return cls( - input_dim=layer_struct["input_dim"], - output_dim=layer_struct["output_dim"], - weight=layer_struct.get("weight", "uniform"), - bias=layer_struct.get("bias", "uniform"), - ei_mask=layer_struct.get("ei_mask"), - sparsity_mask=layer_struct.get("sparsity_mask"), - plasticity_mask=layer_struct.get("plasticity_mask"), - ) - - def _check_layer(self): - """ - Check if the layer is initialized properly - """ - # TODO: Implement this - pass - - # INIT TRAINABLE - # ====================================================================================== - def _init_trainable(self): - # Enfore constraints - self._init_constraints() - # Convert weight and bias to learnable parameters - self.weight = nn.Parameter( - self.weight, requires_grad=self.weight_dist is not None - ) - self.bias = nn.Parameter(self.bias, requires_grad=self.bias_dist is not None) - - def _init_constraints(self): - """ - Initialize constraints - It will also balance excitatory and inhibitory neurons - """ - if self.sparsity_mask is not None: - - self.weight *= self.sparsity_mask - if self.ei_mask is not None: - # Apply Dale's law - self.weight[self.ei_mask == 1] = torch.clamp( - self.weight[self.ei_mask == 1], min=0 - ) # For excitatory neurons, set negative weights to 0 - self.weight[self.ei_mask == -1] = torch.clamp( - self.weight[self.ei_mask == -1], max=0 - ) # For inhibitory neurons, set positive weights to 0 - - # Balance excitatory and inhibitory neurons weight magnitudes - self._balance_excitatory_inhibitory() - - def _generate_bias(self, bias_init): - """Generate random bias""" - if bias_init == "uniform": - # If uniform, let b be uniform in [-sqrt(k), sqrt(k)] - sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim)) - b = torch.rand(self.output_dim) * sqrt_k - b = b * 2 - sqrt_k - elif bias_init == "normal": - b = torch.randn(self.output_dim) / torch.sqrt(torch.tensor(self.input_dim)) - elif bias_init == "zero" or bias_init == None: - b = torch.zeros(self.output_dim) - elif type(bias_init) == np.ndarray: - b = torch.from_numpy(bias_init) - else: - raise NotImplementedError - return b.float() - - def _generate_weight(self, weight_init): - """Generate random weight""" - if weight_init == "uniform": - # If uniform, let w be uniform in [-sqrt(k), sqrt(k)] - sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim)) - w = torch.rand(self.output_dim, self.input_dim) * sqrt_k - w = w * 2 - sqrt_k - elif weight_init == "normal": - w = torch.randn(self.output_dim, self.input_dim) / torch.sqrt( - torch.tensor(self.input_dim) - ) - elif weight_init == "zero": - w = torch.zeros((self.output_dim, self.input_dim)) - elif type(weight_init) == np.ndarray: - w = torch.from_numpy(weight_init) - else: - raise NotImplementedError - return w.float() - - def _balance_excitatory_inhibitory(self): - """Balance excitatory and inhibitory weights""" - scale_mat = torch.ones_like(self.weight) - ext_sum = self.weight[self.sparsity_mask == 1].sum() - inh_sum = self.weight[self.sparsity_mask == -1].sum() - if ext_sum == 0 or inh_sum == 0: - # Automatically stop balancing if one of the sums is 0 - # devide by 10 to avoid recurrent explosion/decay - self.weight /= 10 - else: - if ext_sum > abs(inh_sum): - _scale = abs(inh_sum).item() / ext_sum.item() - scale_mat[self.sparsity_mask == 1] = _scale - elif ext_sum < abs(inh_sum): - _scale = ext_sum.item() / abs(inh_sum).item() - scale_mat[self.sparsity_mask == -1] = _scale - # Apply scaling - self.weight *= 
scale_mat - - # TRAINING - # ====================================================================================== - def to(self, device): - """Move the network to the device (cpu/gpu)""" - super().to(device) - if self.sparsity_mask is not None: - self.sparsity_mask = self.sparsity_mask.to(device) - if self.ei_mask is not None: - self.ei_mask = self.ei_mask.to(device) - if self.bias.requires_grad: - self.bias = self.bias.to(device) - return self - - def forward(self, x): - """ - Forwardly update network - - Inputs: - - x: input, shape: (batch_size, input_dim) - - Returns: - - state: shape: (batch_size, hidden_size) - """ - return x.float() @ self.weight.T + self.bias - - def apply_plasticity(self): - """ - Apply plasticity mask to the weight gradient - """ - with torch.no_grad(): - # assume the plasticity mask are all valid and being checked in ctrnn class - for scale in self.plasticity_scales: - if self.weight.grad is not None: - self.weight.grad[self.plasticity_mask == scale] *= scale - else: - raise RuntimeError( - "Weight gradient is None, possibly because the forward loop is non-differentiable" - ) - - def freeze(self): - """Freeze the layer""" - self.weight.requires_grad = False - self.bias.requires_grad = False - - def unfreeze(self): - """Unfreeze the layer""" - self.weight.requires_grad = True - self.bias.requires_grad = True - - # CONSTRAINTS - # ====================================================================================== - def enforce_constraints(self): - """ - Enforce constraints - - The constraints are: - - sparsity_mask: mask for sparse connectivity - - ei_mask: mask for Dale's law - """ - if self.sparsity_mask is not None: - self._enforce_sparsity() - if self.ei_mask is not None: - self._enforce_ei() - - def _enforce_sparsity(self): - """Enforce sparsity""" - w = self.weight.detach().clone() * self.sparsity_mask - self.weight.data.copy_(torch.nn.Parameter(w)) - - def _enforce_ei(self): - """Enforce Dale's law""" - w = self.weight.detach().clone() - w[self.ei_mask == 1] = torch.clamp(w[self.ei_mask == 1], min=0) - w[self.ei_mask == -1] = torch.clamp(w[self.ei_mask == -1], max=0) - self.weight.data.copy_(torch.nn.Parameter(w)) - - # HELPER FUNCTIONS - # ====================================================================================== - def set_weight(self, weight): - """Set the value of weight""" - assert ( - weight.shape == self.weight.shape - ), f"Weight shape mismatch, expected {self.weight.shape}, got {weight.shape}" - with torch.no_grad(): - self.weight.copy_(weight) - - def plot_layer(self): - """Plot the weights matrix and distribution of each layer""" - weight = ( - self.weight.cpu() - if self.weight.device != torch.device("cpu") - else self.weight - ) - utils.plot_connectivity_matrix_dist( - weight.detach().numpy(), - f"Weight", - False, - self.sparsity_mask is not None, - ) def get_specs(self): - """Print the specs of each layer""" - return { - "input_dim": self.input_dim, - "output_dim": self.output_dim, - "weight_learnable": self.weight.requires_grad, - "weight_min": self.weight.min().item(), - "weight_max": self.weight.max().item(), - "bias_learnable": self.bias.requires_grad, - "bias_min": self.bias.min().item(), - "bias_max": self.bias.max().item(), - "sparsity": ( - self.sparsity_mask.sum() / self.sparsity_mask.numel() - if self.sparsity_mask is not None - else 1 - ) - } + pass def print_layer(self): """ Print the specs of the layer """ - utils.print_dict("Layer Specs", self.get_specs()) + utils.print_dict(f"{self.__class__.__name__} layer", 
self.get_specs()) diff --git a/nn4n/layer/hidden_layer.py b/nn4n/layer/hidden_layer.py index 0640e47..7bd657d 100644 --- a/nn4n/layer/hidden_layer.py +++ b/nn4n/layer/hidden_layer.py @@ -39,11 +39,8 @@ def __init__( self.postact_noise = postact_noise self.alpha = ( torch.nn.Parameter( - torch.full((self.linear_layer.input_dim,), - alpha - ), requires_grad=True) - if learn_alpha - else alpha + torch.full((self.hidden_size,), alpha + ), requires_grad=True if learn_alpha else False) ) @property @@ -56,7 +53,7 @@ def output_dim(self) -> int: @property def hidden_size(self) -> int: - return self.output_dim + return self.linear_layer.input_dim @staticmethod def _generate_noise(shape: torch.Size, noise: float) -> torch.Tensor: @@ -69,6 +66,8 @@ def to(self, device): self.alpha = self.alpha.to(device) return self + # FORWARD + # ================================================================================= def forward( self, fr_hid_t: torch.Tensor, @@ -98,3 +97,35 @@ def forward( _postact_noise = self._generate_noise(fr_t_next.size(), self.postact_noise) fr_t_next = fr_t_next + _postact_noise return fr_t_next, v_t_next + + def enforce_constraints(self): + """ + Enforce constraints on the layer + """ + self.linear_layer.enforce_constraints() + self.input_layer.enforce_constraints() + + def apply_plasticity(self): + """ + Apply plasticity masks to the weight gradients + """ + self.linear_layer.apply_plasticity() + self.input_layer.apply_plasticity() + + def train(self): + # TODO: change the noise to regular level + pass + + def eval(self): + # TODO: change the noise to zero + pass + + # HELPER FUNCTIONS + # ====================================================================================== + def plot_layer(self, **kwargs): + """ + Plot the layer + """ + self.linear_layer.plot_layer(**kwargs) + if self.input_layer is not None: + self.input_layer.plot_layer(**kwargs) diff --git a/nn4n/layer/linear_layer.py b/nn4n/layer/linear_layer.py index ad14638..69882ee 100644 --- a/nn4n/layer/linear_layer.py +++ b/nn4n/layer/linear_layer.py @@ -1,15 +1,23 @@ import torch import torch.nn as nn - import numpy as np -import nn4n.utils as utils from .base_layer import BaseLayer +import nn4n.utils as utils class LinearLayer(BaseLayer): """ Linear Layer with optional sparsity, excitatory/inhibitory, and plasticity constraints. The layer is initialized by passing specs in layer_struct. 
+ + Required keywords in layer_struct: + - input_dim: dimension of input + - output_dim: dimension of output + - weight: weight matrix init method/init weight matrix, default: 'uniform' + - bias: bias vector init method/init bias vector, default: 'uniform' + - sparsity_mask: mask for sparse connectivity + - ei_mask: mask for Dale's law + - plasticity_mask: mask for plasticity """ def __init__( @@ -22,21 +30,76 @@ def __init__( sparsity_mask: torch.Tensor = None, plasticity_mask: torch.Tensor = None, ): - super().__init__( - input_dim=input_dim, - output_dim=output_dim, - weight=weight, - bias=bias, - ei_mask=ei_mask, - sparsity_mask=sparsity_mask, - plasticity_mask=plasticity_mask, + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.weight_dist = weight + self.bias_dist = bias + self.weight = self._generate_weight(self.weight_dist) + self.bias = self._generate_bias(self.bias_dist) + self.ei_mask = ei_mask.T if ei_mask is not None else None + self.sparsity_mask = sparsity_mask.T if sparsity_mask is not None else None + self.plasticity_mask = ( + plasticity_mask.T if plasticity_mask is not None else None ) + # All unique plasticity values in the plasticity mask + self.plasticity_scales = ( + torch.unique(self.plasticity_mask) + if self.plasticity_mask is not None + else None + ) + self._init_trainable() # INITIALIZATION # ====================================================================================== + @classmethod + def from_dict(cls, layer_struct): + """ + Alternative constructor to initialize LinearLayer from a dictionary. + """ + # Create an instance using the dictionary values + cls._check_keys(layer_struct) + return cls( + input_dim=layer_struct["input_dim"], + output_dim=layer_struct["output_dim"], + weight=layer_struct.get("weight", "uniform"), + bias=layer_struct.get("bias", "uniform"), + ei_mask=layer_struct.get("ei_mask"), + sparsity_mask=layer_struct.get("sparsity_mask"), + plasticity_mask=layer_struct.get("plasticity_mask"), + ) + @staticmethod def _check_keys(layer_struct): - BaseLayer._check_keys(layer_struct) + required_keys = ["input_dim", "output_dim"] + for key in required_keys: + if key not in layer_struct: + raise ValueError(f"Key '{key}' is missing in layer_struct") + + valid_keys = ["input_dim", "output_dim", "weight", "bias", "ei_mask", "sparsity_mask", "plasticity_mask"] + for key in layer_struct.keys(): + if key not in valid_keys: + raise ValueError(f"Key '{key}' is not a valid key in layer_struct") + + def _check_constaint_dims(self): + """ + Check if the mask dimensions are valid + """ + if self.sparsity_mask is not None: + assert ( + self.sparsity_mask.shape == (self.output_dim, self.input_dim) + ), f"Sparsity mask shape mismatch, expected {(self.output_dim, self.input_dim)}, got {self.sparsity_mask.shape}" + self.sparsity_mask = torch.tensor(self.sparsity_mask, dtype=torch.int) + if self.ei_mask is not None: + assert ( + self.ei_mask.shape == (self.output_dim, self.input_dim) + ), f"Excitatory/Inhibitory mask shape mismatch, expected {(self.output_dim, self.input_dim)}, got {self.ei_mask.shape}" + self.ei_mask = torch.tensor(self.ei_mask, dtype=torch.float) + if self.plasticity_mask is not None: + assert ( + self.plasticity_mask.shape == (self.output_dim, self.input_dim) + ), f"Plasticity mask shape mismatch, expected {(self.output_dim, self.input_dim)}, got {self.plasticity_mask.shape}" + self.plasticity_mask = torch.tensor(self.plasticity_mask, dtype=torch.float) def auto_rescale(self, param_type): """ @@ -62,4 
+125,216 @@ def auto_rescale(self, param_type): self.weight.data.copy_(mat) elif param_type == "bias": self.bias.data.copy_(mat) + + # INIT TRAINABLE + # ====================================================================================== + def _init_trainable(self): + # Enfore constraints + self._init_constraints() + # Convert weight and bias to learnable parameters + self.weight = nn.Parameter( + self.weight, requires_grad=self.weight_dist is not None + ) + self.bias = nn.Parameter(self.bias, requires_grad=self.bias_dist is not None) + + def _init_constraints(self): + """ + Initialize constraints + It will also balance excitatory and inhibitory neurons + """ + self._check_constaint_dims() + if self.sparsity_mask is not None: + self._enforce_sparsity() + if self.ei_mask is not None: + # Apply Dale's law + self.weight[self.ei_mask == 1] = torch.clamp( + self.weight[self.ei_mask == 1], min=0 + ) # For excitatory neurons, set negative weights to 0 + self.weight[self.ei_mask == -1] = torch.clamp( + self.weight[self.ei_mask == -1], max=0 + ) # For inhibitory neurons, set positive weights to 0 + + # Balance excitatory and inhibitory neurons weight magnitudes + self._balance_excitatory_inhibitory() + + def _generate_bias(self, bias_init): + """Generate random bias""" + if bias_init == "uniform": + # If uniform, let b be uniform in [-sqrt(k), sqrt(k)] + sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim)) + b = torch.rand(self.output_dim) * sqrt_k + b = b * 2 - sqrt_k + elif bias_init == "normal": + b = torch.randn(self.output_dim) / torch.sqrt(torch.tensor(self.input_dim)) + elif bias_init == "zero" or bias_init == None: + b = torch.zeros(self.output_dim) + elif type(bias_init) == np.ndarray: + b = torch.from_numpy(bias_init) + else: + raise NotImplementedError + return b.float() + + def _generate_weight(self, weight_init): + """Generate random weight""" + if weight_init == "uniform": + # If uniform, let w be uniform in [-sqrt(k), sqrt(k)] + sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim)) + w = torch.rand(self.output_dim, self.input_dim) * sqrt_k + w = w * 2 - sqrt_k + elif weight_init == "normal": + w = torch.randn(self.output_dim, self.input_dim) / torch.sqrt( + torch.tensor(self.input_dim) + ) + elif weight_init == "zero": + w = torch.zeros((self.output_dim, self.input_dim)) + elif type(weight_init) == np.ndarray: + w = torch.from_numpy(weight_init) + else: + raise NotImplementedError + return w.float() + + def _balance_excitatory_inhibitory(self): + """Balance excitatory and inhibitory weights""" + scale_mat = torch.ones_like(self.weight) + ext_sum = self.weight[self.ei_mask == 1].sum() + inh_sum = self.weight[self.ei_mask == -1].sum() + if ext_sum == 0 or inh_sum == 0: + # Automatically stop balancing if one of the sums is 0 + # devide by 10 to avoid recurrent explosion/decay + self.weight /= 10 + else: + if ext_sum > abs(inh_sum): + _scale = abs(inh_sum).item() / ext_sum.item() + scale_mat[self.ei_mask == 1] = _scale + elif ext_sum < abs(inh_sum): + _scale = ext_sum.item() / abs(inh_sum).item() + scale_mat[self.ei_mask == -1] = _scale + # Apply scaling + self.weight *= scale_mat + + # TRAINING # ====================================================================================== + def to(self, device): + """Move the network to the device (cpu/gpu)""" + super().to(device) + if self.sparsity_mask is not None: + self.sparsity_mask = self.sparsity_mask.to(device) + if self.ei_mask is not None: + self.ei_mask = self.ei_mask.to(device) + if self.bias.requires_grad: + self.bias = 
self.bias.to(device) + return self + + def forward(self, x): + """ + Forwardly update network + + Inputs: + - x: input, shape: (batch_size, input_dim) + + Returns: + - state: shape: (batch_size, hidden_size) + """ + return x.float() @ self.weight.T + self.bias + + def apply_plasticity(self): + """ + Apply plasticity mask to the weight gradient + """ + with torch.no_grad(): + # assume the plasticity mask are all valid and being checked in ctrnn class + for scale in self.plasticity_scales: + if self.weight.grad is not None: + self.weight.grad[self.plasticity_mask == scale] *= scale + else: + raise RuntimeError( + "Weight gradient is None, possibly because the forward loop is non-differentiable" + ) + + def freeze(self): + """Freeze the layer""" + self.weight.requires_grad = False + self.bias.requires_grad = False + + def unfreeze(self): + """Unfreeze the layer""" + self.weight.requires_grad = True + self.bias.requires_grad = True + + # CONSTRAINTS + # ====================================================================================== + def enforce_constraints(self): + """ + Enforce constraints + + The constraints are: + - sparsity_mask: mask for sparse connectivity + - ei_mask: mask for Dale's law + """ + if self.sparsity_mask is not None: + self._enforce_sparsity() + if self.ei_mask is not None: + self._enforce_ei() + + def _enforce_sparsity(self): + """Enforce sparsity""" + if self.sparsity_mask is not None: + # Apply mask directly without scaling + w = self.weight.detach().clone() + w = w * (self.sparsity_mask > 0).float() # Ensure binary masking + self.weight.data.copy_(w) + + def _enforce_ei(self): + """Enforce Dale's law""" + w = self.weight.detach().clone() + w[self.ei_mask == 1] = torch.clamp(w[self.ei_mask == 1], min=0) + w[self.ei_mask == -1] = torch.clamp(w[self.ei_mask == -1], max=0) + self.weight.data.copy_(torch.nn.Parameter(w)) + + # HELPER FUNCTIONS + # ====================================================================================== + def set_weight(self, weight): + """Set the value of weight""" + assert ( + weight.shape == self.weight.shape + ), f"Weight shape mismatch, expected {self.weight.shape}, got {weight.shape}" + with torch.no_grad(): + self.weight.copy_(weight) + + def plot_layer(self, plot_type="weight"): + """Plot the weights matrix and distribution of each layer""" + weight = ( + self.weight.cpu() + if self.weight.device != torch.device("cpu") + else self.weight + ) + if plot_type == "weight": + utils.plot_connectivity_matrix( + w=weight.detach().numpy(), + title=f"Weight", + colorbar=True, + ) + elif plot_type == "dist": + utils.plot_connectivity_distribution( + w=weight.detach().numpy(), + title=f"Weight", + ignore_zeros=self.sparsity_mask is not None, + ) + + def get_specs(self): + """Print the specs of each layer""" + return { + "input_dim": self.input_dim, + "output_dim": self.output_dim, + "weight_learnable": self.weight.requires_grad, + "weight_min": self.weight.min().item(), + "weight_max": self.weight.max().item(), + "bias_learnable": self.bias.requires_grad, + "bias_min": self.bias.min().item(), + "bias_max": self.bias.max().item(), + "sparsity": ( + self.sparsity_mask.sum() / self.sparsity_mask.numel() + if self.sparsity_mask is not None + else 1 + ) + } diff --git a/nn4n/layer/recurrent_layer.py b/nn4n/layer/recurrent_layer.py index 0860e29..eb6df49 100644 --- a/nn4n/layer/recurrent_layer.py +++ b/nn4n/layer/recurrent_layer.py @@ -3,9 +3,10 @@ from nn4n.utils import print_dict, get_activation from nn4n.layer import LinearLayer +from 
.base_layer import BaseLayer -class RecurrentLayer(nn.Module): +class RecurrentLayer(BaseLayer): """ Recurrent layer of the RNN. The layer is initialized by passing specs in layer_struct. @@ -117,8 +118,6 @@ def _recurrence(self, fr_t, v_t, u_t): return fr_t, v_t - # ================================================================================================== - # HELPER FUNCTIONS # ================================================================================================== def plot_layer(self, **kwargs): @@ -126,16 +125,12 @@ def plot_layer(self, **kwargs): self.input_layer.plot_layer() self.hidden_layer.plot_layer() - def print_layer(self): - """Print the weight matrix and distribution of each layer""" - param_dict = { + def get_specs(self): + """Return the specs of the layer""" + return { + "activation": self.act, "preact_noise": self.preact_noise, "postact_noise": self.postact_noise, - "activation": self.act, - "alpha": self.alpha, + "learn_alpha": self.alpha.requires_grad, + "alpha_mean": self.alpha.mean().item() if len(self.alpha) > 0 else self.alpha, } - self.input_layer.print_layer() - print_dict("Recurrence", param_dict) - self.hidden_layer.print_layer() - - # ================================================================================================== diff --git a/nn4n/layer/rnn.py b/nn4n/layer/rnn.py index 7481bd8..bca1dcd 100644 --- a/nn4n/layer/rnn.py +++ b/nn4n/layer/rnn.py @@ -70,6 +70,10 @@ def forward( Returns: - hidden_state_list: hidden states of the network, list of tensors, each element """ + # Skip constraints if the model is not in training mode + if self.training: + self.enforce_constraints() + # Initialize hidden states as a list of tensors _bs, _T, _ = x.size() hidden_states = [ @@ -94,7 +98,6 @@ def forward( for i in range(len(self.hidden_layers)) ] - # Forward pass through time for t in range(_T): for i, layer in enumerate(self.hidden_layers): @@ -113,9 +116,30 @@ def forward( return output, hidden_states + def train(self): + """ + Set pre-activation and post-activation noise to the specified value + and resume enforcing constraints + """ + for layer in self.hidden_layers: + layer.train() + self.training = True + + def eval(self): + """ + Set pre-activation and post-activation noise to zero + and pause enforcing constraints + """ + for layer in self.hidden_layers: + layer.eval() + self.training = False + def apply_plasticity(self): """Apply plasticity masks to the weight gradients""" - pass + for layer in self.hidden_layers: + layer.apply_plasticity() + if self.readout_layer is not None: + self.readout_layer.apply_plasticity() def enforce_constraints(self): """ @@ -123,15 +147,17 @@ def enforce_constraints(self): This is by default automatically called after each forward pass, but can be called manually if needed """ - pass - - # ================================================================================================== + for layer in self.hidden_layers: + layer.enforce_constraints() + if self.readout_layer is not None: + self.readout_layer.enforce_constraints() # HELPER FUNCTIONS # ================================================================================================== def plot_layer(self, **kwargs): """Plot the weight matrix and distribution of each layer""" - pass + for i, layer in enumerate(self.hidden_layers): + layer.plot_layer(**kwargs) def print_layer(self): """Print the weight matrix and distribution of each layer""" diff --git a/nn4n/mask/base_mask.py b/nn4n/mask/base_mask.py index b5f60f1..84a9daf 100644 --- 
a/nn4n/mask/base_mask.py +++ b/nn4n/mask/base_mask.py @@ -1,6 +1,6 @@ import numpy as np import nn4n.utils as utils -from nn4n.utils.help_functions import print_dict +from nn4n.utils.helper_functions import print_dict class BaseMask: @@ -106,10 +106,10 @@ def plot_masks(self): if self.input_mask.shape[1] > self.input_mask.shape[0] else self.input_mask.T ) - utils.plot_connectivity_matrix(input_mask_, "Input Layer Mask", False) + utils.plot_connectivity_matrix(input_mask_, "Input Layer Mask", True) if self.hidden_mask is not None: - utils.plot_connectivity_matrix(self.hidden_mask, "Hidden Layer Mask", False) + utils.plot_connectivity_matrix(self.hidden_mask, "Hidden Layer Mask", True) if self.readout_mask is not None: readout_mask_ = ( @@ -117,7 +117,7 @@ def plot_masks(self): if self.readout_mask.shape[1] > self.readout_mask.shape[0] else self.readout_mask.T ) - utils.plot_connectivity_matrix(readout_mask_, "Readout Layer Mask", False) + utils.plot_connectivity_matrix(readout_mask_, "Readout Layer Mask", True) def get_masks(self): """Return the masks""" diff --git a/nn4n/mask/multi_area.py b/nn4n/mask/multi_area.py index 7cf4029..c7869e6 100644 --- a/nn4n/mask/multi_area.py +++ b/nn4n/mask/multi_area.py @@ -1,6 +1,6 @@ import numpy as np from nn4n.mask.base_mask import BaseMask -from nn4n.utils.help_functions import print_dict +from nn4n.utils.helper_functions import print_dict class MultiArea(BaseMask): diff --git a/nn4n/mask/multi_area_ei.py b/nn4n/mask/multi_area_ei.py index 3f92938..ae06ef4 100644 --- a/nn4n/mask/multi_area_ei.py +++ b/nn4n/mask/multi_area_ei.py @@ -1,6 +1,6 @@ import numpy as np from .multi_area import MultiArea -from nn4n.utils.help_functions import print_dict +from nn4n.utils.helper_functions import print_dict class MultiAreaEI(MultiArea): @@ -16,31 +16,32 @@ def __init__(self, **kwargs): @kwarg inh_readout: whether to readout inhibitory neurons, default: True """ super().__init__(**kwargs) - # initialize parameters + # Initialize parameters self.exc_pct = kwargs.get("exc_pct", 0.8) self.inter_area_connections = kwargs.get( "inter_area_connections", [True, True, True, True] ) self.inh_readout = kwargs.get("inh_readout", True) - # check parameters and generate mask + # Check parameters and generate mask self._check_parameters() self._generate_masks() def _check_parameters(self): super()._check_parameters() - # check exc_pct + # Check exc_pct assert 0 <= self.exc_pct <= 1, "exc_pct must be between 0 and 1" - # check if inter_area_connections is list of 4 boolean + # Check if inter_area_connections is list of 4 boolean assert ( isinstance(self.inter_area_connections, list) and len(self.inter_area_connections) == 4 ), "inter_area_connections must be list of 4 boolean" + # Four elements are ‘exc-exc’, ‘exc-inh’, ‘inh-exc’, and ‘inh-inh’ for i in range(4): assert isinstance( self.inter_area_connections[i], bool ), "inter_area_connections must be list of 4 boolean" - def _generate_mask(self): + def _generate_masks(self): """ Generate the mask for the multi-area network """ @@ -118,6 +119,14 @@ def _masks_to_ei(self): for i in range(self.n_areas): self.readout_mask[:, self.inhibitory_neurons[i]] = 0 + def get_sparsity_masks(self): + masks = self.get_masks() + # The sparsity masks will be binary version of the current masks, all 1 and -1 are 1, 0 remains 0 + sparsity_masks = [] + for mask in masks: + sparsity_masks.append((mask != 0).astype(int)) + return sparsity_masks + def get_specs(self): """ Return the specifications of the network diff --git a/nn4n/mask/multi_io.py 
b/nn4n/mask/multi_io.py index 9176594..665d7da 100644 --- a/nn4n/mask/multi_io.py +++ b/nn4n/mask/multi_io.py @@ -1,6 +1,6 @@ import numpy as np from nn4n.mask.base_mask import BaseMask -from nn4n.utils.help_functions import print_dict +from nn4n.utils.helper_functions import print_dict class MultiIO(BaseMask): diff --git a/nn4n/mask/random_input.py b/nn4n/mask/random_input.py index bcac6f9..74bf05c 100644 --- a/nn4n/mask/random_input.py +++ b/nn4n/mask/random_input.py @@ -1,6 +1,6 @@ import numpy as np from nn4n.mask.base_mask import BaseMask -from nn4n.utils.help_functions import print_dict +from nn4n.utils.helper_functions import print_dict class RandomInput(BaseMask): diff --git a/nn4n/model/ctrnn.py b/nn4n/model/ctrnn.py index a95c080..9865e20 100644 --- a/nn4n/model/ctrnn.py +++ b/nn4n/model/ctrnn.py @@ -52,7 +52,6 @@ class CTRNN(BaseNN): - activation: activation function, default: "relu", can be "relu", "sigmoid", "tanh", "retanh" - dt: time step, default: 10 - tau: time constant, default: 100 - - batch_first: whether the input is batch first or not, default: True - biases: use bias or not for each layer, a list of 3 values or a single value if a single value is passed, it will be broadcasted to a list of 3 values, it can be: - None: no bias @@ -99,7 +98,6 @@ def _initialize(self, **kwargs): self.dims = kwargs.pop("dims", [1, 100, 1]) self.biases = kwargs.pop("biases", None) self.weights = kwargs.pop("weights", "uniform") - self.batch_first = kwargs.pop("batch_first", True) # network dynamics parameters self.sparsity_masks = kwargs.pop("sparsity_masks", None) @@ -441,8 +439,6 @@ def _build_structures(self, kwargs): } return rc_struct, out_struct - # ====================================================================================== - # FORWARD # ====================================================================================== def to(self, device): @@ -459,19 +455,12 @@ def forward(self, x: torch.Tensor, init_state: torch.Tensor = None) -> torch.Ten Inputs: - x: input, shape: (batch_size, n_timesteps, input_dim) """ - if not self.batch_first: - x = x.transpose(0, 1) - # skip constraints if the model is not in training mode if self.training: self._enforce_constraints() hidden_states = self.recurrent_layer(x, init_state) output = self.readout_layer(hidden_states.float()) - if not self.batch_first: - output = output.transpose(0, 1) - hidden_states = hidden_states.transpose(0, 1) - return output, {"h": hidden_states} def train(self): diff --git a/nn4n/utils/__init__.py b/nn4n/utils/__init__.py index 5b8421e..a0fdb4c 100644 --- a/nn4n/utils/__init__.py +++ b/nn4n/utils/__init__.py @@ -1,2 +1,2 @@ -from .help_functions import * +from .helper_functions import * from .area_manager import * diff --git a/nn4n/utils/area_manager.py b/nn4n/utils/area_manager.py index 427bf43..eef53b9 100644 --- a/nn4n/utils/area_manager.py +++ b/nn4n/utils/area_manager.py @@ -11,21 +11,18 @@ def wrapper(self, *args, **kwargs): class AreaManager: - def __init__(self, area_indices=None, batch_first=True): + def __init__(self, area_indices=None): """ Initialize the AreaManager Inputs: - area_indices: a list of indices (array) denoting a neuron's area assignment - - batch_first: whether the states are batch-first (default: True) """ if area_indices is not None: self.set_area_indices(area_indices) else: self._area_indices = None # Ensure it's None initially - self._batch_first = batch_first - def set_area_indices(self, area_indices): """ Set the area indices diff --git a/nn4n/utils/help_functions.py 
b/nn4n/utils/helper_functions.py similarity index 80% rename from nn4n/utils/help_functions.py rename to nn4n/utils/helper_functions.py index 2a43466..841ec43 100644 --- a/nn4n/utils/help_functions.py +++ b/nn4n/utils/helper_functions.py @@ -39,15 +39,14 @@ def get_activation(act): raise NotImplementedError -def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False): +def plot_connectivity_matrix(w, title, colorbar=True): """ - Plot the distribution of a connectivity matrix + Plot a connectivity matrix with larger values in blue and smaller values in red. Inputs: - w: connectivity matrix, must be a numpy array or a torch tensor - title: title of the plot - colorbar: whether to show the colorbar (default: True) - - ignore_zeros: whether to ignore zeros in the distribution (needed for sparse matrices) (default: False) """ if type(w) == torch.Tensor: w = w.detach().numpy() @@ -65,11 +64,14 @@ def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False): mat_h = img_width mat_w = img_width * hw_ratio + hist_height + # Reverse the colormap to make blue larger and red smaller + cmap = plt.cm.bwr.reversed() # Reverse 'bwr' colormap + fig, ax = plt.subplots(figsize=(mat_w, mat_h)) - ax.imshow(-w, cmap="bwr", vmin=-r, vmax=r) - ax.set_title(f"{title}" if not ignore_zeros else f"{title} (nonzero)") + cax = ax.imshow(w, cmap=cmap, vmin=-r, vmax=r) # Apply reversed colormap + ax.set_title(title) if colorbar: - fig.colorbar(ax.imshow(-w, cmap="bwr", vmin=-r, vmax=r), ax=ax) + fig.colorbar(cax, ax=ax) if w.shape[1] < 5: ax.set_xticks([]) if w.shape[0] < 5: @@ -77,7 +79,20 @@ def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False): plt.tight_layout() plt.show() - fig, ax = plt.subplots(figsize=(img_width, hist_height)) + +def plot_connectivity_distribution(w, title, ignore_zeros=False): + """ + Plot the distribution of a connectivity matrix + + Inputs: + - w: connectivity matrix, must be a numpy array or a torch tensor + - title: title of the distribution + - ignore_zeros: whether to ignore zeros in the distribution (default: False) + """ + if type(w) == torch.Tensor: + w = w.detach().numpy() + + fig, ax = plt.subplots(figsize=(6, 2)) ax.set_title(f"{title} distribution") if ignore_zeros: mean_nonzero = np.mean(np.abs(w)[np.abs(w) != 0]) @@ -88,28 +103,18 @@ def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False): plt.show() -def plot_connectivity_matrix(w, title, colorbar=True): +def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False): """ - Plot a connectivity matrix + Plot the connectivity matrix and its distribution Inputs: - w: connectivity matrix, must be a numpy array or a torch tensor - title: title of the plot - colorbar: whether to show the colorbar (default: True) + - ignore_zeros: whether to ignore zeros in the distribution (needed for sparse matrices) (default: False) """ - if type(w) == torch.Tensor: - w = w.detach().numpy() - - r = np.max(np.abs(w)) - - fig, ax = plt.subplots(figsize=(6, 6 * w.shape[0] / w.shape[1])) - - ax.imshow(-w, cmap="bwr", vmin=-r, vmax=r) - ax.set_title(title) - if colorbar: - fig.colorbar(ax.imshow(w, cmap="bwr", vmin=-r, vmax=r), ax=ax) - # plt.tight_layout() - plt.show() + plot_connectivity_matrix(w, title, colorbar=colorbar) + plot_connectivity_distribution(w, title, ignore_zeros=ignore_zeros) def plot_eigenvalues(w, title): diff --git a/todo.md b/todo.md index 81354f3..7cf9a00 100644 --- a/todo.md +++ b/todo.md @@ -3,7 +3,7 @@ - [ ] Resolve the transpose 
issue in the model module and the mask module. - [x] Make the model use `batch_first` by default. All `batch_first` parameters are removed; users handle batch ordering in their own code. - [x] Refactor the RNNLoss part to take a dictionary instead of many separate `lambda_*` parameters. --> added the `CompositeLoss` instead. -- [x] Added batch_first parameter. Adjusted to batch_first by default to follow PyTorch standard. +- [x] Adjusted the network to batch_first by default to follow PyTorch standard. - [x] Varying `alpha`. `alpha` is now learnable. - [x] Allow `alpha` to be defined as a vector. - [ ] Need to adjust the implementation of `apply_plasticity` as it won't support the SSL framework.
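Usage sketch for the refactored `LinearLayer` (illustrative only): the dimensions, mask values, and data below are made-up assumptions; only the `from_dict` keys, the mask orientation (supplied as `(input_dim, output_dim)` and transposed internally), and the constraint hooks come from the patched code.

```python
import torch
from nn4n.layer import LinearLayer

input_dim, output_dim = 6, 4

# Dale's law mask: presynaptic units 0-3 excitatory (+1), units 4-5 inhibitory (-1)
ei_mask = torch.ones(input_dim, output_dim)
ei_mask[4:, :] = -1

# Sparsity mask: keep roughly half of the connections
sparsity_mask = (torch.rand(input_dim, output_dim) > 0.5).int()

layer = LinearLayer.from_dict({
    "input_dim": input_dim,
    "output_dim": output_dim,
    "weight": "uniform",        # also accepts "normal", "zero", or a numpy array
    "bias": "zero",
    "ei_mask": ei_mask,
    "sparsity_mask": sparsity_mask,
})

x = torch.rand(32, input_dim)   # (batch_size, input_dim)
y = layer(x)                    # (batch_size, output_dim)

# Re-apply the sparsity and Dale's law constraints, e.g. after an optimizer step
layer.enforce_constraints()
layer.print_layer()             # prints the spec dict returned by get_specs()
```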
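A minimal training-step sketch showing where the constraint and plasticity hooks fit: the `CTRNN` kwargs, data shapes, loss weighting, and optimizer are illustrative assumptions, and `CTRNN` is assumed to be importable from `nn4n.model`.

```python
import torch
from nn4n.model import CTRNN
from nn4n.criterion.firing_rate_loss import FiringRateLoss

model = CTRNN(dims=[1, 100, 1])           # input_dim, hidden_size, output_dim
task_loss = torch.nn.MSELoss()
fr_loss = FiringRateLoss(metric="l2")     # firing-rate regularizer on the hidden states
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.rand(8, 50, 1)                  # (batch_size, n_timesteps, input_dim), batch_first
target = torch.rand(8, 50, 1)

model.train()                             # constraints are re-enforced on each forward pass
output, states = model(x)                 # hidden states are returned under the "h" key
loss = task_loss(output, target) + 0.01 * fr_loss(states["h"])

optimizer.zero_grad()
loss.backward()
# If plasticity masks are set, the layers' apply_plasticity() scales the weight
# gradients and belongs here, between backward() and optimizer.step().
optimizer.step()
```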