refactored models into modules and models
added notebooks for testing components & integrating foolbox
moved individual loaders into single utils file
fixed gradient leaking issue with ensemble models
fixed checkpointing errors and added checkpoint helper funcs
changed double quotes to single quotes
dpaiton committed Apr 24, 2020
1 parent e9dfdcf commit 771556c
Showing 30 changed files with 973 additions and 427 deletions.
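Taken together, the changes below split each old model class into a module (the weights and forward computations, under modules/) and a model (parameter handling, logging, checkpointing, and optimization, under models/). A rough sketch of the composition pattern, inferred from the new files in this commit; the comments describe assumed responsibilities rather than exact interfaces:

# Illustrative composition only; the real classes appear in the diffs below.
from DeepSparseCoding.models.base_model import BaseModel    # params, logging, checkpoints, optimizers
from DeepSparseCoding.modules.lca_module import LcaModule   # weights and encode/decode computations

class LcaModel(BaseModel, LcaModule):
    pass  # models/lca_model.py adds only loss and update-logging code on top of these two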
Empty file removed models/__init__.py
73 changes: 31 additions & 42 deletions models/base.py → models/base_model.py
@@ -1,17 +1,12 @@
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from DeepSparseCoding.utils.file_utils import Logger


class BaseModel(nn.Module):
def __init__(self):
super(BaseModel, self).__init__()
self.params_loaded = False

class BaseModel(object):
def setup(self, params, logger=None):
"""
Setup required model components
@@ -25,31 +20,34 @@ def setup(self, params, logger=None):
self.log_params()
else:
self.logger = logger
self.setup_model()
self.setup_module(params)
self.setup_optimizer()

def load_params(self, params):
"""
Calculates a few extra parameters
Sets parameters as member variable
"""
params.cp_latest_filename = "latest_checkpoint_v"+params.version
if not hasattr(params, "model_out_dir"):
if not hasattr(params, 'model_out_dir'):
params.model_out_dir = os.path.join(params.out_dir, params.model_name)
params.cp_save_dir = os.path.join(params.model_out_dir, "checkpoints")
params.log_dir = os.path.join(params.model_out_dir, "logfiles")
params.save_dir = os.path.join(params.model_out_dir, "savefiles")
params.disp_dir = os.path.join(params.model_out_dir, "vis")
params.cp_save_dir = os.path.join(params.model_out_dir, 'checkpoints')
params.log_dir = os.path.join(params.model_out_dir, 'logfiles')
params.save_dir = os.path.join(params.model_out_dir, 'savefiles')
params.disp_dir = os.path.join(params.model_out_dir, 'vis')
params.batches_per_epoch = params.epoch_size / params.batch_size
params.num_batches = params.num_epochs * params.batches_per_epoch
if not hasattr(params, "cp_latest_filename"):
params.cp_latest_filename = os.path.join(params.cp_save_dir,
f'{params.model_name}_latest_checkpoint_v{params.version}.pt')
self.params = params
self.params_loaded = True

def check_params(self):
"""
Check parameters with assertions
"""
assert self.params.num_pixels == int(np.prod(self.params.data_shape))
if self.params.device is torch.device('cpu'):
print('WARNING: Model is running on the CPU')

def get_param(self, param_name):
"""
@@ -79,7 +77,7 @@ def init_logging(self, log_filename=None):
if self.params.log_to_file:
if log_filename is None:
log_filename = os.path.join(self.params.log_dir,
self.params.model_name+"_v"+self.params.version+".log")
self.params.model_name+'_v'+self.params.version+'.log')
self.logger = Logger(filename=log_filename, overwrite=True)
else:
self.logger = Logger(filename=None)
@@ -100,42 +98,35 @@ def log_info(self, string):
"""Log input string"""
self.logger.log_info(string)

def write_checkpoint(self, session):
def write_checkpoint(self):
"""Write checkpoints"""
base_save_path = os.path.join(self.params.cp_save_dir,
self.params.model_name+"_v"+self.params.version)
full_save_path = base_save_path+self.params.cp_latest_filename
torch.save(self.state_dict(), full_save_path)
self.logger.log_info("Full model saved in file %s"%full_save_path)
return base_save_path
torch.save(self.state_dict(), self.params.cp_latest_filename)
self.log_info('Full model saved in file %s'%self.params.cp_latest_filename)

def load_checkpoint(self, model_dir):
def load_checkpoint(self, cp_file=None):
"""
Load checkpoint model into session.
Load checkpoint
Inputs:
model_dir: String specifying the path to the checkpoint
"""
assert self.params.cp_load == True, ("cp_load must be set to true to load a checkpoint")
cp_file = os.path.join(model_dir, self.params.cp_latest_filename)
return torch.load(cp_file)

def setup_model(self):
raise NotImplementedError
if cp_file is None:
cp_file = self.params.cp_latest_filename
return self.load_state_dict(torch.load(cp_file))

def get_optimizer(self, optimizer_params, trainable_variables):
optimizer_name = optimizer_params.optimizer.name
if(optimizer_name == "sgd"):
if(optimizer_name == 'sgd'):
optimizer = torch.optim.SGD(
trainable_variables,
lr=optimizer_params.weight_lr,
weight_decay=optimizer_params.weight_decay)
elif optimizer_name == "adam":
elif optimizer_name == 'adam':
optimizer = torch.optim.Adam(
trainable_variables,
lr=optimizer_params.weight_lr,
weight_decay=optimizer_params.weight_decay)
else:
assert False, ("optimizer name must be 'sgd' or 'adam', not %s"%(optimizer_name))
assert False, ('optimizer name must be "sgd" or "adam", not %s'%(optimizer_name))
return optimizer

def setup_optimizer(self):
@@ -147,9 +138,6 @@ def setup_optimizer(self):
milestones=self.params.optimizer.milestones,
gamma=self.params.optimizer.lr_decay_rate)

def get_encodings(self):
raise NotImplementedError

def print_update(self, input_data, input_labels=None, batch_step=0):
"""
Log train progress information
@@ -160,20 +148,21 @@ def print_update(self, input_data, input_labels=None, batch_step=0):
NOTE: For the analysis code to parse update statistics, the self.js_dumpstring() call
must receive a dict object. Additionally, the self.js_dumpstring() output must be
logged with <stats> </stats> tags.
For example: logging.info("<stats>"+self.js_dumpstring(output_dictionary)+"</stats>")
For example: logging.info('<stats>'+self.js_dumpstring(output_dictionary)+'</stats>')
"""
update_dict = self.generate_update_dict(input_data, input_labels, batch_step)
js_str = self.js_dumpstring(update_dict)
self.log_info("<stats>"+js_str+"</stats>")
self.log_info('<stats>'+js_str+'</stats>')

def generate_update_dict(self, input_data, input_labels=None, batch_step=0):
def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
"""
Generates a dictionary to be logged in the print_update function
"""
update_dict = dict()
if update_dict is None:
update_dict = dict()
for param_name, param_var in self.named_parameters():
grad = param_var.grad
update_dict[param_name+"_grad_max_mean_min"] = [
update_dict[param_name+'_grad_max_mean_min'] = [
grad.max().item(), grad.mean().item(), grad.min().item()]
return update_dict
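
The checkpoint helpers above now save and restore the full state_dict at params.cp_latest_filename, which load_params builds from cp_save_dir, model_name, and version. A minimal usage sketch, assuming a concrete subclass that implements setup_module and a populated params object (both hypothetical here):

# Hypothetical usage of the new checkpoint helpers; 'SomeModel' and 'params'
# are placeholders, not names from this commit.
model = SomeModel()
model.setup(params)                               # builds module, logger, and optimizer
model.write_checkpoint()                          # saves state_dict to params.cp_latest_filename
model.load_checkpoint()                           # restores from params.cp_latest_filename by default
model.load_checkpoint('some_other_checkpoint.pt') # or restores from an explicit file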

63 changes: 0 additions & 63 deletions models/ensemble.py

This file was deleted.

59 changes: 59 additions & 0 deletions models/ensemble_model.py
@@ -0,0 +1,59 @@
import torch

import DeepSparseCoding.utils.loaders as loaders
from DeepSparseCoding.models.base_model import BaseModel
from DeepSparseCoding.modules.ensemble_module import EnsembleModule


class EnsembleModel(BaseModel, EnsembleModule):
def setup_module(self, params):
for subparams in params.ensemble_params:
subparams.epoch_size = params.epoch_size
subparams.batches_per_epoch = params.batches_per_epoch
subparams.num_batches = params.num_batches
subparams.num_val_images = params.num_val_images
subparams.num_test_images = params.num_test_images
subparams.data_shape = params.data_shape
super(EnsembleModel, self).setup_ensemble_module(params)
self.submodel_classes = []
for submodel_params in self.params.ensemble_params:
self.submodel_classes.append(loaders.load_model_class(
submodel_params.model_type,
self.params.lib_root_dir))

def setup_optimizer(self):
for module in self:
module.optimizer = self.get_optimizer(
optimizer_params=module.params,
trainable_variables=module.parameters())
module.scheduler = torch.optim.lr_scheduler.MultiStepLR(
module.optimizer,
milestones=module.params.optimizer.milestones,
gamma=module.params.optimizer.lr_decay_rate)

def preprocess_data(self, data):
"""
We assume that only the first submodel will be preprocessing the input data
"""
submodule = self.__getitem__(0)
return self.submodel_classes[0].preprocess_data(submodule, data)

def get_total_loss(self, input_tuple, ensemble_index):
submodule = self.__getitem__(ensemble_index)
submodel_class = self.submodel_classes[ensemble_index]
return submodel_class.get_total_loss(submodule, input_tuple)

def generate_update_dict(self, input_data, input_labels=None, batch_step=0):
update_dict = super(EnsembleModel, self).generate_update_dict(input_data,
input_labels, batch_step)
x = input_data.clone() # TODO: Do I need to clone it? If not then don't.
for ensemble_index, submodel_class in enumerate(self.submodel_classes):
submodule = self.__getitem__(ensemble_index)
submodel_update_dict = submodel_class.generate_update_dict(submodule, x,
input_labels, batch_step, update_dict=dict())
for key, value in submodel_update_dict.items():
if key not in ['epoch', 'batch_step']:
key = submodule.params.model_type+'_'+key
update_dict[key] = value
x = submodule.get_encodings(x)
return update_dict
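
With the refactor, each submodule carries its own optimizer and scheduler, and get_total_loss takes an ensemble_index so the loss can be computed per submodel. A hypothetical training-step sketch; the actual training loop is not part of this diff, and 'ensemble', the batch tensors, and the detach call are assumptions (the commit message mentions a gradient-leaking fix, but the fix itself is not shown here):

# Hypothetical per-submodule update step, not code from this commit.
x = ensemble.preprocess_data(batch_images)
for index, submodule in enumerate(ensemble):
    submodule.optimizer.zero_grad()
    loss = ensemble.get_total_loss((x, batch_labels), index)
    loss.backward()
    submodule.optimizer.step()
    x = submodule.get_encodings(x).detach()  # detach so gradients do not flow back into earlier submodels (assumed)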
42 changes: 42 additions & 0 deletions models/lca_model.py
@@ -0,0 +1,42 @@
import numpy as np
import torch

from DeepSparseCoding.models.base_model import BaseModel
from DeepSparseCoding.modules.lca_module import LcaModule
import DeepSparseCoding.modules.losses as losses


class LcaModel(BaseModel, LcaModule):
def get_total_loss(self, input_tuple):
input_tensor, input_labels = input_tuple
latents = self.get_encodings(input_tensor)
recon = self.get_recon_from_latents(latents)
recon_loss = losses.half_squared_l2(input_tensor, recon)
sparse_loss = self.params.sparse_mult * losses.l1_norm(latents)
total_loss = recon_loss + sparse_loss
return total_loss

def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
if update_dict is None:
update_dict = super(LcaModel, self).generate_update_dict(input_data, input_labels, batch_step)
epoch = batch_step / self.params.batches_per_epoch
stat_dict = {
'epoch':int(epoch),
'batch_step':batch_step,
'train_progress':np.round(batch_step/self.params.num_batches, 3),
'weight_lr':self.scheduler.get_lr()[0]}
latents = self.get_encodings(input_data)
recon = self.get_recon_from_latents(latents)
recon_loss = losses.half_squared_l2(input_data, recon).item()
sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item()
stat_dict['loss_recon'] = recon_loss
stat_dict['loss_sparse'] = sparse_loss
stat_dict['loss_total'] = recon_loss + sparse_loss
stat_dict['input_max_mean_min'] = [
input_data.max().item(), input_data.mean().item(), input_data.min().item()]
stat_dict['recon_max_mean_min'] = [
recon.max().item(), recon.mean().item(), recon.min().item()]
latent_nnz = torch.sum(latents != 0).item() # TODO: github issue 23907 requests torch.count_nonzero
stat_dict['latents_fraction_active'] = latent_nnz / latents.numel()
update_dict.update(stat_dict)
return update_dict
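
The LCA objective above is the usual sparse-coding energy: half squared reconstruction error plus an L1 penalty on the latents weighted by sparse_mult. A sketch of what the two loss helpers are assumed to compute (these are not the definitions from modules/losses.py, which is not shown in this diff):

import torch

# Assumed behavior of the loss terms used by LcaModel.get_total_loss():
# total = 0.5 * sum((x - recon)**2) + sparse_mult * sum(abs(latents))
def half_squared_l2(x, recon):
    return 0.5 * torch.sum((x - recon) ** 2)

def l1_norm(latents):
    return torch.sum(torch.abs(latents))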
60 changes: 0 additions & 60 deletions models/mlp.py

This file was deleted.

