diff --git a/args.py b/args.py
index 4cfdf5e..c4f1f77 100644
--- a/args.py
+++ b/args.py
@@ -80,6 +80,8 @@ def parse_args():
                         help='which checkpoint to resume from possible values ["latest", "best", epoch]')
     parser.add_argument('--pretrained', action='store_true', default=False,
                         help='use pre-trained model')
+    parser.add_argument('--masking', action='store_true', default=False,
+                        help='whether to use masking when training the models')
 
     # data settings
     parser.add_argument('--num-classes', default=5, type=int)
diff --git a/commander.py b/commander.py
index 2b6a2c2..1009a2c 100644
--- a/commander.py
+++ b/commander.py
@@ -14,6 +14,7 @@
 from toolbox import utils, logger, metrics, losses, optimizers
 import trainer
 from args import parse_args
+from models.cnn1d import add_mask_to_vector
 
 from torch.utils.tensorboard import SummaryWriter
 
@@ -100,7 +101,14 @@ def main():
 
     # init data loaders
    loader = get_loader(args)
-    train_data = loader(data_dir=args.data_dir, split='train', phase='train')
+    loader_args = {}
+    if args.masking:
+        # compute the mask in the data-loading phase, not in the model
+        loader_args['custom_transforms'] = add_mask_to_vector
+
+    train_data = loader(data_dir=args.data_dir, split='train',
+                        phase='train', num_classes=args.num_classes, **loader_args)
+
     sample_method, cb_weights, sample_weights = None, None, None
     if args.sampler:
         sample_weights = torch.tensor(
@@ -115,7 +123,7 @@
         shuffle=False if args.sampler else True, num_workers=args.workers,
         pin_memory=True, sampler=sample_method)
     val_loader = torch.utils.data.DataLoader(loader(data_dir=args.data_dir, split='val',
-                                             phase='test'), batch_size=args.batch_size,
+                                             phase='test', **loader_args), batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.workers, pin_memory=True)
 
     exp_logger, lr = None, None
@@ -145,7 +153,7 @@
     if args.test:
         test_loader = torch.utils.data.DataLoader(loader(data_dir=args.data_dir, split='test',
-                                                  phase='test', num_classes=args.num_classes), batch_size=args.batch_size,
+                                                  phase='test'), batch_size=args.batch_size,
                                                   shuffle=False, num_workers=args.workers, pin_memory=True)
         trainer.test(args, test_loader, model, criterion, args.start_epoch,
                      eval_score=metrics.accuracy_classif, output_dir=args.out_pred_dir,
                      has_gt=True, tb_writer=tb_writer)
diff --git a/loaders/ecg_loader.py b/loaders/ecg_loader.py
index 2ad2368..5009402 100644
--- a/loaders/ecg_loader.py
+++ b/loaders/ecg_loader.py
@@ -43,12 +43,13 @@ def __getitem__(self, index: int):
         ecg = self.data.iloc[index, :-
                              1].values.astype(np.float32).reshape((1, 187))
         label = self.labels[index]
+
+        ecg = torch.tensor(ecg).float()
+        label = torch.tensor(label).long()
+
         if self.transform is not None:
             ecg = self.transform(ecg)
-            label = self.transform(label)
-        else:
-            ecg = torch.tensor(ecg).float()
-            label = torch.tensor(label).long()
+
         return tuple([ecg, label, index])
 
     def __len__(self):
diff --git a/models/__init__.py b/models/__init__.py
index d46647a..47e505e 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -12,7 +12,7 @@ def get_model(args):
 
     print('Fetching model %s - %s ' % (arch, args.model_name))
     model_generator = get_generator(arch)
-    model = model_generator(args.model_name, num_classes=args.num_classes)
+    model = model_generator(args.model_name, num_classes=args.num_classes, masking=args.masking)
 
     return model
 
diff --git a/models/cnn1d.py b/models/cnn1d.py
index df0bf1f..6b4c822 100644
--- a/models/cnn1d.py
+++ b/models/cnn1d.py
@@ -1,13 +1,37 @@
 import torch.nn as nn
 import torch.nn.functional as F
+import torch
+
+import numpy as np
+
+
+def get_vector_mask(vector):
+    mask_length = np.trim_zeros(vector.numpy(), 'b').shape[0]  # signal length without the trailing zero padding
+    return torch.cat([torch.ones(mask_length), torch.zeros(vector.shape[0] - mask_length)])
+
+
+def get_mask(input_batch):
+    result = torch.ones(input_batch.shape)  # input_batch: (batch, 1, seq_len)
+
+    for i in range(input_batch.shape[0]):
+        result[i, 0, :] = get_vector_mask(input_batch[i, 0, :])
+
+    return result
+
+
+def add_mask_to_vector(x):
+    x = x.squeeze()
+    mask = get_vector_mask(x)
+    return torch.stack([x, mask], dim=0)  # (2, seq_len): signal channel + mask channel
+
 
 class Cnn1d(nn.Module):
-    def __init__(self, num_classes=5, input_channels=1):
+    def __init__(self, num_classes=5, masking=False):
         super(Cnn1d, self).__init__()
-        self.conv1 = nn.Conv1d(in_channels=input_channels,
-                               out_channels=8, kernel_size=7, stride=1)
+        self.masking = masking
+        if masking:
+            self.conv1 = nn.Conv1d(in_channels=2, out_channels=8, kernel_size=7, stride=1)
+        else:
+            self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=7, stride=1)
         self.bn1 = nn.BatchNorm1d(8)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool1d(kernel_size=3, stride=3)
@@ -22,6 +46,9 @@ def __init__(self, num_classes=5, input_channels=1):
         self.softmax = nn.LogSoftmax(dim=1)
 
     def forward(self, x):
+        # when masking is enabled the DataLoader transform has already stacked the
+        # mask into the input as a second channel, so nothing extra happens here
+
         out = self.conv1(x)
         out = self.relu(out)
         out = self.maxpool(out)
@@ -39,7 +66,7 @@ def cnn1d_3(**kwargs):
     return model
 
 
-def cnn1d(model_name, num_classes):
+def cnn1d(model_name, num_classes, **kwargs):
     return{
-        'cnn1d_3': cnn1d_3(num_classes=num_classes, input_channels=1),
+        'cnn1d_3': cnn1d_3(num_classes=num_classes, masking=kwargs.get("masking", False)),
     }[model_name]
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/test_cnn1d.py b/test/test_cnn1d.py
new file mode 100644
index 0000000..1803b0f
--- /dev/null
+++ b/test/test_cnn1d.py
@@ -0,0 +1,22 @@
+import unittest
+import torch
+
+from models.cnn1d import get_mask
+
+
+class MaskingCNNTest(unittest.TestCase):
+    def test_mask(self):
+        # data is assumed to be of shape (batch, 1, seq_len);
+        # here we generate seq_len == 5
+        input_data = torch.tensor([[[1, 2, 3, 0, 0]], [[3, 0, 1, 0, 0]], [[0, 0, 1, 2, 0]]])
+        print(input_data.shape)
+        self.assertTrue(input_data.shape[0] > 1)
+        self.assertEqual(input_data.shape[1], 1)
+        self.assertEqual(input_data.shape[2], 5)
+
+        result = get_mask(input_data)
+        self.assertEqual(result.shape, input_data.shape)
+
+        result = result.squeeze()
+        self.assertEqual(result[0, :].tolist(), [1, 1, 1, 0, 0])
+        self.assertEqual(result[1, :].tolist(), [1, 1, 1, 0, 0])
+        self.assertEqual(result[2, :].tolist(), [1, 1, 1, 1, 0])
\ No newline at end of file
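
Data-flow note: with --masking set, ecg_loader first converts each row to a
(1, 187) float tensor and then applies the transform, so every sample leaves
the dataset as a (2, 187) tensor (channel 0: signal, channel 1: padding mask)
and collates to (batch, 2, 187), which is what the two-channel conv1 expects.
A minimal sanity check of that shape contract, assuming the patched
models/cnn1d.py is importable; the 100-sample beat below is made up for
illustration:

    import torch
    from models.cnn1d import add_mask_to_vector

    # a padded "beat": 100 non-zero samples followed by zero padding,
    # shaped (1, 187) exactly as ecg_loader emits it before the transform runs
    ecg = torch.zeros(1, 187)
    ecg[0, :100] = torch.randn(100).abs() + 0.1  # strictly positive, so only the padding is trimmed

    sample = add_mask_to_vector(ecg)
    print(sample.shape)        # torch.Size([2, 187])
    print(sample[1, 98:102])   # tensor([1., 1., 0., 0.]) -- the padding boundary

    batch = torch.stack([sample, sample])  # what the DataLoader collates
    print(batch.shape)         # torch.Size([2, 2, 187]), matches Conv1d(in_channels=2)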
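
The factory wiring can be smoke-tested without any data; this sketch only
inspects conv1, so it holds regardless of the rest of the network:

    from models.cnn1d import cnn1d

    model = cnn1d('cnn1d_3', num_classes=5, masking=True)
    assert model.conv1.in_channels == 2   # signal + mask channels

    model = cnn1d('cnn1d_3', num_classes=5)
    assert model.conv1.in_channels == 1   # masking defaults to off

The new unit test runs from the repository root with
python -m unittest test.test_cnn1d (the empty test/__init__.py makes the
directory a package).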