
Commit

complete modularized rnn
zhaozewang committed Nov 25, 2024
1 parent 6d02eff commit 71024ac
Showing 18 changed files with 423 additions and 394 deletions.
19 changes: 3 additions & 16 deletions nn4n/criterion/firing_rate_loss.py
@@ -2,17 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F


class CustomLoss(nn.Module):
def __init__(self, batch_first=True):
super().__init__()
self.batch_first = batch_first

def forward(self, **kwargs):
pass


class FiringRateLoss(CustomLoss):
class FiringRateLoss(nn.Module):
def __init__(self, metric="l2", **kwargs):
super().__init__(**kwargs)
assert metric in ["l1", "l2"], "metric must be either l1 or l2"
@@ -29,7 +19,7 @@ def forward(self, states, **kwargs):
return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction="mean")


class FiringRateDistLoss(CustomLoss):
class FiringRateDistLoss(nn.Module):
def __init__(self, metric="sd", **kwargs):
super().__init__(**kwargs)
valid_metrics = ["sd", "cv", "mean_ad", "max_ad"]
@@ -63,15 +53,12 @@ def forward(self, states, **kwargs):
return torch.max(torch.abs(mean_fr - avg_mean_fr))


class StatePredictionLoss(CustomLoss):
class StatePredictionLoss(nn.Module):
def __init__(self, tau=1, **kwargs):
super().__init__(**kwargs)
self.tau = tau

def forward(self, states, **kwargs):
if not self.batch_first:
states = states.transpose(0, 1)

# Ensure the sequence is long enough for the prediction window
assert (
states.shape[1] > self.tau
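For orientation, the refactor above removes the CustomLoss base class and the batch_first handling, so these loss modules now subclass nn.Module directly and assume batch-first states of shape (batch_size, n_timesteps, hidden_size). A minimal usage sketch under that assumption (tensor sizes are arbitrary; the import path simply mirrors the file shown above):

import torch
from nn4n.criterion.firing_rate_loss import FiringRateLoss, StatePredictionLoss

# Fake batch-first hidden states: (batch_size, n_timesteps, hidden_size)
states = torch.randn(8, 100, 64, requires_grad=True)

fr_loss = FiringRateLoss(metric="l2")    # penalize the mean firing rate
pred_loss = StatePredictionLoss(tau=2)   # state-prediction loss with horizon tau

total = fr_loss(states) + pred_loss(states)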
7 changes: 0 additions & 7 deletions nn4n/criterion/rnn_loss.py
@@ -35,7 +35,6 @@ class RNNLoss(nn.Module):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
self.batch_first = model.batch_first
if type(self.model) != CTRNN:
raise TypeError("model must be CTRNN")
self._init_losses(**kwargs)
@@ -103,8 +102,6 @@ def _loss_fr(self, states, **kwargs):
This computes the L2 norm (for now) of the hidden states across all timesteps and the batch,
then takes the square of the mean of the norm
"""
if not self.batch_first:
states = states.transpose(0, 1)
mean_fr = torch.mean(states, dim=(0, 1))
# return torch.pow(torch.mean(states, dim=(0, 1)), 2).mean() # this might not be correct
# return torch.norm(states, p='fro')**2/states.numel() # this might not be correct
@@ -119,8 +116,6 @@ def _loss_fr_sd(self, states, **kwargs):
Parameters:
- states: size=(batch_size, n_timesteps, hidden_size), hidden states of the network
"""
if not self.batch_first:
states = states.transpose(0, 1)
avg_fr = torch.mean(states, dim=(0, 1))
return avg_fr.std()

@@ -133,8 +128,6 @@ def _loss_fr_cv(self, states, **kwargs):
Parameters:
- states: size=(batch_size, n_timesteps, hidden_size), hidden states of the network
"""
if not self.batch_first:
states = states.transpose(0, 1)
avg_fr = torch.mean(torch.sqrt(torch.square(states)), dim=(0, 1))
return avg_fr.std() / avg_fr.mean()

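As a quick reference for the firing-rate regularizers above, here is a standalone restatement of what the remaining batch-first code computes; it is not part of the commit, just the same tensor operations written out:

import torch

# states: (batch_size, n_timesteps, hidden_size), batch-first
states = torch.randn(8, 100, 64)

# _loss_fr_sd: spread of the per-neuron mean firing rates
avg_fr = torch.mean(states, dim=(0, 1))        # shape: (hidden_size,)
loss_fr_sd = avg_fr.std()

# _loss_fr_cv: coefficient of variation of per-neuron mean absolute activity
abs_fr = torch.mean(torch.sqrt(torch.square(states)), dim=(0, 1))
loss_fr_cv = abs_fr.std() / abs_fr.mean()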
2 changes: 1 addition & 1 deletion nn4n/layer/__init__.py
@@ -1,4 +1,4 @@
from .linear_layer import LinearLayer
from .hidden_layer import HiddenLayer
from .recurrent_layer import RecurrentLayer
from .rnn import RNN
from .rnn import RNN
286 changes: 4 additions & 282 deletions nn4n/layer/base_layer.py
@@ -6,295 +6,17 @@

class BaseLayer(nn.Module):
"""
Linear Layer with optional sparsity, excitatory/inhibitory, and plasticity constraints.
The layer is initialized by passing specs in layer_struct.
Required keywords in layer_struct:
- input_dim: dimension of input
- output_dim: dimension of output
- weight: weight matrix init method/init weight matrix, default: 'uniform'
- bias: bias vector init method/init bias vector, default: 'uniform'
- sparsity_mask: mask for sparse connectivity
- ei_mask: mask for Dale's law
- plasticity_mask: mask for plasticity
nn4n Layer class
"""

def __init__(
self,
input_dim: int,
output_dim: int,
weight: str = "uniform",
bias: str = "uniform",
ei_mask: torch.Tensor = None,
sparsity_mask: torch.Tensor = None,
plasticity_mask: torch.Tensor = None,
):
def __init__(self):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.weight_dist = weight
self.bias_dist = bias
self.weight = self._generate_weight(self.weight_dist)
self.bias = self._generate_bias(self.bias_dist)
self.ei_mask = ei_mask.T if ei_mask is not None else None
self.sparsity_mask = sparsity_mask.T if sparsity_mask is not None else None
self.plasticity_mask = (
plasticity_mask.T if plasticity_mask is not None else None
)
# All unique plasticity values in the plasticity mask
self.plasticity_scales = (
torch.unique(self.plasticity_mask)
if self.plasticity_mask is not None
else None
)

self._init_trainable()
self._check_layer()

# INITIALIZATION
# ======================================================================================
@staticmethod
def _check_keys(layer_struct):
required_keys = ["input_dim", "output_dim"]
for key in required_keys:
if key not in layer_struct:
raise ValueError(f"Key '{key}' is missing in layer_struct")

valid_keys = ["input_dim", "output_dim", "weight", "bias", "ei_mask", "sparsity_mask", "plasticity_mask"]
for key in layer_struct.keys():
if key not in valid_keys:
raise ValueError(f"Key '{key}' is not a valid key in layer_struct")

@classmethod
def from_dict(cls, layer_struct):
"""
Alternative constructor to initialize LinearLayer from a dictionary.
"""
# Create an instance using the dictionary values
cls._check_keys(layer_struct)
return cls(
input_dim=layer_struct["input_dim"],
output_dim=layer_struct["output_dim"],
weight=layer_struct.get("weight", "uniform"),
bias=layer_struct.get("bias", "uniform"),
ei_mask=layer_struct.get("ei_mask"),
sparsity_mask=layer_struct.get("sparsity_mask"),
plasticity_mask=layer_struct.get("plasticity_mask"),
)

def _check_layer(self):
"""
Check if the layer is initialized properly
"""
# TODO: Implement this
pass

# INIT TRAINABLE
# ======================================================================================
def _init_trainable(self):
# Enforce constraints
self._init_constraints()
# Convert weight and bias to learnable parameters
self.weight = nn.Parameter(
self.weight, requires_grad=self.weight_dist is not None
)
self.bias = nn.Parameter(self.bias, requires_grad=self.bias_dist is not None)

def _init_constraints(self):
"""
Initialize constraints
It will also balance excitatory and inhibitory neurons
"""
if self.sparsity_mask is not None:

self.weight *= self.sparsity_mask
if self.ei_mask is not None:
# Apply Dale's law
self.weight[self.ei_mask == 1] = torch.clamp(
self.weight[self.ei_mask == 1], min=0
) # For excitatory neurons, set negative weights to 0
self.weight[self.ei_mask == -1] = torch.clamp(
self.weight[self.ei_mask == -1], max=0
) # For inhibitory neurons, set positive weights to 0

# Balance excitatory and inhibitory neurons weight magnitudes
self._balance_excitatory_inhibitory()

def _generate_bias(self, bias_init):
"""Generate random bias"""
if bias_init == "uniform":
# If uniform, let b be uniform in [-sqrt(k), sqrt(k)]
sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
b = torch.rand(self.output_dim) * sqrt_k
b = b * 2 - sqrt_k
elif bias_init == "normal":
b = torch.randn(self.output_dim) / torch.sqrt(torch.tensor(self.input_dim))
elif bias_init == "zero" or bias_init == None:
b = torch.zeros(self.output_dim)
elif type(bias_init) == np.ndarray:
b = torch.from_numpy(bias_init)
else:
raise NotImplementedError
return b.float()

def _generate_weight(self, weight_init):
"""Generate random weight"""
if weight_init == "uniform":
# If uniform, let w be uniform in [-sqrt(k), sqrt(k)]
sqrt_k = torch.sqrt(torch.tensor(1 / self.input_dim))
w = torch.rand(self.output_dim, self.input_dim) * sqrt_k
w = w * 2 - sqrt_k
elif weight_init == "normal":
w = torch.randn(self.output_dim, self.input_dim) / torch.sqrt(
torch.tensor(self.input_dim)
)
elif weight_init == "zero":
w = torch.zeros((self.output_dim, self.input_dim))
elif type(weight_init) == np.ndarray:
w = torch.from_numpy(weight_init)
else:
raise NotImplementedError
return w.float()

def _balance_excitatory_inhibitory(self):
"""Balance excitatory and inhibitory weights"""
scale_mat = torch.ones_like(self.weight)
ext_sum = self.weight[self.sparsity_mask == 1].sum()
inh_sum = self.weight[self.sparsity_mask == -1].sum()
if ext_sum == 0 or inh_sum == 0:
# Automatically stop balancing if one of the sums is 0
# divide by 10 to avoid recurrent explosion/decay
self.weight /= 10
else:
if ext_sum > abs(inh_sum):
_scale = abs(inh_sum).item() / ext_sum.item()
scale_mat[self.sparsity_mask == 1] = _scale
elif ext_sum < abs(inh_sum):
_scale = ext_sum.item() / abs(inh_sum).item()
scale_mat[self.sparsity_mask == -1] = _scale
# Apply scaling
self.weight *= scale_mat

# TRAINING
# ======================================================================================
def to(self, device):
"""Move the network to the device (cpu/gpu)"""
super().to(device)
if self.sparsity_mask is not None:
self.sparsity_mask = self.sparsity_mask.to(device)
if self.ei_mask is not None:
self.ei_mask = self.ei_mask.to(device)
if self.bias.requires_grad:
self.bias = self.bias.to(device)
return self

def forward(self, x):
"""
Forward pass of the layer
Inputs:
- x: input, shape: (batch_size, input_dim)
Returns:
- state: shape: (batch_size, output_dim)
"""
return x.float() @ self.weight.T + self.bias

def apply_plasticity(self):
"""
Apply plasticity mask to the weight gradient
"""
with torch.no_grad():
# assume the plasticity masks are all valid and have been checked in the CTRNN class
for scale in self.plasticity_scales:
if self.weight.grad is not None:
self.weight.grad[self.plasticity_mask == scale] *= scale
else:
raise RuntimeError(
"Weight gradient is None, possibly because the forward loop is non-differentiable"
)

def freeze(self):
"""Freeze the layer"""
self.weight.requires_grad = False
self.bias.requires_grad = False

def unfreeze(self):
"""Unfreeze the layer"""
self.weight.requires_grad = True
self.bias.requires_grad = True

# CONSTRAINTS
# ======================================================================================
def enforce_constraints(self):
"""
Enforce constraints
The constraints are:
- sparsity_mask: mask for sparse connectivity
- ei_mask: mask for Dale's law
"""
if self.sparsity_mask is not None:
self._enforce_sparsity()
if self.ei_mask is not None:
self._enforce_ei()

def _enforce_sparsity(self):
"""Enforce sparsity"""
w = self.weight.detach().clone() * self.sparsity_mask
self.weight.data.copy_(torch.nn.Parameter(w))

def _enforce_ei(self):
"""Enforce Dale's law"""
w = self.weight.detach().clone()
w[self.ei_mask == 1] = torch.clamp(w[self.ei_mask == 1], min=0)
w[self.ei_mask == -1] = torch.clamp(w[self.ei_mask == -1], max=0)
self.weight.data.copy_(torch.nn.Parameter(w))

# HELPER FUNCTIONS
# ======================================================================================
def set_weight(self, weight):
"""Set the value of weight"""
assert (
weight.shape == self.weight.shape
), f"Weight shape mismatch, expected {self.weight.shape}, got {weight.shape}"
with torch.no_grad():
self.weight.copy_(weight)

def plot_layer(self):
"""Plot the weights matrix and distribution of each layer"""
weight = (
self.weight.cpu()
if self.weight.device != torch.device("cpu")
else self.weight
)
utils.plot_connectivity_matrix_dist(
weight.detach().numpy(),
f"Weight",
False,
self.sparsity_mask is not None,
)

def get_specs(self):
"""Print the specs of each layer"""
return {
"input_dim": self.input_dim,
"output_dim": self.output_dim,
"weight_learnable": self.weight.requires_grad,
"weight_min": self.weight.min().item(),
"weight_max": self.weight.max().item(),
"bias_learnable": self.bias.requires_grad,
"bias_min": self.bias.min().item(),
"bias_max": self.bias.max().item(),
"sparsity": (
self.sparsity_mask.sum() / self.sparsity_mask.numel()
if self.sparsity_mask is not None
else 1
)
}

def print_layer(self):
"""
Print the specs of the layer
"""
utils.print_dict("Layer Specs", self.get_specs())
utils.print_dict(f"{self.__class__.__name__} layer", self.get_specs())
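Most of BaseLayer's constraint machinery is deleted here; per the commit message it presumably moves into the concrete, modularized layer classes, whose diffs are not loaded in this view. The core idea being relocated, re-applying sparsity and Dale's-law constraints by masking and clamping the weight matrix after each update, can be restated independently of the nn4n classes as a rough sketch:

import torch
import torch.nn as nn
from typing import Optional

def enforce_constraints(weight: nn.Parameter,
                        sparsity_mask: Optional[torch.Tensor] = None,
                        ei_mask: Optional[torch.Tensor] = None) -> None:
    """Re-apply sparsity and Dale's-law constraints to a weight matrix in place."""
    with torch.no_grad():
        w = weight.detach().clone()
        if sparsity_mask is not None:
            w = w * sparsity_mask                                    # zero out pruned connections
        if ei_mask is not None:
            w[ei_mask == 1] = torch.clamp(w[ei_mask == 1], min=0)    # excitatory weights stay >= 0
            w[ei_mask == -1] = torch.clamp(w[ei_mask == -1], max=0)  # inhibitory weights stay <= 0
        weight.data.copy_(w)

A typical call site would be right after optimizer.step(), so the constraints hold going into the next forward pass.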