Skip to content

Commit

Permalink
adds preprocessing option 'rescale_data_to_one'
Browse files: browse the repository at this point in the history.
  • Loading branch information
dpaiton committed Jun 17, 2020
1 parent b510b5f commit 3121d50
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 19 deletions.
1 change: 1 addition & 0 deletions params/base_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(self):

def set_params(self):
self.standardize_data = False
+ self.rescale_data_to_one = False
self.model_type = None
self.log_to_file = True
self.train_logs_per_epoch = None
Expand Down
1 change: 1 addition & 0 deletions params/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self):
self.num_val_images = 0
self.num_test_images = 0
self.standardize_data = False
+ self.rescale_data_to_one = False
self.num_epochs = 3
self.train_logs_per_epoch = 1

Expand Down
2 changes: 1 addition & 1 deletion tests/test_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def rescale_data_to_one(self):
for samplewise in samplewise_options:
err_msg = (f'\ninput_shape={shape}\neps={eps_val}\nsamplewise={samplewise}')
random_tensor = torch.rand(shape)
- func_output = dp.standardize(random_tensor, eps=eps_val, samplewise=samplewise)
+ func_output = dp.rescale_data_to_one(random_tensor, eps=eps_val, samplewise=samplewise)
norm_tensor = func_output[0].numpy()
orig_min = func_output[1]
orig_max = func_output[2]
Expand Down
33 changes: 18 additions & 15 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,24 @@ def test_mnist(self):
except:
return 0
standardize_data_list = [True, False]
- for standardize_data in standardize_data_list:
- params = types.SimpleNamespace()
- params.standardize_data = standardize_data
- if(params.standardize_data):
- params.eps = 1e-8
- params.data_dir = self.data_dir
- params.dataset = 'mnist'
- params.shuffle_data = True
- params.batch_size = 10000
- train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)
- for key, value in data_params.items():
- setattr(params, key, value)
- assert len(train_loader.dataset) == params.epoch_size
- (data, target) = next(iter(train_loader))
- assert data.numpy().shape == (params.batch_size, 28, 28, 1)
+ rescale_data_list = [True, False]
+ for standardize_data in [True, False]:
+ for rescale_data_to_one in [True, False]:
+ params = types.SimpleNamespace()
+ params.standardize_data = standardize_data
+ params.rescale_data_to_one = rescale_data_to_one
+ if(params.standardize_data or params.rescale_data_to_one):
+ params.eps = 1e-8
+ params.data_dir = self.data_dir
+ params.dataset = 'mnist'
+ params.shuffle_data = True
+ params.batch_size = 10000
+ train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)
+ for key, value in data_params.items():
+ setattr(params, key, value)
+ assert len(train_loader.dataset) == params.epoch_size
+ (data, target) = next(iter(train_loader))
+ assert data.numpy().shape == (params.batch_size, 28, 28, 1)

def test_synthetic(self):
epoch_size_list = [20, 50]
Expand Down
6 changes: 4 additions & 2 deletions utils/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,10 @@ def rescale_data_to_one(data, eps=None, samplewise=True):
if(eps is None):
eps = 1.0 / np.sqrt(data[0,...].numel())
if(samplewise):
- data_min = data.view(-1, np.prod(list(range(data.ndim)[1:]))).min(axis=1, keepdims=True)
- data_max = data.view(-1, np.prod(list(range(data.ndim)[1:]))).max(axis=1, keepdims=True)
+ data_min = torch.min(data.view(-1, np.prod(data.shape[1:])),
+ axis=1, keepdims=False)[0].view(-1, *[1]*(data.ndim-1))
+ data_max = torch.max(data.view(-1, np.prod(data.shape[1:])),
+ axis=1, keepdims=False)[0].view(-1, *[1]*(data.ndim-1))
else:
data_min = torch.min(data)
data_max = torch.max(data)
Expand Down
7 changes: 6 additions & 1 deletion utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import torch
from torchvision import datasets, transforms

- ROOT_DIR = os.path.dirname(os.getcwd())
+ ROOT_DIR = os.getcwd()
+ while 'DeepSparseCoding' in ROOT_DIR:
+ ROOT_DIR = os.path.dirname(ROOT_DIR)
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)

import DeepSparseCoding.utils.data_processing as dp
Expand All @@ -23,6 +25,9 @@ def load_dataset(params):
if params.standardize_data:
preprocessing_pipeline.append(
transforms.Lambda(lambda x: dp.standardize(x, eps=params.eps)[0]))
+ if params.rescale_data_to_one:
+ preprocessing_pipeline.append(
+ transforms.Lambda(lambda x: dp.rescale_data_to_one(x, eps=params.eps, samplewise=True)[0]))
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(root=params.data_dir, train=True, download=True,
transform=transforms.Compose(preprocessing_pipeline)),
Expand Down

0 comments on commit 3121d50

Please sign in to comment.