refactor loss; add composite loss constructor
Commit 495ada9 (1 parent: 9945383)
Showing 10 changed files with 187 additions and 34 deletions.
@@ -1,3 +1,4 @@
 from .rnn_loss import RNNLoss
 from .mlp_loss import MLPLoss
-from .firing_rate import *
+from .firing_rate_loss import *
+from .composite_loss import CompositeLoss
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
from .firing_rate_loss import *


class CompositeLoss(nn.Module):
    def __init__(self, loss_cfg):
        """
        Initializes the CompositeLoss module.

        Args:
            loss_cfg: A dictionary whose keys are unique identifiers for each loss
                (e.g., 'loss_fr_1') and whose values are dictionaries specifying the
                loss type, params, and lambda weight. Example:
                {
                    'loss_fr_1': {'type': 'fr', 'params': {'metric': 'l2'}, 'lambda': 1.0},
                    'loss_mse': {'type': 'mse', 'lambda': 1.0}
                }
        """
        super().__init__()
        self.loss_components = {}

        # Mapping of loss types to their respective classes
        loss_types = {
            'fr': FiringRateLoss,
            'fr_dist': FiringRateDistLoss,
            'state_pred': StatePredictionLoss,
            'mse': nn.MSELoss,
        }
        torch_losses = ['mse']

        # Iterate over loss_cfg to instantiate and store each loss
        for loss_name, loss_spec in loss_cfg.items():
            loss_type = loss_spec['type']
            loss_params = loss_spec.get('params', {})
            loss_lambda = loss_spec.get('lambda', 1.0)

            # Instantiate the loss function
            if loss_type in loss_types:
                loss_class = loss_types[loss_type]
                if loss_type in torch_losses:
                    # Torch built-in losses are constructed without custom params
                    loss_instance = loss_class()
                else:
                    # Custom losses may take params
                    loss_instance = loss_class(**loss_params)

                # Store the loss instance and its weight
                self.loss_components[loss_name] = (loss_instance, loss_lambda)
            else:
                raise ValueError(f"Invalid loss type '{loss_type}'. Available types are: {list(loss_types.keys())}")

    def forward(self, loss_input_dict):
        """
        Forward pass that computes the total weighted loss.

        Args:
            loss_input_dict: A dictionary whose keys correspond to the initialized loss
                identifiers (e.g., 'loss_fr_1') and whose values are dictionaries of
                keyword arguments passed to the corresponding loss function during the
                forward pass (e.g., {'states': <tensor>}).
        """
        total_loss = 0
        loss_dict = {}
        for loss_name, (loss_fn, loss_weight) in self.loss_components.items():
            # Retrieve the corresponding input for this loss
            if loss_name in loss_input_dict:
                loss_inputs = loss_input_dict[loss_name]
                if isinstance(loss_fn, nn.MSELoss):
                    # For MSELoss, inputs are expected under 'input' and 'target'
                    loss_value = loss_fn(loss_inputs['input'], loss_inputs['target'])
                else:
                    loss_value = loss_fn(**loss_inputs)
                loss_dict[loss_name] = loss_weight * loss_value
                total_loss += loss_dict[loss_name]
            else:
                raise KeyError(f"Loss input for '{loss_name}' not provided in forward.")

        return total_loss, loss_dict
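
For reference, a minimal usage sketch of the new constructor; the tensor shapes and variable names below are illustrative assumptions, not part of the commit:

import torch

loss_cfg = {
    'loss_fr_1': {'type': 'fr', 'params': {'metric': 'l2'}, 'lambda': 1.0},
    'loss_mse': {'type': 'mse', 'lambda': 1.0},
}
criterion = CompositeLoss(loss_cfg)

states = torch.randn(8, 50, 64)                       # assumed (batch, time, units) layout
output = torch.randn(8, 50, 10, requires_grad=True)   # hypothetical model output
target = torch.randn(8, 50, 10)                       # hypothetical target

# Keys must match the identifiers used in loss_cfg
total_loss, loss_dict = criterion({
    'loss_fr_1': {'states': states},
    'loss_mse': {'input': output, 'target': target},
})
total_loss.backward()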
This file was deleted.
@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomLoss(nn.Module):
    def __init__(self, batch_first=True):
        super().__init__()
        self.batch_first = batch_first

    def forward(self, **kwargs):
        pass


class FiringRateLoss(CustomLoss):
    def __init__(self, metric='l2', **kwargs):
        super().__init__(**kwargs)
        assert metric in ['l1', 'l2'], "metric must be either l1 or l2"
        self.metric = metric

    def forward(self, states, **kwargs):
        # Mean firing rate per unit, averaged over batch and time
        mean_fr = torch.mean(states, dim=(0, 1))

        # Penalize the mean firing rates using PyTorch's built-in norms
        if self.metric == 'l1':
            return F.l1_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')
        else:
            return F.mse_loss(mean_fr, torch.zeros_like(mean_fr), reduction='mean')


class FiringRateDistLoss(CustomLoss):
    def __init__(self, metric='sd', **kwargs):
        super().__init__(**kwargs)
        valid_metrics = ['sd', 'cv', 'mean_ad', 'max_ad']
        assert metric in valid_metrics, (
            "metric must be chosen from 'sd' (standard deviation), "
            "'cv' (coefficient of variation), 'mean_ad' (mean abs deviation), "
            "or 'max_ad' (max abs deviation)."
        )
        self.metric = metric

    def forward(self, states, **kwargs):
        mean_fr = torch.mean(states, dim=(0, 1))

        # Standard deviation
        if self.metric == 'sd':
            return torch.std(mean_fr)

        # Coefficient of variation
        elif self.metric == 'cv':
            return torch.std(mean_fr) / torch.mean(mean_fr)

        # Mean absolute deviation
        elif self.metric == 'mean_ad':
            avg_mean_fr = torch.mean(mean_fr)
            # F.l1_loss against the mean gives the mean absolute deviation
            return F.l1_loss(mean_fr, avg_mean_fr.expand_as(mean_fr), reduction='mean')

        # Maximum absolute deviation
        elif self.metric == 'max_ad':
            avg_mean_fr = torch.mean(mean_fr)
            return torch.max(torch.abs(mean_fr - avg_mean_fr))


class StatePredictionLoss(CustomLoss):
    def __init__(self, tau=1, **kwargs):
        super().__init__(**kwargs)
        self.tau = tau

    def forward(self, states, **kwargs):
        if not self.batch_first:
            states = states.transpose(0, 1)

        # Ensure the sequence is long enough for the prediction window
        assert states.shape[1] > self.tau, "The sequence length is shorter than the prediction window."

        # Compare each state with the state tau steps later along the time
        # dimension (dim 1, since states is batch-first at this point)
        return F.mse_loss(states[:, :-self.tau], states[:, self.tau:], reduction='mean')
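
A short sketch of calling the individual losses directly; the shapes and variable names are again assumptions for illustration:

import torch

states = torch.randn(8, 50, 64)  # assumed (batch, time, units) layout

fr_loss = FiringRateLoss(metric='l1')(states=states)        # mean rates pulled toward zero
dist_loss = FiringRateDistLoss(metric='cv')(states=states)  # dispersion of rates across units
pred_loss = StatePredictionLoss(tau=2)(states=states)       # change over a tau-step window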