Commit
refactored models into modules and models
added notebooks for testing components & integrating foolbox
moved individual loaders into single utils file
fixed gradient leaking issue with ensemble models
fixed checkpointing errors and added checkpoint helper funcs
changed double quotes to single quotes
Showing 30 changed files with 973 additions and 427 deletions.
Empty file.
This file was deleted.
@@ -0,0 +1,59 @@
import torch

import DeepSparseCoding.utils.loaders as loaders
from DeepSparseCoding.models.base_model import BaseModel
from DeepSparseCoding.modules.ensemble_module import EnsembleModule


class EnsembleModel(BaseModel, EnsembleModule):
    def setup_module(self, params):
        # Copy dataset-level parameters down to each submodel's params
        for subparams in params.ensemble_params:
            subparams.epoch_size = params.epoch_size
            subparams.batches_per_epoch = params.batches_per_epoch
            subparams.num_batches = params.num_batches
            subparams.num_val_images = params.num_val_images
            subparams.num_test_images = params.num_test_images
            subparams.data_shape = params.data_shape
        super(EnsembleModel, self).setup_ensemble_module(params)
        # Load each submodel's class so its methods can be dispatched per-submodule
        self.submodel_classes = []
        for submodel_params in self.params.ensemble_params:
            self.submodel_classes.append(loaders.load_model_class(
                submodel_params.model_type,
                self.params.lib_root_dir))

    def setup_optimizer(self):
        # Each submodule gets its own optimizer and learning-rate schedule
        for module in self:
            module.optimizer = self.get_optimizer(
                optimizer_params=module.params,
                trainable_variables=module.parameters())
            module.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                module.optimizer,
                milestones=module.params.optimizer.milestones,
                gamma=module.params.optimizer.lr_decay_rate)

    def preprocess_data(self, data):
        """
        We assume that only the first submodel will preprocess the input data
        """
        submodule = self.__getitem__(0)
        return self.submodel_classes[0].preprocess_data(submodule, data)

    def get_total_loss(self, input_tuple, ensemble_index):
        submodule = self.__getitem__(ensemble_index)
        submodel_class = self.submodel_classes[ensemble_index]
        return submodel_class.get_total_loss(submodule, input_tuple)

    def generate_update_dict(self, input_data, input_labels=None, batch_step=0):
        update_dict = super(EnsembleModel, self).generate_update_dict(input_data,
            input_labels, batch_step)
        x = input_data.clone()  # TODO: Do I need to clone it? If not then don't.
        for ensemble_index, submodel_class in enumerate(self.submodel_classes):
            submodule = self.__getitem__(ensemble_index)
            submodel_update_dict = submodel_class.generate_update_dict(submodule, x,
                input_labels, batch_step, update_dict=dict())
            # Prefix each submodel's stats with its model_type so keys stay unique
            for key, value in submodel_update_dict.items():
                if key not in ['epoch', 'batch_step']:
                    key = submodule.params.model_type + '_' + key
                update_dict[key] = value
            # The next submodule receives this submodule's encodings as its input
            x = submodule.get_encodings(x)
        return update_dict
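The commit message mentions a fix for gradients leaking between ensemble submodels. The training path where that fix lives is not shown in this file, so the following is only a hypothetical sketch of the usual pattern for preventing such leakage: detach the chained encodings so a later submodule's loss cannot back-propagate into earlier submodules. The get_encodings chaining mirrors the code above; the helper name and everything else is assumed.

def forward_detached(ensemble, input_data):
    """Hypothetical helper: collect each submodule's input while cutting the
    autograd graph between submodules, so a later submodule's loss cannot
    'leak' gradients back into the parameters of earlier submodules."""
    submodule_inputs = []
    x = input_data
    for submodule in ensemble:
        submodule_inputs.append(x)
        # detach() stops gradient flow at this point in the chain
        x = submodule.get_encodings(x).detach()
    return submodule_inputs

Each per-submodule loss would then be computed from its own detached input, e.g. model.get_total_loss((submodule_inputs[i], labels), i).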
@@ -0,0 +1,42 @@
import numpy as np
import torch

from DeepSparseCoding.models.base_model import BaseModel
from DeepSparseCoding.modules.lca_module import LcaModule
import DeepSparseCoding.modules.losses as losses


class LcaModel(BaseModel, LcaModule):
    def get_total_loss(self, input_tuple):
        input_tensor, input_labels = input_tuple
        latents = self.get_encodings(input_tensor)
        recon = self.get_recon_from_latents(latents)
        # LCA energy: half squared L2 reconstruction error plus a weighted L1 sparsity penalty
        recon_loss = losses.half_squared_l2(input_tensor, recon)
        sparse_loss = self.params.sparse_mult * losses.l1_norm(latents)
        total_loss = recon_loss + sparse_loss
        return total_loss

    def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
        if update_dict is None:
            update_dict = super(LcaModel, self).generate_update_dict(input_data, input_labels, batch_step)
        epoch = batch_step / self.params.batches_per_epoch
        stat_dict = {
            'epoch': int(epoch),
            'batch_step': batch_step,
            'train_progress': np.round(batch_step / self.params.num_batches, 3),
            'weight_lr': self.scheduler.get_lr()[0]}
        latents = self.get_encodings(input_data)
        recon = self.get_recon_from_latents(latents)
        recon_loss = losses.half_squared_l2(input_data, recon).item()
        sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item()
        stat_dict['loss_recon'] = recon_loss
        stat_dict['loss_sparse'] = sparse_loss
        stat_dict['loss_total'] = recon_loss + sparse_loss
        stat_dict['input_max_mean_min'] = [
            input_data.max().item(), input_data.mean().item(), input_data.min().item()]
        stat_dict['recon_max_mean_min'] = [
            recon.max().item(), recon.mean().item(), recon.min().item()]
        # Fraction of latent coefficients that are non-zero (active)
        latent_nnz = torch.sum(latents != 0).item()  # TODO: github issue 23907 requests torch.count_nonzero
        stat_dict['latents_fraction_active'] = latent_nnz / latents.numel()
        update_dict.update(stat_dict)
        return update_dict
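The losses module itself is not part of this diff. Judging from how half_squared_l2 and l1_norm are used above, they presumably reduce to the standard LCA energy terms sketched below; treat these definitions as assumptions rather than the repository's exact implementation.

import torch

def half_squared_l2(original, reconstruction):
    # 0.5 * ||original - reconstruction||_2^2, summed over all elements
    return 0.5 * torch.sum((original - reconstruction) ** 2)

def l1_norm(latents):
    # Sum of absolute latent coefficients; scaled by params.sparse_mult above
    return torch.sum(torch.abs(latents))

Under these definitions, total_loss in get_total_loss is the usual LCA energy 0.5 * ||x - x_hat||^2 + sparse_mult * ||a||_1.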
This file was deleted.