From e9b0b670a4b7fdf92c80bd07f2e16656e1ec09a4 Mon Sep 17 00:00:00 2001 From: Anmol Sharma Date: Wed, 9 Sep 2020 20:58:31 -0700 Subject: [PATCH] initial commit --- .gitignore | 8 + environment.yml | 128 ++ mmgan_hgg.sh | 2 + mmgan_lgg.sh | 2 + modules/advanced_gans/datasets.py | 35 + modules/advanced_gans/models.py | 169 +++ modules/advanced_gans/pix2pix.py | 186 +++ modules/configfile.py | 120 ++ modules/create_hdf5_file.py | 155 +++ modules/dataloader.py | 256 ++++ modules/helpers.py | 1262 ++++++++++++++++++ modules/mischelpers.py | 239 ++++ modules/models.py | 144 ++ modules/preprocess.py | 155 +++ modules/pytorch_msssim/__init__.py | 133 ++ prep_BRATS2015/cedar_gen_hdf5.sh | 14 + prep_BRATS2015/create_hdf5_file.py | 170 +++ prep_BRATS2015/modules/__init__.py | 0 prep_BRATS2015/modules/configfile.py | 102 ++ prep_BRATS2015/modules/dataloader.py | 292 ++++ prep_BRATS2015/modules/mischelpers.py | 239 ++++ prep_BRATS2015/prepare_data_for_synthesis.py | 56 + prep_BRATS2015/preprocess.py | 152 +++ prep_ISLES2015/cedar_gen_hdf5.sh | 14 + prep_ISLES2015/create_hdf5_file.py | 154 +++ prep_ISLES2015/modules/__init__.py | 0 prep_ISLES2015/modules/configfile.py | 81 ++ prep_ISLES2015/modules/dataloader.py | 292 ++++ prep_ISLES2015/modules/mischelpers.py | 239 ++++ prep_ISLES2015/prepare_data_for_synthesis.py | 56 + prep_ISLES2015/preprocess.py | 140 ++ readme.MD | 88 ++ train_mmgan_brats2018.py | 583 ++++++++ train_mmgan_brats2018_single.py | 614 +++++++++ 34 files changed, 6280 insertions(+) create mode 100755 .gitignore create mode 100755 environment.yml create mode 100755 mmgan_hgg.sh create mode 100755 mmgan_lgg.sh create mode 100755 modules/advanced_gans/datasets.py create mode 100755 modules/advanced_gans/models.py create mode 100755 modules/advanced_gans/pix2pix.py create mode 100755 modules/configfile.py create mode 100755 modules/create_hdf5_file.py create mode 100755 modules/dataloader.py create mode 100755 modules/helpers.py create mode 100755 modules/mischelpers.py create mode 100755 modules/models.py create mode 100755 modules/preprocess.py create mode 100755 modules/pytorch_msssim/__init__.py create mode 100755 prep_BRATS2015/cedar_gen_hdf5.sh create mode 100755 prep_BRATS2015/create_hdf5_file.py create mode 100755 prep_BRATS2015/modules/__init__.py create mode 100755 prep_BRATS2015/modules/configfile.py create mode 100755 prep_BRATS2015/modules/dataloader.py create mode 100755 prep_BRATS2015/modules/mischelpers.py create mode 100755 prep_BRATS2015/prepare_data_for_synthesis.py create mode 100755 prep_BRATS2015/preprocess.py create mode 100755 prep_ISLES2015/cedar_gen_hdf5.sh create mode 100755 prep_ISLES2015/create_hdf5_file.py create mode 100755 prep_ISLES2015/modules/__init__.py create mode 100755 prep_ISLES2015/modules/configfile.py create mode 100755 prep_ISLES2015/modules/dataloader.py create mode 100755 prep_ISLES2015/modules/mischelpers.py create mode 100755 prep_ISLES2015/prepare_data_for_synthesis.py create mode 100755 prep_ISLES2015/preprocess.py create mode 100755 readme.MD create mode 100755 train_mmgan_brats2018.py create mode 100755 train_mmgan_brats2018_single.py diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..e51eb99 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +notebooks/.ipynb_checkpoints +.idea/ +modules/__pycache__/ +*/*/*.pyc +*/*/.ipynb_checkpoints +*/*/__pycache__/ +notebooks/ +misc_stuff/ diff --git a/environment.yml b/environment.yml new file mode 100755 index 0000000..ae74f53 --- /dev/null +++ b/environment.yml @@ -0,0 +1,128 @@ 
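The environment pinned below targets Python 3.5.5 and PyTorch 0.4.1 built against CUDA 9.0 and cuDNN 7.1. A minimal pre-flight check before running the training scripts might look like the following sketch (the version strings are taken from the pins below; the check itself is not part of the commit):

    import sys
    import torch

    # environment.yml pins python=3.5.5 and pytorch=0.4.1 (cudatoolkit=9.0, cudnn=7.1.2)
    assert sys.version_info[:2] == (3, 5)
    assert torch.__version__.startswith('0.4')
    print('CUDA available:', torch.cuda.is_available())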
+name: pytorch-CycleGAN-and-pix2pix +channels: + - anaconda + - conda-forge + - defaults +dependencies: + - h5py=2.8.0=py35h989c5e5_3 + - six=1.11.0=py35_1 + - ca-certificates=2018.11.29=ha4d7672_0 + - certifi=2018.8.24=py35_1001 + - cloudpickle=0.6.1=py_0 + - cycler=0.10.0=py_1 + - dask-core=1.0.0=py_0 + - decorator=4.3.0=py_0 + - matplotlib=2.2.3=py35h8e2386c_0 + - networkx=2.2=py_1 + - openssl=1.0.2p=h470a237_1 + - pyqt=5.6.0=py35h8210e8a_7 + - python-dateutil=2.7.5=py_0 + - pywavelets=1.0.1=py35h7eb728f_0 + - scikit-image=0.14.0=py35hfc679d8_1 + - sip=4.18.1=py35hfc679d8_0 + - toolz=0.9.0=py_1 + - tornado=5.1.1=py35h470a237_0 + - blas=1.0=mkl + - cffi=1.11.5=py35he75722e_1 + - cudatoolkit=9.0=h13b8566_0 + - cudnn=7.1.2=cuda9.0_0 + - dbus=1.13.2=h714fa37_1 + - expat=2.2.6=he6710b0_0 + - fontconfig=2.13.0=h9420a91_0 + - freetype=2.9.1=h8a8886c_1 + - glib=2.56.2=hd408876_0 + - gst-plugins-base=1.14.0=hbbd80ab_1 + - gstreamer=1.14.0=hb453b48_1 + - hdf5=1.10.2=hba1933b_1 + - icu=58.2=h9c2bf20_1 + - imageio=2.4.1=py35_0 + - intel-openmp=2019.1=144 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.0.1=py35hf484d3e_0 + - libedit=3.1.20170329=h6b74fdf_2 + - libffi=3.2.1=hd88cf55_4 + - libgcc-ng=8.2.0=hdf63c60_1 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libpng=1.6.35=hbc83047_0 + - libstdcxx-ng=8.2.0=hdf63c60_1 + - libtiff=4.0.9=he85c1e1_2 + - libuuid=1.0.3=h1bed415_2 + - libxcb=1.13=h1bed415_1 + - libxml2=2.9.8=h26e45fe_1 + - mkl=2018.0.3=1 + - mkl_fft=1.0.6=py35h7dd41cf_0 + - mkl_random=1.0.1=py35h4414c95_1 + - nccl=1.3.5=cuda9.0_0 + - ncurses=6.1=hf484d3e_0 + - ninja=1.8.2=py35h6bb024c_1 + - numpy=1.15.2=py35h1d66e8a_0 + - numpy-base=1.15.2=py35h81de0dd_0 + - olefile=0.46=py35_0 + - pcre=8.42=h439df22_0 + - pillow=5.2.0=py35heded4f4_0 + - pip=10.0.1=py35_0 + - pycparser=2.19=py35_0 + - pyparsing=2.2.1=py35_0 + - python=3.5.5=hc3d631a_4 + - pytorch=0.4.1=py35ha74772b_0 + - pytz=2018.5=py35_0 + - qt=5.6.3=h39df351_1 + - readline=7.0=h7b6447c_5 + - scipy=1.1.0=py35hfa4b5c9_1 + - setuptools=40.2.0=py35_0 + - sqlite=3.25.3=h7b6447c_0 + - tk=8.6.8=hbc83047_0 + - wheel=0.31.1=py35_0 + - xz=5.2.4=h14c3975_4 + - zlib=1.2.11=ha838bed_2 + - pip: + - backcall==0.1.0 + - bleach==3.0.2 + - chardet==3.0.4 + - dask==1.0.0 + - defusedxml==0.5.0 + - dominate==2.3.1 + - entrypoints==0.2.3 + - idna==2.7 + - ipykernel==5.1.0 + - ipython==7.2.0 + - ipython-genutils==0.2.0 + - ipywidgets==7.4.2 + - jedi==0.13.1 + - jinja2==2.10 + - jsonschema==2.6.0 + - jupyter==1.0.0 + - jupyter-client==5.2.3 + - jupyter-console==6.0.0 + - jupyter-core==4.4.0 + - markupsafe==1.1.0 + - mistune==0.8.4 + - nbconvert==5.4.0 + - nbformat==4.4.0 + - notebook==5.7.2 + - pandocfilters==1.4.2 + - parso==0.3.1 + - pexpect==4.6.0 + - pickleshare==0.7.5 + - prometheus-client==0.4.2 + - prompt-toolkit==2.0.7 + - ptyprocess==0.6.0 + - pygments==2.3.0 + - pyzmq==17.1.2 + - qtconsole==4.4.3 + - requests==2.20.1 + - send2trash==1.5.0 + - terminado==0.8.1 + - testpath==0.4.2 + - torch==0.4.1 + - torchfile==0.1.0 + - torchvision==0.2.1 + - tqdm==4.28.1 + - traitlets==4.3.2 + - urllib3==1.24.1 + - visdom==0.1.7 + - wcwidth==0.1.7 + - webencodings==0.5.1 + - widgetsnbextension==3.4.2 +prefix: /home/asa224/anaconda3/envs/pytorch-CycleGAN-and-pix2pix + diff --git a/mmgan_hgg.sh b/mmgan_hgg.sh new file mode 100755 index 0000000..b83c2ac --- /dev/null +++ b/mmgan_hgg.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +python train_mmgan_brats2018.py --grade=HGG --train_patient_idx=200 --test_pats=10 --batch_size=4 --dataset=BRATS2018 --n_epochs=60 
--model_name=mmgan_hgg_zeros_cl --log_level=info --n_cpu=4 --c_learning=1 --z_type=zeros diff --git a/mmgan_lgg.sh b/mmgan_lgg.sh new file mode 100755 index 0000000..b2dc886 --- /dev/null +++ b/mmgan_lgg.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +python train_mmgan_brats2018.py --grade=LGG --train_patient_idx=70 --test_pats=5 --batch_size=4 --dataset=BRATS2018 --n_epochs=60 --model_name=mmgan_lgg_zeros_cl --log_level=info --n_cpu=8 --c_learning=1 --z_type=zeros diff --git a/modules/advanced_gans/datasets.py b/modules/advanced_gans/datasets.py new file mode 100755 index 0000000..513188d --- /dev/null +++ b/modules/advanced_gans/datasets.py @@ -0,0 +1,35 @@ +import glob +import random +import os +import numpy as np + +from torch.utils.data import Dataset +from PIL import Image +import torchvision.transforms as transforms + +class ImageDataset(Dataset): + def __init__(self, root, transforms_=None, mode='train'): + self.transform = transforms.Compose(transforms_) + + self.files = sorted(glob.glob(os.path.join(root, mode) + '/*.*')) + if mode == 'train': + self.files.extend(sorted(glob.glob(os.path.join(root, 'test') + '/*.*'))) + + def __getitem__(self, index): + + img = Image.open(self.files[index % len(self.files)]) + w, h = img.size + img_A = img.crop((0, 0, w/2, h)) + img_B = img.crop((w/2, 0, w, h)) + + if np.random.random() < 0.5: + img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], 'RGB') + img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], 'RGB') + + img_A = self.transform(img_A) + img_B = self.transform(img_B) + + return {'A': img_A, 'B': img_B} + + def __len__(self): + return len(self.files) diff --git a/modules/advanced_gans/models.py b/modules/advanced_gans/models.py new file mode 100755 index 0000000..974ea6e --- /dev/null +++ b/modules/advanced_gans/models.py @@ -0,0 +1,169 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch +import numpy as np +import random +np.random.seed(1337) +torch.manual_seed(1337) +random.seed(1337) +torch.backends.cudnn.deterministic = True +def weights_init_normal(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + torch.nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + torch.nn.init.normal_(m.weight.data, 1.0, 0.02) + torch.nn.init.constant_(m.bias.data, 0.0) + +############################## +# U-NET +############################## + +class UNetDown(nn.Module): + def __init__(self, in_size, out_size, normalize=True, dropout=0.0): + super(UNetDown, self).__init__() + layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)] + if normalize: + layers.append(nn.InstanceNorm2d(out_size)) + layers.append(nn.LeakyReLU(0.2)) + if dropout: + layers.append(nn.Dropout(dropout)) + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + +class UNetUp(nn.Module): + def __init__(self, in_size, out_size, dropout=0.0): + super(UNetUp, self).__init__() + layers = [ nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False), + nn.InstanceNorm2d(out_size), + nn.ReLU(inplace=True)] + if dropout: + layers.append(nn.Dropout(dropout)) + + self.model = nn.Sequential(*layers) + + def forward(self, x, skip_input): + x = self.model(x) + x = torch.cat((x, skip_input), 1) + + return x + +class GeneratorUNet(nn.Module): + def __init__(self, in_channels=3, out_channels=3, with_tanh=False, with_relu=False): + super(GeneratorUNet, self).__init__() + + # original dropout was 0.5 + self.down1 = UNetDown(in_channels, 64, normalize=False) + self.down2 = UNetDown(64, 128) 
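        # Each UNetDown halves the spatial size: Conv2d(kernel=4, stride=2,
        # padding=1) maps an even H to (H + 2 - 4) // 2 + 1 = H // 2. Assuming
        # the 256 x 256 inputs used by pix2pix.py, down1..down8 therefore yield
        # 128, 64, 32, 16, 8, 4, 2 and finally 1 x 1 at the bottleneck, which is
        # presumably why down8 sets normalize=False: instance norm is degenerate
        # on a single spatial element.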
+ self.down3 = UNetDown(128, 256) + self.down4 = UNetDown(256, 512, dropout=0.2) + self.down5 = UNetDown(512, 512, dropout=0.2) + self.down6 = UNetDown(512, 512, dropout=0.2) + self.down7 = UNetDown(512, 512, dropout=0.2) + self.down8 = UNetDown(512, 512, normalize=False, dropout=0.2) + + self.up1 = UNetUp(512, 512, dropout=0.2) + self.up2 = UNetUp(1024, 512, dropout=0.2) + self.up3 = UNetUp(1024, 512, dropout=0.2) + self.up4 = UNetUp(1024, 512, dropout=0.2) + self.up5 = UNetUp(1024, 256) + self.up6 = UNetUp(512, 128) + self.up7 = UNetUp(256, 64) + + if with_tanh: + + self.final = nn.Sequential( + nn.Upsample(scale_factor=2), + nn.ZeroPad2d((1, 0, 1, 0)), + nn.Conv2d(128, out_channels, 4, padding=1), + nn.Tanh() + ) + elif with_relu: + # this is for ISLES2015 + self.final = nn.Sequential( + nn.Upsample(scale_factor=2), + nn.ZeroPad2d((1, 0, 1, 0)), + nn.Conv2d(128, out_channels, 4, padding=1), + nn.ReLU() + ) + + else: + self.final = nn.Sequential( + nn.Upsample(scale_factor=2), + nn.ZeroPad2d((1, 0, 1, 0)), + nn.Conv2d(128, out_channels, 4, padding=1) + ) + + + def forward(self, x): + # U-Net generator with skip connections from encoder to decoder + d1 = self.down1(x) + d2 = self.down2(d1) + d3 = self.down3(d2) + d4 = self.down4(d3) + d5 = self.down5(d4) + d6 = self.down6(d5) + d7 = self.down7(d6) + d8 = self.down8(d7) + u1 = self.up1(d8, d7) + u2 = self.up2(u1, d6) + u3 = self.up3(u2, d5) + u4 = self.up4(u3, d4) + u5 = self.up5(u4, d3) + u6 = self.up6(u5, d2) + u7 = self.up7(u6, d1) + + return self.final(u7) + + +############################## +# Discriminator +############################## + +class Discriminator(nn.Module): + def __init__(self, in_channels=3, out_channels=4, dataset='BRATS2018'): + super(Discriminator, self).__init__() + + # inp, stride, pad, dil, kernel = (256, 2, 1, 1, 8) + # np.floor(((inp + 2*pad - dil*(kernel - 1) - 1)/stride) + 1) + + if 'BRATS' in dataset: + def discriminator_block(in_filters, out_filters, normalization=True): + """Returns downsampling layers of each discriminator block""" + layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] + if normalization: + layers.append(nn.InstanceNorm2d(out_filters)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *discriminator_block(in_channels*2, 64, normalization=False), + *discriminator_block(64, 128), + *discriminator_block(128, 256), + *discriminator_block(256, 512), + nn.ZeroPad2d((1, 0, 1, 0)), + nn.Conv2d(512, out_channels, 4, padding=1, bias=False) + ) + else: + # FOR ISLES2015 + def discriminator_block(in_filters, out_filters, normalization=True): + """Returns downsampling layers of each discriminator block""" + layers = [nn.Conv2d(in_filters, out_filters, 4, stride=4, padding=1)] + if normalization: + layers.append(nn.InstanceNorm2d(out_filters)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *discriminator_block(in_channels * 2, 8, normalization=False), + *discriminator_block(8, 8), + nn.ZeroPad2d((1, 0, 1, 0)), + nn.Conv2d(8, out_channels, 4, padding=1, bias=False) + ) + + def forward(self, img_A, img_B): + # Concatenate image and condition image by channels to produce input + img_input = torch.cat((img_A, img_B), 1) + return self.model(img_input) diff --git a/modules/advanced_gans/pix2pix.py b/modules/advanced_gans/pix2pix.py new file mode 100755 index 0000000..fe5f79a --- /dev/null +++ b/modules/advanced_gans/pix2pix.py @@ -0,0 +1,186 @@ +import argparse +import os +import numpy as np 
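# Note: `models` and `datasets` below are the sibling modules added in this
# commit; models.py provides GeneratorUNet, Discriminator and
# weights_init_normal, and datasets.py provides ImageDataset.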
+import math +import itertools +import time +import datetime +import sys + +import torchvision.transforms as transforms +from torchvision.utils import save_image + +from torch.utils.data import DataLoader +from torchvision import datasets +from torch.autograd import Variable + +from models import * +from datasets import * + +import torch.nn as nn +import torch.nn.functional as F +import torch + +parser = argparse.ArgumentParser() +parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from') +parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training') +parser.add_argument('--dataset_name', type=str, default="facades", help='name of the dataset') +parser.add_argument('--batch_size', type=int, default=1, help='size of the batches') +parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate') +parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') +parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') +parser.add_argument('--decay_epoch', type=int, default=100, help='epoch from which to start lr decay') +parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation') +parser.add_argument('--img_height', type=int, default=256, help='size of image height') +parser.add_argument('--img_width', type=int, default=256, help='size of image width') +parser.add_argument('--channels', type=int, default=3, help='number of image channels') +parser.add_argument('--sample_interval', type=int, default=500, help='interval between sampling of images from generators') +parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints') +opt = parser.parse_args() +print(opt) + +os.makedirs('images/%s' % opt.dataset_name, exist_ok=True) +os.makedirs('saved_models/%s' % opt.dataset_name, exist_ok=True) + +cuda = True if torch.cuda.is_available() else False + +# Loss functions +criterion_GAN = torch.nn.MSELoss() +criterion_pixelwise = torch.nn.L1Loss() + +# Loss weight of L1 pixel-wise loss between translated image and real image +lambda_pixel = 100 + +# Calculate output of image discriminator (PatchGAN) +patch = (1, opt.img_height//2**4, opt.img_width//2**4) + +# Initialize generator and discriminator +generator = GeneratorUNet() +discriminator = Discriminator() + +if cuda: + generator = generator.cuda() + discriminator = discriminator.cuda() + criterion_GAN.cuda() + criterion_pixelwise.cuda() + +if opt.epoch != 0: + # Load pretrained models + generator.load_state_dict(torch.load('saved_models/%s/generator_%d.pth' % (opt.dataset_name, opt.epoch))) + discriminator.load_state_dict(torch.load('saved_models/%s/discriminator_%d.pth' % (opt.dataset_name, opt.epoch))) +else: + # Initialize weights + generator.apply(weights_init_normal) + discriminator.apply(weights_init_normal) + +# Optimizers +optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) +optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) + +# Configure dataloaders +transforms_ = [ transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ] + +dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_), + batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu) + 
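A quick sanity check of the PatchGAN geometry behind `patch = (1, opt.img_height//2**4, opt.img_width//2**4)` above, sketched with the script defaults of 256 x 256 inputs: the discriminator stacks four stride-2 convolutions, and its trailing ZeroPad2d plus stride-1 convolution preserve that size, so `valid` and `fake` in the training loop are 16 x 16 maps of per-patch labels rather than single scalars.

    h = 256                       # opt.img_height default
    for _ in range(4):            # four discriminator_block convs: k=4, s=2, p=1
        h = (h + 2 - 4) // 2 + 1  # 256 -> 128 -> 64 -> 32 -> 16
    # ZeroPad2d((1, 0, 1, 0)) then Conv2d(k=4, s=1, p=1): (16 + 1 + 2 - 4) + 1 = 16
    assert h == 256 // 2 ** 4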
+val_dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode='val'), + batch_size=10, shuffle=True, num_workers=1) + +# Tensor type +Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor + +def sample_images(batches_done): + """Saves a generated sample from the validation set""" + imgs = next(iter(val_dataloader)) + real_A = Variable(imgs['B'].type(Tensor)) + real_B = Variable(imgs['A'].type(Tensor)) + fake_B = generator(real_A) + img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2) + save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True) + +# ---------- +# Training +# ---------- + +prev_time = time.time() + +for epoch in range(opt.epoch, opt.n_epochs): + for i, batch in enumerate(dataloader): + + # Model inputs + real_A = Variable(batch['B'].type(Tensor)) + real_B = Variable(batch['A'].type(Tensor)) + + # Adversarial ground truths + valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False) + fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False) + + # ------------------ + # Train Generators + # ------------------ + + optimizer_G.zero_grad() + + # GAN loss + fake_B = generator(real_A) + pred_fake = discriminator(fake_B, real_A) + loss_GAN = criterion_GAN(pred_fake, valid) + # Pixel-wise loss + loss_pixel = criterion_pixelwise(fake_B, real_B) + + # Total loss + loss_G = loss_GAN + lambda_pixel * loss_pixel + + loss_G.backward() + + optimizer_G.step() + + # --------------------- + # Train Discriminator + # --------------------- + + optimizer_D.zero_grad() + + # Real loss + pred_real = discriminator(real_B, real_A) + loss_real = criterion_GAN(pred_real, valid) + + # Fake loss + pred_fake = discriminator(fake_B.detach(), real_A) + loss_fake = criterion_GAN(pred_fake, fake) + + # Total loss + loss_D = 0.5 * (loss_real + loss_fake) + + loss_D.backward() + optimizer_D.step() + + # -------------- + # Log Progress + # -------------- + + # Determine approximate time left + batches_done = epoch * len(dataloader) + i + batches_left = opt.n_epochs * len(dataloader) - batches_done + time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) + prev_time = time.time() + + # Print log + sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s" % + (epoch, opt.n_epochs, + i, len(dataloader), + loss_D.item(), loss_G.item(), + loss_pixel.item(), loss_GAN.item(), + time_left)) + + # If at sample interval save image + if batches_done % opt.sample_interval == 0: + sample_images(batches_done) + + + if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0: + # Save model checkpoints + torch.save(generator.state_dict(), 'saved_models/%s/generator_%d.pth' % (opt.dataset_name, epoch)) + torch.save(discriminator.state_dict(), 'saved_models/%s/discriminator_%d.pth' % (opt.dataset_name, epoch)) diff --git a/modules/configfile.py b/modules/configfile.py new file mode 100755 index 0000000..a4aad35 --- /dev/null +++ b/modules/configfile.py @@ -0,0 +1,120 @@ +""" +========================================================== + Config File to set Parameters +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. 
Mostafa Fatehi (VGH) +DESCRIPTION: This file is solely created for the purpose of + managing parameters in a global setting. All the + database loading and generation parameters reside + here, and are inherited by create_hdf5_file.py + to generate the HDF5 data store. + + The parameters are also used in the test_database.py + script to test the created database. +LICENCE: Proprietary for now. +""" +import os +import platform + +# WE CAN USE THIS TO CHANGE IMAGE_DATA_FORMAT on the fly +# keras.backend.common._IMAGE_DATA_FORMAT='channels_first' + +# to make the code portable even on cedar,you need to add conditions here +node_name = platform.node() +if node_name == 'XPS15': + # this is my laptop, so the cedar-rm directory is at a different place + mount_path_prefix = '/home/anmol/mounts/cedar-rm/' +elif 'computecanada' in node_name: # we're in compute canada, maybe in an interactive node, or a scheduler node. + mount_path_prefix = '/home/asa224/' # home directory +else: + # this is probably my workstation or TS server + mount_path_prefix = '/local-scratch/anmol/data' + +config = {} +# set the data directory and output hdf5 file path. +# data_dir is the top level path containing both training and validation sets of the brats dataset. +config['data_dir_prefix'] = os.path.join(mount_path_prefix, 'BRATS2018_Full/') # this should be top level path +config['hdf5_filepath_prefix'] = os.path.join(mount_path_prefix, 'BRATS2018/HDF5_Datasets/BRATS2018_Unprocessed.h5') # top level path +config['hdf5_filepath_prefix_2017'] = os.path.join(mount_path_prefix, 'scratch/asa224/Datasets/BRATS2017/HDF5_Datasets/BRATS.h5') # top level path +config['hdf5_combined'] = os.path.join(os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1]), 'BRATS_Combined_Unprocessed.h5') +config['hdf5_filepath_cropped'] = os.path.join(mount_path_prefix, 'BRATS2018/HDF5_Datasets/BRATS2018_Cropped_Normalized_Unprocessed.h5') # top level path +config['saveMeanVarFilepathHGG'] = os.path.join(os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1]), 'BRATS2018_HDF5_Datasetstraining_data_hgg_mean_var.p') +config['saveMeanVarFilepathLGG'] = os.path.join(os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1]), 'BRATS2018_HDF5_Datasetstraining_data_lgg_mean_var.p') +config['saveMeanVarCombinedData'] = os.path.join(os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1]), 'combined_data_mean_var.p') + +config['model_snapshot_location'] = os.path.join(mount_path_prefix, 'scratch/asa224/model-snapshots/') +config['model_checkpoint_location'] = os.path.join(mount_path_prefix, 'scratch/asa224/model-checkpoints/') +config['model_prediction_location'] = os.path.join(mount_path_prefix, 'scratch/asa224/model-predictions/') +# # IF YOU PERFORM PREPROCESSING, THESE VARIABLES ARE TO BE CHANGED. DEFAULT VALUES ARE: +# config['spatial_size_for_training'] = (240, 240) # If any preprocessing is done, then this needs to change. This is the shape of data that you want to train with. If you are changing this that means you did some preprocessing. +# config['num_slices'] = 155 # number of slices in input data. THIS SHOULD CHANGE AS WELL WHEN PERFORMING PREPROCESSING + +# IF YOU PERFORM PREPROCESSING, THESE VARIABLES ARE TO BE CHANGED. +config['spatial_size_for_training'] = (240, 240) # If any preprocessing is done, then this needs to change. This is the shape of data that you want to train with. If you are changing this that means you did some preprocessing. +config['num_slices'] = 155 # number of slices in input data. 
THIS SHOULD CHANGE AS WELL WHEN PERFORMING PREPROCESSING +config['volume_size'] = list(config['spatial_size_for_training']) + [config['num_slices']] +config['seed'] = 1338 +config['data_order'] = 'th' # what order should the indices be to store in hdf5 file +config['train_hgg_patients'] = 210 # number of HGG patients in training +config['train_lgg_patients'] = 75 # number of LGG patients in training +config['validation_patients'] = 66 # number of patients in validation + +config['batch_size'] = 1 # how many images to load at once in the generator + +config['cropping_coords'] = [29, 223, 41, 196, 0, 148] # coordinates used to crop the volumes, this is generated using the notebook checkLargestCropSize.ipynb +config['size_after_cropping'] = [194, 155, 148] # set this if you set the above variable. Calculate this using the notebook again. + +config['data_split'] = {'train': 98, + 'test': 2} + +config['std_scale_range'] = [6] # [4,6,8,10] scale the standard deviation for path generation process to allow patches from far off regions +config['num_patches_per_patient'] = 50 # number of patches to generate for a single patient +config['patch_size'] = [64, 64, 64] # size of patch to extract +config['patch_input_shape'] = [4] + config['patch_size'] +config['gen_patches_from'] = 'original' # generate patches from the cropped version of the database or original. +config['validate_on'] = 'original' # Perform validation on original images or cropped images +config['num_labels'] = 3 # number of labels in the segmentation mask, except background +config['max_label_val'] = 4 + +config['val_shape_after_prediction'] = [] + +# check the order of data and chose proper data shape to save images +if config['data_order'] == 'th': + config['train_shape_hgg'] = (config['train_hgg_patients'], 4, config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices']) + config['train_shape_lgg'] = (config['train_lgg_patients'], 4, config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices']) + config['train_segmasks_shape_hgg'] = (config['train_hgg_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices']) + config['train_segmasks_shape_lgg'] = (config['train_lgg_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices']) + config['val_shape'] = (config['validation_patients'], 4, config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices']) + + config['train_shape_hgg_crop'] = (config['train_hgg_patients'], 4, config['size_after_cropping'][0], config['size_after_cropping'][1], config['size_after_cropping'][2]) + config['train_shape_lgg_crop'] = (config['train_lgg_patients'], 4, config['size_after_cropping'][0], config['size_after_cropping'][1], config['size_after_cropping'][2]) + config['train_segmasks_shape_hgg_crop'] = (config['train_hgg_patients'], config['size_after_cropping'][0],config['size_after_cropping'][1], config['size_after_cropping'][2]) + config['train_segmasks_shape_lgg_crop'] = (config['train_lgg_patients'], config['size_after_cropping'][0], config['size_after_cropping'][1], config['size_after_cropping'][2]) + config['val_shape_crop'] = (config['validation_patients'], 4, config['size_after_cropping'][0], config['size_after_cropping'][1], config['size_after_cropping'][2]) + config['numpy_patch_size'] = (config['num_patches_per_patient'], 4, config['patch_size'][0], config['patch_size'][1], 
+ config['patch_size'][2]) +elif config['data_order'] == 'tf': + config['train_shape_hgg'] = (config['train_hgg_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices'], 4) + config['train_shape_lgg'] = (config['train_lgg_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices'], 4) + config['train_segmasks_shape_hgg'] = (config['train_hgg_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices']) + config['train_segmasks_shape_lgg'] = (config['train_lgg_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices']) + config['val_shape'] = (config['validation_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1], config['num_slices'], 4) + + config['train_shape_hgg_crop'] = (config['train_hgg_patients'], config['size_after_cropping'][0], config['size_after_cropping'][1], config['size_after_cropping'][2], 4) + config['train_shape_lgg_crop'] = (config['train_lgg_patients'], config['size_after_cropping'][0], config['size_after_cropping'][1], config['size_after_cropping'][2], 4) + config['train_segmasks_shape_hgg_crop'] = (config['train_hgg_patients'], config['size_after_cropping'][0],config['size_after_cropping'][1], config['size_after_cropping'][2]) + config['train_segmasks_shape_lgg_crop'] = (config['train_lgg_patients'], config['size_after_cropping'][0], config['size_after_cropping'][1], config['size_after_cropping'][2]) + config['val_shape_crop'] = (config['validation_patients'], config['size_after_cropping'][0], config['size_after_cropping'][1], config['size_after_cropping'][2], 4) + config['numpy_patch_size'] = (config['num_patches_per_patient'], config['patch_size'][0], config['patch_size'][1], config['patch_size'][2], 4) + +tmp = list(config['val_shape']) +tmp[1] = config['num_labels'] +config['val_shape_after_prediction'] = tuple(tmp) \ No newline at end of file diff --git a/modules/create_hdf5_file.py b/modules/create_hdf5_file.py new file mode 100755 index 0000000..ae5fa5d --- /dev/null +++ b/modules/create_hdf5_file.py @@ -0,0 +1,155 @@ +""" +========================================================== + Prepare BRATS 2017 Data +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. Mostafa Fatehi (VGH) +DESCRIPTION: This file is used to generate an HDF5 dataset, + which is easy to load and manipulate compared + to working directly with raw data all the time. + Loading and working with HDF5 files is much + faster and efficient due to its asynchronous loading + system. + + The HDF5 file generated can be hosted on a remote server + (like CEDAR) and then accessed over SSHFS. Practically, + this is very effective and does not hinder the performance + by a large margin. + + This script generates a simple HDF5 data store, + which contains the original numpy arrays of the + data store. To perform any preprocessing, implement + the preprocessData()function in dataloader.py to + work directly on nibabel objects, instead of + numpy objects. +LICENCE: Proprietary for now. 
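For orientation, reading back the store this script creates might
look like the sketch below. The group and dataset names are the
ones created in createHDF5File(); the ('th'-order) shapes come from
configfile.py, and the file name is the configured default, so both
may differ per machine.

    import h5py

    with h5py.File('BRATS2018_Unprocessed.h5', 'r') as f:
        hgg = f['original_data/training_data_hgg']             # (210, 4, 240, 240, 155), float32
        names = f['original_data/training_data_hgg_pat_name']  # byte strings, dtype S100
        vol = hgg[0]                                           # one patient, read lazily from disk
        print(names[0].decode('utf-8'), vol.shape)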
+""" + +import os +import glob +from modules import dataloader +import logging +import numpy as np +import h5py +import sys +sys.path.append('../') +from modules.configfile import config + +logging.basicConfig(level=logging.INFO) +try: + logger = logging.getLogger(__file__.split('/')[-1]) +except: + logger = logging.getLogger(__name__) + +# whether or not to preprocess the data before creating the HDF5 file? Check the preprocess function in dataloader to +# know exactly what preprocessing is being performed. +PREPROCESS_DATA = False + +logger.info('[IMPORTANT] This will create a new HDF5 file in SAFE MODE. It will NOT OVERWRITE A PREVIOUS HDF5 FILE ' + 'IF ITS PRESENT') +def createHDF5File(config): + """ + Function to create a new HDF5 File to hold the BRATS 2017 data. The function will fail if there's already a file + present with the same name (SAFE OPERATION) + + :param config: The config variable defined in configfile.py + :return: hdf5_file object + """ + + # w- mode fails when there is a file already. + hdf5_file = h5py.File(config['hdf5_filepath_prefix'], mode='w') + + # create a new parent directory to hold the data inside it + grp = hdf5_file.create_group("original_data") + + # the dataset is int16 originally, checked using nibabel, however we create float32 containers to make the dataset + # compatible with further preprocessing. + # HGG Data + grp.create_dataset("training_data_hgg", config['train_shape_hgg'], np.float32) + grp.create_dataset("training_data_hgg_pat_name", (config['train_shape_hgg'][0],), dtype='S100') + grp.create_dataset("training_data_segmasks_hgg", config['train_segmasks_shape_hgg'], np.int16) + + # LGG Data + grp.create_dataset("training_data_lgg", config['train_shape_lgg'], np.float32) + grp.create_dataset("training_data_lgg_pat_name", (config['train_shape_lgg'][0],), dtype='S100') + grp.create_dataset("training_data_segmasks_lgg", config['train_segmasks_shape_lgg'], np.int16) + + # Validation Data, with no segmentation masks + grp.create_dataset("validation_data", config['val_shape'], np.float32) + grp.create_dataset("validation_data_pat_name", (config['val_shape'][0],), dtype='S100') + return hdf5_file + +def main(): + hdf5_file_main = createHDF5File(config) + # hdf5_file_main = h5py.File(config['hdf5_filepath_prefix'], mode='w') + # Go inside the "original_data" parent directory. + # we need to create the validation data dataset again since the shape has changed. + hdf5_file = hdf5_file_main['original_data'] + del hdf5_file['validation_data'] + del hdf5_file['validation_data_pat_name'] + # Validation Data, with no segmentation masks + hdf5_file.create_dataset("validation_data", config['val_shape'], np.float32) + hdf5_file.create_dataset("validation_data_pat_name", (config['val_shape'][0],), dtype="S100") + + for dataset_splits in glob.glob(os.path.join(config['data_dir_prefix'], '*')): # Training/Validation data? + if os.path.isdir(dataset_splits) and 'Validation' in dataset_splits: # make sure its a directory + # VALIDATION data handler + logger.info('currently loading Validation data.') + count = 0 + # validation data does not have HGG and LGG distinctions + for images, pats in dataloader.loadDataGenerator(dataset_splits, + batch_size=config['batch_size'], loadSurvival=False, csvFilePath=None, + loadSeg=False, preprocess=PREPROCESS_DATA): + hdf5_file['validation_data'][count:count+config['batch_size'],...] 
= images + t = 0 + + for i in range(count, count + config['batch_size']): + hdf5_file['validation_data_pat_name'][i] = pats[t].split('/')[-1].encode('utf-8') + t += 1 + + # logger.debug('array equal?: {}'.format(np.array_equal(hdf5_file['validation_data'][count:count+config['batch_size'],...], images))) + logger.info('loaded {} patient(s) from {}'.format(count + config['batch_size'], dataset_splits)) + count += config['batch_size'] + + else: + # TRAINING data handler + if os.path.isdir(dataset_splits) and 'Training' in dataset_splits: + for grade_type in glob.glob(os.path.join(dataset_splits, '*')): + # there may be other files in there (like the survival data), ignore them. + if os.path.isdir(grade_type): + count = 0 + logger.info('currently loading Training data.') + for images, segmasks, pats in dataloader.loadDataGenerator(grade_type, + batch_size=config['batch_size'], loadSurvival=False, + csvFilePath=None, loadSeg=True, + preprocess=PREPROCESS_DATA): + logger.info('loading patient {} from {}'.format(count, grade_type)) + if 'HGG' in grade_type: + hdf5_file['training_data_hgg'][count:count+config['batch_size'],...] = images + hdf5_file['training_data_segmasks_hgg'][count:count+config['batch_size'], ...] = segmasks + t = 0 + for i in range(count, count + config['batch_size']): + hdf5_file['training_data_hgg_pat_name'][i] = pats[t].split('/')[-1].encode('utf-8') + t += 1 + elif 'LGG' in grade_type: + hdf5_file['training_data_lgg'][count:count+config['batch_size'], ...] = images + hdf5_file['training_data_segmasks_lgg'][count:count+config['batch_size'], ...] = segmasks + t = 0 + for i in range(count, count + config['batch_size']): + hdf5_file['training_data_lgg_pat_name'][i] = pats[t].split('/')[-1].encode('utf-8') + t += 1 + + logger.info('loaded {} patient(s) from {}'.format(count + config['batch_size'], grade_type)) + count += config['batch_size'] + # close the HDF5 file + hdf5_file_main.close() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/modules/dataloader.py b/modules/dataloader.py new file mode 100755 index 0000000..a6b9441 --- /dev/null +++ b/modules/dataloader.py @@ -0,0 +1,256 @@ +""" +========================================================== + Load BRATS 2017 Data +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. Mostafa Fatehi (VGH) +DESCRIPTION: The script has multiple functions to load, + preprocess, and standardize the BRATS + 2017 dataset, along with its survival annotations. + Main function is the loadDataGenerator which loads + the data using a generator, and doesn't hog memory. + + The loadDataGenerator is capable of applying + arbitrary preprocessing steps to the data. This can be + achieved by implementing the function preprocessData. +LICENCE: Proprietary for now. +""" + +from __future__ import print_function +import glob as glob +import numpy as np +import pickle +import sys as sys +from pandas import read_csv +import os +import logging +from configfile import config +import SimpleITK as sitk + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def preprocessData(img_obj, process=False): + """ + Perform preprocessing on the original nibabel object. 
+ Use this function to: + 1) Resize/Resample the 3D Volume + 2) Crop the brain region + 3) Do (2) then (1). + + When you do preprocessing, especially something that + changes the spatial size of the volume, make sure you + update config['spatial_size_for_training'] = (240, 240) + value in the config file. + + :param img_obj: + :param process: + :return: + """ + if process == False: + return img_obj + else: + maskImage = sitk.OtsuThreshold(img_obj, 0, 1, 200) + image = sitk.Cast(img_obj, sitk.sitkFloat32) + corrector = sitk.N4BiasFieldCorrectionImageFilter() + numberFilltingLevels = 4 + corrector.SetMaximumNumberOfIterations([4] * numberFilltingLevels) + output = corrector.Execute(image, maskImage) + return output + + +def loadDataGenerator(data_dir, batch_size=1, preprocess=False, loadSurvival=False, + csvFilePath=None, loadSeg=True): + """ + Main function to load BRATS 2017 dataset. + + :param data_dir: path to the folder where patient data resides, needs individual paths for HGG and LGG + :param batch_size: size of batch to load (default=1) + :param loadSurvival: load survival data (True/False) (default=False) + :param csvFilePath: If loadSurvival is True, provide path to survival data (default=False) + :param loadSeg: load segmentations (True/False) (default=True) + :return: + """ + + patID = 0 # used to keep count of how many patients loaded already. + num_sequences = 4 # number of sequences in the data. BRATS has 4. + num_slices = config['num_slices'] + running_pats = [] + out_shape = config['spatial_size_for_training'] # shape of the training data + + # create placeholders, currently only supports theano type convention (num_eg, channels, x, y, z) + images = np.empty((batch_size, num_sequences, out_shape[0], out_shape[1], num_slices)).astype(np.int16) + labels = np.empty((batch_size, 1)).astype(np.int16) + slices = None # not used anymore + + if loadSeg == True: + # create placeholder for the segmentation mask + seg_masks = np.empty((batch_size, out_shape[0], out_shape[1], num_slices)).astype(np.int16) + + csv_flag = 0 + + if loadSurvival == True: + logger.info('Loading annotations as well..') + if csvFilePath != None: + logger.info('trying to load CSV file...') + csv_file = read_csv(csvFilePath) + csv_flag = 1 + logger.info('opened CSV file successfully') + else: + logger.debug('loadSurvival is True but no csvFilePath provided!') + raise Exception + + batch_id = 1 # counter for batches loaded + logger.info('starting to load images..') + for patients in glob.glob(data_dir + '/*'): + if os.path.isdir(patients): + logger.debug('{} is a directory.'.format(patients)) + + # save the name of the patient + running_pats.append(patients) + if csv_flag == 1: + patName = patients.split('/')[-1] + try: + labels[patID] = csv_file[csv_file['Brats17ID'] == patName]['Survival'].tolist()[0] + logger.debug('Added survival data..') + except IndexError: + labels[patID] = None + + # this hacky piece of code is to reorder the filenames, so that segmentation file is always at the end. 
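            # (An equivalent, more compact form, as a sketch: a stable sort on
            # the boolean key `'seg' in x` pushes the segmentation path to the
            # end, e.g. files_new = sorted(glob.glob(patients + '/*'),
            # key=lambda x: 'seg' in x).)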
+ # get all the filepaths + files = glob.glob(patients + '/*') + + # create list without the "seg" filepath inside it + files_new = [x for x in files if 'seg' not in x] + + # create another list with only seg folder inside it + seg_filename = [x for x in files if 'seg' in x] + if seg_filename != []: + # concatenate the list, now the seg filepath is at the end + files_new = files_new + seg_filename + + for imagefile in files_new: # get the filepath of the image (nii.gz) + if 'seg' in imagefile: + if loadSeg == True: + logger.debug('loading segmentation for this patient..') + + # open using SimpleITK + # SimpleITK would allow me to add number of preprocessing steps that are well defined and + # implemented in SITK for their own object type. We can leverage those functions if we preserve + # the image object. + + img_obj = sitk.ReadImage(imagefile) + pix_data = sitk.GetArrayViewFromImage(img_obj) + + # check Practice - SimpleiTK.ipynb notebook for more info on why this swapaxes operation is req + pix_data_swapped = np.swapaxes(pix_data, 0, 1) + pix_data_swapped = np.swapaxes(pix_data_swapped, 1, 2) + + seg_masks[patID, :, :, :] = pix_data_swapped + else: + continue + else: + # this is to ensure that each channel stays at the same place + if 't1.' in imagefile: + i = 0 + seq_name = 't1' + elif 't2.' in imagefile: + i = 1 + seq_name = 't2' + elif 't1ce.' in imagefile: + i = 2 + seq_name = 't1ce' + elif 'flair.' in imagefile: + i = 3 + seq_name = 'flair' + + img_obj = sitk.ReadImage(imagefile) + if preprocess == True: + logger.debug('performing N4ITK Bias Field Correction on {} modality'.format(seq_name)) + img_obj = preprocessData(img_obj, process=preprocess) + + pix_data = sitk.GetArrayViewFromImage(img_obj) + + pix_data_swapped = np.swapaxes(pix_data, 0, 1) + pix_data_swapped = np.swapaxes(pix_data_swapped, 1, 2) + + images[patID, i, :, :, :] = pix_data_swapped + + patID += 1 + + if batch_id % batch_size == 0: + patID = 0 + if csv_flag == 1 and loadSeg == True: + yield images, labels, seg_masks, running_pats + elif csv_flag == 0 and loadSeg == True: + yield images, seg_masks, running_pats + elif csv_flag == 0 and loadSeg == False: + yield images, running_pats + + running_pats = [] + + batch_id += 1 + + +def standardize(images, findMeanVarOnly=True, saveDump=None, applyToTest=None): + """ + This function standardizes the input data to zero mean and unit variance. It is capable of calculating the + mean and std values from the input data, or can also apply user specified mean/std values to the images. + + :param images: numpy ndarray of shape (num_qg, channels, x, y, z) to apply mean/std normalization to + :param findMeanVarOnly: only find the mean and variance of the input data, do not normalize + :param saveDump: if True, saves the calculated mean/variance values to the disk in pickle form + :param applyToTest: apply user specified mean/var values to given images. 
checkLargestCropSize.ipynb has more info + :return: standardized images, and vals (if mean/val was calculated by the function + """ + + # takes a dictionary + if applyToTest != None: + logger.info('Applying to test data using provided values') + from training_helpers import apply_mean_std + images = apply_mean_std(images, applyToTest) + return images + + logger.info('Calculating mean value..') + vals = { + 'mn': [], + 'var': [] + } + for i in range(4): + vals['mn'].append(np.mean(images[:, i, :, :, :])) + + logger.info('Calculating variance..') + for i in range(4): + vals['var'].append(np.var(images[:, i, :, :, :])) + + if findMeanVarOnly == False: + logger.info('Starting standardization process..') + + for i in range(4): + images[:, i, :, :, :] = ((images[:, i, :, :, :] - vals['mn'][i]) / float(vals['var'][i])) + + logger.info('Data standardized!') + + if saveDump != None: + logger.info('Dumping mean and var values to disk..') + pickle.dump(vals, open(saveDump, 'wb')) + logger.info('Done!') + + return images, vals + + +if __name__ == "__main__": + """ + Only for testing purpose, DO NOT ATTEMPT TO RUN THIS SCRIPT. ONLY IMPORT AS MODULE + """ + data_dir = '/local-scratch/cedar-rm/scratch/asa224/Datasets/BRATS2017/MICCAI_BraTS17_Data_Training/HGG/' + images, segmasks = loadDataGenerator(data_dir, batch_size=2, loadSurvival=False, + csvFilePath=None, loadSeg=True) \ No newline at end of file diff --git a/modules/helpers.py b/modules/helpers.py new file mode 100755 index 0000000..9f50ce5 --- /dev/null +++ b/modules/helpers.py @@ -0,0 +1,1262 @@ +import h5py as h5py +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms, utils +import platform +from math import log10 +import csv +import shutil +import scipy.misc +if 'cdr' in platform.node() or 'ts' in platform.node(): + import matplotlib.pyplot + matplotlib.pyplot.switch_backend('agg') +from skimage.measure import compare_ssim +from matplotlib import pylab +import matplotlib.pyplot as plt +import os +import numpy.ma as ma +import pickle as pickle +import numpy as np +from skimage.transform import resize +import itertools +import torch.nn as nn +import copy +import logging +import modules.pytorch_msssim as pyt_ssim + +try: + logger = logging.getLogger(__file__.split('/')[-1]) +except: + logger = logging.getLogger(__name__) + + +def show_volume(sample, slice_idx=3, modality=1): + """ + params: + sample: Sample dictionary generated by the DataLoader class + slice_idx: The dimension along which to index the 3D volume for display[1,3] + modality: Index of modality to display [0,4] + """ + data = np.array(sample['image']).transpose(1, 2, 3, 0) + pat_name = sample['name'] + seg = np.array(sample['seg']).transpose(1, 2, 0) + total_slices = data.shape[slice_idx] + offset = total_slices // 10 + fig, ax = plt.subplots(nrows=2, ncols=5, squeeze=False, figsize=(20, 10)) + ax = [i for ls in ax for i in ls] + c = 0 + for a in ax: + if c < total_slices: + if slice_idx == 1: + image_mask = ma.masked_where(seg > 0, seg) + a.imshow(data[modality, c, :, :], cmap='gray') + a.imshow(image_mask[c, :, :], cmap='rainbow', alpha=0.4) + elif slice_idx == 2: + image_mask = ma.masked_where(seg > 0, seg) + a.imshow(data[modality, :, c, :], cmap='gray') + a.imshow(image_mask[:, c, :], cmap='rainbow', alpha=0.4) + elif slice_idx == 3: + image_mask = ma.masked_where(seg > 0, seg) + a.imshow(data[modality, :, :, c], cmap='gray') + a.imshow(image_mask[:, :, c], cmap='rainbow', alpha=0.4) + a.axis('off') + c += offset + else: + break + 
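    # (The loop above samples ten slices evenly, stepping by
    # offset = total_slices // 10 to fill the 2 x 5 grid. Note that
    # ma.masked_where(seg > 0, seg) hides the labelled voxels themselves; if the
    # intent is to tint only the tumor, masking the complement, seg == 0, would
    # be the usual form.)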
plt.tight_layout() + plt.suptitle(pat_name) + plt.show() + + +class BrainMRIData(Dataset): + def __init__(self, h5path, mean_var_path, parent_name, dataset_name, + load_pat_names, load_seg, + transform=None, apply_normalization=False, which_normalization=None, + dataset='BRATS2018', load_indices=None, train_range=None): + """ + params: + h5path: Path to the HDF5 file from which to import data + mean_var_path: Path to the mean/variance file for the HDf5 + file + parent_name: Parent name of the group in HDF5 file + dataset_name: Name of dataset in HDF5 file to import data + from. + load_pat_names: Boolean. Whether or not to load patient + names + load_seg: Whether or not to load segmentations + transform: Whether to apply a transformation on data + """ + self.valid_parent_names = ['original_data', + 'combined', + 'preprocessed'] + + self.valid_dataset_names = ['training_data_hgg', + 'training_data_hgg_pat_name', + 'training_data_lgg', + 'training_data_lgg_pat_name', + 'training_data_segmasks_hgg', + 'training_data_segmasks_lgg', + 'validation_data', + 'validation_data_pat_name', + # for combined h5py file + 'training_data', + 'training_data_pat_name', + 'training_data_segmasks'] + + if h5path is None: + raise ValueError("Please specify the path for the HDF5 file") + + if mean_var_path is None: + raise ValueError("Please specify the path for the mean/var file") + + if parent_name is None or \ + parent_name not in self.valid_parent_names: + raise ValueError("Invalid parent group name, please check") + + if dataset_name is None or \ + dataset_name not in self.valid_dataset_names: + raise ValueError("Invalid dataset name, please check") + + self.h5path = h5path + self.h5file = h5py.File(h5path, 'r') + # file was pickled in Python 2, so to open in Python 3 + # we need to set encoding + self.mean_var_file = pickle.load(open(mean_var_path, 'rb'), + encoding='latin1') + self.parent_name = parent_name + + # the actual dataset that we load + self.xdataset_name = dataset_name + # corresponding segmentation mask dataset + self.ydataset_name = dataset_name + '_segmasks' + # corresponding name dataset + self.name_dataset_name = dataset_name + "_pat_name" + self.apply_normalization = apply_normalization + + self.dataset = dataset + self.load_indices = load_indices + + try: + logger.info("Loading datasets to memory to make the training faster") + self.xdataset = self.h5file[self.parent_name][self.xdataset_name] + self.x_max = [] + if train_range is not None: + # only calculate max value from the training_range, not from testing patients + # Problem is, three patients have an error intensity value of 32767.0 as max, which messes up + # normalization. 
Hence find those patients, and remove them from intensity max calculation + # curr_max = {0: 0, 1: 0, 2: 0, 3: 0} + # for i in train_range: + # for m in range(0, 4): + # if (not self.xdataset[i, m].max() > 30000.0) and (self.xdataset[i, m].max() >= curr_max[m]): + # curr_max[m] = self.xdataset[i, m].max() + + # typically BRATS uses 12 bits (all dicoms do), but there are like + self.x_max = {0: 4096.0, 1: 4096.0, 2: 4096.0, 3: 4096.0} + + if 'validation_data' not in self.xdataset_name and load_seg: + self.ydataset = self.h5file[self.parent_name][self.ydataset_name] + self.name_dataset = self.h5file[self.parent_name][self.name_dataset_name] + except KeyError as e: + text = 'Attempted keys: {}, {}, {}'.format(self.xdataset_name, + self.ydataset_name, + self.name_dataset) + raise KeyError(text) + + self.load_pat_names = load_pat_names + self.load_seg = load_seg + self.transform = transform + + if self.apply_normalization: + if self.dataset == 'BRATS2018' or self.dataset == 'BRATS2015': + if which_normalization is None: + # self.normalization_function = self.apply_mean_std + self.normalization_function = self.apply_normalization_for_ISLES + elif which_normalization == 'tanh': + print('Using TANH normalization') + self.normalization_function = self.apply_tanh_normalization + # self.normalization_function = self.apply_mean_std + elif self.dataset == 'ISLES2015': + self.normalization_function = self.apply_normalization_for_ISLES + + def __len__(self): + try: + return len(self.h5file[self.parent_name][self.xdataset_name]) + except ValueError: + return 0 + + def getitem_via_index(self, idx): + sample = {} + + if self.apply_normalization: + # if its ISLES2015 normalization, it will ignore the mean_var_file. Otherwise it uses it. + sample['image'] = self.normalization_function(self.xdataset[idx], self.mean_var_file) + else: + sample['image'] = self.xdataset[idx] + + if self.load_seg: + sample['seg'] = self.ydataset[idx] + + if self.load_pat_names: + sample['name'] = self.name_dataset[idx] + + if self.transform: + sample = self.transform(sample) + + return sample + + def __getitem__(self, idx): + sample = {} + + if self.load_indices is not None: + # convert the idx into the relative index in load_indices train/test index generated for cross validation + idx = self.load_indices[idx] + + logger.debug("value of idx in dataloader's __getitem__: {}".format(idx)) + if self.apply_normalization: + # if its ISLES2015 normalization, it will ignore the mean_var_file. Otherwise it uses it. + sample['image'] = self.normalization_function(self.xdataset[idx], self.mean_var_file) + else: + sample['image'] = self.xdataset[idx] + + if self.load_seg: + sample['seg'] = self.ydataset[idx] + + if self.load_pat_names: + sample['name'] = self.name_dataset[idx] + + if self.transform: + sample = self.transform(sample) + + return sample + + def apply_normalization_for_ISLES(self, im, mean_var_file): + """ + In order to not break compatibility with anothe rnormaization function for BRATS, we have an unused argument here. + :param im: + :param mean_var_file: + :return: + """ + # remove all negative values + im[im < 0] = 0.0 + for m in range(0, 4): + if len(im.shape) > 4: + for k in range(0, im.shape[0]): + im[k, m, ...] = (im[k, m, ...] / np.mean(im[k, m, ...])) + else: + # remove clipping + # if im[m, ...].min() != 0.0: + # im[m, ...] = np.clip(im[m, ...], a_min=0.0, a_max=np.max(im[m, ...])) + + im[m, ...] = (im[m, ...] 
/ np.mean(im[m, ...])) + + return im + + def apply_tanh_normalization(self, im, max_val=None): + im = self.apply_mean_std(im, self.mean_var_file) + # make sure there are no negative entries + # im[im < 0] = 0.0 + + for m in range(0, 4): + if len(im.shape) > 4: + for k in range(0, im.shape[0]): + den = self.x_max[m] - im[k, m, ...].min() # max of this modality across the dataset + im[k, m, ...] = 2 * ((im[k, m, ...] - im[k, m, ...].min()) / den) - 1 + else: + den = im[m, ...].max() - im[m, ...].min() # max of this modality across the dataset + im[m, ...] = 2 * ((im[m, ...] - im[m, ...].min()) / den) - 1 + + final = np.clip(im, -1, 1) + return final + + def apply_tanh_normalization_old(self, im, dummy): + # make sure there are no negative entries + im[im < 0] = 0.0 + + for m in range(0, 4): + if len(im.shape) > 4: + for k in range(0, im.shape[0]): + den = im[k, m, ...].max() - im[k, m, ...].min() + if den == 0.0: + im[k, m, ...] = -1.0 + else: + im[k, m, ...] = 2 * ((im[k, m, ...] - im[k, m, ...].min()) / den) - 1 + else: + den = im[m, ...].min() - im[m, ...].min() + + if den == 0.0: + im[m, ...] = -1.0 + else: + im[m, ...] = 2 * ((im[m, ...] - im[m, ...].min()) / den) - 1 + return im + + def apply_mean_std(self, im, mean_var): + """ + Supercedes the standardize function. Takes the mean/var file generated during preprocessed data generation and + applies the normalization step to the patch. + :param im: patch of size (num_egs, channels, x, y, z) or (channels, x, y, z) + :param mean_var: dictionary containing mean/var value calculated in preprocess.py + :return: normalized patch + """ + + # expects a dictionary of means and VARIANCES, NOT STD + for m in range(0, 4): + if len(im.shape) > 4: + im[:, m, ...] = (im[:, m, ...] - mean_var['mn'][m]) / np.sqrt(mean_var['var'][m]) + else: + im[m, ...] = (im[m, ...] - mean_var['mn'][m]) / np.sqrt(mean_var['var'][m]) + + return im + + def close_connection(self): + self.h5file.close() + + def reconnect(self): + try: + self.h5file = h5py.File(self.h5path, 'r') + return True + except: + raise + +def revert_mean_std(im, mean_var): + """ + Inverse function of apply_mean_std to get back original intensities + :param im: + :param mean_var: + :return: + """ + for m in range(0, 4): + if len(im.shape) > 4: + im[:, m, ...] = (im[:, m, ...] * np.sqrt(mean_var['var'][m]) + mean_var['mn'][m]) + else: + im[m, ...] = (im[m, ...] 
* np.sqrt(mean_var['var'][m]) + mean_var['mn'][m]) + + return im + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + image = sample['image'] + seg = sample.get('seg') + name = sample.get('name') + final = {} + + # swap color axis because + # numpy image: H x W x Z x C + # torch image: C X H X W x Z + final['image'] = image.transpose((3, 0, 1, 2)) + final['image'] = torch.from_numpy(final['image']).type(torch.FloatTensor) + + if seg is not None: + final['seg'] = seg.transpose((2, 0, 1)) + final['seg'] = torch.from_numpy(final['seg']) + + if name is not None: + final['name'] = name + + return final + + +class Resize(object): + """Convert ndarrays in sample to Tensors.""" + + def __init__(self, size): + self.size = size + + def __call__(self, sample): + image = sample['image'] + seg = sample.get('seg') + name = sample.get('name') + final = {} + sh = image.shape + size = self.size + + dummy_im = np.empty((sh[0], size[0], size[1], sh[-1])) + + for curr_slice in range(0, image.shape[-1]): + for curr_seq in range(0, 4): + dummy_im[curr_seq, :, :, curr_slice] = resize(image[curr_seq, :, :, curr_slice], output_shape=size, + preserve_range=True) + + sample['image'] = dummy_im + + return sample + + +def my_collate(batch): + image_arrays = (curr_batch['image'] for curr_batch in batch) + seg_arrays = (curr_batch['seg'] for curr_batch in batch if 'seg' in curr_batch) + name_arrays = (curr_batch['name'] for curr_batch in batch if 'name' in curr_batch) + + final_batch = {} + collated_image_array = torch.cat(tuple(image_arrays), dim=0) + final_batch['image'] = collated_image_array + + if 'seg' in batch[0].keys(): + collated_seg_array = torch.cat(tuple(seg_arrays), dim=0) + final_batch['seg'] = collated_seg_array + + if 'name' in batch[0].keys(): + collated_name_array = np.vstack(name_arrays) + final_batch['name'] = collated_name_array + + return final_batch + + +def generate_training_strategy(dataset_name, curr_epoch, total_epochs): + if dataset_name == 'ISLES2015': + FIRST_BRACKET = 50 # X% of the epochs for easy scenarios + if curr_epoch <= ((total_epochs * FIRST_BRACKET) / 100): + logger.debug('First Bracket') + rand_val = torch.randint(low=3, high=7, size=(1,)) + + SECOND_BRACKET = 80 + # if the curr_epoch is above X%%, but less than T%, train ONLY with difficult examples + if (curr_epoch > ((total_epochs * FIRST_BRACKET) / 100)) and \ + (curr_epoch <= ((total_epochs * SECOND_BRACKET) / 100)): + logger.debug('Second Bracket') + rand_val = torch.randint(low=0, high=3, size=(1,)) + + if (curr_epoch > ((total_epochs * SECOND_BRACKET) / 100)): + logger.debug('Third Bracket') + rand_val = torch.randint(low=0, high=7, size=(1,)) + + elif dataset_name == 'BRATS2018': + FIRST_BRACKET = 30 # X% of the epochs for easy scenarios + if curr_epoch <= ((total_epochs * FIRST_BRACKET) / 100): + logger.debug('First Bracket') + rand_val = torch.randint(low=10, high=14, size=(1,)) + + SECOND_BRACKET = 70 + # if the curr_epoch is above X%%, but less than T%, train ONLY with difficult examples + if (curr_epoch > ((total_epochs * FIRST_BRACKET) / 100)) and \ + (curr_epoch <= ((total_epochs * SECOND_BRACKET) / 100)): + logger.debug('Second Bracket') + rand_val = torch.randint(low=4, high=10, size=(1,)) + + THIRD_BRACKET = 90 + if (curr_epoch > ((total_epochs * SECOND_BRACKET) / 100)) and \ + (curr_epoch <= ((total_epochs * THIRD_BRACKET) / 100)): + logger.debug('Third Bracket') + rand_val = torch.randint(low=0, high=4, size=(1,)) + + if (curr_epoch > ((total_epochs * 
+def show_intermediate_results_BRATS(G, test_patient, save_path, all_scenarios, epoch, curr_scenario_range=None, batch_size_to_test=5):
+    if isinstance(test_patient, list):
+        patients = test_patient
+    else:
+        patients = [test_patient]
+    # put the generator in EVAL mode.
+    G.eval()
+    mse = nn.MSELoss()
+
+    sh = patients[0]['image'].shape
+    G_all_slices = np.empty((sh[0], sh[1], sh[2], sh[3]))
+
+    for patient in patients:
+        pat_name = patient['name'].decode('UTF-8')
+        logger.info('Testing Patient: {}'.format(pat_name))
+
+        curr_saving_directory = os.path.join(save_path, pat_name, "epoch_" + str(epoch))
+
+        if not os.path.isdir(curr_saving_directory):
+            os.makedirs(curr_saving_directory)
+
+        patient_copy = patient['image'].clone()
+        patient_numpy = patient_copy.detach().cpu().numpy()
+        scenarios = copy.deepcopy(all_scenarios)
+
+        sh = patient_numpy.shape
+
+        batch_size = batch_size_to_test
+
+        if curr_scenario_range is not None:
+            scenarios = scenarios[curr_scenario_range[0]:curr_scenario_range[1]]
+
+        logger.info('Testing on scenarios: {}'.format(scenarios))
+        # what's the current scenario
+        for curr_scenario in scenarios:
+
+            logger.info('Testing on scenario: {}'.format(curr_scenario))
+            # get the batch indices
+            batch_indices = range(0, sh[0], batch_size)
+
+            G_mse_loss = []
+            G_ssim_val = []
+            G_psnr_val = []
+            # for each batch
+            for _num, batch_idx in enumerate(batch_indices):
+                x_test_r = patient['image'][batch_idx:batch_idx + batch_size, ...].cuda()
+                x_test_z = x_test_r.clone().cuda().type(torch.cuda.FloatTensor)
+
+                for idx_, k in enumerate(curr_scenario):
+                    if k == 0:
+                        x_test_z[:, idx_, ...] = torch.randn((sh[-2], sh[-1]))
+
+                G_result = G(x_test_z)
+
+                for idx_curr_label, i in enumerate(curr_scenario):
+                    if i == 0:
+                        G_mse_loss.append(mse(
+                            G_result[:, idx_curr_label, ...],
+                            x_test_r[:, idx_curr_label, ...]).item())
+                        G_ssim_val.append(pyt_ssim.ssim(
+                            G_result[:, idx_curr_label, ...].unsqueeze(1),
+                            x_test_r[:, idx_curr_label, ...].unsqueeze(1),
+                            val_range=2).item())
+                        # psnr_torch takes only (pred, gt); it normalizes internally
+                        G_psnr_val.append(psnr_torch(
+                            G_result[:, idx_curr_label, ...].unsqueeze(1),
+                            x_test_r[:, idx_curr_label, ...].unsqueeze(1)).item())
+
+                g_out_np = G_result.detach().cpu().numpy()
+                G_all_slices[batch_idx:batch_idx + batch_size] = g_out_np
+
+            slice_offset = 10
+            start_slice = 0
+            end_slice = 150 if sh[0] >= 150 else sh[0]
+            num_synthesized = curr_scenario.count(0)
+            rows = 4 + num_synthesized
+            cols = (end_slice - start_slice) // slice_offset
+
+            f, axes = plt.subplots(rows, cols, sharey=True, figsize=(30, 10))
+
+            # plot the original real patient, row by row until the 4th row
+            # the order of sequences in the HDF5 file is (T1, T2, T1c, T2F), NOT (T1, T1c, T2, T2F).
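+            # e.g. curr_scenario = [1, 0, 1, 1] means T2 is the missing sequence: it was
+            # imputed with noise above, and only its synthesized output is scored and
+            # plotted in the extra row below.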
+ seq_names = ['T1', 'T2', 'T1c', 'T2F'] + + for curr_row in range(0, 4): # curr_row also controls the sequence to show + for curr_col, curr_slice_idx in zip(range(0, cols), range(start_slice, end_slice, slice_offset)): + # we use patient['image'] as it contains ALL slices for the patient, not batch wise + axes[curr_row, curr_col].imshow(patient_numpy[curr_slice_idx, curr_row, ...], cmap='gray') + axes[curr_row, curr_col].axis('off') + axes[curr_row, 0].axis('on') + axes[curr_row, 0].get_xaxis().set_ticks([]) + axes[curr_row, 0].get_yaxis().set_ticks([]) + axes[curr_row, 0].set_ylabel(seq_names[curr_row], rotation=90, size='large') + + # now we plot the synthesized one + indices = [i for i, x in enumerate(curr_scenario) if x == 0] + for curr_row, curr_sequence_index in zip(range(4, rows), + indices): # curr_row also controls the sequence to show + for curr_col, curr_slice_idx in zip(range(0, cols), range(start_slice, end_slice, slice_offset)): + # we use patient['image'] as it contains ALL slices for the patient, not batch wise + axes[curr_row, curr_col].imshow(G_all_slices[curr_slice_idx, curr_sequence_index,...], cmap='gray') + axes[curr_row, curr_col].axis('off') + axes[curr_row, 0].axis('on') + axes[curr_row, 0].get_xaxis().set_ticks([]) + axes[curr_row, 0].get_yaxis().set_ticks([]) + axes[curr_row, 0].set_ylabel(seq_names[curr_sequence_index], rotation=90, size='large') + + # Question: Why is the ordering like this? + # Answer: Check the original dataloader in dataloader.py. The "i" values are this way. + plt.suptitle('Epoch: {}\nT1 T2 T1c T2F\n{} {} {} {}'.format(epoch, + curr_scenario[0], + curr_scenario[1], + curr_scenario[2], + curr_scenario[3])) + + f.text(0.65, 0.95, pat_name) + f.text(0.35, 0.95, "MSE: %.5f" % np.mean(G_mse_loss)) + f.text(0.35, 0.935, "PSNR: %.5f" % np.mean(G_psnr_val)) + f.text(0.35, 0.92, "SSIM: %.5f" % np.mean(G_ssim_val)) + + plt.savefig(os.path.join(curr_saving_directory, "".join([str(x) for x in curr_scenario]) + ".png")) + pylab.close(f) + del patient_numpy, patient_copy, patient + # plt.savefig(result_pathname + "_" + "{}".format(pat_name) + "_" + "".join([str(x) for x in curr_scenario]) + ".png", bbox_inches='tight') + + # put the generator back to train mode + G.train() + return 0 + + +def show_intermediate_results(G, test_patient, save_path, all_scenarios, epoch, curr_scenario_range=None, + batch_size_to_test=5, seq_type='T1', dataset='ISLES2015'): + + if isinstance(test_patient, list): + patients = test_patient + else: + patients = [test_patient] + # put the generator in EVAL mode. 
+    G.eval()
+    mse = nn.MSELoss()
+
+    sh = patients[0]['image'].shape
+    G_all_slices = np.empty((sh[0], 1, sh[-2], sh[-1]))
+
+    for patient in patients:
+        pat_name = patient['name'].decode('UTF-8')
+        logger.debug('Testing Patient: {}'.format(pat_name))
+
+        curr_saving_directory = os.path.join(save_path, pat_name, "epoch_" + str(epoch))
+
+        if not os.path.isdir(curr_saving_directory):
+            os.makedirs(curr_saving_directory)
+
+        patient_copy = patient['image'].clone()
+        patient_numpy = patient_copy.detach().cpu().numpy()
+        scenarios = all_scenarios
+
+        if seq_type == "T1":
+            SEQ_IDX = 0
+        elif seq_type == 'T2':
+            SEQ_IDX = 1
+        else:  # synthesize FLAIR; this is the ISLES2015 setting
+            SEQ_IDX = 3
+
+        sh = patient_numpy.shape
+
+        batch_size = batch_size_to_test
+
+        if curr_scenario_range is not None:
+            scenarios = scenarios[curr_scenario_range[0]:curr_scenario_range[1]]
+
+        logger.debug('Testing on scenarios: {}'.format(scenarios))
+        # what's the current scenario
+        for curr_scenario in scenarios:
+
+            logger.debug('Testing on scenario: {}'.format(curr_scenario))
+            # get the batch indices
+            batch_indices = range(0, sh[0], batch_size)
+
+            # for each batch
+            G_mse_loss = []
+            G_ssim_val = []
+            G_psnr_val = []
+            for _num, batch_idx in enumerate(batch_indices):
+                x_test_r = patient['image'][batch_idx:batch_idx + batch_size, ...].cuda()
+                x_test_z = x_test_r.clone().cuda().type(torch.cuda.FloatTensor)
+
+                for idx_, k in enumerate(curr_scenario):
+                    if k == 0:
+                        x_test_z[:, idx_, ...] = torch.ones((sh[-2], sh[-1])) * -1.0
+
+                G_result = G(x_test_z)
+
+                for i in range(batch_size):
+                    G_mse_loss.append(mse(G_result[i, 0, ...], x_test_r[i, SEQ_IDX, ...]).item())
+
+                G_ssim_val.append(pyt_ssim.ssim(
+                    G_result[:, 0].unsqueeze(1) / (torch.max(x_test_r[:, SEQ_IDX]) + 0.0001),
+                    x_test_r[:, SEQ_IDX].unsqueeze(1) / (torch.max(x_test_r[:, SEQ_IDX]) + 0.0001),
+                    val_range=1).item())
+
+                G_psnr_val.append(
+                    psnr_torch(G_result, x_test_r[:, SEQ_IDX, ...].unsqueeze(1)).item())
+
+                g_out_np = G_result.detach().cpu().numpy()
+                G_all_slices[batch_idx:batch_idx + batch_size] = g_out_np
+
+            slice_offset = 10
+            start_slice = 0
+            end_slice = 150 if sh[0] >= 150 else sh[0]
+
+            num_synthesized = 1  # WE ONLY SYNTHESIZE ONE SEQUENCE (FLAIR by default)
+
+            rows = 4 + num_synthesized
+            cols = (end_slice - start_slice) // slice_offset
+
+            f, axes = plt.subplots(rows, cols, sharey=True, figsize=(30, 10))
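+            # e.g. a 155-slice patient gives end_slice = 150 and cols = 15; with the one
+            # synthesized sequence the grid has rows = 5 (four real rows + one synthesized).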
+            # plot the original real patient, row by row until the 4th row
+            # the order of sequences in the HDF5 file is (T1, T2, <T1c or DWI>, T2F), NOT (T1, T1c, T2, T2F).
+
+            if dataset == 'ISLES2015':
+                seq_names = ['T1', 'T2', 'DWI', 'T2F']
+            else:
+                seq_names = ['T1', 'T2', 'T1c', 'T2F']
+
+            for curr_row in range(0, 4):  # curr_row also controls the sequence to show
+                for curr_col, curr_slice_idx in zip(range(0, cols), range(start_slice, end_slice, slice_offset)):
+                    # we use patient['image'] as it contains ALL slices for the patient, not batch wise
+                    axes[curr_row, curr_col].imshow(np.clip(patient_numpy[curr_slice_idx, curr_row, ...] /
+                                                            np.max(patient_numpy[curr_slice_idx, curr_row, ...]), 0.0, 1.0),
+                                                    cmap='gray', vmin=0, vmax=1)
+                    axes[curr_row, curr_col].axis('off')
+                axes[curr_row, 0].axis('on')
+                axes[curr_row, 0].get_xaxis().set_ticks([])
+                axes[curr_row, 0].get_yaxis().set_ticks([])
+                axes[curr_row, 0].set_ylabel(seq_names[curr_row], rotation=90, size='large')
+
+            # now we plot the synthesized one
+            indices = [0]  # the generator emits a single channel, so index 0 holds the synthesized sequence
+            for curr_row, curr_sequence_index in zip(range(4, rows),
+                                                     indices):  # curr_row also controls the sequence to show
+                for curr_col, curr_slice_idx in zip(range(0, cols), range(start_slice, end_slice, slice_offset)):
+                    # we use patient['image'] as it contains ALL slices for the patient, not batch wise
+                    axes[curr_row, curr_col].imshow(np.clip(G_all_slices[curr_slice_idx, curr_sequence_index,...] /
+                                                            np.max(G_all_slices[curr_slice_idx, curr_sequence_index,...]), 0.0, 1.0),
+                                                    cmap='gray', vmin=0, vmax=1)
+                    axes[curr_row, curr_col].axis('off')
+                axes[curr_row, 0].axis('on')
+                axes[curr_row, 0].get_xaxis().set_ticks([])
+                axes[curr_row, 0].get_yaxis().set_ticks([])
+                # label the synthesized row with the target sequence (SEQ_IDX), not index 0
+                axes[curr_row, 0].set_ylabel(seq_names[SEQ_IDX], rotation=90, size='large')
+
+            # Question: Why is the ordering like this?
+            # Answer: Check the original dataloader in dataloader.py. The "i" values are this way.
+            if dataset == 'ISLES2015':
+                plt.suptitle('Epoch: {}\nT1 T2 DWI T2F\n{} {} {} {}'.format(epoch,
+                                                                            curr_scenario[0],
+                                                                            curr_scenario[1],
+                                                                            curr_scenario[2],
+                                                                            curr_scenario[3]))
+            else:
+                plt.suptitle('Epoch: {}\nT1 T2 T1c T2F\n{} {} {} {}'.format(epoch,
+                                                                            curr_scenario[0],
+                                                                            curr_scenario[1],
+                                                                            curr_scenario[2],
+                                                                            curr_scenario[3]))
+
+            f.text(0.65, 0.95, pat_name)
+            f.text(0.35, 0.95, "MSE: %.5f" % np.mean(G_mse_loss))
+            f.text(0.35, 0.935, "PSNR: %.5f" % np.mean(G_psnr_val))
+            f.text(0.35, 0.92, "SSIM: %.5f" % np.mean(G_ssim_val))
+            plt.subplots_adjust(wspace=0, hspace=0)
+            plt.savefig(os.path.join(curr_saving_directory, "".join([str(x) for x in curr_scenario]) + ".png"))
+            pylab.close(f)
+        del patient_numpy, patient_copy, patient
+    # put the generator back to train mode
+    G.train()
+    del G_all_slices
+    return 0
+
+# =========================================================================================
+# Functions used by the original cDCGAN implementation that this code was based upon.
+# Note: they reference module-level globals (G, fixed_z_, fixed_y_label_) and are
+# kept only for reference.
+# =========================================================================================
+def show_result(num_epoch, show = False, save = False, path = 'result.png'):
+
+    G.eval()
+    test_images = G(fixed_z_, fixed_y_label_)
+    G.train()
+
+    size_figure_grid = 10
+    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
+    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
+        ax[i, j].get_xaxis().set_visible(False)
+        ax[i, j].get_yaxis().set_visible(False)
+
+    for k in range(10*10):
+        i = k // 10
+        j = k % 10
+        ax[i, j].cla()
+        ax[i, j].imshow(test_images[k, 0].cpu().data.numpy(), cmap='gray')
+
+    label = 'Epoch {0}'.format(num_epoch)
+    fig.text(0.5, 0.04, label, ha='center')
+    plt.savefig(path)
+
+    if show:
+        plt.show()
+    else:
+        plt.close()
+
+def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
+    x = range(len(hist['D_losses']))
+
+    y1 = hist['D_losses']
+    y2 = hist['G_losses']
+
+    plt.plot(x, y1, label='D_loss')
+    plt.plot(x, y2, label='G_loss')
+
+    plt.xlabel('Epoch')
+    plt.ylabel('Loss')
+
+    plt.legend(loc=4)
+    plt.grid(True)
+    plt.tight_layout()
+
+    if save:
+        
plt.savefig(path) + + if show: + plt.show() + else: + plt.close() + + +def create_dataloaders(parent_path='/scratch/asa224/asa224/Datasets/BRATS2018/HDF5_Datasets/', + parent_name='preprocessed', + dataset_name='training_data_hgg', + load_pat_names=True, + dataset_type='cropped', + load_seg=False, + transform_fn=[Resize(size=(256, 256)), ToTensor()], + apply_normalization = True, + which_normalization=None, + train_range=None, + resize_slices=148, + get_viz_dataloader=True, + num_workers=0, + dataset='BRATS2018', + load_indices=None, + shuffle=False): + + logger.info("Setting paths") + # train_h5path = os.path.join(parent_path, 'BRATS_Combined.h5') + if dataset == 'BRATS2018': + if dataset_type == 'cropped': + train_h5path = os.path.join(parent_path, 'BRATS2018_Cropped.h5') + else: + train_h5path = os.path.join(parent_path, 'BRATS2018.h5') + + if 'lgg' in dataset_name: + mean_var_path = os.path.join(parent_path, 'BRATS2018_HDF5_Datasetstraining_data_lgg_mean_var.p') + else: + mean_var_path = os.path.join(parent_path, 'BRATS2018_HDF5_Datasetstraining_data_hgg_mean_var.p') + + + elif dataset == 'BRATS2015': + if dataset_type == 'cropped': + train_h5path = os.path.join(parent_path, 'BRATS2015_Cropped.h5') + else: + train_h5path = os.path.join(parent_path, 'BRATS2015.h5') + + # Only use the LGG data as per the experiment design. It's actually mean and var, not std. Ignore the name + mean_var_path = os.path.join(parent_path, 'HDF5_Datasetstraining_data_lgg_mean_std.p') + + elif dataset == 'ISLES2015': + if dataset_type == 'cropped': + train_h5path = os.path.join(parent_path, 'ISLES2015_Cropped.h5') + else: + train_h5path = os.path.join(parent_path, 'ISLES2015.h5') + + # only using SISS data. It's actually mean and var, not std. Ignore the name + mean_var_path = os.path.join(parent_path, 'HDF5_Datasetstraining_data_mean_std.p') + + logger.debug('Using the following paths:') + logger.debug('train_h5_path: {}'.format(train_h5path)) + logger.debug('mean_var_path: {}'.format(mean_var_path)) + + transform = transforms.Compose(transform_fn) + + # build loader object + logger.info("Build loader object") + loader = BrainMRIData(train_h5path, mean_var_path, parent_name, + dataset_name, load_pat_names, + load_seg, transform, apply_normalization=apply_normalization, + which_normalization=which_normalization, + dataset=dataset, load_indices=load_indices, train_range=train_range) + + logger.info("Getting DataLoader instance using the loader object") + # batch size here means patients. How many patients to load at once? + dataloader = DataLoader(loader, batch_size=1, + shuffle=shuffle, num_workers=num_workers, collate_fn=my_collate) + + if get_viz_dataloader: + dataloader_for_viz = loader + return dataloader, dataloader_for_viz + + return dataloader + + +def impute_reals_into_fake(x_z, fake_x, label_scenario): + for idx, k in enumerate(label_scenario): + if k == 1: # THIS IS A REAL AVAILABLE SEQUENCE + fake_x[:, idx, ...] 
= x_z[:, idx, ...].clone().cuda() + + return fake_x + + +def save_checkpoint(state, filename, pickle_module): + torch.save(state, filename, pickle_module=pickle_module) + +def load_checkpoint(model, optimizer, filename, pickle_module): + if os.path.isfile(filename): + logger.info("Loading checkpoint '{}'".format(filename)) + checkpoint = torch.load(filename, pickle_module=pickle_module) + # args.start_epoch = checkpoint['epoch'] + # best_prec1 = checkpoint['best_prec1'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + logger.info("Loaded checkpoint '{}' (epoch {})" + .format(filename, checkpoint['epoch'])) + else: + logger.critical('Checkpoint {} does not exist.'.format(filename)) + + return model, optimizer + + + +def psnr_torch(pred, gt): + # normalize images between [0, 1] + epsilon = 0.00001 + epsilon2 = torch.from_numpy(np.array(0.01, dtype=np.float32)) + # always use ground truth + gt_n = gt / (gt.max() + epsilon) + pred_n = pred / (pred.max() + epsilon) + + PIXEL_MAX = 1.0 + + mse = torch.mean((gt_n - pred_n) ** 2) + if mse.item() == 0.0: + psnr = 20 * torch.log10(PIXEL_MAX / epsilon2) + else: + psnr = 20 * torch.log10(PIXEL_MAX / torch.sqrt(mse)) + + return psnr + + +def l2_torch(a, b): + return torch.mean((a - b) ** 2) + +def calculate_metrics(G, patient_list, + save_path, all_scenarios, + epoch, curr_scenario_range=None, + batch_size_to_test=2, + impute_type=None, + dataset = 'ISLES2015', + convert_normalization=False, + save_stats=False, + mean_var_file=None, + use_pytorch_ssim=False, seq_type='T1'): + + """ + For ISLES2015 + all_scenarios: scenarios + curr_scenario_range: None + batch_to_test: 2 + """ + + if isinstance(patient_list, list): + patients = patient_list + else: + patients = [patient_list] + # put the generator in EVAL mode. + mse = nn.MSELoss() + save_im_path = os.path.join(save_path, 'all_slices', 'epoch_{}'.format(epoch)) + + if not os.path.isdir(save_im_path): + os.makedirs(save_im_path) + + + # contains metrics for EACH slice from EACH OF THE SCENARIO. Basically everything. 
This is what we need for + # ISLES2015 + running_mse = {} + running_psnr = {} + running_ssim = {} + + for (pat_ind, patient) in enumerate(patients): + pat_name = patient['name'].decode('UTF-8') + logger.debug('Testing Patient: {}'.format(pat_name)) + patient_image = patient['image'] + + patient_copy = patient['image'].clone() + + patient_numpy = patient_copy.detach().cpu().numpy() + + scenarios = all_scenarios + all_minus_1_g = torch.ones((batch_size_to_test,1,256,256)).cuda() * -1.0 + all_minus_x_test_r = torch.ones((batch_size_to_test, 256, 256)).cuda() * -1.0 + + sh = patient_numpy.shape + + batch_size = batch_size_to_test + + # this will store output for ALL patients + + if curr_scenario_range is not None: + scenarios = scenarios[curr_scenario_range[0]:curr_scenario_range[1]] + + logger.debug('Testing on scenarios: {}'.format(scenarios)) + for curr_scenario in scenarios: + + curr_scenario_str = ''.join([str(x) for x in curr_scenario]) + + running_mse[curr_scenario_str] = [] + running_psnr[curr_scenario_str] = [] + running_ssim[curr_scenario_str] = [] + + logger.debug('Testing on scenario: {}'.format(curr_scenario)) + + # get the batch indices + batch_indices = range(0, sh[0], batch_size) + + # for each batch + for _num, batch_idx in enumerate(batch_indices): + x_test_r = patient_image[batch_idx:batch_idx + batch_size, ...].cuda() + x_test_z = x_test_r.clone().cuda().type(torch.cuda.FloatTensor) + + if impute_type == 'noise': + impute_tensor = torch.randn((batch_size, + 256, + 256), device='cuda') + + elif impute_type == 'average': + avail_indx = [i for i, x in enumerate(curr_scenario) if x == 1] + impute_tensor = torch.mean(x_test_r[:, avail_indx, ...], dim=1) + elif impute_type == 'zeros': + impute_tensor = torch.zeros((batch_size, + 256, + 256), device='cuda') + else: + impute_tensor = torch.ones((sh[-2], sh[-1])) * -1.0 + # print('Imputing with -1') + + # print('Imputing with {}'.format(impute_type)) + for idx_, k in enumerate(curr_scenario): + if k == 0: + x_test_z[:, idx_, ...] = impute_tensor + + G_result = G(x_test_z) + # save all images + if dataset == 'ISLES2015' or dataset == 'BRATS2015': + if 'ISLES' in dataset: + SEQ_IND = 3 + ssim = compare_ssim + else: + if seq_type == "T1": + SEQ_IND = 0 + elif seq_type == "T2": + SEQ_IND = 1 + ssim = pyt_ssim.ssim + + if dataset == 'BRATS2015': + + # den_G_result = G_result.max() - G_result.min() + # if den_G_result != 0.0: + # G_result_norm = 2*((G_result - G_result.min())/ (den_G_result)) -1 + # else: + # G_result_norm = all_minus_1_g + # + # den_x_test_r = x_test_r[:, SEQ_IND, ...].max() - x_test_r[:, SEQ_IND, ...].min() + # if den_x_test_r != 0.0: + # x_test_r_norm = 2*((x_test_r[:, SEQ_IND, ...] 
- x_test_r[:, SEQ_IND, ...].min())/ (den_x_test_r)) - 1 + # else: + # x_test_r_norm = all_minus_x_test + + # calculate metrics + running_mse[curr_scenario_str].append( + mse(G_result, + x_test_r[:, SEQ_IND, ...].unsqueeze(1)).item()) + + running_ssim[curr_scenario_str].append(ssim( + G_result[:, 0].unsqueeze(1) / (torch.max(x_test_r[:, SEQ_IND]) + 0.0001), + x_test_r[:, SEQ_IND].unsqueeze(1) / (torch.max(x_test_r[:, SEQ_IND]) + 0.0001), + val_range=1).item()) + + running_psnr[curr_scenario_str].append( + psnr_torch(G_result, x_test_r[:, SEQ_IND, ...].unsqueeze(1)).item()) + + # running_mse[curr_scenario_str].append( + # mse(G_result, x_test_r[:,SEQ_IND].unsqueeze(1)).item()) + # + # s = ssim(G_result, x_test_r[:,SEQ_IND].unsqueeze(1), val_range=2).item() + # if s > 0: + # running_ssim[curr_scenario_str].append(s) + # else: + # running_ssim[curr_scenario_str].append(0.0) + # + # p = psnr_torch(G_result, x_test_r[:, SEQ_IND].unsqueeze(1), val_range=2).item() + # if p > 0: + # running_psnr[curr_scenario_str].append(p) + # else: + # running_psnr[curr_scenario_str].append(0) + + real_filepath = os.path.join(save_im_path, '{}-{}_real.png'.format(pat_ind, _num)) + fake_filepath = os.path.join(save_im_path, '{}-{}_fake.png'.format(pat_ind, _num)) + + scipy.misc.toimage(G_result[0, 0].detach().cpu().numpy(), cmin=-1.0, cmax=1.0).save( + fake_filepath) + scipy.misc.toimage(x_test_r[0, SEQ_IND].detach().cpu().numpy(), cmin=-1.0, cmax=1.0).save( + real_filepath) + + else: + running_mse[curr_scenario_str].append( + mse(G_result, + x_test_r[:, SEQ_IND, ...].unsqueeze(1)).item()) + + running_ssim[curr_scenario_str].append(ssim( + G_result[:, 0].unsqueeze(1) / (torch.max(x_test_r[:, SEQ_IND]) + 0.0001), + x_test_r[:, SEQ_IND].unsqueeze(1) / (torch.max(x_test_r[:, SEQ_IND]) + 0.0001), + val_range=1).item()) + + # running_ssim[curr_scenario_str].append(ssim(G_result[:, 0].unsqueeze(1), + # x_test_r[:, SEQ_IND].unsqueeze(1)).item()) + + running_psnr[curr_scenario_str].append( + psnr_torch(G_result, x_test_r[:, SEQ_IND, ...].unsqueeze(1)).item()) + + else: + for idx_curr_label, j in enumerate(curr_scenario): + if j == 0: + running_mse[curr_scenario_str].append( + mse(G_result[:, idx_curr_label] / + (torch.max(G_result[:, idx_curr_label]) + 0.0001), + x_test_r[:, idx_curr_label] / + (torch.max(x_test_r[:, idx_curr_label]) + 0.0001)).item()) + + running_ssim[curr_scenario_str].append(pyt_ssim.ssim( + G_result[:, idx_curr_label].unsqueeze(1) / + (torch.max(G_result[:, idx_curr_label]) + 0.0001), + x_test_r[:, idx_curr_label].unsqueeze(1) / + (torch.max(x_test_r[:, idx_curr_label]) + 0.0001), + val_range=1).item()) + + running_psnr[curr_scenario_str].append( + psnr_torch(G_result[:, idx_curr_label], + x_test_r[:, idx_curr_label]).item()) + + + + num_dict = {} + all_mean_mse = [] + all_mean_psnr = [] + all_mean_ssim = [] + + for (mse_key, mse_list, psnr_key, psnr_list, ssim_key, ssim_list) in zip(running_mse.keys(), running_mse.values(), + running_psnr.keys(), running_psnr.values(), + running_ssim.keys(), running_ssim.values()): + + assert mse_key == ssim_key == psnr_key + num_dict[mse_key] = { + 'mse': np.mean(mse_list), + 'psnr': np.mean(psnr_list), + 'ssim': np.mean(ssim_list) + } + + all_mean_mse += mse_list + all_mean_psnr += psnr_list + all_mean_ssim += ssim_list + + num_dict['mean'] = { + 'mse': np.mean(all_mean_mse), + 'psnr': np.mean(all_mean_psnr), + 'ssim': np.mean(all_mean_ssim) + } + if save_stats: + stat_folder = os.path.join(save_path, "stats/".format()) + if not os.path.isdir(stat_folder): + 
os.makedirs(stat_folder) + print('Saving running statistics to folder: {}'.format(stat_folder)) + # save mse, psnr and ssim + pickle.dump(running_mse, open(os.path.join(stat_folder, 'mse.p'), 'wb')) + pickle.dump(running_psnr, open(os.path.join(stat_folder, 'psnr.p'), 'wb')) + pickle.dump(running_ssim, open(os.path.join(stat_folder, 'ssim.p'), 'wb')) + + return num_dict, running_mse, running_psnr, running_ssim + + return num_dict + + +def calculate_metrics_pgan(G, patient_list, + save_path, all_scenarios, + epoch, curr_scenario_range=None, + batch_size_to_test=2, + dataset = 'ISLES2015', + convert_normalization=False, + mean_var_file=None, + use_pytorch_ssim=False, seq_type='T1'): + + """ + For ISLES2015 + all_scenarios: scenarios + curr_scenario_range: None + batch_to_test: 2 + """ + + if isinstance(patient_list, list): + patients = patient_list + else: + patients = [patient_list] + # put the generator in EVAL mode. + mse = nn.MSELoss() + + # contains metrics for EACH slice from EACH OF THE SCENARIO. Basically everything. This is what we need for + # ISLES2015 + running_mse = [] + running_psnr = [] + running_ssim = [] + + for (pat_ind, patient) in enumerate(patients): + pat_name = patient['name'].decode('UTF-8') + logger.debug('Testing Patient: {}'.format(pat_name)) + patient_image = patient['image'] + + patient_copy = patient['image'].clone() + + patient_numpy = patient_copy.detach().cpu().numpy() + + sh = patient_numpy.shape + + batch_size = batch_size_to_test + + # get the batch indices + batch_indices = range(0, sh[0], batch_size) + + if seq_type == 'T1': + SEQ_IDX = 1 # WE TRAIN WITH T2 + SEQ_IDX_SYNTH = 0 + else: + SEQ_IDX = 0 # WE TRAIN WITH T1 + SEQ_IDX_SYNTH = 1 + + # for each batch + for _num, batch_idx in enumerate(batch_indices): + x_test_r = patient_image[batch_idx:batch_idx + batch_size, SEQ_IDX_SYNTH,...].unsqueeze(1).cuda() + x_test_z =patient_image[batch_idx:batch_idx + batch_size, SEQ_IDX,...].unsqueeze(1).cuda() + + G_result = G(x_test_z) + ssim = pyt_ssim.ssim + + running_mse.append( + mse(G_result, + x_test_r).item()) + + running_ssim.append(ssim( + G_result / (torch.max(x_test_r) + 0.0001), + x_test_r / (torch.max(x_test_r) + 0.0001), + val_range=1).item()) + + # running_ssim[curr_scenario_str].append(ssim(G_result[:, 0].unsqueeze(1), + # x_test_r[:, SEQ_IND].unsqueeze(1)).item()) + + running_psnr.append( + psnr_torch(G_result, x_test_r).item()) + + num_dict = {} + num_dict['mean'] = { + 'mse': np.mean(running_mse), + 'psnr': np.mean(running_psnr), + 'ssim': np.mean(running_ssim) + } + + return num_dict + + +def printTable(result_dict): + print("{:<10} {:<10} {:<10} {:<10}".format('Scenario', 'MSE', 'PSNR', 'SSIM')) + result_keys = sorted(result_dict.keys()) + for k, v in zip(result_keys, result_dict.values()): + print("{:<10} {:<10.4f} {:<10.4f} {:<10.4f}".format(k, result_dict[k]['mse'], + result_dict[k]['psnr'], + result_dict[k]['ssim'])) diff --git a/modules/mischelpers.py b/modules/mischelpers.py new file mode 100755 index 0000000..d411af5 --- /dev/null +++ b/modules/mischelpers.py @@ -0,0 +1,239 @@ +""" +========================================================== + Misc Helper Classes/Functions +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. 
Mostafa Fatehi (VGH) +DESCRIPTION: The module has various helper classes/functions + that can be used throughout the pipeline, and + don't fit exactly in either data loading or + visualization operations. +LICENCE: Proprietary for now. +""" + +import numpy as np +from configfile import config +import h5py +from nilearn._utils import check_niimg +from nilearn.image import new_img_like +from nilearn.image import reorder_img, resample_img + + + +class Rect3D: + """ + Class to encapsulate the Rectangle coordinates. This prevents future + issues when the coordinates need to be standardized. + """ + def __init__(self, coord_list): + if len(coord_list) < 6: + print('Coordinate list shape is incorrect, creating empty object!') + coord_list = [0, 0, 0, 0, 0, 0] + self.empty = True + else: + self.empty = False + + self.rmin = coord_list[0] + self.rmax = coord_list[1] + self.cmin = coord_list[2] + self.cmax = coord_list[3] + self.zmin = coord_list[4] + self.zmax = coord_list[5] + self.list_view = coord_list + + def show(self): + return self.list_view + +class Rect2D: + """ + Class to encapsulate the Rectangle coordinates. This prevents future + issues when the coordinates need to be standardized. + """ + def __init__(self, coord_list): + if len(coord_list) < 4: + print('Coordinate list shape is incorrect, creating empty object!') + coord_list = [0, 0, 0, 0] + self.empty = True + else: + self.empty = False + + self.rmin = coord_list[0] + self.rmax = coord_list[1] + self.cmin = coord_list[2] + self.cmax = coord_list[3] + self.list_view = coord_list + + def show(self): + return self.list_view + +def bbox_3D(img, tol=0.5): + """ + TOL = argument used when dark regions are >0 + (usually after some preprocessing, like + rescaling). + """ + r, c, z = np.where(img > tol) + rmin, rmax, cmin, cmax, zmin, zmax = np.min(r), np.max(r), np.min(c), np.max(c), np.min(z), np.max(z) + rect_obj = Rect3D([rmin, rmax, cmin, cmax, zmin, zmax]) + return rect_obj + +def bbox_2D(img, tol=0.5): + """ + TOL = argument used when dark regions are >0 + (usually after some preprocessing, like + rescaling). + """ + r, c = np.where(img > tol) + if r.size == 0 or c.size == 0: + return Rect2D([-1, -1, -1, -1]) + else: + rmin, rmax, cmin, cmax = np.min(r), np.max(r), np.min(c), np.max(c) + rect_obj = Rect2D([rmin, rmax, cmin, cmax]) + return rect_obj + +def open_hdf5(filepath=None, mode='r'): + if filepath == None: + filepath = config['hdf5_filepath_prefix'] + + return h5py.File(filepath, mode=mode) + +def get_data_splits_bbox(hdf5_filepath, train_start=0, train_end=190, test_start=190, test_end=None): + """ + + :param hdf5_filepath: + :param train_start: Start index to slice to get the training data. For 10 instances starting from 0, choose 0. + :param train_end: End index for training. Remember this index is 'exclusive', so if you want 10 instances, choose this as 10 + :param test_start: Start index to slice to get the testing data. Same comment as above. + :param test_end: End index for testing. + :return: Keras instances to slice into x_train, y_train, x_test, y_test. 
+ """ + import keras + filepath = config['hdf5_filepath_prefix'] if hdf5_filepath is None else hdf5_filepath + + x_train = keras.utils.io_utils.HDF5Matrix(filepath, "training_data_hgg", start=train_start, end=train_end, + normalizer=None) + y_train = keras.utils.io_utils.HDF5Matrix(filepath, "bounding_box_hgg", start=train_start, end=train_end, + normalizer=None) + + x_test = keras.utils.io_utils.HDF5Matrix(filepath, "training_data_hgg", start=test_start, end=test_end, + normalizer=None) + y_test = keras.utils.io_utils.HDF5Matrix(filepath, "bounding_box_hgg", start=test_start, end=test_end, + normalizer=None) + + return x_train, y_train, x_test, y_test + +def createDense(bbox, im): + box = np.zeros(im.shape) + box[bbox[0]:bbox[1], bbox[2]:bbox[3], bbox[4]:bbox[5]] = 1 + return box + + +def _crop_img_to(img, slices, copy=True): + """Crops image to a smaller size + Crop img to size indicated by slices and adjust affine + accordingly + Parameters + ---------- + img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + Img to be cropped. If slices has less entries than img + has dimensions, the slices will be applied to the first len(slices) + dimensions + slices: list of slices + Defines the range of the crop. + E.g. [slice(20, 200), slice(40, 150), slice(0, 100)] + defines a 3D cube + copy: boolean + Specifies whether cropped data is to be copied or not. + Default: True + Returns + ------- + cropped_img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + Cropped version of the input image + """ + + img = check_niimg(img) + + data = img.get_data() + affine = img.affine + + cropped_data = data[slices] + if copy: + cropped_data = cropped_data.copy() + + linear_part = affine[:3, :3] + old_origin = affine[:3, 3] + new_origin_voxel = np.array([s.start for s in slices]) + new_origin = old_origin + linear_part.dot(new_origin_voxel) + + new_affine = np.eye(4) + new_affine[:3, :3] = linear_part + new_affine[:3, 3] = new_origin + + return new_img_like(img, cropped_data, new_affine) + + +def crop_img_custom(img, slices=None, rtol=1e-8, copy=True): + """Crops img as much as possible + Will crop img, removing as many zero entries as possible + without touching non-zero entries. Will leave one voxel of + zero padding around the obtained non-zero area in order to + avoid sampling issues later on. + Parameters + ---------- + img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + img to be cropped. + rtol: float + relative tolerance (with respect to maximal absolute + value of the image), under which values are considered + negligeable and thus croppable. + copy: boolean + Specifies whether cropped data is copied or not. 
+ Returns + ------- + cropped_img: image + Cropped version of the input image + """ + + img = check_niimg(img) + data = img.get_data() + + if slices is not None: + return _crop_img_to(img, slices, copy=copy), slices + else: + infinity_norm = max(-data.min(), data.max()) + passes_threshold = np.logical_or(data < -rtol * infinity_norm, + data > rtol * infinity_norm) + + if data.ndim == 4: + passes_threshold = np.any(passes_threshold, axis=-1) + coords = np.array(np.where(passes_threshold)) + start = coords.min(axis=1) + end = coords.max(axis=1) + 1 + + # pad with one voxel to avoid resampling problems + start = np.maximum(start - 1, 0) + end = np.minimum(end + 1, data.shape[:3]) + + slices = [slice(s, e) for s, e in zip(start, end)] + + return _crop_img_to(img, slices, copy=copy), slices + + +def resize(image, new_shape, interpolation="continuous"): + input_shape = np.asarray(image.shape, dtype=np.float16) + ras_image = reorder_img(image, resample=interpolation) + output_shape = np.asarray(new_shape) + new_spacing = input_shape/output_shape + new_affine = np.copy(ras_image.affine) + new_affine[:3, :3] = ras_image.affine[:3, :3] * np.diag(new_spacing) + return resample_img(ras_image, target_affine=new_affine, target_shape=output_shape, interpolation=interpolation, clip=True) \ No newline at end of file diff --git a/modules/models.py b/modules/models.py new file mode 100755 index 0000000..5739950 --- /dev/null +++ b/modules/models.py @@ -0,0 +1,144 @@ +import torch.nn as nn +import torch.nn.functional as F +import os, time +import matplotlib.pyplot as plt +import itertools +import pickle +import imageio +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +# ================================================================================ +# Conditional version of Pix2Pix GAN +# ================================================================================ + +# G(z) +class cPix2PixGenerator(nn.Module): + def __init__(self): + super(cPix2PixGenerator, self).__init__() + self.conv1_1_input = nn.Conv2d(4, 64, 3, 1, 1) + self.batch_norm1_1_input = nn.BatchNorm2d(64) + + self.conv1_1_label = nn.Conv2d(4, 64, 3, 1, 1) + self.batch_norm1_1_label = nn.BatchNorm2d(64) + + self.conv2_1 = nn.Conv2d(128, 128, 3, 1, 1) + self.batch_norm2_1 = nn.BatchNorm2d(128) + + self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1) + self.batch_norm3_1 = nn.BatchNorm2d(256) + + self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1) + self.batch_norm4_1 = nn.BatchNorm2d(512) + + self.conv5_1 = nn.Conv2d(512, 256, 3, 1, 1) + self.batch_norm5_1 = nn.BatchNorm2d(256) + + self.conv6_1 = nn.Conv2d(256, 128, 3, 1, 1) + self.batch_norm6_1 = nn.BatchNorm2d(128) + + self.conv7_1 = nn.Conv2d(128, 4, 3, 1, 1) + + # weight_init + def weight_init(self, mean, std): + for m in self._modules: + normal_init(self._modules[m], mean, std) + + # forward method + def forward(self, input, label): + inp_image = F.relu(self.batch_norm1_1_input(self.conv1_1_input(input))) + inp_label = F.relu(self.batch_norm1_1_label(self.conv1_1_label(label))) + x = torch.cat([inp_image, inp_label], 1) + x = F.relu(self.batch_norm2_1(self.conv2_1(x))) + x = F.relu(self.batch_norm3_1(self.conv3_1(x))) + x = F.relu(self.batch_norm4_1(self.conv4_1(x))) + x = F.relu(self.batch_norm5_1(self.conv5_1(x))) + x = F.relu(self.batch_norm6_1(self.conv6_1(x))) + x = self.conv7_1(x) + + return x + +class cPix2PixDiscriminator(nn.Module): + def __init__(self): + super(cPix2PixDiscriminator, self).__init__() + self.conv1_1_input = nn.Conv2d(4, 64, 5, 
1, 0) + self.batch_norm1_1_input = nn.BatchNorm2d(64) + self.maxpool1_1_input = nn.MaxPool2d(2, 2, 0) + + self.conv1_1_label = nn.Conv2d(4, 64, 5, 1, 0) + self.batch_norm1_1_label = nn.BatchNorm2d(64) + self.maxpool1_1_label = nn.MaxPool2d(2, 2, 0) + + self.conv2_1 = nn.Conv2d(128, 256, 5, 1, 0) + self.batch_norm2_1 = nn.BatchNorm2d(256) + self.maxpool2_1 = nn.MaxPool2d(2, 2, 0) + + self.conv3_1 = nn.Conv2d(256, 128, 5, 1, 0) + self.batch_norm3_1 = nn.BatchNorm2d(128) + self.maxpool3_1 = nn.MaxPool2d(2, 2, 0) + + self.conv4_1 = nn.Conv2d(128, 1, 5, 1, 0) + + self.linear1 = nn.Linear(1 * 24 * 24, 128) + self.linear2 = nn.Linear(128, 256) + self.linear3 = nn.Linear(256, 4) + + # weight_init + def weight_init(self, mean, std): + for m in self._modules: + normal_init(self._modules[m], mean, std) + + # forward method + def forward(self, input, label): + inp_image = F.leaky_relu(self.maxpool1_1_input(self.batch_norm1_1_input(self.conv1_1_input(input))), 0.2) + inp_label = F.leaky_relu(self.maxpool1_1_label(self.batch_norm1_1_label(self.conv1_1_label(label))), 0.2) + x = torch.cat([inp_image, inp_label], 1) + x = F.leaky_relu(self.maxpool2_1(self.batch_norm2_1(self.conv2_1(x))), 0.2) + x = F.leaky_relu(self.maxpool3_1(self.batch_norm3_1(self.conv3_1(x))), 0.2) + x = F.leaky_relu(self.conv4_1(x), 0.2) + x = F.leaky_relu(self.linear1(x.view(-1, self.num_flat_features(x))), 0.2) + x = F.leaky_relu(self.linear2(x), 0.2) + x = F.sigmoid(self.linear3(x)) + + return x + + def num_flat_features(self, x): + size = x.size()[1:] # all dimensions except the batch dimension + num_features = 1 + for s in size: + num_features *= s + return num_features + +# Weight initialization function used by the above cPix2Pix + +def normal_init(m, mean, std): + if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d): + m.weight.data.normal_(mean, std) + m.bias.data.zero_() + +# ================================================================================ +# Sample Neural Network to test stuff +# ================================================================================ +class SampleNet(nn.Module): + + def __init__(self): + super().__init__() + # 1 input image channel, 6 output channels, 5x5 square convolution + # kernel + self.conv1 = nn.Conv2d(4, 1, 5, padding=2) + + def forward(self, x): + # Max pooling over a (2, 2) window + x = F.relu(self.conv1(x)) + return x + + def num_flat_features(self, x): + size = x.size()[1:] # all dimensions except the batch dimension + num_features = 1 + for s in size: + num_features *= s + return num_features + + diff --git a/modules/preprocess.py b/modules/preprocess.py new file mode 100755 index 0000000..2cb36c8 --- /dev/null +++ b/modules/preprocess.py @@ -0,0 +1,155 @@ +""" +========================================================== + Preprocess BRATS Data +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. Mostafa Fatehi (VGH) +DESCRIPTION: This file uses the previously generated data + (using create_hdf5_file.py) and generates a + new file with similar structure, but after + applying a couple of preprocessing steps. + More specifically, the script applies the + following operations on the data: + + 1) Crop out the dark margins in the scans + to only leave a concise brain area. 
For + this a generous estimate of bounding box + generated from the whole dataset is used. + For more information, see checkLargestCropSize + notebook. + + The code DOES NOT APPLY MEAN/VAR normalization, + but simply calculates the values and saves on disk. + Check lines 140-143 for more information. + + The saved mean/var files are to be used before + the training process. + +LICENCE: Proprietary for now. +""" + +import h5py +from modules.configfile import config +import numpy as np +import SimpleITK as sitk +import optparse +import logging +# from modules.mischelpers import * +from modules.dataloader import standardize +import os + +logging.basicConfig(level=logging.DEBUG) + +try: + logger = logging.getLogger(__file__.split('/')[-1]) +except: + logger = logging.getLogger(__name__) + +logger.warning('[IMPORTANT] The code DOES NOT APPLY mean/var normalization, rather it calculates it and saves to disk') +# ------------------------------------------------------------------------------------ +# open existing datafile +# ------------------------------------------------------------------------------------ +logger.info('opening previously generated HDF5 file.') + +# open the existing datafile +hdf5_file_main = h5py.File(config['hdf5_filepath_prefix'], 'r') + +logger.info('opened HDF5 file at {}'.format(config['hdf5_filepath_prefix'])) + +# get the group identifier for original dataset +hdf5_file = hdf5_file_main['original_data'] + +# ==================================================================================== + +# ------------------------------------------------------------------------------------ +# create new HDF5 file to hold cropped data. +# ------------------------------------------------------------------------------------ +logger.info('creating new HDF5 dataset to hold cropped/normalized data') +filename = os.path.join(os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1]), 'BRATS_Cropped_Normalized_Unprocessed.h5') +new_hdf5 = h5py.File(filename, mode='w') +logger.info('created new database at {}'.format(filename)) + +# create a folder group to hold the datasets. The schema is similar to original one except for the name of the folder +# group +new_group_preprocessed = new_hdf5.create_group('preprocessed') + +# create similar datasets in this file. 
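+# (the *_crop shapes come from configfile.py, e.g. train_shape_hgg_crop is
+# (num_hgg_patients, 4) + size_after_cropping; the *_pat_name datasets hold
+# fixed-length byte strings, hence dtype "S100")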
+new_group_preprocessed.create_dataset("training_data_hgg", config['train_shape_hgg_crop'], np.float32)
+new_group_preprocessed.create_dataset("training_data_hgg_pat_name", (config['train_shape_hgg_crop'][0],), dtype="S100")
+new_group_preprocessed.create_dataset("training_data_segmasks_hgg", config['train_segmasks_shape_hgg_crop'], np.int16)
+
+new_group_preprocessed.create_dataset("training_data_lgg", config['train_shape_lgg_crop'], np.float32)
+new_group_preprocessed.create_dataset("training_data_lgg_pat_name", (config['train_shape_lgg_crop'][0],), dtype="S100")
+new_group_preprocessed.create_dataset("training_data_segmasks_lgg", config['train_segmasks_shape_lgg_crop'], np.int16)
+
+new_group_preprocessed.create_dataset("validation_data", config['val_shape_crop'], np.float32)
+new_group_preprocessed.create_dataset("validation_data_pat_name", (config['val_shape_crop'][0],), dtype="S100")
+# ====================================================================================
+
+# just copy the patient names directly
+new_group_preprocessed['training_data_hgg_pat_name'][:] = hdf5_file['training_data_hgg_pat_name'][:]
+new_group_preprocessed['training_data_lgg_pat_name'][:] = hdf5_file['training_data_lgg_pat_name'][:]
+new_group_preprocessed['validation_data_pat_name'][:] = hdf5_file['validation_data_pat_name'][:]
+
+# ------------------------------------------------------------------------------------
+# start cropping process and standardization process
+# ------------------------------------------------------------------------------------
+
+# get the directory where the mean/var files are stored
+# TODO: Use the config file global path, not this one.
+
+saveMeanVarFilename = os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1])
+logger.info('starting the Cropping/Normalization process.')
+
+# only run the cropping steps on these datasets
+run_on_list = ['training_data_segmasks_hgg', 'training_data_hgg', 'training_data_lgg', 'training_data_segmasks_lgg', 'validation_data']
+
+# only run the mean/var calculation on these datasets
+std_list = ['training_data_hgg', 'training_data_lgg']
+for run_on in run_on_list:
+
+    # we define the final shape after cropping in the config file to make it easy to access. More information available in
+    # the checkLargestCropSize.ipynb notebook.
+    if run_on == 'training_data_hgg':
+        im_np = np.empty(config['train_shape_hgg_crop'])
+    elif run_on == 'training_data_lgg':
+        im_np = np.empty(config['train_shape_lgg_crop'])
+    elif run_on == 'validation_data':
+        im_np = np.empty(config['val_shape_crop'])
+
+    logger.info('Running on {}'.format(run_on))
+    for i in range(0, hdf5_file[run_on].shape[0]):
+        # cropping operation
+        logger.debug('{}:- Patient {}'.format(run_on, i+1))
+        im = hdf5_file[run_on][i]
+        m = config['cropping_coords']
+        if 'segmasks' in run_on:
+            # there are no channels for segmasks
+            k = im[m[0]:m[1], m[2]:m[3], m[4]:m[5]]
+        else:
+            k = im[:, m[0]:m[1], m[2]:m[3], m[4]:m[5]]
+
+        if run_on in std_list:
+            # save the image to this numpy array
+            im_np[i] = k
+        new_group_preprocessed[run_on][i] = k
+    # find the mean and variance, and write them to disk (the normalization itself is NOT applied here)
+    if run_on in std_list:
+        logger.info('The dataset {} needs standardization'.format(run_on))
+        _tmp, vals = standardize(im_np, findMeanVarOnly=True, saveDump=saveMeanVarFilename + run_on + '_mean_std.p')
+        logger.info('Calculated normalization values for {}:\n{}'.format(run_on, vals))
+        del im_np
+
+# ====================================================================================
+
+hdf5_file_main.close()
+new_hdf5.close()
+
+
diff --git a/modules/pytorch_msssim/__init__.py b/modules/pytorch_msssim/__init__.py
new file mode 100755
index 0000000..3032b46
--- /dev/null
+++ b/modules/pytorch_msssim/__init__.py
@@ -0,0 +1,133 @@
+import torch
+import torch.nn.functional as F
+from math import exp
+import numpy as np
+
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+    return gauss/gauss.sum()
+
+
+def create_window(window_size, channel=1):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+    return window
+
+
+def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+    if val_range is None:
+        if torch.max(img1) > 128:
+            max_val = 255
+        else:
+            max_val = 1
+
+        if torch.min(img1) < -0.5:
+            min_val = -1
+        else:
+            min_val = 0
+        L = max_val - min_val
+    else:
+        L = val_range
+
+    padd = 0
+    (_, channel, height, width) = img1.size()
+    if window is None:
+        real_size = min(window_size, height, width)
+        window = create_window(real_size, channel=channel).to(img1.device)
+
+    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
+
+    C1 = (0.01 * L) ** 2
+    C2 = (0.03 * L) ** 2
+
+    v1 = 2.0 * sigma12 + C2
+    v2 = sigma1_sq + sigma2_sq + C2
+    cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+    if size_average:
+        ret = ssim_map.mean()
+    else:
+        ret = ssim_map.mean(1).mean(1).mean(1)
+
+    if full:
+        return ret, cs
+    return ret
+
+
+def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
+    device = img1.device
+    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
+    levels = weights.size()[0]
+    mssim = []
+    mcs = []
+    for _ in range(levels):
+        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
+        mssim.append(sim)
+        mcs.append(cs)
+
+        img1 = F.avg_pool2d(img1, (2, 2))
+        img2 = F.avg_pool2d(img2, (2, 2))
+
+    mssim = torch.stack(mssim)
+    mcs = torch.stack(mcs)
+
+    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
+    if normalize:
+        mssim = (mssim + 1) / 2
+        mcs = (mcs + 1) / 2
+
+    pow1 = mcs ** weights
+    pow2 = mssim ** weights
+    # From the Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ :
+    # product of the contrast-sensitivity terms over all but the coarsest scale,
+    # times the full SSIM term at the coarsest scale.
+    output = torch.prod(pow1[:-1]) * pow2[-1]
+    return output
+
+
+# 
Classes to re-use window +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, val_range=None): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.val_range = val_range + + # Assume 1 channel for SSIM + self.channel = 1 + self.window = create_window(window_size) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) + self.window = window + self.channel = channel + + return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) + +class MSSSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, channel=3): + super(MSSSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = channel + + def forward(self, img1, img2): + # TODO: store window between calls if possible + return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) diff --git a/prep_BRATS2015/cedar_gen_hdf5.sh b/prep_BRATS2015/cedar_gen_hdf5.sh new file mode 100755 index 0000000..764d53a --- /dev/null +++ b/prep_BRATS2015/cedar_gen_hdf5.sh @@ -0,0 +1,14 @@ +#!/bin/bash +#SBATCH --nodes=1 # number of nodes +#SBATCH --ntasks=1 # number of MPI processes +#SBATCH --cpus-per-task=2 # 24 cores on cedar nodes +#SBATCH --account=rrg-hamarneh +#SBATCH --mem=16G # give all memory you have in the node +#SBATCH --time=3-05:00 # time (DD-HH:MM) +#SBATCH --job-name=GenerateHDF5File +#SBATCH --output=GenerateHDF5File.out +#SBATCH --mail-user=asa224@sfu.ca +#SBATCH --mail-type=ALL + +# run the command +~/.virtualenvs/mm_synthesis/bin/python create_hdf5_file.py diff --git a/prep_BRATS2015/create_hdf5_file.py b/prep_BRATS2015/create_hdf5_file.py new file mode 100755 index 0000000..d148c4f --- /dev/null +++ b/prep_BRATS2015/create_hdf5_file.py @@ -0,0 +1,170 @@ +""" +========================================================== + Prepare BRATS 2015 Data +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. Mostafa Fatehi (VGH) +DESCRIPTION: This file is used to generate an HDF5 dataset, + which is easy to load and manipulate compared + to working directly with raw data all the time. + Loading and working with HDF5 files is much + faster and efficient due to its asynchronous loading + system. + + The HDF5 file generated can be hosted on a remote server + (like CEDAR) and then accessed over SSHFS. Practically, + this is very effective and does not hinder the performance + by a large margin. + + This script generates a simple HDF5 data store, + which contains the original numpy arrays of the + data store. To perform any preprocessing, implement + the preprocessData()function in dataloader.py to + work directly on nibabel objects, instead of + numpy objects. +LICENCE: Proprietary for now. 
+""" + +import os +import glob +from modules import dataloader +import logging +import numpy as np +import h5py +import sys +# sys.path.append('../') +from modules.configfile import config + +logging.basicConfig(level=logging.INFO) +try: + logger = logging.getLogger(__file__.split('/')[-1]) +except: + logger = logging.getLogger(__name__) + +try: + logger.info('Dataloader file in use: {}'.format(dataloader.__file__)) +except: + logger.info('Cannot determine the dataloader class being used, check carefully before proceeding') + + +# whether or not to preprocess the data before creating the HDF5 file? Check the preprocess function in dataloader to +# know exactly what preprocessing is being performed. +PREPROCESS_DATA = True + +logger.info('[IMPORTANT] This will create a new HDF5 file in SAFE MODE. It will NOT OVERWRITE A PREVIOUS HDF5 FILE ' + 'IF ITS PRESENT') +def createHDF5File(config): + """ + Function to create a new HDF5 File to hold the BRATS 2015 data. The function will fail if there's already a file + present with the same name (SAFE OPERATION) + + :param config: The config variable defined in configfile.py + :return: hdf5_file object + """ + + # w- mode fails when there is a file already. + hdf5_file = h5py.File(config['hdf5_filepath_prefix'], mode='w') + + # create a new parent directory to hold the data inside it + grp = hdf5_file.create_group("original_data") + + # the dataset is int16 originally, checked using nibabel, however we create float32 containers to make the dataset + # compatible with further preprocessing. + # HGG Data + + grp.create_dataset("training_data_hgg", config['train_shape_hgg'], np.float32) + grp.create_dataset("training_data_hgg_pat_name", (config['train_shape_hgg'][0],), dtype="S100") + grp.create_dataset("training_data_segmasks_hgg", config['train_segmasks_shape_hgg'], np.int16) + + # LGG Data + grp.create_dataset("training_data_lgg", config['train_shape_lgg'], np.float32) + grp.create_dataset("training_data_lgg_pat_name", (config['train_shape_lgg'][0],), dtype="S100") + grp.create_dataset("training_data_segmasks_lgg", config['train_segmasks_shape_lgg'], np.int16) + + grp.create_dataset("testing_hgglgg_patients", config['testing_hgglgg_patients_shape'], np.uint16) + + grp.create_dataset("testing_hgglgg_patients_pat_name", (config['testing_hgglgg_patients_shape'][0],), dtype="S100") + + return hdf5_file + +def main(): + hdf5_file_main = createHDF5File(config) + # hdf5_file_main = h5py.File(config['hdf5_filepath_prefix'], mode='a') + # Go inside the "original_data" parent directory. + # we need to create the validation data dataset again since the shape has changed. + hdf5_file = hdf5_file_main['original_data'] + contents = glob.glob(os.path.join(config['data_dir_prefix'], '*')) + + # for debugging, making sure Training set is loaded first not Testing, since that is tested already. + contents.reverse() + for dataset_splits in contents: # Challenge/LeaderBoard data? + if os.path.isdir(dataset_splits): # make sure its a directory + for grade_type in glob.glob(os.path.join(dataset_splits, '*')): + # there may be other files in there (like the survival data), ignore them. 
+ if os.path.isdir(grade_type): + count = 0 + if 'Testing' in dataset_splits: + logger.info('currently loading Testing -> {} data.'.format(os.path.basename(grade_type))) + ty = 'Testing' + + for images, pats in dataloader.loadDataGenerator(grade_type, + batch_size=config['batch_size'], loadSurvival=False, + csvFilePath=None, loadSeg=False, + preprocess=PREPROCESS_DATA, dataset='2013'): + logger.info('loading patient {} from {}'.format(count, grade_type)) + if 'HGG_LGG' in grade_type: + if ty == 'Testing': + main_data_name = 'testing_hgglgg_patients' + main_data_pat_name = 'testing_hgglgg_patients_pat_name' + + hdf5_file[main_data_name][count:count+config['batch_size'],...] = images + t = 0 + for i in range(count, count + config['batch_size']): + hdf5_file[main_data_pat_name][i] = pats[t].split('.')[-2] + t += 1 + + logger.info('loaded {} patient(s) from {}'.format(count + config['batch_size'], grade_type)) + count += config['batch_size'] + else: + # TRAINING data handler + if os.path.isdir(dataset_splits) and 'Training' in dataset_splits: + for grade_type in glob.glob(os.path.join(dataset_splits, '*')): + # there may be other files in there (like the survival data), ignore them. + if os.path.isdir(grade_type): + count = 0 + logger.info('currently loading Training data.') + for images, segmasks, pats in dataloader.loadDataGenerator(grade_type, + batch_size=config['batch_size'], loadSurvival=False, + csvFilePath=None, loadSeg=True, + preprocess=PREPROCESS_DATA): + logger.info('loading patient {} from {}'.format(count, grade_type)) + if 'HGG' in grade_type: + hdf5_file['training_data_hgg'][count:count+config['batch_size'],...] = images + hdf5_file['training_data_segmasks_hgg'][count:count+config['batch_size'], ...] = segmasks + t = 0 + for i in range(count, count + config['batch_size']): + hdf5_file['training_data_hgg_pat_name'][i] = pats[t].split('/')[-1] + t += 1 + elif 'LGG' in grade_type: + hdf5_file['training_data_lgg'][count:count+config['batch_size'], ...] = images + hdf5_file['training_data_segmasks_lgg'][count:count+config['batch_size'], ...] = segmasks + t = 0 + for i in range(count, count + config['batch_size']): + hdf5_file['training_data_lgg_pat_name'][i] = pats[t].split('/')[-1] + t += 1 + + logger.info('loaded {} patient(s) from {}'.format(count + config['batch_size'], grade_type)) + count += config['batch_size'] + # close the HDF5 file + # close the HDF5 file + hdf5_file_main.close() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/prep_BRATS2015/modules/__init__.py b/prep_BRATS2015/modules/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/prep_BRATS2015/modules/configfile.py b/prep_BRATS2015/modules/configfile.py new file mode 100755 index 0000000..e820073 --- /dev/null +++ b/prep_BRATS2015/modules/configfile.py @@ -0,0 +1,102 @@ +""" +========================================================== + Config File to set Parameters +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. Mostafa Fatehi (VGH) +DESCRIPTION: This file is solely created for the purpose of + managing parameters in a global setting. All the + database loading and generation parameters reside + here, and are inherited by create_hdf5_file.py + to generate the HDF5 data store. 
+
+             The parameters are also used in the test_database.py
+             script to test the created database.
+LICENCE: Proprietary for now.
+"""
+import os
+import platform
+
+# WE CAN USE THIS TO CHANGE IMAGE_DATA_FORMAT on the fly
+# keras.backend.common._IMAGE_DATA_FORMAT='channels_first'
+
+# to make the code portable even on cedar, you need to add conditions here
+node_name = platform.node()
+if node_name == 'XPS15':
+    # this is my laptop, so the cedar-rm directory is at a different place
+    mount_path_prefix = '/home/anmol/mounts/cedar-rm/'
+elif 'computecanada' in node_name: # we're in compute canada, maybe in an interactive node, or a scheduler node.
+    # mount_path_prefix = '/home/asa224/' # home directory
+    mount_path_prefix = '' # home directory
+else:
+    # this is probably my workstation or TS server
+    # mount_path_prefix = '/local-scratch/asa224_new/mounts/cedar-rm/'
+    mount_path_prefix = ''
+
+config = {}
+# set the data directory and output hdf5 file path.
+# data_dir is the top level path containing both training and validation sets of the brats dataset.
+# config['data_dir_prefix'] = os.path.join(mount_path_prefix, 'rrg_proj_dir/scratch_files_globus/Datasets/BRATS2015/') # this should be the top level path
+config['data_dir_prefix'] = os.path.join(mount_path_prefix, '/local-scratch/anmol/data/BRATS2015') # this should be the top level path
+config['hdf5_filepath_prefix'] = os.path.join(mount_path_prefix, '/local-scratch/anmol/data/BRATS2015/HDF5_Datasets/BRATS2015.h5') # top level path
+
+config['spatial_size_for_training'] = (240, 240) # in-plane shape the training code expects; if any preprocessing changes the spatial size, update this value.
+config['num_slices'] = 155 # number of slices in input data. THIS SHOULD CHANGE AS WELL WHEN PERFORMING PREPROCESSING
+config['volume_size'] = list(config['spatial_size_for_training']) + [config['num_slices']]
+config['seed'] = 1338
+config['data_order'] = 'th' # what order should the indices be to store in hdf5 file
+config['train_hgg_patients'] = 220 # number of HGG patients in training
+config['train_lgg_patients'] = 54 # number of LGG patients in training
+config['testing_hgglgg_patients'] = 110 # number of HGG+LGG patients in testing
+
+config['cropping_coords'] = [29, 223, 41, 196, 0, 148] # coordinates used to crop the volumes, generated using the notebook checkLargestCropSize.ipynb
+config['size_after_cropping'] = [194, 155, 148] # set this if you set the above variable. Calculate this using the notebook again.
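+# Worked check (illustrative): the cropped size follows directly from the coordinates above,
+# i.e. [223 - 29, 196 - 41, 148 - 0] == [194, 155, 148].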
+
+config['batch_size'] = 1 # how many images to load at once in the generator
+
+# check the order of data and choose the proper data shape to save images
+if config['data_order'] == 'th':
+
+    config['train_shape_hgg'] = (
+        config['train_hgg_patients'], 4, config['spatial_size_for_training'][0], config['spatial_size_for_training'][1],
+        config['num_slices'])
+    config['train_shape_lgg'] = (
+        config['train_lgg_patients'], 4, config['spatial_size_for_training'][0], config['spatial_size_for_training'][1],
+        config['num_slices'])
+    config['train_segmasks_shape_hgg'] = (
+        config['train_hgg_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1],
+        config['num_slices'])
+    config['train_segmasks_shape_lgg'] = (
+        config['train_lgg_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1],
+        config['num_slices'])
+    config['testing_hgglgg_patients_shape'] = (config['testing_hgglgg_patients'], 4,
+                                               config['spatial_size_for_training'][0],
+                                               config['spatial_size_for_training'][1],
+                                               config['num_slices'])
+
+    config['train_shape_hgg_crop'] = (
+        config['train_hgg_patients'], 4, config['size_after_cropping'][0], config['size_after_cropping'][1],
+        config['size_after_cropping'][2])
+    config['train_shape_lgg_crop'] = (
+        config['train_lgg_patients'], 4, config['size_after_cropping'][0], config['size_after_cropping'][1],
+        config['size_after_cropping'][2])
+    config['train_segmasks_shape_hgg_crop'] = (
+        config['train_hgg_patients'], config['size_after_cropping'][0], config['size_after_cropping'][1],
+        config['size_after_cropping'][2])
+    config['train_segmasks_shape_lgg_crop'] = (
+        config['train_lgg_patients'], config['size_after_cropping'][0], config['size_after_cropping'][1],
+        config['size_after_cropping'][2])
+    config['testing_hgglgg_patients_shape_crop'] = (config['testing_hgglgg_patients'], 4,
+                                                    config['size_after_cropping'][0],
+                                                    config['size_after_cropping'][1],
+                                                    config['size_after_cropping'][2])
+
diff --git a/prep_BRATS2015/modules/dataloader.py b/prep_BRATS2015/modules/dataloader.py
new file mode 100755
index 0000000..fcae5af
--- /dev/null
+++ b/prep_BRATS2015/modules/dataloader.py
@@ -0,0 +1,292 @@
+"""
+==========================================================
+    Load BRATS 2015 Data
+==========================================================
+AUTHOR: Anmol Sharma
+AFFILIATION: Simon Fraser University
+             Burnaby, BC, Canada
+PROJECT: Analysis of Brain MRI Scans for Management of
+         Malignant Tumors
+COLLABORATORS: Anmol Sharma (SFU)
+               Prof. Ghassan Hamarneh (SFU)
+               Dr. Brian Toyota (VGH)
+               Dr. Mostafa Fatehi (VGH)
+DESCRIPTION: The script has multiple functions to load,
+             preprocess, and standardize the BRATS
+             2015 dataset, along with its survival annotations.
+             The main function is loadDataGenerator, which loads
+             the data using a generator and doesn't hog memory.
+
+             The loadDataGenerator is capable of applying
+             arbitrary preprocessing steps to the data. This can be
+             achieved by implementing the function preprocessData.
+LICENCE: Proprietary for now.
+"""
+
+from __future__ import print_function
+import glob as glob
+import numpy as np
+import pickle
+import sys as sys
+# from pandas import read_csv
+import os
+import logging
+from configfile import config
+import SimpleITK as sitk
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+def preprocessData(img_obj, process=False):
+    """
+    Perform preprocessing on the original SimpleITK image object.
+    Use this function to:
+        1) Resize/Resample the 3D Volume
+        2) Crop the brain region
+        3) Do (2) then (1).
+
+    When you do preprocessing, especially something that
+    changes the spatial size of the volume, make sure you
+    update the config['spatial_size_for_training'] = (240, 240)
+    value in the config file.
+
+    :param img_obj: SimpleITK image object to preprocess
+    :param process: if False, return the image unchanged; if True, apply N4ITK bias field correction
+    :return: the (possibly corrected) SimpleITK image object
+    """
+    if not process:
+        return img_obj
+    else:
+        maskImage = sitk.OtsuThreshold(img_obj, 0, 1, 200)
+        image = sitk.Cast(img_obj, sitk.sitkFloat32)
+        corrector = sitk.N4BiasFieldCorrectionImageFilter()
+        numberFittingLevels = 4
+        corrector.SetMaximumNumberOfIterations([4] * numberFittingLevels)
+        output = corrector.Execute(image, maskImage)
+        return output
+
+def resize_mha_volume(image, spacing=[1,1,1], size=[240,240,155]):
+
+    # Create the reference image
+    reference_origin = image.GetOrigin()
+    reference_direction = np.identity(image.GetDimension()).flatten()
+    reference_image = sitk.Image(size, image.GetPixelIDValue())
+    reference_image.SetOrigin(reference_origin)
+    reference_image.SetSpacing(spacing)
+    reference_image.SetDirection(reference_direction)
+
+    # Transform which maps from the reference_image to the current image (output-to-input)
+    transform = sitk.AffineTransform(image.GetDimension())
+    transform.SetMatrix(image.GetDirection())
+    transform.SetTranslation(np.array(image.GetOrigin()) - reference_origin)
+
+    # Modify the transformation to align the centers of the original and reference image
+    reference_center = np.array(
+        reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))
+    centering_transform = sitk.TranslationTransform(image.GetDimension())
+    img_center = np.array(image.TransformContinuousIndexToPhysicalPoint(np.array(image.GetSize()) / 2.0))
+    centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
+    centered_transform = sitk.Transform(transform)
+    centered_transform.AddTransform(centering_transform)
+
+    # Resample using the linear interpolator. NOTE: centered_transform is built above but the
+    # resample below passes the uncentered transform; keep this in mind if volumes come out off-center.
+    image_rs = sitk.Resample(image, reference_image, transform, sitk.sitkLinear, 0.0)
+    return image_rs
+
+
+def loadDataGenerator(data_dir, batch_size=1, preprocess=False, loadSurvival=False,
+                      csvFilePath=None, loadSeg=True, dataset='2018'):
+    """
+    Main function to load the BRATS dataset.
+
+    :param data_dir: path to the folder where patient data resides, needs individual paths for HGG and LGG
+    :param batch_size: size of batch to load (default=1)
+    :param preprocess: apply preprocessData() to each image (True/False) (default=False)
+    :param loadSurvival: load survival data (True/False) (default=False)
+    :param csvFilePath: If loadSurvival is True, provide path to survival data (default=None)
+    :param loadSeg: load segmentations (True/False) (default=True)
+    :param dataset: dataset identifier as a string; it is compared with .lower() below, so it must not be an int
+    :return:
+    """
+
+    patID = 0   # used to keep count of how many patients loaded already.
+    num_sequences = 4 # number of sequences in the data. BRATS has 4.
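+    # With the default BRATS 2015 config, the placeholders allocated just below work out to, e.g.,
+    # images.shape == (batch_size, 4, 240, 240, 155) and seg_masks.shape == (batch_size, 240, 240, 155).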
+ num_slices = config['num_slices'] + running_pats = [] + out_shape = config['spatial_size_for_training'] # shape of the training data + + # create placeholders, currently only supports theano type convention (num_eg, channels, x, y, z) + images = np.empty((batch_size, num_sequences, out_shape[0], out_shape[1], num_slices)).astype(np.int16) + labels = np.empty((batch_size, 1)).astype(np.int16) + + if loadSeg == True: + # create placeholder for the segmentation mask + seg_masks = np.empty((batch_size, out_shape[0], out_shape[1], num_slices)).astype(np.int16) + + csv_flag = 0 + + batch_id = 1 # counter for batches loaded + logger.info('starting to load images..') + for patient in glob.glob(data_dir + '/*'): + if os.path.isdir(patient): + logger.debug('{} is a directory.'.format(patient)) + + # this hacky piece of code is to reorder the filenames, so that segmentation file is always at the end. + # get all the filepaths + sequence_folders = glob.glob(patient + '/*') + + vsd_id = [] + for curr_seq in sequence_folders: # get the filepath of the image (nii.gz) + imagefile = [x for x in glob.glob(os.path.join(curr_seq, '*')) if '.txt' not in x][0] + # save the name of the patient + if '.OT.' in imagefile: + if loadSeg == True: + logger.debug('loading segmentation for this patient..') + + # open using SimpleITK + # SimpleITK would allow me to add number of preprocessing steps that are well defined and + # implemented in SITK for their own object type. We can leverage those functions if we preserve + # the image object. + + img_obj = sitk.ReadImage(imagefile) + pix_data = sitk.GetArrayViewFromImage(img_obj) + + # check Practice - SimpleiTK.ipynb notebook for more info on why this swapaxes operation is req + pix_data_swapped = np.swapaxes(pix_data, 0, 1) + pix_data_swapped = np.swapaxes(pix_data_swapped, 1, 2) + + seg_masks[patID, :, :, :] = pix_data_swapped + else: + continue + else: + # this is to ensure that each channel stays at the same place + if 'isles' in dataset.lower(): + if 'T1.' in imagefile: + i = 0 + seq_name = 't1' + elif 'T2.' in imagefile: + i = 1 + seq_name = 't2' + elif 'DWI.' in imagefile: + i = 2 + seq_name = 'dwi' + elif 'Flair.' in imagefile: + i = 3 + seq_name = 'flair' + vsd_id.append(os.path.basename(imagefile)) + else: + if 'T1.' in imagefile: + i = 0 + seq_name = 't1' + elif 'T2.' in imagefile: + i = 1 + seq_name = 't2' + elif 'T1c.' in imagefile: + i = 2 + seq_name = 't1c' + elif 'Flair.' in imagefile: + i = 3 + seq_name = 'flair' + vsd_id.append(os.path.basename(imagefile)) + + img_obj = sitk.ReadImage(imagefile) + if preprocess == True: + logger.debug('performing N4ITK Bias Field Correction on {} modality'.format(seq_name)) + + if 'isles' not in dataset.lower(): + img_obj_res = resize_mha_volume(img_obj, spacing=[1, 1, 1], size=[240, 240, 155]) + + img_obj_res = preprocessData(img_obj_res, process=preprocess) + + pix_data = sitk.GetArrayViewFromImage(img_obj_res) + + pix_data_swapped = np.swapaxes(pix_data, 0, 1) + pix_data_swapped = np.swapaxes(pix_data_swapped, 1, 2) + + images[patID, i, :, :, :] = pix_data_swapped + + patID += 1 + + if batch_id % batch_size == 0: + patID = 0 + if loadSeg == True: + yield images, seg_masks, vsd_id + elif loadSeg == False: + yield images, vsd_id + + vsd_id = [] + + batch_id += 1 + + + +def apply_mean_std(im, mean_var): + """ + Supercedes the standardize function. Takes the mean/var file generated during preprocessed data generation and + applies the normalization step to the patch. 
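+
+    Illustrative usage (the pickle filename is an assumption based on the naming used in preprocess.py):
+        >>> import pickle
+        >>> mean_var = pickle.load(open('training_data_hgg_mean_std.p', 'rb'))
+        >>> patch = apply_mean_std(patch, mean_var)
+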
+    :param im: patch of size (num_egs, channels, x, y, z) or (channels, x, y, z)
+    :param mean_var: dictionary containing the mean/var values calculated in preprocess.py
+    :return: normalized patch
+    """
+
+    # expects a dictionary of means and VARIANCES, NOT STD
+    for m in range(0, 4):
+        if len(np.shape(im)) > 4:
+            im[:, m, ...] = (im[:, m, ...] - mean_var['mn'][m]) / np.sqrt(mean_var['var'][m])
+        else:
+            im[m, ...] = (im[m, ...] - mean_var['mn'][m]) / np.sqrt(mean_var['var'][m])
+
+    return im
+
+
+def standardize(images, findMeanVarOnly=True, saveDump=None, applyToTest=None):
+    """
+    This function standardizes the input data to zero mean and unit variance. It is capable of calculating the
+    mean and std values from the input data, or can also apply user specified mean/std values to the images.
+
+    :param images: numpy ndarray of shape (num_egs, channels, x, y, z) to apply mean/std normalization to
+    :param findMeanVarOnly: only find the mean and variance of the input data, do not normalize
+    :param saveDump: path to which the calculated mean/variance values are pickled (None to skip saving)
+    :param applyToTest: apply user specified mean/var values to given images. checkLargestCropSize.ipynb has more info
+    :return: standardized images, and vals (if mean/var was calculated by the function)
+    """
+
+    # takes a dictionary
+    if applyToTest is not None:
+        logger.info('Applying to test data using provided values')
+        images = apply_mean_std(images, applyToTest)
+        return images
+
+    logger.info('Calculating mean value..')
+    vals = {
+        'mn': [],
+        'var': []
+    }
+    for i in range(4):
+        vals['mn'].append(np.mean(images[:, i, :, :, :]))
+
+    logger.info('Calculating variance..')
+    for i in range(4):
+        vals['var'].append(np.var(images[:, i, :, :, :]))
+
+    if not findMeanVarOnly:
+        logger.info('Starting standardization process..')
+
+        for i in range(4):
+            # divide by the standard deviation (sqrt of the variance) to actually get unit
+            # variance, matching apply_mean_std above; dividing by the raw variance was a bug.
+            images[:, i, :, :, :] = ((images[:, i, :, :, :] - vals['mn'][i]) / np.sqrt(vals['var'][i]))
+
+        logger.info('Data standardized!')
+
+    if saveDump is not None:
+        logger.info('Dumping mean and var values to disk..')
+        pickle.dump(vals, open(saveDump, 'wb'))
+        logger.info('Done!')
+
+    return images, vals
+
+
+if __name__ == "__main__":
+    """
+    Only for testing purposes, DO NOT ATTEMPT TO RUN THIS SCRIPT. ONLY IMPORT AS MODULE
+    """
+    data_dir = '/local-scratch/cedar-rm/scratch/asa224/Datasets/BRATS2017/MICCAI_BraTS17_Data_Training/HGG/'
+    gen = loadDataGenerator(data_dir, batch_size=2, loadSurvival=False,
+                            csvFilePath=None, loadSeg=True)
+    images, segmasks, pats = next(gen)  # loadDataGenerator is a generator; grab a single batch
diff --git a/prep_BRATS2015/modules/mischelpers.py b/prep_BRATS2015/modules/mischelpers.py
new file mode 100755
index 0000000..d411af5
--- /dev/null
+++ b/prep_BRATS2015/modules/mischelpers.py
@@ -0,0 +1,239 @@
+"""
+==========================================================
+    Misc Helper Classes/Functions
+==========================================================
+AUTHOR: Anmol Sharma
+AFFILIATION: Simon Fraser University
+             Burnaby, BC, Canada
+PROJECT: Analysis of Brain MRI Scans for Management of
+         Malignant Tumors
+COLLABORATORS: Anmol Sharma (SFU)
+               Prof. Ghassan Hamarneh (SFU)
+               Dr. Brian Toyota (VGH)
+               Dr. Mostafa Fatehi (VGH)
+DESCRIPTION: The module has various helper classes/functions
+             that can be used throughout the pipeline, and
+             don't fit exactly in either data loading or
+             visualization operations.
+LICENCE: Proprietary for now.
+""" + +import numpy as np +from configfile import config +import h5py +from nilearn._utils import check_niimg +from nilearn.image import new_img_like +from nilearn.image import reorder_img, resample_img + + + +class Rect3D: + """ + Class to encapsulate the Rectangle coordinates. This prevents future + issues when the coordinates need to be standardized. + """ + def __init__(self, coord_list): + if len(coord_list) < 6: + print('Coordinate list shape is incorrect, creating empty object!') + coord_list = [0, 0, 0, 0, 0, 0] + self.empty = True + else: + self.empty = False + + self.rmin = coord_list[0] + self.rmax = coord_list[1] + self.cmin = coord_list[2] + self.cmax = coord_list[3] + self.zmin = coord_list[4] + self.zmax = coord_list[5] + self.list_view = coord_list + + def show(self): + return self.list_view + +class Rect2D: + """ + Class to encapsulate the Rectangle coordinates. This prevents future + issues when the coordinates need to be standardized. + """ + def __init__(self, coord_list): + if len(coord_list) < 4: + print('Coordinate list shape is incorrect, creating empty object!') + coord_list = [0, 0, 0, 0] + self.empty = True + else: + self.empty = False + + self.rmin = coord_list[0] + self.rmax = coord_list[1] + self.cmin = coord_list[2] + self.cmax = coord_list[3] + self.list_view = coord_list + + def show(self): + return self.list_view + +def bbox_3D(img, tol=0.5): + """ + TOL = argument used when dark regions are >0 + (usually after some preprocessing, like + rescaling). + """ + r, c, z = np.where(img > tol) + rmin, rmax, cmin, cmax, zmin, zmax = np.min(r), np.max(r), np.min(c), np.max(c), np.min(z), np.max(z) + rect_obj = Rect3D([rmin, rmax, cmin, cmax, zmin, zmax]) + return rect_obj + +def bbox_2D(img, tol=0.5): + """ + TOL = argument used when dark regions are >0 + (usually after some preprocessing, like + rescaling). + """ + r, c = np.where(img > tol) + if r.size == 0 or c.size == 0: + return Rect2D([-1, -1, -1, -1]) + else: + rmin, rmax, cmin, cmax = np.min(r), np.max(r), np.min(c), np.max(c) + rect_obj = Rect2D([rmin, rmax, cmin, cmax]) + return rect_obj + +def open_hdf5(filepath=None, mode='r'): + if filepath == None: + filepath = config['hdf5_filepath_prefix'] + + return h5py.File(filepath, mode=mode) + +def get_data_splits_bbox(hdf5_filepath, train_start=0, train_end=190, test_start=190, test_end=None): + """ + + :param hdf5_filepath: + :param train_start: Start index to slice to get the training data. For 10 instances starting from 0, choose 0. + :param train_end: End index for training. Remember this index is 'exclusive', so if you want 10 instances, choose this as 10 + :param test_start: Start index to slice to get the testing data. Same comment as above. + :param test_end: End index for testing. + :return: Keras instances to slice into x_train, y_train, x_test, y_test. 
+ """ + import keras + filepath = config['hdf5_filepath_prefix'] if hdf5_filepath is None else hdf5_filepath + + x_train = keras.utils.io_utils.HDF5Matrix(filepath, "training_data_hgg", start=train_start, end=train_end, + normalizer=None) + y_train = keras.utils.io_utils.HDF5Matrix(filepath, "bounding_box_hgg", start=train_start, end=train_end, + normalizer=None) + + x_test = keras.utils.io_utils.HDF5Matrix(filepath, "training_data_hgg", start=test_start, end=test_end, + normalizer=None) + y_test = keras.utils.io_utils.HDF5Matrix(filepath, "bounding_box_hgg", start=test_start, end=test_end, + normalizer=None) + + return x_train, y_train, x_test, y_test + +def createDense(bbox, im): + box = np.zeros(im.shape) + box[bbox[0]:bbox[1], bbox[2]:bbox[3], bbox[4]:bbox[5]] = 1 + return box + + +def _crop_img_to(img, slices, copy=True): + """Crops image to a smaller size + Crop img to size indicated by slices and adjust affine + accordingly + Parameters + ---------- + img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + Img to be cropped. If slices has less entries than img + has dimensions, the slices will be applied to the first len(slices) + dimensions + slices: list of slices + Defines the range of the crop. + E.g. [slice(20, 200), slice(40, 150), slice(0, 100)] + defines a 3D cube + copy: boolean + Specifies whether cropped data is to be copied or not. + Default: True + Returns + ------- + cropped_img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + Cropped version of the input image + """ + + img = check_niimg(img) + + data = img.get_data() + affine = img.affine + + cropped_data = data[slices] + if copy: + cropped_data = cropped_data.copy() + + linear_part = affine[:3, :3] + old_origin = affine[:3, 3] + new_origin_voxel = np.array([s.start for s in slices]) + new_origin = old_origin + linear_part.dot(new_origin_voxel) + + new_affine = np.eye(4) + new_affine[:3, :3] = linear_part + new_affine[:3, 3] = new_origin + + return new_img_like(img, cropped_data, new_affine) + + +def crop_img_custom(img, slices=None, rtol=1e-8, copy=True): + """Crops img as much as possible + Will crop img, removing as many zero entries as possible + without touching non-zero entries. Will leave one voxel of + zero padding around the obtained non-zero area in order to + avoid sampling issues later on. + Parameters + ---------- + img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + img to be cropped. + rtol: float + relative tolerance (with respect to maximal absolute + value of the image), under which values are considered + negligeable and thus croppable. + copy: boolean + Specifies whether cropped data is copied or not. 
+ Returns + ------- + cropped_img: image + Cropped version of the input image + """ + + img = check_niimg(img) + data = img.get_data() + + if slices is not None: + return _crop_img_to(img, slices, copy=copy), slices + else: + infinity_norm = max(-data.min(), data.max()) + passes_threshold = np.logical_or(data < -rtol * infinity_norm, + data > rtol * infinity_norm) + + if data.ndim == 4: + passes_threshold = np.any(passes_threshold, axis=-1) + coords = np.array(np.where(passes_threshold)) + start = coords.min(axis=1) + end = coords.max(axis=1) + 1 + + # pad with one voxel to avoid resampling problems + start = np.maximum(start - 1, 0) + end = np.minimum(end + 1, data.shape[:3]) + + slices = [slice(s, e) for s, e in zip(start, end)] + + return _crop_img_to(img, slices, copy=copy), slices + + +def resize(image, new_shape, interpolation="continuous"): + input_shape = np.asarray(image.shape, dtype=np.float16) + ras_image = reorder_img(image, resample=interpolation) + output_shape = np.asarray(new_shape) + new_spacing = input_shape/output_shape + new_affine = np.copy(ras_image.affine) + new_affine[:3, :3] = ras_image.affine[:3, :3] * np.diag(new_spacing) + return resample_img(ras_image, target_affine=new_affine, target_shape=output_shape, interpolation=interpolation, clip=True) \ No newline at end of file diff --git a/prep_BRATS2015/prepare_data_for_synthesis.py b/prep_BRATS2015/prepare_data_for_synthesis.py new file mode 100755 index 0000000..b720575 --- /dev/null +++ b/prep_BRATS2015/prepare_data_for_synthesis.py @@ -0,0 +1,56 @@ +""" +======================================================================== + Prepare BRATS 2013 Validation data for MM_Synthesis +======================================================================== +AUTHOR: Anmol Sharma +Description: This file prepares the BRATS 2013 validation set, which is + divided into challenge and leaderboard, and further divided + into the grade types, into NPZ format that the mm_synthesis + module expects. + + This is similar to how BRATS 2018 data was prepared. Check + the original mm_synthesis codebase for more info. + +""" + +import os, sys, h5py +sys.path.append('..') +from modules.configfile import config, mount_path_prefix +import numpy as np + +def saveNPZ(data, save_path, pat_names): + np.save(open(os.path.join(save_path, 'pat_names_validation.npz'), 'wb'), pat_names) + t1 = data[:,0,...] + t1 = np.swapaxes(t1, 3, 2) + t1 = np.swapaxes(t1, 2, 1) + np.save(open(save_path + 'T1.npz', 'wb'), t1) + del t1 + + t2 = data[:,1,...] + t2 = np.swapaxes(t2, 3, 2) + t2 = np.swapaxes(t2, 2, 1) + np.save(open(save_path + 'T2.npz', 'wb'), t2) + del t2 + + t1ce = data[:,2,...] + t1ce = np.swapaxes(t1ce, 3, 2) + t1ce = np.swapaxes(t1ce, 2, 1) + np.save(open(save_path + 'T1CE.npz', 'wb'), t1ce) + del t1ce + + t2flair = data[:,3,...] 
+ t2flair = np.swapaxes(t2flair, 3, 2) + t2flair = np.swapaxes(t2flair, 2, 1) + np.save(open(save_path + 'T2FLAIR.npz', 'wb'), t2flair) + del t2flair + + print('Done!') + +hf = h5py.File(config['hdf5_filepath_prefix'], 'r') +hf = hf['validation_data'] + +# SAVE CHALLENGE HGG Data +save_path_c_hg = os.path.join(mount_path_prefix, "scratch/asa224/Datasets/BRATS2015/mm_synthesis/validation_data/testing_hgglgg_patients/HGG_LGG/") +pat_names_c_hg = hf['testing_hgglgg_patients_pat_name'] +challenge_data_hgg = hf['testing_hgglgg_patients'] +saveNPZ(challenge_data_hgg, save_path_c_hg, pat_names_c_hg) diff --git a/prep_BRATS2015/preprocess.py b/prep_BRATS2015/preprocess.py new file mode 100755 index 0000000..cc8e5c3 --- /dev/null +++ b/prep_BRATS2015/preprocess.py @@ -0,0 +1,152 @@ +""" +========================================================== + Preprocess BRATS Data +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. Mostafa Fatehi (VGH) +DESCRIPTION: This file uses the previously generated data + (using create_hdf5_file.py) and generates a + new file with similar structure, but after + applying a couple of preprocessing steps. + More specifically, the script applies the + following operations on the data: + + 1) Crop out the dark margins in the scans + to only leave a concise brain area. For + this a generous estimate of bounding box + generated from the whole dataset is used. + For more information, see checkLargestCropSize + notebook. + + The code DOES NOT APPLY MEAN/VAR normalization, + but simply calculates the values and saves on disk. + Check lines 140-143 for more information. + + The saved mean/var files are to be used before + the training process. + +LICENCE: Proprietary for now. +""" + +import h5py +from modules.configfile import config +import numpy as np +import logging +from modules.dataloader import standardize +import os + +logging.basicConfig(level=logging.DEBUG) + +try: + logger = logging.getLogger(__file__.split('/')[-1]) +except: + logger = logging.getLogger(__name__) + +logger.warning('[IMPORTANT] The code DOES NOT APPLY mean/var normalization, rather it calculates it and saves to disk') +# ------------------------------------------------------------------------------------ +# open existing datafile +# ------------------------------------------------------------------------------------ +logger.info('opening previously generated HDF5 file.') + +# open the existing datafile +hdf5_file_main = h5py.File(config['hdf5_filepath_prefix'], 'r') + +logger.info('opened HDF5 file at {}'.format(config['hdf5_filepath_prefix'])) + +# get the group identifier for original dataset +hdf5_file = hdf5_file_main['original_data'] + +# ==================================================================================== + +# ------------------------------------------------------------------------------------ +# create new HDF5 file to hold cropped data. 
+# ------------------------------------------------------------------------------------
+logger.info('creating new HDF5 dataset to hold cropped/normalized data')
+filename = os.path.join(os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1]), 'BRATS2015_Cropped_Normalized.h5')
+new_hdf5 = h5py.File(filename, mode='w')
+logger.info('created new database at {}'.format(filename))
+
+# create a folder group to hold the datasets. The schema is similar to the original one except for
+# the name of the folder group.
+new_group_preprocessed = new_hdf5.create_group('preprocessed')
+
+# create similar datasets in this file.
+new_group_preprocessed.create_dataset("training_data_hgg", config['train_shape_hgg_crop'], np.float32)
+new_group_preprocessed.create_dataset("training_data_hgg_pat_name", (config['train_shape_hgg_crop'][0],), dtype="S100")
+new_group_preprocessed.create_dataset("training_data_segmasks_hgg", config['train_segmasks_shape_hgg_crop'], np.int16)
+
+new_group_preprocessed.create_dataset("training_data_lgg", config['train_shape_lgg_crop'], np.float32)
+new_group_preprocessed.create_dataset("training_data_lgg_pat_name", (config['train_shape_lgg_crop'][0],), dtype="S100")
+new_group_preprocessed.create_dataset("training_data_segmasks_lgg", config['train_segmasks_shape_lgg_crop'], np.int16)
+
+new_group_preprocessed.create_dataset("testing_hgglgg_patients", config['testing_hgglgg_patients_shape_crop'], np.float32)
+new_group_preprocessed.create_dataset("testing_hgglgg_patients_pat_name", (config['testing_hgglgg_patients_shape'][0],), dtype="S100")
+# ====================================================================================
+
+# just copy the patient names directly
+new_group_preprocessed['training_data_hgg_pat_name'][:] = hdf5_file['training_data_hgg_pat_name'][:]
+new_group_preprocessed['training_data_lgg_pat_name'][:] = hdf5_file['training_data_lgg_pat_name'][:]
+new_group_preprocessed['testing_hgglgg_patients_pat_name'][:] = hdf5_file['testing_hgglgg_patients_pat_name'][:]
+
+# ------------------------------------------------------------------------------------
+# start cropping process and standardization process
+# ------------------------------------------------------------------------------------
+
+# directory where the mean/var files are to be stored
+# TODO: Use the config file global path, not this one.
+
+saveMeanVarFilename = os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1])
+logger.info('starting the Cropping/Normalization process.')
+
+# only run the cropping steps on these datasets
+run_on_list = ['training_data_segmasks_hgg', 'training_data_hgg', 'training_data_lgg', 'training_data_segmasks_lgg', 'testing_hgglgg_patients']
+
+# only run the mean/var normalization on these datasets
+std_list = ['training_data_hgg', 'training_data_lgg']
+for run_on in run_on_list:
+
+    # we define the final shape after cropping in the config file to make it easy to access. More information
+    # is available in the checkLargestCropSize.ipynb notebook.
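+    # e.g. with the config above, config['train_shape_hgg_crop'] works out to
+    # (220, 4, 194, 155, 148): 220 HGG patients, 4 sequences, and the cropped volume size.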
+    if run_on == 'training_data_hgg':
+        im_np = np.empty(config['train_shape_hgg_crop'])
+    elif run_on == 'training_data_lgg':
+        im_np = np.empty(config['train_shape_lgg_crop'])
+    elif run_on == 'testing_hgglgg_patients':
+        im_np = np.empty(config['testing_hgglgg_patients_shape_crop'])
+
+    logger.info('Running on {}'.format(run_on))
+    for i in range(0, hdf5_file[run_on].shape[0]):
+        # cropping operation
+        logger.debug('{}:- Patient {}'.format(run_on, i+1))
+        im = hdf5_file[run_on][i]
+        m = config['cropping_coords']
+        if 'segmasks' in run_on:
+            # there are no channels for segmasks
+            k = im[m[0]:m[1], m[2]:m[3], m[4]:m[5]]
+        else:
+            k = im[:, m[0]:m[1], m[2]:m[3], m[4]:m[5]]
+
+        if run_on in std_list:
+            # save the image to this numpy array
+            im_np[i] = k
+        new_group_preprocessed[run_on][i] = k
+    # find the mean and variance and write them to disk; the normalization itself is applied
+    # later, before training (findMeanVarOnly=True below).
+    if run_on in std_list:
+        logger.info('The dataset {} needs standardization'.format(run_on))
+        # os.path.join avoids gluing the directory and filename together without a separator
+        _tmp, vals = standardize(im_np, findMeanVarOnly=True, saveDump=os.path.join(saveMeanVarFilename, run_on + '_mean_std.p'))
+        logger.info('Calculated normalization values for {}:\n{}'.format(run_on, vals))
+        del im_np
+
+# ====================================================================================
+
+hdf5_file_main.close()
+new_hdf5.close()
+
+
diff --git a/prep_ISLES2015/cedar_gen_hdf5.sh b/prep_ISLES2015/cedar_gen_hdf5.sh
new file mode 100755
index 0000000..764d53a
--- /dev/null
+++ b/prep_ISLES2015/cedar_gen_hdf5.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+#SBATCH --nodes=1  # number of nodes
+#SBATCH --ntasks=1  # number of MPI processes
+#SBATCH --cpus-per-task=2  # 24 cores on cedar nodes
+#SBATCH --account=rrg-hamarneh
+#SBATCH --mem=16G  # give all memory you have in the node
+#SBATCH --time=3-05:00  # time (DD-HH:MM)
+#SBATCH --job-name=GenerateHDF5File
+#SBATCH --output=GenerateHDF5File.out
+#SBATCH --mail-user=asa224@sfu.ca
+#SBATCH --mail-type=ALL
+
+# run the command
+~/.virtualenvs/mm_synthesis/bin/python create_hdf5_file.py
diff --git a/prep_ISLES2015/create_hdf5_file.py b/prep_ISLES2015/create_hdf5_file.py
new file mode 100755
index 0000000..351412f
--- /dev/null
+++ b/prep_ISLES2015/create_hdf5_file.py
@@ -0,0 +1,154 @@
+"""
+==========================================================
+    Prepare ISLES 2015 Data
+==========================================================
+AUTHOR: Anmol Sharma
+AFFILIATION: Simon Fraser University
+             Burnaby, BC, Canada
+PROJECT: Analysis of Brain MRI Scans for Management of
+         Malignant Tumors
+COLLABORATORS: Anmol Sharma (SFU)
+               Prof. Ghassan Hamarneh (SFU)
+               Dr. Brian Toyota (VGH)
+               Dr. Mostafa Fatehi (VGH)
+DESCRIPTION: This file is used to generate an HDF5 dataset,
+             which is easy to load and manipulate compared
+             to working directly with raw data all the time.
+             Loading and working with HDF5 files is much
+             faster and more efficient due to its asynchronous
+             loading system.
+
+             The HDF5 file generated can be hosted on a remote server
+             (like CEDAR) and then accessed over SSHFS. Practically,
+             this is very effective and does not hinder the performance
+             by a large margin.
+
+             This script generates a simple HDF5 data store,
+             which contains the original numpy arrays of the
+             data store. To perform any preprocessing, implement
+             the preprocessData() function in dataloader.py to
+             work directly on SimpleITK image objects, instead of
+             numpy objects.
+LICENCE: Proprietary for now.
+""" + +import os +import glob +from modules import dataloader +import logging +import numpy as np +import h5py +from modules.configfile import config + +logging.basicConfig(level=logging.INFO) +try: + logger = logging.getLogger(__file__.split('/')[-1]) +except: + logger = logging.getLogger(__name__) + +try: + logger.info('Dataloader file in use: {}'.format(dataloader.__file__)) +except: + logger.info('Cannot determine the dataloader class being used, check carefully before proceeding') + + +# whether or not to preprocess the data before creating the HDF5 file? Check the preprocess function in dataloader to +# know exactly what preprocessing is being performed. +# NO BIAS FIELD CORRECTION FOR ISLES DATASET FOR COMPARISON +PREPROCESS_DATA = False + +logger.info('[IMPORTANT] This will create a new HDF5 file in SAFE MODE. It will NOT OVERWRITE A PREVIOUS HDF5 FILE ' + 'IF ITS PRESENT') +def createHDF5File(config): + """ + Function to create a new HDF5 File to hold the BRATS 2015 data. The function will fail if there's already a file + present with the same name (SAFE OPERATION) + + :param config: The config variable defined in configfile.py + :return: hdf5_file object + """ + + # w- mode fails when there is a file already. + hdf5_file = h5py.File(config['hdf5_filepath_prefix'], mode='w') + + # create a new parent directory to hold the data inside it + grp = hdf5_file.create_group("original_data") + + # the dataset is int16 originally, checked using nibabel, however we create float32 containers to make the dataset + # compatible with further preprocessing. + # HGG Data + + grp.create_dataset("training_data", config['train_shape'], np.float32) + grp.create_dataset("training_data_pat_name", (config['train_shape'][0],), dtype="S100") + grp.create_dataset("training_data_segmasks", config['train_segmasks_shape'], np.int16) + + return hdf5_file + +def main(): + hdf5_file_main = createHDF5File(config) + # hdf5_file_main = h5py.File(config['hdf5_filepath_prefix'], mode='a') + # Go inside the "original_data" parent directory. + # we need to create the validation data dataset again since the shape has changed. + hdf5_file = hdf5_file_main['original_data'] + contents = glob.glob(os.path.join(config['data_dir_prefix'], '*')) + + # for debugging, making sure Training set is loaded first not Testing, since that is tested already. + contents.reverse() + for dataset_splits in contents: # Challenge/LeaderBoard data? + if os.path.isdir(dataset_splits): # make sure its a directory + for grade_type in glob.glob(os.path.join(dataset_splits, '*')): + # there may be other files in there (like the survival data), ignore them. + if os.path.isdir(grade_type): + count = 0 + if 'Testing' in dataset_splits: + logger.info('currently loading Testing -> {} data.'.format(os.path.basename(grade_type))) + ty = 'Testing' + + for images, pats in dataloader.loadDataGenerator(grade_type, + batch_size=config['batch_size'], loadSurvival=False, + csvFilePath=None, loadSeg=False, + preprocess=PREPROCESS_DATA, dataset='ISLES'): + logger.info('loading patient {} from {}'.format(count, grade_type)) + if 'HGG_LGG' in grade_type: + if ty == 'Testing': + main_data_name = 'testing_hgglgg_patients' + main_data_pat_name = 'testing_hgglgg_patients_pat_name' + + hdf5_file[main_data_name][count:count+config['batch_size'],...] 
+                            t = 0
+                            for i in range(count, count + config['batch_size']):
+                                hdf5_file[main_data_pat_name][i] = pats[t].split('.')[-2]
+                                t += 1
+
+                        logger.info('loaded {} patient(s) from {}'.format(count + config['batch_size'], grade_type))
+                        count += config['batch_size']
+            else:
+                # TRAINING data handler. NOTE: this is a for-else; since the loop above never
+                # breaks, this block always runs once the grade_type iteration finishes.
+                if os.path.isdir(dataset_splits) and 'Training' in dataset_splits:
+                    for grade_type in glob.glob(os.path.join(dataset_splits, '*')):
+                        # there may be other files in there (like the survival data), ignore them.
+                        if os.path.isdir(grade_type):
+                            count = 0
+                            logger.info('currently loading Training data.')
+                            for images, segmasks, pats in dataloader.loadDataGenerator(grade_type,
+                                                                  batch_size=config['batch_size'], loadSurvival=False,
+                                                                  csvFilePath=None, loadSeg=True,
+                                                                  preprocess=PREPROCESS_DATA,
+                                                                  dataset='ISLES'):
+                                logger.info('loading patient {} from {}'.format(count, grade_type))
+
+                                hdf5_file['training_data'][count:count+config['batch_size'],...] = images
+                                hdf5_file['training_data_segmasks'][count:count+config['batch_size'], ...] = segmasks
+                                t = 0
+                                for i in range(count, count + config['batch_size']):
+                                    hdf5_file['training_data_pat_name'][i] = pats[t].split('/')[-1]
+                                    t += 1
+
+                                logger.info('loaded {} patient(s) from {}'.format(count + config['batch_size'], grade_type))
+                                count += config['batch_size']
+    # close the HDF5 file
+    hdf5_file_main.close()
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/prep_ISLES2015/modules/__init__.py b/prep_ISLES2015/modules/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/prep_ISLES2015/modules/configfile.py b/prep_ISLES2015/modules/configfile.py
new file mode 100755
index 0000000..5427102
--- /dev/null
+++ b/prep_ISLES2015/modules/configfile.py
@@ -0,0 +1,81 @@
+"""
+==========================================================
+    Config File to set Parameters
+==========================================================
+AUTHOR: Anmol Sharma
+AFFILIATION: Simon Fraser University
+             Burnaby, BC, Canada
+PROJECT: Analysis of Brain MRI Scans for Management of
+         Malignant Tumors
+COLLABORATORS: Anmol Sharma (SFU)
+               Prof. Ghassan Hamarneh (SFU)
+               Dr. Brian Toyota (VGH)
+               Dr. Mostafa Fatehi (VGH)
+DESCRIPTION: This file is solely created for the purpose of
+             managing parameters in a global setting. All the
+             database loading and generation parameters reside
+             here, and are inherited by create_hdf5_file.py
+             to generate the HDF5 data store.
+
+             The parameters are also used in the test_database.py
+             script to test the created database.
+LICENCE: Proprietary for now.
+"""
+import os
+import platform
+
+# WE CAN USE THIS TO CHANGE IMAGE_DATA_FORMAT on the fly
+# keras.backend.common._IMAGE_DATA_FORMAT='channels_first'
+
+# to make the code portable even on cedar, you need to add conditions here
+node_name = platform.node()
+if node_name == 'XPS15':
+    # this is my laptop, so the cedar-rm directory is at a different place
+    mount_path_prefix = '/home/anmol/mounts/cedar-rm/'
+elif 'computecanada' in node_name: # we're in compute canada, maybe in an interactive node, or a scheduler node.
+    # mount_path_prefix = '/home/asa224/' # home directory
+    mount_path_prefix = '' # home directory
+else:
+    # this is probably my workstation or TS server
+    # mount_path_prefix = '/local-scratch/asa224_new/mounts/cedar-rm/'
+    mount_path_prefix = ''
+
+config = {}
+# set the data directory and output hdf5 file path.
+# data_dir is the top level path containing both training and testing sets of the ISLES dataset.
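+# Note (standard os.path.join behavior): when the second argument is an absolute path, as below,
+# os.path.join discards mount_path_prefix entirely, so the prefix only matters for relative paths.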
+# config['data_dir_prefix'] = os.path.join(mount_path_prefix, 'rrg_proj_dir/scratch_files_globus/Datasets/BRATS2015/') # this should be the top level path
+config['data_dir_prefix'] = os.path.join(mount_path_prefix, '/local-scratch/anmol/data/ISLES2015') # this should be the top level path
+config['hdf5_filepath_prefix'] = os.path.join(mount_path_prefix, '/local-scratch/anmol/data/ISLES2015/HDF5_Datasets/ISLES2015.h5') # top level path
+
+config['spatial_size_for_training'] = (230, 230) # in-plane shape the training code expects; if any preprocessing changes the spatial size, update this value.
+config['num_slices'] = 154 # number of slices in input data. THIS SHOULD CHANGE AS WELL WHEN PERFORMING PREPROCESSING
+config['volume_size'] = list(config['spatial_size_for_training']) + [config['num_slices']]
+config['seed'] = 1338
+config['data_order'] = 'th' # what order should the indices be to store in hdf5 file
+config['train_patients'] = 28
+
+# SET LATER
+config['cropping_coords'] = [8, 204, 39, 189, 0, 158] # coordinates used to crop the volumes, generated using the notebook checkLargestCropSize.ipynb
+config['size_after_cropping'] = (196, 150, 154) # set this if you set the above variable. Calculate this using the notebook again.
+
+config['batch_size'] = 1 # how many images to load at once in the generator
+
+# check the order of data and choose the proper data shape to save images
+if config['data_order'] == 'th':
+
+    config['train_shape'] = (
+        config['train_patients'], 4, config['spatial_size_for_training'][0], config['spatial_size_for_training'][1],
+        config['num_slices'])
+
+    config['train_segmasks_shape'] = (
+        config['train_patients'], config['spatial_size_for_training'][0], config['spatial_size_for_training'][1],
+        config['num_slices'])
+
+    config['train_shape_crop'] = (
+        config['train_patients'], 4, config['size_after_cropping'][0], config['size_after_cropping'][1],
+        config['size_after_cropping'][2])
+
+    config['train_segmasks_shape_crop'] = (
+        config['train_patients'], config['size_after_cropping'][0], config['size_after_cropping'][1],
+        config['size_after_cropping'][2])
diff --git a/prep_ISLES2015/modules/dataloader.py b/prep_ISLES2015/modules/dataloader.py
new file mode 100755
index 0000000..a18ce11
--- /dev/null
+++ b/prep_ISLES2015/modules/dataloader.py
@@ -0,0 +1,292 @@
+"""
+==========================================================
+    Load ISLES 2015 Data
+==========================================================
+AUTHOR: Anmol Sharma
+AFFILIATION: Simon Fraser University
+             Burnaby, BC, Canada
+PROJECT: Analysis of Brain MRI Scans for Management of
+         Malignant Tumors
+COLLABORATORS: Anmol Sharma (SFU)
+               Prof. Ghassan Hamarneh (SFU)
+               Dr. Brian Toyota (VGH)
+               Dr. Mostafa Fatehi (VGH)
+DESCRIPTION: The script has multiple functions to load,
+             preprocess, and standardize the ISLES
+             2015 dataset.
+             The main function is loadDataGenerator, which loads
+             the data using a generator and doesn't hog memory.
+
+             The loadDataGenerator is capable of applying
+             arbitrary preprocessing steps to the data. This can be
+             achieved by implementing the function preprocessData.
+LICENCE: Proprietary for now.
+""" + +from __future__ import print_function +import glob as glob +import numpy as np +import pickle +import sys as sys +# from pandas import read_csv +import os +import logging +from configfile import config +import SimpleITK as sitk + +logger = logging.getLogger(__name__) + +def preprocessData(img_obj, process=False): + """ + Perform preprocessing on the original nibabel object. + Use this function to: + 1) Resize/Resample the 3D Volume + 2) Crop the brain region + 3) Do (2) then (1). + + When you do preprocessing, especially something that + changes the spatial size of the volume, make sure you + update config['spatial_size_for_training'] = (240, 240) + value in the config file. + + :param img_obj: + :param process: + :return: + """ + if process == False: + return img_obj + else: + maskImage = sitk.OtsuThreshold(img_obj, 0, 1, 200) + image = sitk.Cast(img_obj, sitk.sitkFloat32) + corrector = sitk.N4BiasFieldCorrectionImageFilter() + numberFilltingLevels = 4 + corrector.SetMaximumNumberOfIterations([4] * numberFilltingLevels) + output = corrector.Execute(image, maskImage) + return output + +def resize_mha_volume(image, spacing=[1,1,1], size=[240,240,155]): + + # Create the reference image + reference_origin = image.GetOrigin() + reference_direction = np.identity(image.GetDimension()).flatten() + reference_image = sitk.Image(size, image.GetPixelIDValue()) + reference_image.SetOrigin(reference_origin) + reference_image.SetSpacing(spacing) + reference_image.SetDirection(reference_direction) + + # Transform which maps from the reference_image to the current image (output-to-input) + transform = sitk.AffineTransform(image.GetDimension()) + transform.SetMatrix(image.GetDirection()) + transform.SetTranslation(np.array(image.GetOrigin()) - reference_origin) + + # Modify the transformation to align the centers of the original and reference image + reference_center = np.array( + reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0)) + centering_transform = sitk.TranslationTransform(image.GetDimension()) + img_center = np.array(image.TransformContinuousIndexToPhysicalPoint(np.array(image.GetSize()) / 2.0)) + centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center)) + centered_transform = sitk.Transform(transform) + centered_transform.AddTransform(centering_transform) + + # Using the linear interpolator + image_rs = sitk.Resample(image, reference_image, transform, sitk.sitkLinear, 0.0) + return image_rs + + +def loadDataGenerator(data_dir, batch_size=1, preprocess=False, loadSurvival=False, + csvFilePath=None, loadSeg=True, dataset=2018): + """ + Main function to load BRATS 2017 dataset. + + :param data_dir: path to the folder where patient data resides, needs individual paths for HGG and LGG + :param batch_size: size of batch to load (default=1) + :param loadSurvival: load survival data (True/False) (default=False) + :param csvFilePath: If loadSurvival is True, provide path to survival data (default=False) + :param loadSeg: load segmentations (True/False) (default=True) + :return: + """ + + patID = 0 # used to keep count of how many patients loaded already. + num_sequences = 4 # number of sequences in the data. BRATS has 4. 
+ num_slices = config['num_slices'] + running_pats = [] + out_shape = config['spatial_size_for_training'] # shape of the training data + + # create placeholders, currently only supports theano type convention (num_eg, channels, x, y, z) + images = np.empty((batch_size, num_sequences, out_shape[0], out_shape[1], num_slices)).astype(np.int16) + labels = np.empty((batch_size, 1)).astype(np.int16) + + if loadSeg == True: + # create placeholder for the segmentation mask + seg_masks = np.empty((batch_size, out_shape[0], out_shape[1], num_slices)).astype(np.int16) + + csv_flag = 0 + + batch_id = 1 # counter for batches loaded + logger.info('starting to load images..') + for patient in glob.glob(data_dir + '/*'): + if os.path.isdir(patient): + logger.debug('{} is a directory.'.format(patient)) + + # this hacky piece of code is to reorder the filenames, so that segmentation file is always at the end. + # get all the filepaths + sequence_folders = glob.glob(patient + '/*') + + vsd_id = [] + for curr_seq in sequence_folders: # get the filepath of the image (nii.gz) + imagefile = [x for x in glob.glob(os.path.join(curr_seq, '*')) if '.txt' not in x][0] + # save the name of the patient + if '.OT.' in imagefile: + if loadSeg == True: + logger.debug('loading segmentation for this patient..') + + # open using SimpleITK + # SimpleITK would allow me to add number of preprocessing steps that are well defined and + # implemented in SITK for their own object type. We can leverage those functions if we preserve + # the image object. + + img_obj = sitk.ReadImage(imagefile) + img_obj = resize_mha_volume(img_obj, spacing=[1, 1, 1], + size=[out_shape[0], out_shape[1], num_slices]) + + pix_data = sitk.GetArrayViewFromImage(img_obj) + + # check Practice - SimpleiTK.ipynb notebook for more info on why this swapaxes operation is req + pix_data_swapped = np.swapaxes(pix_data, 0, 1) + pix_data_swapped = np.swapaxes(pix_data_swapped, 1, 2) + + seg_masks[patID, :, :, :] = pix_data_swapped + else: + continue + else: + # this is to ensure that each channel stays at the same place + if 'isles' in dataset.lower(): + if 'T1.' in imagefile: + i = 0 + seq_name = 't1' + elif 'T2.' in imagefile: + i = 1 + seq_name = 't2' + elif 'DWI.' in imagefile: + i = 2 + seq_name = 'dwi' + elif 'Flair.' in imagefile: + i = 3 + seq_name = 'flair' + vsd_id.append(os.path.basename(imagefile)) + else: + if 'T1.' in imagefile: + i = 0 + seq_name = 't1' + elif 'T2.' in imagefile: + i = 1 + seq_name = 't2' + elif 'T1c.' in imagefile: + i = 2 + seq_name = 't1c' + elif 'Flair.' in imagefile: + i = 3 + seq_name = 'flair' + vsd_id.append(os.path.basename(imagefile)) + + img_obj = sitk.ReadImage(imagefile) + if preprocess == True: + logger.debug('performing N4ITK Bias Field Correction on {} modality'.format(seq_name)) + img_obj = preprocessData(img_obj, process=preprocess) + + img_obj = resize_mha_volume(img_obj, spacing=[1, 1, 1], size=[out_shape[0], out_shape[1], num_slices]) + + pix_data = sitk.GetArrayViewFromImage(img_obj) + + pix_data_swapped = np.swapaxes(pix_data, 0, 1) + pix_data_swapped = np.swapaxes(pix_data_swapped, 1, 2) + + images[patID, i, :, :, :] = pix_data_swapped + + patID += 1 + + if batch_id % batch_size == 0: + patID = 0 + if loadSeg == True: + yield images, seg_masks, vsd_id + elif loadSeg == False: + yield images, vsd_id + + vsd_id = [] + + batch_id += 1 + + + +def apply_mean_std(im, mean_var): + """ + Supercedes the standardize function. 
Takes the mean/var file generated during preprocessed data generation and
+    applies the normalization step to the patch.
+    :param im: patch of size (num_egs, channels, x, y, z) or (channels, x, y, z)
+    :param mean_var: dictionary containing the mean/var values calculated in preprocess.py
+    :return: normalized patch
+    """
+
+    # expects a dictionary of means and VARIANCES, NOT STD
+    for m in range(0, 4):
+        if len(np.shape(im)) > 4:
+            im[:, m, ...] = (im[:, m, ...] - mean_var['mn'][m]) / np.sqrt(mean_var['var'][m])
+        else:
+            im[m, ...] = (im[m, ...] - mean_var['mn'][m]) / np.sqrt(mean_var['var'][m])
+
+    return im
+
+
+def standardize(images, findMeanVarOnly=True, saveDump=None, applyToTest=None):
+    """
+    This function standardizes the input data to zero mean and unit variance. It is capable of calculating the
+    mean and std values from the input data, or can also apply user specified mean/std values to the images.
+
+    :param images: numpy ndarray of shape (num_egs, channels, x, y, z) to apply mean/std normalization to
+    :param findMeanVarOnly: only find the mean and variance of the input data, do not normalize
+    :param saveDump: path to which the calculated mean/variance values are pickled (None to skip saving)
+    :param applyToTest: apply user specified mean/var values to given images. checkLargestCropSize.ipynb has more info
+    :return: standardized images, and vals (if mean/var was calculated by the function)
+    """
+
+    # takes a dictionary
+    if applyToTest is not None:
+        logger.info('Applying to test data using provided values')
+        images = apply_mean_std(images, applyToTest)
+        return images
+
+    logger.info('Calculating mean value..')
+    vals = {
+        'mn': [],
+        'var': []
+    }
+    for i in range(4):
+        vals['mn'].append(np.mean(images[:, i, :, :, :]))
+
+    logger.info('Calculating variance..')
+    for i in range(4):
+        vals['var'].append(np.var(images[:, i, :, :, :]))
+
+    if not findMeanVarOnly:
+        logger.info('Starting standardization process..')
+
+        for i in range(4):
+            # divide by the standard deviation (sqrt of the variance) to actually get unit
+            # variance, matching apply_mean_std above; dividing by the raw variance was a bug.
+            images[:, i, :, :, :] = ((images[:, i, :, :, :] - vals['mn'][i]) / np.sqrt(vals['var'][i]))
+
+        logger.info('Data standardized!')
+
+    if saveDump is not None:
+        logger.info('Dumping mean and var values to disk..')
+        pickle.dump(vals, open(saveDump, 'wb'))
+        logger.info('Done!')
+
+    return images, vals
+
+
+if __name__ == "__main__":
+    """
+    Only for testing purposes, DO NOT ATTEMPT TO RUN THIS SCRIPT. ONLY IMPORT AS MODULE
+    """
+    data_dir = '/local-scratch/cedar-rm/scratch/asa224/Datasets/BRATS2017/MICCAI_BraTS17_Data_Training/HGG/'
+    gen = loadDataGenerator(data_dir, batch_size=2, loadSurvival=False,
+                            csvFilePath=None, loadSeg=True)
+    images, segmasks, pats = next(gen)  # loadDataGenerator is a generator; grab a single batch
diff --git a/prep_ISLES2015/modules/mischelpers.py b/prep_ISLES2015/modules/mischelpers.py
new file mode 100755
index 0000000..d411af5
--- /dev/null
+++ b/prep_ISLES2015/modules/mischelpers.py
@@ -0,0 +1,239 @@
+"""
+==========================================================
+    Misc Helper Classes/Functions
+==========================================================
+AUTHOR: Anmol Sharma
+AFFILIATION: Simon Fraser University
+             Burnaby, BC, Canada
+PROJECT: Analysis of Brain MRI Scans for Management of
+         Malignant Tumors
+COLLABORATORS: Anmol Sharma (SFU)
+               Prof. Ghassan Hamarneh (SFU)
+               Dr. Brian Toyota (VGH)
+               Dr. Mostafa Fatehi (VGH)
+DESCRIPTION: The module has various helper classes/functions
+             that can be used throughout the pipeline, and
+             don't fit exactly in either data loading or
+             visualization operations.
+LICENCE: Proprietary for now.
+""" + +import numpy as np +from configfile import config +import h5py +from nilearn._utils import check_niimg +from nilearn.image import new_img_like +from nilearn.image import reorder_img, resample_img + + + +class Rect3D: + """ + Class to encapsulate the Rectangle coordinates. This prevents future + issues when the coordinates need to be standardized. + """ + def __init__(self, coord_list): + if len(coord_list) < 6: + print('Coordinate list shape is incorrect, creating empty object!') + coord_list = [0, 0, 0, 0, 0, 0] + self.empty = True + else: + self.empty = False + + self.rmin = coord_list[0] + self.rmax = coord_list[1] + self.cmin = coord_list[2] + self.cmax = coord_list[3] + self.zmin = coord_list[4] + self.zmax = coord_list[5] + self.list_view = coord_list + + def show(self): + return self.list_view + +class Rect2D: + """ + Class to encapsulate the Rectangle coordinates. This prevents future + issues when the coordinates need to be standardized. + """ + def __init__(self, coord_list): + if len(coord_list) < 4: + print('Coordinate list shape is incorrect, creating empty object!') + coord_list = [0, 0, 0, 0] + self.empty = True + else: + self.empty = False + + self.rmin = coord_list[0] + self.rmax = coord_list[1] + self.cmin = coord_list[2] + self.cmax = coord_list[3] + self.list_view = coord_list + + def show(self): + return self.list_view + +def bbox_3D(img, tol=0.5): + """ + TOL = argument used when dark regions are >0 + (usually after some preprocessing, like + rescaling). + """ + r, c, z = np.where(img > tol) + rmin, rmax, cmin, cmax, zmin, zmax = np.min(r), np.max(r), np.min(c), np.max(c), np.min(z), np.max(z) + rect_obj = Rect3D([rmin, rmax, cmin, cmax, zmin, zmax]) + return rect_obj + +def bbox_2D(img, tol=0.5): + """ + TOL = argument used when dark regions are >0 + (usually after some preprocessing, like + rescaling). + """ + r, c = np.where(img > tol) + if r.size == 0 or c.size == 0: + return Rect2D([-1, -1, -1, -1]) + else: + rmin, rmax, cmin, cmax = np.min(r), np.max(r), np.min(c), np.max(c) + rect_obj = Rect2D([rmin, rmax, cmin, cmax]) + return rect_obj + +def open_hdf5(filepath=None, mode='r'): + if filepath == None: + filepath = config['hdf5_filepath_prefix'] + + return h5py.File(filepath, mode=mode) + +def get_data_splits_bbox(hdf5_filepath, train_start=0, train_end=190, test_start=190, test_end=None): + """ + + :param hdf5_filepath: + :param train_start: Start index to slice to get the training data. For 10 instances starting from 0, choose 0. + :param train_end: End index for training. Remember this index is 'exclusive', so if you want 10 instances, choose this as 10 + :param test_start: Start index to slice to get the testing data. Same comment as above. + :param test_end: End index for testing. + :return: Keras instances to slice into x_train, y_train, x_test, y_test. 
+ """ + import keras + filepath = config['hdf5_filepath_prefix'] if hdf5_filepath is None else hdf5_filepath + + x_train = keras.utils.io_utils.HDF5Matrix(filepath, "training_data_hgg", start=train_start, end=train_end, + normalizer=None) + y_train = keras.utils.io_utils.HDF5Matrix(filepath, "bounding_box_hgg", start=train_start, end=train_end, + normalizer=None) + + x_test = keras.utils.io_utils.HDF5Matrix(filepath, "training_data_hgg", start=test_start, end=test_end, + normalizer=None) + y_test = keras.utils.io_utils.HDF5Matrix(filepath, "bounding_box_hgg", start=test_start, end=test_end, + normalizer=None) + + return x_train, y_train, x_test, y_test + +def createDense(bbox, im): + box = np.zeros(im.shape) + box[bbox[0]:bbox[1], bbox[2]:bbox[3], bbox[4]:bbox[5]] = 1 + return box + + +def _crop_img_to(img, slices, copy=True): + """Crops image to a smaller size + Crop img to size indicated by slices and adjust affine + accordingly + Parameters + ---------- + img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + Img to be cropped. If slices has less entries than img + has dimensions, the slices will be applied to the first len(slices) + dimensions + slices: list of slices + Defines the range of the crop. + E.g. [slice(20, 200), slice(40, 150), slice(0, 100)] + defines a 3D cube + copy: boolean + Specifies whether cropped data is to be copied or not. + Default: True + Returns + ------- + cropped_img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + Cropped version of the input image + """ + + img = check_niimg(img) + + data = img.get_data() + affine = img.affine + + cropped_data = data[slices] + if copy: + cropped_data = cropped_data.copy() + + linear_part = affine[:3, :3] + old_origin = affine[:3, 3] + new_origin_voxel = np.array([s.start for s in slices]) + new_origin = old_origin + linear_part.dot(new_origin_voxel) + + new_affine = np.eye(4) + new_affine[:3, :3] = linear_part + new_affine[:3, 3] = new_origin + + return new_img_like(img, cropped_data, new_affine) + + +def crop_img_custom(img, slices=None, rtol=1e-8, copy=True): + """Crops img as much as possible + Will crop img, removing as many zero entries as possible + without touching non-zero entries. Will leave one voxel of + zero padding around the obtained non-zero area in order to + avoid sampling issues later on. + Parameters + ---------- + img: Niimg-like object + See http://nilearn.github.io/manipulating_images/input_output.html + img to be cropped. + rtol: float + relative tolerance (with respect to maximal absolute + value of the image), under which values are considered + negligeable and thus croppable. + copy: boolean + Specifies whether cropped data is copied or not. 
+    Returns
+    -------
+    cropped_img: image
+        Cropped version of the input image
+    """
+
+    img = check_niimg(img)
+    data = img.get_data()
+
+    if slices is not None:
+        return _crop_img_to(img, slices, copy=copy), slices
+    else:
+        infinity_norm = max(-data.min(), data.max())
+        passes_threshold = np.logical_or(data < -rtol * infinity_norm,
+                                         data > rtol * infinity_norm)
+
+        if data.ndim == 4:
+            passes_threshold = np.any(passes_threshold, axis=-1)
+        coords = np.array(np.where(passes_threshold))
+        start = coords.min(axis=1)
+        end = coords.max(axis=1) + 1
+
+        # pad with one voxel to avoid resampling problems
+        start = np.maximum(start - 1, 0)
+        end = np.minimum(end + 1, data.shape[:3])
+
+        slices = [slice(s, e) for s, e in zip(start, end)]
+
+    return _crop_img_to(img, slices, copy=copy), slices
+
+
+def resize(image, new_shape, interpolation="continuous"):
+    input_shape = np.asarray(image.shape, dtype=np.float16)
+    ras_image = reorder_img(image, resample=interpolation)
+    output_shape = np.asarray(new_shape)
+    new_spacing = input_shape/output_shape
+    new_affine = np.copy(ras_image.affine)
+    new_affine[:3, :3] = ras_image.affine[:3, :3] * np.diag(new_spacing)
+    return resample_img(ras_image, target_affine=new_affine, target_shape=output_shape, interpolation=interpolation, clip=True)
\ No newline at end of file
diff --git a/prep_ISLES2015/prepare_data_for_synthesis.py b/prep_ISLES2015/prepare_data_for_synthesis.py
new file mode 100755
index 0000000..b720575
--- /dev/null
+++ b/prep_ISLES2015/prepare_data_for_synthesis.py
@@ -0,0 +1,56 @@
+"""
+========================================================================
+   Prepare BRATS 2015 Validation Data for MM_Synthesis
+========================================================================
+AUTHOR: Anmol Sharma
+Description: This file converts the BRATS 2015 validation set, which is
+             divided into challenge and leaderboard subsets, and further
+             divided by tumor grade, into the NPZ format that the
+             mm_synthesis module expects.
+
+             This is similar to how the BRATS 2018 data was prepared.
+             Check the original mm_synthesis codebase for more info.
+
+"""
+
+import os, sys, h5py
+sys.path.append('..')
+from modules.configfile import config, mount_path_prefix
+import numpy as np
+
+def saveNPZ(data, save_path, pat_names):
+    # note: np.save writes NPY-format data to the already-open handles below;
+    # the .npz file names are kept as-is for compatibility with the rest of the pipeline
+    np.save(open(os.path.join(save_path, 'pat_names_validation.npz'), 'wb'), pat_names)
+    t1 = data[:,0,...]
+    t1 = np.swapaxes(t1, 3, 2)
+    t1 = np.swapaxes(t1, 2, 1)
+    np.save(open(os.path.join(save_path, 'T1.npz'), 'wb'), t1)
+    del t1
+
+    t2 = data[:,1,...]
+    t2 = np.swapaxes(t2, 3, 2)
+    t2 = np.swapaxes(t2, 2, 1)
+    np.save(open(os.path.join(save_path, 'T2.npz'), 'wb'), t2)
+    del t2
+
+    t1ce = data[:,2,...]
+    t1ce = np.swapaxes(t1ce, 3, 2)
+    t1ce = np.swapaxes(t1ce, 2, 1)
+    np.save(open(os.path.join(save_path, 'T1CE.npz'), 'wb'), t1ce)
+    del t1ce
+
+    t2flair = data[:,3,...]
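+    # as in the T1/T2/T1CE blocks above, the two swapaxes calls below reorder each
+    # volume from (num_pats, x, y, z) to (num_pats, z, x, y), i.e. axial slices
+    # first (assumed to be the layout the mm_synthesis loaders expect)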
+ t2flair = np.swapaxes(t2flair, 3, 2) + t2flair = np.swapaxes(t2flair, 2, 1) + np.save(open(save_path + 'T2FLAIR.npz', 'wb'), t2flair) + del t2flair + + print('Done!') + +hf = h5py.File(config['hdf5_filepath_prefix'], 'r') +hf = hf['validation_data'] + +# SAVE CHALLENGE HGG Data +save_path_c_hg = os.path.join(mount_path_prefix, "scratch/asa224/Datasets/BRATS2015/mm_synthesis/validation_data/testing_hgglgg_patients/HGG_LGG/") +pat_names_c_hg = hf['testing_hgglgg_patients_pat_name'] +challenge_data_hgg = hf['testing_hgglgg_patients'] +saveNPZ(challenge_data_hgg, save_path_c_hg, pat_names_c_hg) diff --git a/prep_ISLES2015/preprocess.py b/prep_ISLES2015/preprocess.py new file mode 100755 index 0000000..1e1ebf6 --- /dev/null +++ b/prep_ISLES2015/preprocess.py @@ -0,0 +1,140 @@ +""" +========================================================== + Preprocess BRATS Data +========================================================== +AUTHOR: Anmol Sharma +AFFILIATION: Simon Fraser University + Burnaby, BC, Canada +PROJECT: Analysis of Brain MRI Scans for Management of + Malignant Tumors +COLLABORATORS: Anmol Sharma (SFU) + Prof. Ghassan Hamarneh (SFU) + Dr. Brian Toyota (VGH) + Dr. Mostafa Fatehi (VGH) +DESCRIPTION: This file uses the previously generated data + (using create_hdf5_file.py) and generates a + new file with similar structure, but after + applying a couple of preprocessing steps. + More specifically, the script applies the + following operations on the data: + + 1) Crop out the dark margins in the scans + to only leave a concise brain area. For + this a generous estimate of bounding box + generated from the whole dataset is used. + For more information, see checkLargestCropSize + notebook. + + The code DOES NOT APPLY MEAN/VAR normalization, + but simply calculates the values and saves on disk. + Check lines 140-143 for more information. + + The saved mean/var files are to be used before + the training process. + +LICENCE: Proprietary for now. +""" +import h5py +from modules.configfile import config +import logging +# from modules.mischelpers import * +from modules.dataloader import standardize +import os +import numpy as np + +logging.basicConfig(level=logging.DEBUG) + +try: + logger = logging.getLogger(__file__.split('/')[-1]) +except: + logger = logging.getLogger(__name__) + +logger.warning('[IMPORTANT] The code DOES NOT APPLY mean/var normalization, rather it calculates it and saves to disk') +# ------------------------------------------------------------------------------------ +# open existing datafile +# ------------------------------------------------------------------------------------ +logger.info('opening previously generated HDF5 file.') + +# open the existing datafile +hdf5_file_main = h5py.File(config['hdf5_filepath_prefix'], 'r') + +logger.info('opened HDF5 file at {}'.format(config['hdf5_filepath_prefix'])) + +# get the group identifier for original dataset +hdf5_file = hdf5_file_main['original_data'] + +# ==================================================================================== + +# ------------------------------------------------------------------------------------ +# create new HDF5 file to hold cropped data. 
+# ------------------------------------------------------------------------------------
+logger.info('creating new HDF5 dataset to hold cropped/normalized data')
+filename = os.path.join(os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1]), 'ISLES_Cropped_Normalized.h5')
+new_hdf5 = h5py.File(filename, mode='w')
+logger.info('created new database at {}'.format(filename))
+
+# create a folder group to hold the datasets. The schema is similar to the original one, except for the name of the
+# folder group
+new_group_preprocessed = new_hdf5.create_group('preprocessed')
+
+# create similar datasets in this file.
+new_group_preprocessed.create_dataset("training_data", config['train_shape_crop'], np.float32)
+new_group_preprocessed.create_dataset("training_data_pat_name", (config['train_shape_crop'][0],), dtype="S100")
+new_group_preprocessed.create_dataset("training_data_segmasks", config['train_segmasks_shape_crop'], np.int16)
+
+# ====================================================================================
+
+# just copy the patient names directly
+new_group_preprocessed['training_data_pat_name'][:] = hdf5_file['training_data_pat_name'][:]
+
+# ------------------------------------------------------------------------------------
+# start the cropping and standardization process
+# ------------------------------------------------------------------------------------
+
+# get the directory where the mean/var values are stored
+# TODO: Use the config file global path, not this one.
+
+saveMeanVarFilename = os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1])
+logger.info('starting the Cropping/Normalization process.')
+
+# only run the cropping steps on these datasets
+run_on_list = ['training_data_segmasks', 'training_data']
+
+# only run the mean/var calculation on these datasets
+std_list = ['training_data']
+for run_on in run_on_list:
+
+    # we define the final shape after cropping in the config file to make it easy to access. More information is
+    # available in the checkLargestCropSize.ipynb notebook.
+    if run_on == 'training_data':
+        im_np = np.empty(config['train_shape_crop'])
+
+    logger.info('Running on {}'.format(run_on))
+    for i in range(0, hdf5_file[run_on].shape[0]):
+        # cropping operation
+        logger.debug('{}:- Patient {}'.format(run_on, i+1))
+        im = hdf5_file[run_on][i]
+        m = config['cropping_coords']
+        if 'segmasks' in run_on:
+            # there are no channels for segmasks
+            k = im[m[0]:m[1], m[2]:m[3], m[4]:m[5]]
+        else:
+            k = im[:, m[0]:m[1], m[2]:m[3], m[4]:m[5]]
+
+        if run_on in std_list:
+            # save the image to this numpy array
+            im_np[i] = k
+        new_group_preprocessed[run_on][i] = k
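+    # note: the pickle written below stores per-channel means and *variances*;
+    # apply_mean_std() in modules/dataloader.py divides by sqrt(var) when applying
+    # these values, so nothing needs to be square-rooted here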
+    # find the mean and standard deviation, and write the mean/std values to disk
+    if run_on in std_list:
+        logger.info('The dataset {} needs standardization'.format(run_on))
+        # use os.path.join so the directory name and file name are not fused together
+        _tmp, vals = standardize(im_np, findMeanVarOnly=True, saveDump=os.path.join(saveMeanVarFilename, run_on + '_mean_std.p'))
+        logger.info('Calculated normalization values for {}:\n{}'.format(run_on, vals))
+        del im_np
+
+# ====================================================================================
+
+hdf5_file_main.close()
+new_hdf5.close()
+
+
diff --git a/readme.MD b/readme.MD
new file mode 100755
index 0000000..a7c27ce
--- /dev/null
+++ b/readme.MD
@@ -0,0 +1,88 @@
+# MM-GAN: Missing MRI Pulse Sequence Synthesis using Multi-Modal Generative Adversarial Network
+
+This repository contains the source code for MM-GAN: [Missing MRI Pulse Sequence Synthesis using Multi-Modal Generative Adversarial Network](https://ieeexplore.ieee.org/document/8859286). MM-GAN is a novel GAN-based approach for synthesizing missing pulse sequences (modalities) of an MRI scan. For more details, please refer to the paper.
+
+To cite the paper:
+```
+@ARTICLE{Sharma20,
+  author={A. {Sharma} and G. {Hamarneh}},
+  journal={IEEE Transactions on Medical Imaging},
+  title={Missing MRI Pulse Sequence Synthesis Using Multi-Modal Generative Adversarial Network},
+  year={2020},
+  volume={39},
+  number={4},
+  pages={1170-1183},}
+```
+
+To cite this repository:
+```
+@misc{Sharma20Code,
+  author = {A. {Sharma}},
+  title = {MM-GAN: Missing MRI Pulse Sequence Synthesis using Multi-Modal Generative Adversarial Network},
+  year = {2020},
+  publisher = {GitHub},
+  journal = {GitHub Repository},
+  howpublished = {\url{https://github.com/trane293/mm-gan}},
+  commit = {4f57d6a0e4c030202a07a60bc1bb1ed1544bf679}
+}
+```
+
+## How to Run
+To run the code, we recommend the [Anaconda](https://docs.anaconda.com/anaconda/install/) distribution. We provide an environment.yml file that recreates the exact Python environment needed to run this code.
+
+Once Anaconda is installed, simply do:
+```sh
+conda env create -f environment.yml
+```
+
+Download the BRATS2018 dataset from the official website https://www.med.upenn.edu/sbia/brats2018/data.html and extract both the Training and Validation zip files to a folder, say:
+
+```
+/local-scratch/data/BRATS2018/Training/
+    /local-scratch/data/BRATS2018/Training/HGG
+    /local-scratch/data/BRATS2018/Training/LGG
+```
+Validation patients do not have HGG/LGG labels:
+```
+/local-scratch/data/BRATS2018/Validation/
+```
+
+Open `modules/configfile.py` and change the following paths to match your folder structure:
+
+```
+config['data_dir_prefix'] = os.path.join(mount_path_prefix, 'BRATS2018_Full/') # this should be top level path
+config['hdf5_filepath_prefix'] = os.path.join(mount_path_prefix, 'BRATS2018/HDF5_Datasets/BRATS2018_Unprocessed.h5') # top level path
+config['hdf5_filepath_prefix_2017'] = os.path.join(mount_path_prefix, 'scratch/asa224/Datasets/BRATS2017/HDF5_Datasets/BRATS.h5') # top level path
+config['hdf5_combined'] = os.path.join(os.sep.join(config['hdf5_filepath_prefix'].split(os.sep)[0:-1]), 'BRATS_Combined_Unprocessed.h5')
+config['hdf5_filepath_cropped'] = os.path.join(mount_path_prefix, 'BRATS2018/HDF5_Datasets/BRATS2018_Cropped_Normalized_Unprocessed.h5') # top level path
+```
+
+Once done, switch to the conda virtualenv and run:
+
+`python modules/create_hdf5_file.py`
+
+Once it's finished, run:
+
+`python preprocess.py`
+
+You should now have two HDF5 datasets: the raw, unprocessed BRATS2018 data, and a version cropped to the bounding-box coordinates computed during preprocessing.
+
+To train MM-GAN on HGG data, run:
+```sh
+./mmgan_hgg.sh
+```
+
+Similarly, for training on LGG data:
+```sh
+./mmgan_lgg.sh
+```
+This should start the network training process.
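+You can also call the training script directly to override the defaults; the flag values below are illustrative only (`--model_name` is a made-up run name; see the argparse block in `train_mmgan_brats2018.py` for all options):
+```sh
+python train_mmgan_brats2018.py --grade=LGG --train_patient_idx=70 --test_pats=5 \
+    --batch_size=4 --dataset=BRATS2018 --n_epochs=60 --model_name=my_lgg_run
+```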
+
+## Directory Structure
+| Folder Name | Purpose |
+| ------ | ------ |
+| modules | contains modules and helper functions, along with network architectures |
+| modules/advanced_gans | contains the main Pix2Pix implementation |
+| notebooks | contains rough notebooks for quick prototyping |
+| prep_BRATS2015 | contains code to prepare BRATS2015 data for training/evaluation |
+| prep_ISLES2015 | contains code to prepare ISLES2015 data for training/evaluation |
\ No newline at end of file
diff --git a/train_mmgan_brats2018.py b/train_mmgan_brats2018.py
new file mode 100755
index 0000000..b0c302f
--- /dev/null
+++ b/train_mmgan_brats2018.py
@@ -0,0 +1,583 @@
+import os
+import argparse
+from modules.advanced_gans.models import *
+from torch.autograd import Variable
+from modules.models import cPix2PixDiscriminator
+import time
+import itertools
+import pickle, gc
+from modules.helpers import (ToTensor,
+                             torch,
+                             show_intermediate_results_BRATS,
+                             Resize,
+                             create_dataloaders,
+                             impute_reals_into_fake,
+                             save_checkpoint,
+                             load_checkpoint,
+                             generate_training_strategy,
+                             calculate_metrics,
+                             printTable)
+import logging
+import numpy as np
+import copy, sys
+
+try:
+    logger = logging.getLogger(__file__.split('/')[-1])
+except:
+    logger = logging.getLogger(__name__)
+
+# Ignore warnings
+import warnings
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from')
+parser.add_argument('--n_epochs', type=int, default=3, help='number of epochs of training')
+parser.add_argument('--dataset', type=str, default="BRATS2018", help='name of the dataset')
+parser.add_argument('--grade', type=str, default="LGG", help='grade of tumor to train on')
+parser.add_argument('--path_prefix', type=str, default="", help='path prefix to choose')
+parser.add_argument('--batch_size', type=int, default=4, help='size of the batches')
+parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
+parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
+parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
+parser.add_argument('--decay_epoch', type=int, default=100, help='epoch from which to start lr decay')
+parser.add_argument('--n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
+parser.add_argument('--img_height', type=int, default=256, help='size of image height')
+parser.add_argument('--img_width', type=int, default=256, help='size of image width')
+parser.add_argument('--channels', type=int, default=4, help='number of image channels')
+parser.add_argument('--out_channels', type=int, default=4, help='number of output channels')
+parser.add_argument('--sample_interval', type=int, default=500, help='interval between sampling of images from generators')
+parser.add_argument('--train_patient_idx', type=int, default=3, help='number of patients to train with')
+parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints')
+parser.add_argument('--discrim_type', type=int, default=1, help='discriminator type to use, 0 for normal, 1 for PatchGAN')
+parser.add_argument('--test_pats', type=int, default=1, help='number of test patients')
+parser.add_argument('--model_name', type=str, default='model_pycharm_test', help='name of the model')
+parser.add_argument('--log_level', type=str, default='info', help='logging level to choose')
+parser.add_argument('--c_learning', type=int, default=1, help='whether or not to use the curriculum learning framework')
+parser.add_argument('--use_tanh', action='store_true', help='use tanh normalization throughout')
+parser.add_argument('--z_type', type=str, default='noise', help='what type of imputation method to use')
+parser.add_argument('--ic', type=int, default=1, help='whether to use implicit conditioning (1) or not (0)')
+
+opt = parser.parse_args()
+print(opt)
+
+if 'info' in opt.log_level:
+    logging.basicConfig(level=logging.INFO)
+elif 'debug' in opt.log_level:
+    logging.basicConfig(level=logging.DEBUG)
+
+# =============================================================================
+# Create Training and Validation data loaders
+# =============================================================================
+if opt.path_prefix == "":
+    # parent_path = '/scratch/asa224/asa224/Datasets/BRATS2018/HDF5_Datasets/'
+    parent_path = '/local-scratch/anmol/data/{}/HDF5_Datasets/'.format(opt.dataset)
+else:
+    # notice there's one less asa224 here
+    parent_path = os.path.join(opt.path_prefix, 'scratch/asa224/Datasets/BRATS2018/HDF5_Datasets/')
+
+if opt.dataset == 'BRATS2018':
+    if opt.grade == 'HGG':
+        logger.info('Running on HGG Dataset')
+        parent_name = 'preprocessed'
+        dataset_name = 'training_data_hgg'
+        dataset_type = 'cropped'
+        ALL_PATS = 210
+        TRAINING_PATS = 190
+        VALIDATION_PATS = 10
+        TESTING_PATS = 10
+        resize_slices = 148
+    elif opt.grade == 'LGG':
+        logger.info('Running on LGG Dataset')
+        parent_name = 'preprocessed'
+        dataset_name = 'training_data_lgg'
+        dataset_type = 'cropped'
+        ALL_PATS = 75
+        TRAINING_PATS = 70
+        resize_slices = 148
+else:
+    logger.critical("Invalid dataset name: {}".format(opt.dataset))
+    sys.exit(-1)
+
+logger.debug('\tparent_path: \t\t{}'.format(parent_path))
+logger.debug('\tparent_name: \t\t{}'.format(parent_name))
+logger.debug('\tdataset_name: \t\t{}'.format(dataset_name))
+logger.debug('\tdataset_type: \t\t{}'.format(dataset_type))
+
+logger.info('\tTraining with CL \t=\t {}'.format(opt.c_learning))
+logger.info('\tImputing Tensor with \t=\t {}'.format(opt.z_type))
+logger.info('\tImplicit Conditioning\t=\t {}'.format(opt.ic))
+
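+# the check below requires the batch size to evenly divide the number of axial
+# slices per patient (resize_slices = 148 above); valid values are therefore
+# 1, 2, 4, 37, 74 and 148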
+if resize_slices % opt.batch_size != 0:
+    logger.critical("Batch size {} is not compatible; please choose a batch size that evenly divides {}".format(opt.batch_size, resize_slices))
+    sys.exit(-1)
+
+if opt.use_tanh:
+    which_normalization = 'tanh'
+else:
+    which_normalization = None
+
+n_dataloader, dataloader_for_viz = create_dataloaders(parent_path=parent_path,
+                                                      parent_name=parent_name,
+                                                      dataset_name=dataset_name,
+                                                      dataset_type=dataset_type,
+                                                      load_pat_names=True,
+                                                      load_seg=False,
+                                                      transform_fn=[Resize(size=(opt.img_height, opt.img_width)), ToTensor()],
+                                                      apply_normalization=True,
+                                                      which_normalization=which_normalization,
+                                                      resize_slices=resize_slices,
+                                                      get_viz_dataloader=True,
+                                                      num_workers=opt.n_cpu,
+                                                      load_indices=None,
+                                                      dataset=opt.dataset,
+                                                      shuffle=False)
+
+test_patient = []
+for k in range(0, opt.test_pats):
+    test_patient.append(dataloader_for_viz.getitem_via_index(opt.train_patient_idx + k))  # there should be no +1 here
+
+# If train_patient_idx = 200, the testing loop will evaluate at idx_pat = 199, since the condition is
+# idx_pat + 1 == opt.train_patient_idx; the testing patients then run from index 200 through 209.
+
+# =============================================================================
+
+# =============================================================================
+# Initialize Networks
+# =============================================================================
+#
+# os.makedirs('images/%s' % opt.dataset_name, exist_ok=True)
+# os.makedirs('saved_models/%s' % opt.dataset_name, exist_ok=True)
+
+cuda = True if torch.cuda.is_available() else False
+
+# =============================================================================
+# Loss functions
+# =============================================================================
+criterion_GAN = torch.nn.BCELoss() if opt.discrim_type == 0 else torch.nn.MSELoss()
+criterion_pixelwise = torch.nn.L1Loss()
+mse_fake_vs_real = torch.nn.MSELoss()
+# =============================================================================
+
+# Loss weight of L1 pixel-wise loss between translated image and real image
+lambda_pixel = 100
+
+# Calculate output shape of the image discriminator (PatchGAN); e.g. 256x256 inputs give (out_channels, 16, 16)
+patch = (opt.out_channels, opt.img_height//2**4, opt.img_width//2**4)
+
+# Initialize generator and discriminator
+if which_normalization == 'tanh':
+    generator = GeneratorUNet(in_channels=opt.channels, out_channels=opt.out_channels, with_relu=False, with_tanh=True)
+else:
+    generator = GeneratorUNet(in_channels=opt.channels, out_channels=opt.out_channels, with_relu=True, with_tanh=False)
+discriminator = Discriminator(in_channels=opt.channels, dataset='BRATS2018')
+
+# =============================================================================
+
+# =============================================================================
+# Where to save results
+# =============================================================================
+
+if opt.path_prefix == "":
+    root = '/local-scratch/anmol/results_new/project_880/'
+else:  # NOT USED
+    root = os.path.join(opt.path_prefix, 'rrg_proj_dir/Results/project_880_new/mm_synthesis_gan_results/')
+    logger.warning("root: {}".format(root))
+    logger.warning('Possible bad value for opt.path_prefix')
+
+model = opt.model_name
+if not os.path.isdir(root):
+    os.mkdir(root)
+if not os.path.isdir(os.path.join(root, model)):
+    os.mkdir(os.path.join(root, model))
+if not os.path.isdir(os.path.join(root, model, "{}".format(opt.dataset), 'scenario_results')):
+    os.makedirs(os.path.join(root, model, "{}".format(opt.dataset), 'scenario_results'))
+# =============================================================================
+
+# Optimizers
+optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+
+# Send everything to GPU
+if cuda:
+    generator = nn.DataParallel(generator.cuda())
+    discriminator = nn.DataParallel(discriminator.cuda())
+    criterion_GAN.cuda()
+    criterion_pixelwise.cuda()
+    mse_fake_vs_real.cuda()
+
+# =============================================================================
+# Init networks and optimizers
+# =============================================================================
+if opt.epoch != 0:
+    # Load pretrained models
+    logger.info('Loading previous checkpoint!')
+    generator, optimizer_G = load_checkpoint(generator, optimizer_G, os.path.join(root, opt.model_name,
+                                                                                  "{}_param_{}_{}.pkl".format(
+                                                                                      'generator', opt.model_name,
+                                                                                      opt.epoch)), pickle_module=pickle)
+    discriminator, optimizer_D = load_checkpoint(discriminator, optimizer_D, os.path.join(root, opt.model_name,
+                                                                                          "{}_param_{}_{}.pkl".format(
+                                                                                              'discriminator',
+                                                                                              opt.model_name,
+                                                                                              opt.epoch)), pickle_module=pickle)
+
+else:
+    # Initialize weights
+    generator.apply(weights_init_normal)
+    discriminator.apply(weights_init_normal)
+
+# =============================================================================
+
+# Tensor type
+Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
+
+
+# =============================================================================
+# Training
+# =============================================================================
+
+# Book keeping
+train_hist = {}
+train_hist['D_losses'] = []
+train_hist['G_losses'] = []
+train_hist['per_epoch_ptimes'] = []
+train_hist['total_ptime'] = []
+train_hist['test_loss'] = {
+    'mse': [],
+    'psnr': [],
+    'ssim': []
+}
+# Get the device we're working on.
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Create all scenarios: itertools.product gives 2**4 = 16 combinations; 0000 and 1111 are removed below, leaving 14
+scenarios = list(map(list, itertools.product([0, 1], repeat=4)))
+
+# Generate new label placeholders for this particular batch
+# This is for the G (Changed below)
+# label_map = torch.ones((opt.batch_size, 4, opt.img_height, opt.img_width), requires_grad=False).cuda().type(
+#     torch.cuda.FloatTensor)
+
+# This is for D (Changed below)
+
+label_list = torch.from_numpy(np.ones((opt.batch_size,
+                                       patch[0],
+                                       patch[1],
+                                       patch[2]))).cuda().type(torch.cuda.FloatTensor)
+
+# remove the empty scenario and the all-available scenario
+scenarios.remove([0,0,0,0])
+scenarios.remove([1,1,1,1])
+
+# sort the scenarios in order of decreasing difficulty: easy scenarios (most sequences available) last, difficult ones first
+scenarios.sort(key=lambda x: x.count(1))
+
+logger.info("Starting Training")
+start_time = time.time()
+
+for epoch in range(opt.epoch, opt.n_epochs, 1):
+    D_losses = []
+    D_real_losses = []
+    D_fake_losses = []
+    G_train_l1_losses = []
+    G_train_losses = []
+    G_losses = []
+    synth_losses = []
+
+    # patient: Whole patient dictionary containing image, seg, name etc.
+    # x_patient: Just the images of a single patient
+    # x_r: Batch of images taken from x_patient according to the batch size specified.
+ # x_z: Batch from x_r where some sequences are imputed with noise for input to G + epoch_start_time = time.time() + for idx_pat, patient in enumerate(n_dataloader): + logger.info("Current idx_pat: {}".format(idx_pat)) + # if idx_pat > opt.train_patient_idx: + # logger.info("Now testing on patient {}".format(opt.train_patient_idx + 1)) + # main_path = os.path.join(root, model, 'scenario_results') + # + # fixed_p = os.path.join(root, model, 'scenario_results', 'viz' + "_" + str(epoch + 1)) + # + # logger.info("Saving result as {}".format(fixed_p)) + # status = show_intermediate_results(generator, test_patient, save_path=main_path, + # all_scenarios=copy.deepcopy(scenarios), epoch=epoch, + # curr_scenario_range=None, + # batch_size_to_test=opt.batch_size) + # break + + # Put the whole patient in GPU to aid quicker training + x_patient = patient['image'] + batch_indices = list(range(0, resize_slices, opt.batch_size)) + + # this shuffles the 2D axial slice batches for efficient training + # tag1 + random.shuffle(batch_indices) + + # create batches out of this patient + for _num, batch_idx in enumerate(batch_indices): + logger.debug("Patient #{}\nBatch #{}".format(idx_pat, _num)) + + logger.debug("\tSplicing batch from x_real") + x_r = x_patient[batch_idx:batch_idx + opt.batch_size, ...].cuda().type(Tensor) + + if opt.c_learning == 1: + # Curriculum Learning: Train with easier cases in the first epochs, then start training on harder ones + if epoch <= 10: + curr_scenario_range = [11, 14] + rand_val = torch.randint(low=10, high=14, size=(1,)) + if epoch > 10 and epoch <= 20: + curr_scenario_range = [7, 14] + rand_val = torch.randint(low=7, high=14, size=(1,)) + if epoch > 20 and epoch <= 30: + curr_scenario_range = [3, 14] + rand_val = torch.randint(low=3, high=14, size=(1,)) + if epoch > 30: + curr_scenario_range = [0, 14] + rand_val = torch.randint(low=0, high=14, size=(1,)) + elif opt.c_learning == 2: + rand_val = torch.randint(low=0, high=14, size=(1,)) + + label_scenario = scenarios[int(rand_val.numpy()[0])] + logger.debug('\tTraining this batch with Scenario: {}'.format(label_scenario)) + + # create a new x_imputed and x_real with this label scenario + x_z = x_r.clone().cuda() + + label_list_r = torch.from_numpy(np.ones((opt.batch_size, + patch[0], + patch[1], + patch[2]))).cuda().type(torch.cuda.FloatTensor) + + if opt.z_type == 'noise': + impute_tensor = torch.randn((opt.batch_size, + opt.img_height, + opt.img_width), device=device) + elif opt.z_type == 'average': + avail_indx = [i for i, x in enumerate(label_scenario) if x == 1] + impute_tensor = torch.mean(x_r[:, avail_indx,...], dim=1) + elif opt.z_type == 'zeros': + impute_tensor = torch.zeros((opt.batch_size, + opt.img_height, + opt.img_width), device=device) + + for idx, k in enumerate(label_scenario): + if k == 0: + x_z[:, idx, ...] = impute_tensor + + # label_map[:, idx, ...] = 0 + + # this works with both discriminator types. + label_list[:, idx] = 0 + + elif k == 1: + # label_map[:, idx, ...] = 1 + + # this works with both discriminator types. 
+ label_list[:, idx] = 1 + + # TRAIN GENERATOR G + logger.debug('\tTraining Generator') + generator.zero_grad() + optimizer_G.zero_grad() + + # G_result have already been computed above, but we need this again in order to backpropagate again + # G_result = generator(x_z, label_map) + + fake_x = generator(x_z) + + # tag1 + if opt.ic == 1: # we're using IC + fake_x = impute_reals_into_fake(x_z, fake_x, label_scenario) + + pred_fake = discriminator(fake_x, x_r) + + # G_train_loss = BCE_loss(D_result, label_list) + # The discriminator should think that the pred_fake is real, so we minimize the loss between pred_fake + # and label_list_r, ie. make the pred_fake look real, and reducing the error that the discriminator makes + # when predicting it. + + if pred_fake.size() != label_list_r.size(): + logger.warning('Error!') + import sys + sys.exit(-1) + + loss_GAN = criterion_GAN(pred_fake, label_list_r) + + # pixel-wise loss + if opt.ic == 1: + loss_pixel = 0 + synth_loss = 0 + count = 0 + for idx_curr_label, i in enumerate(label_scenario): + if i == 0: + loss_pixel += criterion_pixelwise(fake_x[:, idx_curr_label, ...], x_r[:, idx_curr_label, ...]) + + synth_loss += mse_fake_vs_real(fake_x[:, idx_curr_label, ...], x_r[:, idx_curr_label, ...]) + count += 1 + + + loss_pixel /= count + synth_loss /= count + else: # no IC, calculate loss for all output w.r.t all GT. + loss_pixel = criterion_pixelwise(fake_x, x_r) + + synth_loss = mse_fake_vs_real(fake_x, x_r) + + # variable that sets the relative importance to loss_GAN and loss_pixel + lam = 0.9 + G_train_total_loss = (1 - lam) * loss_GAN + lam * loss_pixel + + G_train_total_loss.backward() + optimizer_G.step() + + # save the losses + G_train_l1_losses.append(loss_pixel.data[0]) + G_train_losses.append(loss_GAN.data[0]) + G_losses.append(G_train_total_loss.data[0]) + synth_losses.append(synth_loss.data[0]) + + # TRAIN DISCRIMINATOR D + # this takes in the real x as X-INPUT and real x as Y-INPUT + logger.debug('\tTraining Discriminator') + discriminator.zero_grad() + optimizer_D.zero_grad() + + # real loss + # EDIT: We removed noise addition + # We can add noise to the inputs of the discriminator + pred_real = discriminator(x_r, + x_r) + + loss_real = criterion_GAN(pred_real, label_list_r) + + # fake loss + # fake_x = generator(x_z, label_map) + fake_x = generator(x_z) + + # tag1 + if opt.ic == 1: + fake_x = impute_reals_into_fake(x_z, fake_x, label_scenario) + + # we add noise to the inputs of the discriminator here as well + pred_fake = discriminator(fake_x.detach(), x_r) + # pred_fake = discriminator(fake_x, x_r) + + loss_fake = criterion_GAN(pred_fake, label_list) + + D_train_loss = 0.5 * (loss_real + loss_fake) + + # for printing purposes + D_real_losses.append(loss_real.data[0]) + D_fake_losses.append(loss_fake.data[0]) + D_losses.append(D_train_loss.data[0]) + + D_train_loss.backward() + optimizer_D.step() + + + + logger.info(" E [{}/{}] P #{} ".format(epoch, opt.n_epochs, + idx_pat) + 'B [%d/%d] - loss_d: [real: %.5f, fake: %.5f, comb: %.5f], loss_g: [gan: %.5f, l1: %.5f, comb: %.5f], synth_loss_mse(ut): %.5f' % ( + (_num + 1), resize_slices // opt.batch_size, torch.mean(torch.FloatTensor(D_real_losses)), + torch.mean(torch.FloatTensor(D_fake_losses)), + torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_train_losses)), + torch.mean(torch.FloatTensor(G_train_l1_losses)), torch.mean(torch.FloatTensor(G_losses)), + torch.mean(torch.FloatTensor(synth_losses)))) + # Check if we have trained with exactly opt.train_patient_idx 
patients (if opt.train_patient_idx is 10, then idx_pat will be 9, so this condition will evaluate to true + if idx_pat + 1 == opt.train_patient_idx: + logger.info('Testing on test set for this fold') + main_path = os.path.join(root, model, "{}".format(opt.dataset), 'scenario_results') + + logger.info("Saving results at {}".format(main_path)) + + generator.eval() + + logger.info("Calculating metric on test set") + result_dict_test, _running_mse, _running_psnr, _running_ssim = calculate_metrics( + generator, test_patient, save_path=main_path, + all_scenarios=copy.deepcopy(scenarios), + epoch=epoch, save_stats=True, + curr_scenario_range=None, + batch_size_to_test=1, + impute_type=opt.z_type, + dataset=opt.dataset) + + logger.info("\t\tTesting Performance Numbers") + printTable(result_dict_test) + gc.collect() + + # logger.info("Writing detailed visualizations for each scenario") + # status = show_intermediate_results_BRATS(generator, test_patient, save_path=main_path, + # all_scenarios=copy.deepcopy(scenarios), epoch=epoch, + # curr_scenario_range=None, + # batch_size_to_test=opt.batch_size) + + generator.train() + gc.collect() + break + + epoch_end_time = time.time() + per_epoch_ptime = epoch_end_time - epoch_start_time + + print( + '[%d/%d] - ptime: %.2f, loss_d: [real: %.5f, fake: %.5f, comb: %.5f], loss_g: [gan: %.5f, l1: %.5f, comb: %.5f], ' + 'synth_loss_mse(ut): %.5f' % ( + (epoch + 1), opt.n_epochs, per_epoch_ptime, torch.mean(torch.FloatTensor(D_real_losses)), + torch.mean(torch.FloatTensor(D_fake_losses)), + torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_train_losses)), + torch.mean(torch.FloatTensor(G_train_l1_losses)), torch.mean(torch.FloatTensor(G_losses)), + torch.mean(torch.FloatTensor(synth_losses)))) + + # Checkpoint the models + + gen_state_checkpoint = { + 'epoch': epoch + 1, + 'arch': opt.model_name, + 'state_dict': generator.state_dict(), + 'optimizer' : optimizer_G.state_dict(), + } + + des_state_checkpoint = { + 'epoch': epoch + 1, + 'arch': opt.model_name, + 'state_dict': discriminator.state_dict(), + 'optimizer': optimizer_D.state_dict(), + } + + save_checkpoint(gen_state_checkpoint, os.path.join(root, model, 'generator_param_{}_{}.pkl'.format(model, epoch + 1)), + pickle_module=pickle) + + save_checkpoint(des_state_checkpoint, + os.path.join(root, model, 'discriminator_param_{}_{}.pkl'.format(model, epoch + 1)), + pickle_module=pickle) + + with open(os.path.join(root, model, "{}".format(opt.dataset), + 'result_dict_test_epoch_{}.pkl'.format(epoch)), 'wb') as f: + pickle.dump(result_dict_test, f) + + logger.info('[Testing] num_pats: {}, mse: {:.5f}, psnr: {:.5f}, ssim: {:.5f}'.format( + opt.test_pats, + result_dict_test['mean']['mse'], + result_dict_test['mean']['psnr'], + result_dict_test['mean']['ssim'] + )) + + train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses))) + train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses))) + + train_hist['test_loss']['mse'].append(result_dict_test['mean']['mse']) + train_hist['test_loss']['psnr'].append(result_dict_test['mean']['psnr']) + train_hist['test_loss']['ssim'].append(result_dict_test['mean']['ssim']) + + train_hist['per_epoch_ptimes'].append(per_epoch_ptime) + + end_time = time.time() + total_ptime = end_time - start_time + train_hist['total_ptime'].append(total_ptime) + + print("Avg one epoch ptime: %.2f, total %d epochs ptime: %.2f" % ( + torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), opt.n_epochs, total_ptime)) + +with open(os.path.join(root, 
model, 'train_hist.pkl'), 'wb') as f: + pickle.dump(train_hist, f) diff --git a/train_mmgan_brats2018_single.py b/train_mmgan_brats2018_single.py new file mode 100755 index 0000000..b20e5b2 --- /dev/null +++ b/train_mmgan_brats2018_single.py @@ -0,0 +1,614 @@ +import os +import argparse +from modules.advanced_gans.models import * +from torch.autograd import Variable +from modules.models import cPix2PixDiscriminator +import time +import itertools +import pickle, gc +from modules.helpers import (ToTensor, + torch, + show_intermediate_results, + Resize, + create_dataloaders, + impute_reals_into_fake, + save_checkpoint, + load_checkpoint, + generate_training_strategy, + calculate_metrics, + printTable) +import logging +import numpy as np +import copy, sys + +try: + logger = logging.getLogger(__file__.split('/')[-1]) +except: + logger = logging.getLogger(__name__) + +# Ignore warnings +import warnings +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser() +parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from') +parser.add_argument('--n_epochs', type=int, default=2, help='number of epochs of training') +parser.add_argument('--dataset', type=str, default="BRATS2018", help='name of the dataset') +parser.add_argument('--grade', type=str, default="LGG", help='grade of tumor to train on') +parser.add_argument('--path_prefix', type=str, default="", help='path prefix to choose') +parser.add_argument('--batch_size', type=int, default=4, help='size of the batches') +parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate') +parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient') +parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient') +parser.add_argument('--decay_epoch', type=int, default=100, help='epoch from which to start lr decay') +parser.add_argument('--n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation') +parser.add_argument('--img_height', type=int, default=256, help='size of image height') +parser.add_argument('--img_width', type=int, default=256, help='size of image width') +parser.add_argument('--channels', type=int, default=4, help='number of image channels') +parser.add_argument('--out_channels', type=int, default=1, help='number of output channels') +parser.add_argument('--sample_interval', type=int, default=500, help='interval between sampling of images from generators') +parser.add_argument('--train_patient_idx', type=int, default=5, help='number of patients to train with') +parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints') +parser.add_argument('--discrim_type', type=int, default=1, help='discriminator type to use, 0 for normal, 1 for PatchGAN') +parser.add_argument('--test_pats', type=int, default=2, help='number of test patients') +parser.add_argument('--model_name', type=str, default='pycharm_test', help='name of mode') +parser.add_argument('--log_level', type=str, default='info', help='logging level to choose') +parser.add_argument('--c_learning', type=int, default=2, help='whether or not use curriculum learning framework') +parser.add_argument('--type', type=str, default='T1', help='what sequence to synthesize') +parser.add_argument('--use_tanh', action='store_true', help='use tanh normalization throughout') + +opt = parser.parse_args() +print(opt) + +if 'info' in opt.log_level: + 
logging.basicConfig(level=logging.INFO) +elif 'debug' in opt.log_level: + logging.basicConfig(level=logging.DEBUG) + +# ============================================================================= +# Create Training and Validation data loaders +# ============================================================================= +if opt.path_prefix == "": + # parent_path = '/scratch/asa224/asa224/Datasets/BRATS2018/HDF5_Datasets/' + parent_path = '/local-scratch/anmol/data/{}/HDF5_Datasets/'.format(opt.dataset) +else: + # notice there's one less asa224 here + parent_path = os.path.join(opt.path_prefix, 'scratch/asa224/Datasets/{}/HDF5_Datasets/'.format(opt.dataset)) + +if opt.dataset == 'BRATS2018': + if opt.grade == 'HGG': + logger.info('Running on HGG Dataset') + parent_name = 'preprocessed' + dataset_name = 'training_data_hgg' + dataset_type = 'cropped' + ALL_PATS = 210 + TRAINING_PATS = 190 + VALIDATION_PATS = 10 + TESTING_PATS = 10 + resize_slices = 148 + elif opt.grade == 'LGG': + logger.info('Running on LGG Dataset') + parent_name = 'preprocessed' + dataset_name = 'training_data_lgg' + dataset_type = 'cropped' + ALL_PATS = 75 + TRAINING_PATS = 70 + resize_slices = 148 +elif opt.dataset == 'BRATS2015': + logger.info("BRATS2015") + if opt.grade == 'HGG': + logger.info('Running on HGG Dataset') + parent_name = 'preprocessed' + dataset_name = 'training_data_hgg' + dataset_type = 'cropped' + ALL_PATS = 220 + TRAINING_PATS = 200 + VALIDATION_PATS = 10 + TESTING_PATS = 10 + resize_slices = 148 + elif opt.grade == 'LGG': + logger.info('Running on LGG Dataset') + parent_name = 'preprocessed' + dataset_name = 'training_data_lgg' + dataset_type = 'cropped' + ALL_PATS = 54 + TRAINING_PATS = 45 + resize_slices = 148 +else: + logger.critical("Invalid dataset name: {}".format(opt.dataset)) + sys.exit(-1) + +logger.debug('\tparent_path: \t\t{}'.format(parent_path)) +logger.debug('\tparent_name: \t\t{}'.format(parent_name)) +logger.debug('\tdataset_name: \t\t{}'.format(dataset_name)) +logger.debug('\tdataset_type: \t\t{}'.format(dataset_type)) + +if resize_slices % opt.batch_size != 0: + logger.critical("Batch size is not compatible, please change it to be a multiple of {}".format(resize_slices)) + sys.exit(-1) + +#DEBUG ONLY +# opt.use_tanh = True + +if opt.use_tanh: + which_normalization = 'tanh' +else: + which_normalization = None + +train_range = list(range(0, opt.train_patient_idx)) + +n_dataloader, dataloader_for_viz = create_dataloaders(parent_path=parent_path, + parent_name=parent_name, + dataset_name=dataset_name, + dataset_type=dataset_type, + load_pat_names=True, + load_seg=False, + transform_fn=[Resize(size=(opt.img_height, opt.img_width)), ToTensor()], + apply_normalization=True, + which_normalization=which_normalization, + train_range=train_range, + resize_slices=resize_slices, + get_viz_dataloader=True, + num_workers=opt.n_cpu, + load_indices=None, + dataset=opt.dataset, + shuffle=False) + +test_patient = [] +for k in range(0, opt.test_pats): + test_patient.append(dataloader_for_viz.getitem_via_index(opt.train_patient_idx + k)) # tehre should be no +1 + +# if train_pat = 200 +# The testing loop will evaluate at train_idx = 199 since the condition is train_idx + 1 == opt.train_patient_idx +# testing patient should start from 200 until 209. 
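+# (with this script's defaults of --train_patient_idx=5 and --test_pats=2, the same
+# logic means training uses patients 0-4 and testing uses patients 5 and 6)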
+ +# ============================================================================= + +# ============================================================================= +# Initialize Networks +# ============================================================================= +# +# os.makedirs('images/%s' % opt.dataset_name, exist_ok=True) +# os.makedirs('saved_models/%s' % opt.dataset_name, exist_ok=True) + +cuda = True if torch.cuda.is_available() else False + +# ============================================================================= +# Loss functions +# ============================================================================= +criterion_GAN = torch.nn.BCELoss() if opt.discrim_type == 0 else torch.nn.MSELoss() +criterion_pixelwise = torch.nn.L1Loss() +mse_fake_vs_real = torch.nn.MSELoss() +# ============================================================================= + +# Loss weight of L1 pixel-wise loss between translated image and real image +lambda_pixel = 100 + +# Calculate output of image discriminator (PatchGAN) +patch = (opt.out_channels, opt.img_height//2**4, opt.img_width//2**4) + +# Initialize generator and discriminator +if which_normalization == 'tanh': + generator = GeneratorUNet(in_channels=opt.channels, out_channels=opt.out_channels, with_relu=False, with_tanh=True) +else: + generator = GeneratorUNet(in_channels=opt.channels, out_channels=opt.out_channels, with_relu=True, with_tanh=False) +discriminator = Discriminator(in_channels=opt.out_channels, out_channels=opt.out_channels, dataset='BRATS2018') + +# ============================================================================= + +# ============================================================================= +# Where to save results +# ============================================================================= + +if opt.path_prefix == "": + root = '/local-scratch/anmol/results_new/project_880/' +else: # NOT USED + root = os.path.join(opt.path_prefix, 'rrg_proj_dir/Results/project_880_new/mm_synthesis_gan_results/') + logger.warning("root: {}".format(root)) + logger.warning('Possible bad value for opt.path_prefix') + +model = opt.model_name +if not os.path.isdir(root): + os.mkdir(root) +if not os.path.isdir(os.path.join(root, model)): + os.mkdir(os.path.join(root, model)) +if not os.path.isdir(os.path.join(root, model, "{}".format(opt.dataset), 'scenario_results')): + os.makedirs(os.path.join(root, model, "{}".format(opt.dataset), 'scenario_results')) +# ============================================================================= + +# Optimizers +optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) +optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) + +# Send everything to GPU +if cuda: + generator = nn.DataParallel(generator.cuda()) + discriminator = nn.DataParallel(discriminator.cuda()) + criterion_GAN.cuda() + criterion_pixelwise.cuda() + mse_fake_vs_real.cuda() + +# ============================================================================= +# Init networks and optimizers +# ============================================================================= +if opt.epoch != 0: + # Load pretrained models + logger.info('Loading previous checkpoint!') + generator, optimizer_G = load_checkpoint(generator, optimizer_G, os.path.join(root, opt.model_name, + "{}_param_{}_{}.pkl".format( + 'generator', opt.model_name, + 1)), pickle_module=pickle) + discriminator, optimizer_D = load_checkpoint(discriminator, optimizer_D, os.path.join(root, 
opt.model_name, + "{}_param_{}_{}.pkl".format( + 'discriminator', + opt.model_name, + 1)), pickle_module=pickle) + +else: + # Initialize weights + generator.apply(weights_init_normal) + discriminator.apply(weights_init_normal) + +# ============================================================================= + +# Tensor type +Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor + + +# ============================================================================= +# Training +# ============================================================================= + +# Book keeping +train_hist = {} +train_hist['D_losses'] = [] +train_hist['G_losses'] = [] +train_hist['per_epoch_ptimes'] = [] +train_hist['total_ptime'] = [] +train_hist['test_loss'] = { + 'mse': [], + 'psnr': [], + 'ssim': [] +} +# Get the device we're working on. +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +# Create all scenrios: Total will 15, but remove 0000 and 1111 +scenarios = list(map(list, itertools.product([0, 1], repeat=4))) + +# Generate new label placeholders for this particular batch +# This is for the G (Changed below) +# label_map = torch.ones((opt.batch_size, 4, opt.img_height, opt.img_width), requires_grad=False).cuda().type( +# torch.cuda.FloatTensor) + +# This is for D (Changed below) +if opt.type == 'T1': + scenarios = [x for x in scenarios if x[0] == 0] +else: + scenarios = [x for x in scenarios if x[1] == 0] +label_list = torch.from_numpy(np.ones((opt.batch_size, + patch[0], + patch[1], + patch[2]))).cuda().type(torch.cuda.FloatTensor) + +# remove the empty scenario and all available scenario +scenarios.remove([0,0,0,0]) +# scenarios.remove([1,1,1,1]) # this is no longer invalid in this case. + +# sort the scenarios according to decreasing difficulty. Easy scenarios last, and difficult ones first. +scenarios.sort(key=lambda x: x.count(1)) + +logger.info("Starting Training") +start_time = time.time() + +for epoch in range(opt.epoch, opt.n_epochs, 1): + D_losses = [] + D_real_losses = [] + D_fake_losses = [] + G_train_l1_losses = [] + G_train_losses = [] + G_losses = [] + synth_losses = [] + + # patient: Whole patient dictionary containing image, seg, name etc. + # x_patient: Just the images of a single patient + # x_r: Batch of images taken from x_patient according to the batch size specified. 
+ # x_z: Batch from x_r where some sequences are imputed with noise for input to G + epoch_start_time = time.time() + for idx_pat, patient in enumerate(n_dataloader): + logger.info("Current idx_pat: {}".format(idx_pat)) + # if idx_pat > opt.train_patient_idx: + # logger.info("Now testing on patient {}".format(opt.train_patient_idx + 1)) + # main_path = os.path.join(root, model, 'scenario_results') + # + # fixed_p = os.path.join(root, model, 'scenario_results', 'viz' + "_" + str(epoch + 1)) + # + # logger.info("Saving result as {}".format(fixed_p)) + # status = show_intermediate_results(generator, test_patient, save_path=main_path, + # all_scenarios=copy.deepcopy(scenarios), epoch=epoch, + # curr_scenario_range=None, + # batch_size_to_test=opt.batch_size) + # break + + # Put the whole patient in GPU to aid quicker training + x_patient = patient['image'] + batch_indices = list(range(0, resize_slices, opt.batch_size)) + + # this shuffles the 2D axial slice batches for efficient training + # tag1 + random.shuffle(batch_indices) + + # create batches out of this patient + for _num, batch_idx in enumerate(batch_indices): + logger.debug("Patient #{}\nBatch #{}".format(idx_pat, _num)) + + logger.debug("\tSplicing batch from x_real") + x_r = x_patient[batch_idx:batch_idx + opt.batch_size, ...] + # x_r = x_r.cuda().type(Tensor) # Tensor will be either a cuda dtype or cpu dtype + + # Curriculum Learning: Train with easier cases in the first epochs, then start training on harder ones + + if opt.c_learning == 1: + # Curriculum Learning: Train with easier cases in the first epochs, then start training on harder ones + if epoch <= 10: + rand_val = torch.Tensor([6]) # last one is the easier, train with that. + if epoch > 10 and epoch <= 20: + rand_val = torch.randint(low=3, high=len(scenarios), size=(1,)) + if epoch > 20: + rand_val = torch.randint(low=0, high=len(scenarios), size=(1,)) + elif opt.c_learning == 0: + rand_val = torch.randint(low=0, high=14, size=(1,)) + elif opt.c_learning == 2: + # always train with the last scenario, means everything present. + rand_val = torch.Tensor([6]) # last one is the easier, train with that. + + + label_scenario = scenarios[int(rand_val.numpy()[0])] + logger.debug('\tTraining this batch with Scenario: {}'.format(label_scenario)) + + # create a new x_imputed and x_real with this label scenario + x_z = x_r.clone().cuda() + + if opt.discrim_type == 1: + label_list_r = torch.from_numpy( + np.ones((opt.batch_size, patch[0], patch[1], patch[2]))).cuda().type( + torch.cuda.FloatTensor) + else: + # This is for D when training on real images (Unchanged) + label_list_r = torch.from_numpy(np.ones((opt.batch_size, 4))).cuda().type(Tensor) + + # impute noise to the input of G + for idx, k in enumerate(label_scenario): + if k == 0: + x_z[:, idx, ...] = torch.ones((opt.img_height, opt.img_width), device=device) * -1.0 + # label_map[:, idx, ...] = 0 + # + # # this works with both discriminator types. + # if opt.dataset != 'ISLES2015' and opt.dataset != 'BRATS2015': + # label_list[:, idx] = 0 + # + # elif k == 1: + # # label_map[:, idx, ...] = 1 + # + # # this works with both discriminator types. + # if opt.dataset != 'ISLES2015' and opt.dataset != 'BRATS2015': + # label_list[:, idx] = 1 + + # The output is ALWAYS synthesized, so its zero always. 
ISLES2015 outputs just FLAIR
+            label_list[:, 0] = 0
+
+            # TRAIN GENERATOR G
+            logger.debug('\tTraining Generator')
+            generator.zero_grad()
+            optimizer_G.zero_grad()
+
+            # ADDING A RELU AT THE END OF GENERATOR FOR ISLES2015, to make sure values are positive
+            fake_x = generator(x_z)
+
+            if opt.type == "T1":
+                SEQ_IDX = 0
+            elif opt.type == "T2":
+                SEQ_IDX = 1
+
+            # if opt.dataset != 'ISLES2015' and opt.dataset != "BRATS2015":
+            #     fake_x = impute_reals_into_fake(x_z, fake_x, label_scenario)
+            #
+            # if opt.dataset != 'ISLES2015' and opt.dataset != "BRATS2015":
+            #     pred_fake = discriminator(fake_x, x_r)
+            # else:
+            # may have to unsqueeze to keep the channel dimension
+            pred_fake = discriminator(fake_x, x_r[:, SEQ_IDX, ...].unsqueeze(1))
+
+            # G_train_loss = BCE_loss(D_result, label_list)
+            # The discriminator should think that pred_fake is real, so we minimize the loss between pred_fake
+            # and label_list_r, i.e. make pred_fake look real, reducing the error the discriminator makes
+            # when predicting it.
+
+            if pred_fake.size() != label_list_r.size():
+                logger.critical('pred_fake and label_list_r sizes do not match!')
+                sys.exit(-1)
+
+            # fool the discriminator: minimize the loss between pred_fake and the "real" labels.
+            # pred_fake must NOT be detached here, otherwise no gradient flows back into the
+            # generator (the multi-output training script computes this loss the same way).
+            loss_GAN = criterion_GAN(pred_fake, label_list_r)
+
+            # pixel-wise loss
+            loss_pixel = 0
+            synth_loss = 0
+
+            # if opt.dataset != 'ISLES2015' and opt.dataset != "BRATS2015":
+            #     for idx_curr_label, i in enumerate(label_scenario):
+            #         if i == 0:
+            #             loss_pixel += criterion_pixelwise(fake_x[:, idx_curr_label, ...], x_r[:, idx_curr_label, ...])
+            #
+            #             synth_loss += mse_fake_vs_real(fake_x[:, idx_curr_label, ...], x_r[:, idx_curr_label, ...])
+            # else:
+            # fake_x is already (B, 1, 256, 256)
+            loss_pixel += criterion_pixelwise(fake_x, x_r[:, SEQ_IDX, ...].unsqueeze(1).type(Tensor))
+
+            synth_loss += mse_fake_vs_real(fake_x, x_r[:, SEQ_IDX, ...].unsqueeze(1).type(Tensor))
+            # logger.debug("Min: {}".format(x_r[:, SEQ_IDX, ...].min()))
+            # logger.debug("Max: {}".format(x_r[:, SEQ_IDX, ...].max()))
+            # logger.debug("Mean: {}".format(x_r[:, SEQ_IDX, ...].mean()))
+
+            # variable that sets the relative importance of loss_GAN and loss_pixel
+            lam = 0.9
+            G_train_total_loss = (1 - lam) * loss_GAN + lam * loss_pixel
+
+            G_train_total_loss.backward()
+            optimizer_G.step()
+
+            # save the losses
+            G_train_l1_losses.append(loss_pixel.data[0])
+            G_train_losses.append(loss_GAN.data[0])
+            G_losses.append(G_train_total_loss.data[0])
+            synth_losses.append(synth_loss.data[0])
+
+            # TRAIN DISCRIMINATOR D
+            # this takes in the real x as X-INPUT and real x as Y-INPUT
+            logger.debug('\tTraining Discriminator')
+            discriminator.zero_grad()
+            optimizer_D.zero_grad()
+
+            # real loss
+            # if opt.dataset == 'ISLES2015' or opt.dataset == "BRATS2015":
+            pred_real = discriminator(
+                x_r[:, SEQ_IDX, ...].unsqueeze(1),
+                x_r[:, SEQ_IDX, ...].unsqueeze(1))
+            # else:
+            #     pred_real = discriminator(x_r, x_r)
+
+            loss_real = criterion_GAN(pred_real, label_list_r)
+
+            # fake loss
+            # fake_x = generator(x_z, label_map)
+            fake_x = generator(x_z)
+
+            # if opt.dataset != 'ISLES2015' and opt.dataset != "BRATS2015":
+            #     fake_x = impute_reals_into_fake(x_z, fake_x.detach(), label_scenario)
+            #
+            # if opt.dataset != 'ISLES2015' and opt.dataset != "BRATS2015":
+            #     pred_fake = discriminator(fake_x.detach(), x_r)
+            # else:
+            sh = fake_x.shape  # unused
+            pred_fake = discriminator(fake_x.detach(),
+                                      x_r[:, SEQ_IDX, ...].unsqueeze(1))
+
+            loss_fake = criterion_GAN(pred_fake, 
+            # TRAIN DISCRIMINATOR D
+            logger.debug('\tTraining Discriminator')
+            discriminator.zero_grad()
+            optimizer_D.zero_grad()
+
+            # real pass: the real target sequence serves as both input and condition
+            pred_real = discriminator(
+                x_r[:, SEQ_IDX, ...].unsqueeze(1),
+                x_r[:, SEQ_IDX, ...].unsqueeze(1))
+            loss_real = criterion_GAN(pred_real, label_list_r)
+
+            # fake pass: detach fake_x so no gradients flow back into G
+            fake_x = generator(x_z)
+            pred_fake = discriminator(fake_x.detach(),
+                                      x_r[:, SEQ_IDX, ...].unsqueeze(1))
+
+            # label_list (set above) carries the "fake" targets for the
+            # synthesized sequence
+            loss_fake = criterion_GAN(pred_fake, label_list)
+
+            D_train_loss = 0.5 * (loss_real + loss_fake)
+
+            # for logging purposes
+            D_real_losses.append(loss_real.item())
+            D_fake_losses.append(loss_fake.item())
+            D_losses.append(D_train_loss.item())
+
+            D_train_loss.backward()
+            optimizer_D.step()
+
+            logger.info(" E [{}/{}] P #{} ".format(epoch, opt.n_epochs, idx_pat) +
+                        'B [%d/%d] - loss_d: [real: %.5f, fake: %.5f, comb: %.5f], '
+                        'loss_g: [gan: %.5f, l1: %.5f, comb: %.5f], synth_loss_mse(ut): %.5f' % (
+                            (_num + 1), resize_slices // opt.batch_size,
+                            torch.mean(torch.FloatTensor(D_real_losses)),
+                            torch.mean(torch.FloatTensor(D_fake_losses)),
+                            torch.mean(torch.FloatTensor(D_losses)),
+                            torch.mean(torch.FloatTensor(G_train_losses)),
+                            torch.mean(torch.FloatTensor(G_train_l1_losses)),
+                            torch.mean(torch.FloatTensor(G_losses)),
+                            torch.mean(torch.FloatTensor(synth_losses))))
+
+        # Once exactly opt.train_patient_idx patients have been seen (idx_pat is
+        # 0-based, so idx_pat == opt.train_patient_idx - 1 makes this True),
+        # evaluate on the held-out test set for this fold.
+        if idx_pat + 1 == opt.train_patient_idx:
+            logger.info('Testing on test set for this fold')
+            main_path = os.path.join(root, model, "{}".format(opt.dataset), 'scenario_results')
+            logger.info("Saving results at {}".format(main_path))
+
+            generator.eval()
+
+            logger.info("Calculating metrics on test set")
+            result_dict_test = calculate_metrics(generator, test_patient, save_path=main_path,
+                                                 all_scenarios=copy.deepcopy(scenarios), epoch=epoch,
+                                                 curr_scenario_range=[6, 7],
+                                                 batch_size_to_test=1,
+                                                 dataset="BRATS2015",  # kept as BRATS2015 for compatibility with the metric code
+                                                 seq_type=opt.type)
+
+            logger.info("\t\tTesting Performance Numbers")
+            printTable(result_dict_test)
+            gc.collect()
+
+            logger.info("Writing detailed visualizations for each scenario")
+            status = show_intermediate_results(generator, test_patient, save_path=main_path,
+                                               all_scenarios=copy.deepcopy(scenarios), epoch=epoch,
+                                               curr_scenario_range=[6, 7],
+                                               batch_size_to_test=opt.batch_size,
+                                               seq_type=opt.type,
+                                               dataset="BRATS2015")  # kept as BRATS2015 for compatibility
+
+            generator.train()
+            gc.collect()
+            break
+
+    epoch_end_time = time.time()
+    per_epoch_ptime = epoch_end_time - epoch_start_time
+
+    print('[%d/%d] - ptime: %.2f, loss_d: [real: %.5f, fake: %.5f, comb: %.5f], '
+          'loss_g: [gan: %.5f, l1: %.5f, comb: %.5f], synth_loss_mse(ut): %.5f' % (
+              (epoch + 1), opt.n_epochs, per_epoch_ptime,
+              torch.mean(torch.FloatTensor(D_real_losses)),
+              torch.mean(torch.FloatTensor(D_fake_losses)),
+              torch.mean(torch.FloatTensor(D_losses)),
+              torch.mean(torch.FloatTensor(G_train_losses)),
+              torch.mean(torch.FloatTensor(G_train_l1_losses)),
+              torch.mean(torch.FloatTensor(G_losses)),
+              torch.mean(torch.FloatTensor(synth_losses))))
+
+    # Checkpoint the models
+    gen_state_checkpoint = {
+        'epoch': epoch + 1,
+        'arch': opt.model_name,
+        'state_dict': generator.state_dict(),
+        'optimizer': optimizer_G.state_dict(),
+    }
+    des_state_checkpoint = {
+        'epoch': epoch + 1,
+        'arch': opt.model_name,
+        'state_dict': discriminator.state_dict(),
+        'optimizer': optimizer_D.state_dict(),
+    }
+
+    save_checkpoint(gen_state_checkpoint,
+                    os.path.join(root, model, 'generator_param_{}_{}.pkl'.format(model, epoch + 1)),
+                    pickle_module=pickle)
+    save_checkpoint(des_state_checkpoint,
+                    os.path.join(root, model, 'discriminator_param_{}_{}.pkl'.format(model, epoch + 1)),
+                    pickle_module=pickle)
+
+    with open(os.path.join(root, model, "{}".format(opt.dataset),
+                           'result_dict_test_epoch_{}.pkl'.format(epoch)), 'wb') as f:
+        pickle.dump(result_dict_test, f)
+
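+    # Illustrative sketch (an assumption; save_checkpoint is defined elsewhere
+    # in this patch): it is expected to behave like a thin wrapper around
+    # torch.save,
+    #
+    #     def save_checkpoint(state, filename, pickle_module=pickle):
+    #         torch.save(state, filename, pickle_module=pickle_module)
+    #
+    # so that a checkpoint can later be restored with
+    #
+    #     checkpoint = torch.load(filename)
+    #     generator.load_state_dict(checkpoint['state_dict'])
+    #     optimizer_G.load_state_dict(checkpoint['optimizer'])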
logger.info('[Testing] num_pats: {}, mse: {:.5f}, psnr: {:.5f}, ssim: {:.5f}'.format( + opt.test_pats, + result_dict_test['mean']['mse'], + result_dict_test['mean']['psnr'], + result_dict_test['mean']['ssim'] + )) + + train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses))) + train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses))) + + train_hist['test_loss']['mse'].append(result_dict_test['mean']['mse']) + train_hist['test_loss']['psnr'].append(result_dict_test['mean']['psnr']) + train_hist['test_loss']['ssim'].append(result_dict_test['mean']['ssim']) + + train_hist['per_epoch_ptimes'].append(per_epoch_ptime) + + end_time = time.time() + total_ptime = end_time - start_time + train_hist['total_ptime'].append(total_ptime) + + print("Avg one epoch ptime: %.2f, total %d epochs ptime: %.2f" % ( + torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), opt.n_epochs, total_ptime)) + +with open(os.path.join(root, model, 'train_hist.pkl'), 'wb') as f: + pickle.dump(train_hist, f)
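+# Illustrative usage (not part of the training script): the pickled history can
+# be reloaded later to inspect or plot the recorded curves, e.g.
+#
+#     import pickle
+#     with open('train_hist.pkl', 'rb') as f:
+#         train_hist = pickle.load(f)
+#     print(train_hist['test_loss']['ssim'])  # per-epoch mean SSIM on the test set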