Commit

refactor loss; add composite loss constructor
zhaozewang committed Oct 6, 2024
1 parent 9945383 commit 495ada9
Showing 10 changed files with 187 additions and 34 deletions.
3 changes: 2 additions & 1 deletion nn4n/criterion/__init__.py
@@ -1,3 +1,4 @@
 from .rnn_loss import RNNLoss
 from .mlp_loss import MLPLoss
-from .firing_rate import *
+from .firing_rate_loss import *
+from .composite_loss import CompositeLoss
76 changes: 76 additions & 0 deletions nn4n/criterion/composite_loss.py
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
from .firing_rate_loss import *

class CompositeLoss(nn.Module):
    def __init__(self, loss_cfg):
        """
        Initializes the CompositeLoss module.

        Args:
            loss_cfg: A dictionary where the keys are unique identifiers for each loss
                (e.g., 'loss_fr_1') and the values are dictionaries specifying the loss
                type, params, and lambda weight. Example:
                {
                    'loss_fr_1': {'type': 'fr', 'params': {'metric': 'l2'}, 'lambda': 1.0},
                    'loss_mse': {'type': 'mse', 'lambda': 1.0}
                }
        """
        super().__init__()
        self.loss_components = {}

        # Mapping of loss types to their respective classes
        loss_types = {
            'fr': FiringRateLoss,
            'fr_dist': FiringRateDistLoss,
            'state_pred': StatePredictionLoss,
            'mse': nn.MSELoss,
        }
        torch_losses = ['mse']

        # Iterate over the loss_cfg to instantiate and store losses
        for loss_name, loss_spec in loss_cfg.items():
            loss_type = loss_spec['type']
            loss_params = loss_spec.get('params', {})
            loss_lambda = loss_spec.get('lambda', 1.0)

            # Instantiate the loss function
            if loss_type in loss_types:
                loss_class = loss_types[loss_type]
                if loss_type in torch_losses:
                    # Torch built-in losses are constructed without custom params
                    loss_instance = loss_class()
                else:
                    # Custom losses may take params
                    loss_instance = loss_class(**loss_params)

                # Store the loss instance and its weight in a dictionary
                self.loss_components[loss_name] = (loss_instance, loss_lambda)
            else:
                raise ValueError(f"Invalid loss type '{loss_type}'. Available types are: {list(loss_types.keys())}")

    def forward(self, loss_input_dict):
        """
        Forward pass that computes the total weighted loss.

        Args:
            loss_input_dict: A dictionary whose keys correspond to the initialized loss
                identifiers (e.g., 'loss_fr_1') and whose values are dictionaries of
                keyword arguments passed to the corresponding loss function during the
                forward pass (e.g., {'states': <tensor>}).
        """
        total_loss = 0
        loss_dict = {}
        for loss_name, (loss_fn, loss_weight) in self.loss_components.items():
            # Retrieve the corresponding input for this loss from the input dictionary
            if loss_name in loss_input_dict:
                loss_inputs = loss_input_dict[loss_name]
                if isinstance(loss_fn, nn.MSELoss):
                    # For MSELoss, the inputs must be 'input' and 'target'
                    loss_value = loss_fn(loss_inputs['input'], loss_inputs['target'])
                else:
                    loss_value = loss_fn(**loss_inputs)
                loss_dict[loss_name] = loss_weight * loss_value
                total_loss += loss_dict[loss_name]
            else:
                raise KeyError(f"Loss input for '{loss_name}' not provided in forward.")

        return total_loss, loss_dict
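
A minimal usage sketch of the new constructor (tensor shapes and loss names here are illustrative, not part of the commit):

import torch
from nn4n.criterion import CompositeLoss

# Hypothetical config: one firing-rate penalty plus a supervised MSE term
loss_cfg = {
    'loss_fr': {'type': 'fr', 'params': {'metric': 'l2'}, 'lambda': 0.1},
    'loss_mse': {'type': 'mse', 'lambda': 1.0},
}
criterion = CompositeLoss(loss_cfg)

states = torch.rand(8, 100, 64, requires_grad=True)  # (batch_size, n_timesteps, hidden_size)
output = torch.rand(8, 100, 10, requires_grad=True)  # network readout
target = torch.rand(8, 100, 10)                      # supervised target

# The keys of the input dict must match the keys used in loss_cfg
total_loss, loss_dict = criterion({
    'loss_fr': {'states': states},
    'loss_mse': {'input': output, 'target': target},
})
total_loss.backward()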
20 changes: 0 additions & 20 deletions nn4n/criterion/firing_rate.py

This file was deleted.

78 changes: 78 additions & 0 deletions nn4n/criterion/firing_rate_loss.py
@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomLoss(nn.Module):
    def __init__(self, batch_first=True):
        super().__init__()
        self.batch_first = batch_first

    def forward(self, **kwargs):
        raise NotImplementedError("Subclasses must implement the forward method.")


class FiringRateLoss(CustomLoss):
    def __init__(self, metric='l2', **kwargs):
        super().__init__(**kwargs)
        assert metric in ['l1', 'l2'], "metric must be either l1 or l2"
        self.metric = metric

    def forward(self, states, **kwargs):
        # Mean firing rate over batch and time, leaving one value per unit
        mean_fr = torch.mean(states, dim=(0, 1))

        # Penalize the distance of the mean rates from zero under the chosen norm
        if self.metric == 'l1':
            return F.l1_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')
        else:
            return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')


class FiringRateDistLoss(CustomLoss):
    def __init__(self, metric='sd', **kwargs):
        super().__init__(**kwargs)
        valid_metrics = ['sd', 'cv', 'mean_ad', 'max_ad']
        assert metric in valid_metrics, (
            "metric must be chosen from 'sd' (standard deviation), "
            "'cv' (coefficient of variation), 'mean_ad' (mean abs deviation), "
            "or 'max_ad' (max abs deviation)."
        )
        self.metric = metric

    def forward(self, states, **kwargs):
        mean_fr = torch.mean(states, dim=(0, 1))

        # Standard deviation
        if self.metric == 'sd':
            return torch.std(mean_fr)

        # Coefficient of variation
        elif self.metric == 'cv':
            return torch.std(mean_fr) / torch.mean(mean_fr)

        # Mean absolute deviation
        elif self.metric == 'mean_ad':
            avg_mean_fr = torch.mean(mean_fr)
            # Use F.l1_loss for mean absolute deviation
            return F.l1_loss(mean_fr, avg_mean_fr.expand_as(mean_fr), reduction='mean')

        # Maximum absolute deviation
        elif self.metric == 'max_ad':
            avg_mean_fr = torch.mean(mean_fr)
            return torch.max(torch.abs(mean_fr - avg_mean_fr))


class StatePredictionLoss(CustomLoss):
    def __init__(self, tau=1, **kwargs):
        super().__init__(**kwargs)
        self.tau = tau

    def forward(self, states, **kwargs):
        if not self.batch_first:
            states = states.transpose(0, 1)

        # Ensure the sequence is long enough for the prediction window
        assert states.shape[1] > self.tau, "The sequence length is shorter than the prediction window."

        # Penalize the mismatch between each state and the state tau steps later,
        # slicing along the time dimension (dim 1, since states is batch-first here)
        return F.mse_loss(states[:, :-self.tau], states[:, self.tau:], reduction='mean')
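
For reference, a quick sketch of calling these losses directly (shapes are illustrative; states is assumed to be (batch, time, hidden), matching the batch_first=True default):

import torch
from nn4n.criterion import FiringRateLoss, FiringRateDistLoss, StatePredictionLoss

states = torch.rand(8, 100, 64)  # (batch_size, n_timesteps, hidden_size)

fr_loss = FiringRateLoss(metric='l1')(states=states)        # push mean rates toward zero
dist_loss = FiringRateDistLoss(metric='cv')(states=states)  # equalize rates across units
pred_loss = StatePredictionLoss(tau=2)(states=states)       # state at t predicts t + tau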
11 changes: 7 additions & 4 deletions nn4n/layer/linear_layer.py
@@ -204,10 +204,13 @@ def set_weight(self, weight):
     def plot_layers(self):
         """ Plot the weights matrix and distribution of each layer """
         weight = self.weight.cpu() if self.weight.device != torch.device('cpu') else self.weight
-        if weight.size(0) < weight.size(1):
-            utils.plot_connectivity_matrix_dist(weight.detach().numpy(), "Weight Matrix", False, self.sparsity_mask is not None)
-        else:
-            utils.plot_connectivity_matrix_dist(weight.detach().numpy().T, "Weight Matrix (Transposed)", False, self.sparsity_mask is not None)
+        # if weight.size(0) < weight.size(1):
+        #     utils.plot_connectivity_matrix_dist(weight.detach().numpy(), "Weight Matrix", False, self.sparsity_mask is not None)
+        # else:
+        #     utils.plot_connectivity_matrix_dist(weight.detach().numpy().T, "Weight Matrix (Transposed)", False, self.sparsity_mask is not None)
+
+        # Disable the transpose as it sometimes causes confusion
+        utils.plot_connectivity_matrix_dist(weight.detach().numpy(), "Weight Matrix", False, self.sparsity_mask is not None)

     def print_layers(self):
         """ Print the specs of each layer """
9 changes: 7 additions & 2 deletions nn4n/layer/recurrent_layer.py
@@ -60,17 +60,22 @@ def to(self, device):
         self.hidden_state = self.hidden_state.to(device)
         return self

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, init_state: torch.Tensor = None) -> torch.Tensor:
         """
         Forward pass through the network
         Inputs:
             - x: input, shape: (batch_size, n_timesteps, input_dim)
+            - init_state: initial state of the network, shape: (batch_size, hidden_size)
         Returns:
             - stacked_states: hidden states of the network, shape: (batch_size, n_timesteps, hidden_size)
         """
-        v_t = self._reset_state().to(x.device)
+        if init_state is not None:
+            v_t = init_state.to(x.device)
+        else:
+            v_t = self._reset_state().to(x.device)

         fr_t = self.activation(v_t)
         # update hidden state and append to stacked_states
         stacked_states = []
4 changes: 2 additions & 2 deletions nn4n/mask/multi_area.py
@@ -25,7 +25,7 @@ def _check_parameters(self):

         # check n_areas
         if isinstance(self.n_areas, int):
-            assert self.hidden_size % self.n_areas == 0, "hidden_size must be devideable by n_areas"
+            assert self.hidden_size % self.n_areas == 0, "hidden_size must be divisible by n_areas if n_areas is an int"
             # create a node assignment list
             node_assigment = np.zeros(self.n_areas, dtype=np.ndarray)
             for i in range(self.n_areas):
@@ -41,7 +41,7 @@
             self.node_assigment = node_assigment
             self.n_areas = len(self.n_areas)
         else:
-            assert False, "n_areas must be int or list"
+            assert False, f"n_areas must be int or list, but got {type(self.n_areas)}"

         if self.n_areas == 1:
             self.input_areas = np.array([0])
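
As a hypothetical illustration of the int branch in isolation (not the actual MultiArea API):

hidden_size = 90
n_areas = 3
# Areas are carved out of the hidden population evenly, hence the divisibility check
assert hidden_size % n_areas == 0, "hidden_size must be divisible by n_areas if n_areas is an int"
nodes_per_area = hidden_size // n_areas  # 30 nodes per area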
4 changes: 2 additions & 2 deletions nn4n/model/ctrnn.py
@@ -310,7 +310,7 @@ def to(self, device):
         self.readout_layer.to(device)
         return self

-    def forward(self, x):
+    def forward(self, x: torch.Tensor, init_state: torch.Tensor = None) -> torch.Tensor:
         """
         Forward pass through the network
@@ -323,7 +323,7 @@ def forward(self, x):
         # skip constraints if the model is not in training mode
         if self.training:
             self._enforce_constraints()
-        hidden_states = self.recurrent_layer(x)
+        hidden_states = self.recurrent_layer(x, init_state)
         output = self.readout_layer(hidden_states.float())

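
A sketch of the new optional argument at the model level (model construction and the exact return signature are outside this diff; `model` is assumed to be an already-constructed CTRNN):

import torch

batch_size, n_timesteps, input_dim, hidden_size = 8, 100, 10, 64
x = torch.rand(batch_size, n_timesteps, input_dim)

# Default behavior is unchanged: the recurrent layer resets its own state
out_default = model(x)

# New: warm-start the recurrent dynamics, e.g., to continue from where a
# previous sequence ended; the shape follows the recurrent_layer docstring
init_state = torch.zeros(batch_size, hidden_size)
out_warm = model(x, init_state)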
11 changes: 9 additions & 2 deletions nn4n/utils.py
@@ -56,9 +56,16 @@ def plot_connectivity_matrix_dist(w, title, colorbar=True, ignore_zeros=False):

     img_width, hist_height = 6, 2
     hw_ratio = w.shape[0] / w.shape[1]
-    plt_height = img_width * hw_ratio + hist_height
+    if hw_ratio > 1:
+        # height > width
+        mat_h = img_width / hw_ratio + hist_height
+        mat_w = img_width
+    else:
+        # width > height
+        mat_h = img_width
+        mat_w = img_width * hw_ratio + hist_height

-    fig, ax = plt.subplots(figsize=(img_width, plt_height))
+    fig, ax = plt.subplots(figsize=(mat_w, mat_h))
     ax.imshow(-w, cmap='bwr', vmin=-r, vmax=r)
     ax.set_title(f'{title}' if not ignore_zeros else f'{title} (nonzero)')
     if colorbar:
5 changes: 4 additions & 1 deletion todo.md
@@ -2,4 +2,7 @@
 - [ ] The examples need to be updated. Especially on the main branch.
 - [ ] Resolve the transpose issue in the model module and the mask module.
 - [ ] Make the model use `batch_first` by default.
-- [ ] Refactor the RNNLoss part, let it take a dictionary instead of many separate `lambda_*` parameters.
+- [ ] Refactor the RNNLoss part, let it take a dictionary instead of many separate `lambda_*` parameters.
+- [ ] Varying alpha
+- [ ] Need to adjust implementation for `apply_plasticity` as it won't support the SSL framework.
+- [ ] Change output to readout.
