Masking for CNN #14

Open · wants to merge 6 commits into master
args.py (2 additions, 0 deletions)
@@ -80,6 +80,8 @@ def parse_args():
                        help='which checkpoint to resume from possible values ["latest", "best", epoch]')
    parser.add_argument('--pretrained', action='store_true',
                        default=False, help='use pre-trained model')
    parser.add_argument('--masking', default=False, type=bool, const=True, nargs='?',
                        help='Whether to use masking when training the models')

    # data settings
    parser.add_argument('--num-classes', default=5, type=int)
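For context, a minimal standalone sketch of how the new `--masking` flag parses (a throwaway parser containing only this option, not the project's full `parse_args`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--masking', default=False, type=bool, const=True, nargs='?',
                    help='Whether to use masking when training the models')

# Omitting the flag keeps the default; passing it with no value falls back to const.
print(parser.parse_args([]).masking)             # False
print(parser.parse_args(['--masking']).masking)  # True
```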
commander.py (11 additions, 3 deletions)
@@ -14,6 +14,7 @@
from toolbox import utils, logger, metrics, losses, optimizers
import trainer
from args import parse_args
from models.cnn1d import add_mask_to_vector

from torch.utils.tensorboard import SummaryWriter

@@ -100,7 +101,14 @@ def main():

    # init data loaders
    loader = get_loader(args)
    train_data = loader(data_dir=args.data_dir, split='train', phase='train')
    loader_args = {}
    if args.masking:
        # calculate masking in the data loader phase
        loader_args['custom_transforms'] = add_mask_to_vector

    train_data = loader(data_dir=args.data_dir, split='train',
                        phase='train', num_classes=args.num_classes, **loader_args)

    sample_method, cb_weights, sample_weights = None, None, None
    if args.sampler:
        sample_weights = torch.tensor(
@@ -115,7 +123,7 @@ def main():
        shuffle=False if args.sampler else True, num_workers=args.workers, pin_memory=True,
        sampler=sample_method)
    val_loader = torch.utils.data.DataLoader(loader(data_dir=args.data_dir, split='val',
                                                    phase='test'), batch_size=args.batch_size,
                                                    phase='test', **loader_args), batch_size=args.batch_size,
                                             shuffle=False, num_workers=args.workers, pin_memory=True)

    exp_logger, lr = None, None
@@ -145,7 +153,7 @@ def main():

    if args.test:
        test_loader = torch.utils.data.DataLoader(loader(data_dir=args.data_dir, split='test',
                                                          phase='test', num_classes=args.num_classes), batch_size=args.batch_size,
                                                          phase='test'), batch_size=args.batch_size,
                                                   shuffle=False, num_workers=args.workers, pin_memory=True)
        trainer.test(args, test_loader, model, criterion, args.start_epoch,
                     eval_score=metrics.accuracy_classif, output_dir=args.out_pred_dir, has_gt=True, tb_writer=tb_writer)
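As a rough sketch of what this wiring produces per batch, assuming the loader simply applies the `custom_transforms` callable to each ECG sample (the loader's internals are not part of this diff):

```python
import torch
from models.cnn1d import add_mask_to_vector

# One zero-padded ECG as the dataset yields it: shape (1, 187) with trailing zeros.
ecg = torch.zeros(1, 187)
ecg[0, :150] = torch.randn(150)

sample = add_mask_to_vector(ecg)        # (2, 187): channel 0 is the signal, channel 1 the mask
batch = torch.stack([sample, sample])   # what the DataLoader's default collate would produce
print(batch.shape)                      # torch.Size([2, 2, 187]) -> (batch, channels, seq_len)
```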
loaders/ecg_loader.py (5 additions, 4 deletions)
@@ -43,12 +43,13 @@ def __getitem__(self, index: int):
        ecg = self.data.iloc[index, :-
                             1].values.astype(np.float32).reshape((1, 187))
        label = self.labels[index]

        ecg = torch.tensor(ecg).float()
        label = torch.tensor(label).long()

        if self.transform is not None:
            ecg = self.transform(ecg)
            label = self.transform(label)
        else:
            ecg = torch.tensor(ecg).float()
            label = torch.tensor(label).long()

        return tuple([ecg, label, index])

    def __len__(self):
models/__init__.py (1 addition, 1 deletion)
@@ -12,7 +12,7 @@ def get_model(args):
    print('Fetching model %s - %s ' % (arch, args.model_name))

    model_generator = get_generator(arch)
    model = model_generator(args.model_name, num_classes=args.num_classes)
    model = model_generator(args.model_name, num_classes=args.num_classes, masking=args.masking)

    return model

models/cnn1d.py (32 additions, 5 deletions)
@@ -1,13 +1,37 @@
import torch.nn as nn
import torch.nn.functional as F
import torch

import numpy as np

def get_vector_mask(vector):
    mask_length = np.trim_zeros(vector.numpy(), 'b').shape[0]
    return torch.cat([torch.ones(mask_length), torch.zeros(vector.shape[0] - mask_length)])


def get_mask(input_batch):
    result = torch.ones(input_batch.shape)

    for i in range(input_batch.shape[0]):
        result[i, 0, :] = get_vector_mask(input_batch[i, 0, :])

    return result

def add_mask_to_vector(x):
    x = x.squeeze()
    mask = get_vector_mask(x)
    return torch.stack([x.unsqueeze(0), mask.unsqueeze(0)], axis=0).squeeze()
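
A quick illustration of these helpers on a toy vector (values made up for illustration; note that only trailing zeros are treated as padding):

```python
import torch
from models.cnn1d import get_vector_mask, add_mask_to_vector

ecg = torch.tensor([[0.3, 0.0, 0.1, 0.0, 0.0]])   # (1, seq_len) with two padding zeros at the end

print(get_vector_mask(ecg.squeeze()))   # tensor([1., 1., 1., 0., 0.]) -> the interior zero is kept
print(add_mask_to_vector(ecg).shape)    # torch.Size([2, 5]): channel 0 signal, channel 1 mask
```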


class Cnn1d(nn.Module):
    def __init__(self, num_classes=5, input_channels=1):
    def __init__(self, num_classes=5, masking=False):
        super(Cnn1d, self).__init__()

        self.conv1 = nn.Conv1d(in_channels=input_channels,
                               out_channels=8, kernel_size=7, stride=1)
        self.masking = masking
        if masking:
            self.conv1 = nn.Conv1d(in_channels=2, out_channels=8, kernel_size=7, stride=1)
        else:
            self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=7, stride=1)
        self.bn1 = nn.BatchNorm1d(8)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=3)
@@ -22,6 +46,9 @@ def __init__(self, num_classes=5, input_channels=1):
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # masking = self.conv1.in_channels == 2
        # if we're masking we don't have to do anything, DataLoader takes care of that

        out = self.conv1(x)
        out = self.relu(out)
        out = self.maxpool(out)
@@ -39,7 +66,7 @@ def cnn1d_3(**kwargs):
    return model


def cnn1d(model_name, num_classes):
def cnn1d(model_name, num_classes, **kwargs):
    return {
        'cnn1d_3': cnn1d_3(num_classes=num_classes, input_channels=1),
        'cnn1d_3': cnn1d_3(num_classes=num_classes, masking=kwargs.get("masking", False)),
    }[model_name]
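
A small sketch of the two model configurations with a dummy forward pass, assuming the layers not shown in this hunk are unchanged and accept 187-sample inputs:

```python
import torch
from models.cnn1d import Cnn1d, get_mask

plain = Cnn1d(num_classes=5, masking=False)
masked = Cnn1d(num_classes=5, masking=True)
print(plain.conv1.in_channels, masked.conv1.in_channels)   # 1 2

signal = torch.randn(4, 1, 187)                    # (batch, 1, seq_len) as yielded by the ECG loader
x = torch.cat([signal, get_mask(signal)], dim=1)   # append the mask as a second channel
print(masked(x).shape)                             # expected: torch.Size([4, 5]) log-probabilities
```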
test/__init__.py (empty file added)
test/test_cnn1d.py (22 additions, 0 deletions)
@@ -0,0 +1,22 @@
import unittest
import torch

from models.cnn1d import get_mask

class MaskingCNNTest(unittest.TestCase):
    def test_mask(self):
        # data is assumed to be of size (batch, 1, seq_len)
        # in this case we generate seq_len == 5
        input_data = torch.tensor([[[1, 2, 3, 0, 0]], [[3, 0, 1, 0, 0]], [[0, 0, 1, 2, 0]]])
        print(input_data.shape)
        self.assertTrue(input_data.shape[0] > 1)
        self.assertTrue(input_data.shape[1] == 1)
        self.assertTrue(input_data.shape[2] == 5)

        result = get_mask(input_data)
        self.assertEqual(result.shape, input_data.shape)

        result = result.squeeze()
        self.assertEqual(result[0, :].tolist(), [1, 1, 1, 0, 0])
        self.assertEqual(result[1, :].tolist(), [1, 1, 1, 0, 0])
        self.assertEqual(result[2, :].tolist(), [1, 1, 1, 1, 0])
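
With the new (empty) test/__init__.py making test a package, the suite should be runnable from the repository root with the standard runner, e.g. `python -m unittest test.test_cnn1d`.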