
Commit

initial commit
trane293 committed Sep 10, 2020
0 parents commit e9b0b67
Showing 34 changed files with 6,280 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -0,0 +1,8 @@
notebooks/.ipynb_checkpoints
.idea/
modules/__pycache__/
*/*/*.pyc
*/*/.ipynb_checkpoints
*/*/__pycache__/
notebooks/
misc_stuff/
128 changes: 128 additions & 0 deletions environment.yml
@@ -0,0 +1,128 @@
name: pytorch-CycleGAN-and-pix2pix
channels:
- anaconda
- conda-forge
- defaults
dependencies:
- h5py=2.8.0=py35h989c5e5_3
- six=1.11.0=py35_1
- ca-certificates=2018.11.29=ha4d7672_0
- certifi=2018.8.24=py35_1001
- cloudpickle=0.6.1=py_0
- cycler=0.10.0=py_1
- dask-core=1.0.0=py_0
- decorator=4.3.0=py_0
- matplotlib=2.2.3=py35h8e2386c_0
- networkx=2.2=py_1
- openssl=1.0.2p=h470a237_1
- pyqt=5.6.0=py35h8210e8a_7
- python-dateutil=2.7.5=py_0
- pywavelets=1.0.1=py35h7eb728f_0
- scikit-image=0.14.0=py35hfc679d8_1
- sip=4.18.1=py35hfc679d8_0
- toolz=0.9.0=py_1
- tornado=5.1.1=py35h470a237_0
- blas=1.0=mkl
- cffi=1.11.5=py35he75722e_1
- cudatoolkit=9.0=h13b8566_0
- cudnn=7.1.2=cuda9.0_0
- dbus=1.13.2=h714fa37_1
- expat=2.2.6=he6710b0_0
- fontconfig=2.13.0=h9420a91_0
- freetype=2.9.1=h8a8886c_1
- glib=2.56.2=hd408876_0
- gst-plugins-base=1.14.0=hbbd80ab_1
- gstreamer=1.14.0=hb453b48_1
- hdf5=1.10.2=hba1933b_1
- icu=58.2=h9c2bf20_1
- imageio=2.4.1=py35_0
- intel-openmp=2019.1=144
- jpeg=9b=h024ee3a_2
- kiwisolver=1.0.1=py35hf484d3e_0
- libedit=3.1.20170329=h6b74fdf_2
- libffi=3.2.1=hd88cf55_4
- libgcc-ng=8.2.0=hdf63c60_1
- libgfortran-ng=7.3.0=hdf63c60_0
- libpng=1.6.35=hbc83047_0
- libstdcxx-ng=8.2.0=hdf63c60_1
- libtiff=4.0.9=he85c1e1_2
- libuuid=1.0.3=h1bed415_2
- libxcb=1.13=h1bed415_1
- libxml2=2.9.8=h26e45fe_1
- mkl=2018.0.3=1
- mkl_fft=1.0.6=py35h7dd41cf_0
- mkl_random=1.0.1=py35h4414c95_1
- nccl=1.3.5=cuda9.0_0
- ncurses=6.1=hf484d3e_0
- ninja=1.8.2=py35h6bb024c_1
- numpy=1.15.2=py35h1d66e8a_0
- numpy-base=1.15.2=py35h81de0dd_0
- olefile=0.46=py35_0
- pcre=8.42=h439df22_0
- pillow=5.2.0=py35heded4f4_0
- pip=10.0.1=py35_0
- pycparser=2.19=py35_0
- pyparsing=2.2.1=py35_0
- python=3.5.5=hc3d631a_4
- pytorch=0.4.1=py35ha74772b_0
- pytz=2018.5=py35_0
- qt=5.6.3=h39df351_1
- readline=7.0=h7b6447c_5
- scipy=1.1.0=py35hfa4b5c9_1
- setuptools=40.2.0=py35_0
- sqlite=3.25.3=h7b6447c_0
- tk=8.6.8=hbc83047_0
- wheel=0.31.1=py35_0
- xz=5.2.4=h14c3975_4
- zlib=1.2.11=ha838bed_2
- pip:
  - backcall==0.1.0
  - bleach==3.0.2
  - chardet==3.0.4
  - dask==1.0.0
  - defusedxml==0.5.0
  - dominate==2.3.1
  - entrypoints==0.2.3
  - idna==2.7
  - ipykernel==5.1.0
  - ipython==7.2.0
  - ipython-genutils==0.2.0
  - ipywidgets==7.4.2
  - jedi==0.13.1
  - jinja2==2.10
  - jsonschema==2.6.0
  - jupyter==1.0.0
  - jupyter-client==5.2.3
  - jupyter-console==6.0.0
  - jupyter-core==4.4.0
  - markupsafe==1.1.0
  - mistune==0.8.4
  - nbconvert==5.4.0
  - nbformat==4.4.0
  - notebook==5.7.2
  - pandocfilters==1.4.2
  - parso==0.3.1
  - pexpect==4.6.0
  - pickleshare==0.7.5
  - prometheus-client==0.4.2
  - prompt-toolkit==2.0.7
  - ptyprocess==0.6.0
  - pygments==2.3.0
  - pyzmq==17.1.2
  - qtconsole==4.4.3
  - requests==2.20.1
  - send2trash==1.5.0
  - terminado==0.8.1
  - testpath==0.4.2
  - torch==0.4.1
  - torchfile==0.1.0
  - torchvision==0.2.1
  - tqdm==4.28.1
  - traitlets==4.3.2
  - urllib3==1.24.1
  - visdom==0.1.7
  - wcwidth==0.1.7
  - webencodings==0.5.1
  - widgetsnbextension==3.4.2
prefix: /home/asa224/anaconda3/envs/pytorch-CycleGAN-and-pix2pix
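
This is a fully pinned conda environment (Python 3.5, PyTorch 0.4.1, CUDA 9.0/cuDNN 7.1). It should be reproducible with `conda env create -f environment.yml`; note that the `prefix:` line is machine-specific and can be removed or ignored on other systems.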

2 changes: 2 additions & 0 deletions mmgan_hgg.sh
@@ -0,0 +1,2 @@
#!/usr/bin/env bash
python train_mmgan_brats2018.py --grade=HGG --train_patient_idx=200 --test_pats=10 --batch_size=4 --dataset=BRATS2018 --n_epochs=60 --model_name=mmgan_hgg_zeros_cl --log_level=info --n_cpu=4 --c_learning=1 --z_type=zeros
2 changes: 2 additions & 0 deletions mmgan_lgg.sh
@@ -0,0 +1,2 @@
#!/usr/bin/env bash
python train_mmgan_brats2018.py --grade=LGG --train_patient_idx=70 --test_pats=5 --batch_size=4 --dataset=BRATS2018 --n_epochs=60 --model_name=mmgan_lgg_zeros_cl --log_level=info --n_cpu=8 --c_learning=1 --z_type=zeros
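
The two launcher scripts differ only in glioma grade and patient counts (200 train / 10 test for HGG vs. 70 train / 5 test for LGG). Going by the flag names alone, `--c_learning=1` appears to enable curriculum learning and `--z_type=zeros` appears to impute the missing sequences with zeros; `train_mmgan_brats2018.py` is the authority on their exact semantics.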
35 changes: 35 additions & 0 deletions modules/advanced_gans/datasets.py
@@ -0,0 +1,35 @@
import glob
import random
import os
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, mode='train'):
        self.transform = transforms.Compose(transforms_)

        self.files = sorted(glob.glob(os.path.join(root, mode) + '/*.*'))
        if mode == 'train':
            self.files.extend(sorted(glob.glob(os.path.join(root, 'test') + '/*.*')))

    def __getitem__(self, index):
        # Each file stores an A|B pair side by side; split it down the middle.
        img = Image.open(self.files[index % len(self.files)])
        w, h = img.size
        img_A = img.crop((0, 0, w / 2, h))
        img_B = img.crop((w / 2, 0, w, h))

        # Random horizontal flip, applied identically to both halves.
        if np.random.random() < 0.5:
            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], 'RGB')
            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], 'RGB')

        img_A = self.transform(img_A)
        img_B = self.transform(img_B)

        return {'A': img_A, 'B': img_B}

    def __len__(self):
        return len(self.files)
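
A minimal usage sketch (not part of the commit), assuming the repo root is on PYTHONPATH and the data folder follows the pix2pix convention of a `train/` subfolder holding images with the A and B halves side by side:

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from modules.advanced_gans.datasets import ImageDataset

# Resize both halves to 256x256 and convert to tensors.
transforms_ = [transforms.Resize((256, 256)), transforms.ToTensor()]
loader = DataLoader(ImageDataset('path/to/data', transforms_=transforms_, mode='train'),
                    batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch['A'].shape, batch['B'].shape)  # e.g. torch.Size([4, 3, 256, 256]) each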
169 changes: 169 additions & 0 deletions modules/advanced_gans/models.py
@@ -0,0 +1,169 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import random

# Fix all RNG seeds for reproducibility.
np.random.seed(1337)
torch.manual_seed(1337)
random.seed(1337)
torch.backends.cudnn.deterministic = True


def weights_init_normal(m):
    # Intended to be applied module-wise, e.g. via model.apply(weights_init_normal).
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

##############################
# U-NET
##############################

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        # Each down block halves the spatial resolution (stride-2 4x4 conv).
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        # Each up block doubles the spatial resolution (stride-2 transposed conv).
        layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
                  nn.InstanceNorm2d(out_size),
                  nn.ReLU(inplace=True)]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        # Concatenate with the matching encoder feature map (skip connection).
        x = torch.cat((x, skip_input), 1)

        return x

class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, with_tanh=False, with_relu=False):
        super(GeneratorUNet, self).__init__()

        # Original pix2pix dropout was 0.5; 0.2 is used here.
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.2)
        self.down5 = UNetDown(512, 512, dropout=0.2)
        self.down6 = UNetDown(512, 512, dropout=0.2)
        self.down7 = UNetDown(512, 512, dropout=0.2)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.2)

        # Up blocks take 2x channels because of the skip concatenation.
        self.up1 = UNetUp(512, 512, dropout=0.2)
        self.up2 = UNetUp(1024, 512, dropout=0.2)
        self.up3 = UNetUp(1024, 512, dropout=0.2)
        self.up4 = UNetUp(1024, 512, dropout=0.2)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        # Output activation: tanh for [-1, 1] outputs, ReLU for non-negative
        # outputs (used for ISLES2015), or no activation at all.
        if with_tanh:
            self.final = nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.ZeroPad2d((1, 0, 1, 0)),
                nn.Conv2d(128, out_channels, 4, padding=1),
                nn.Tanh()
            )
        elif with_relu:
            self.final = nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.ZeroPad2d((1, 0, 1, 0)),
                nn.Conv2d(128, out_channels, 4, padding=1),
                nn.ReLU()
            )
        else:
            self.final = nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.ZeroPad2d((1, 0, 1, 0)),
                nn.Conv2d(128, out_channels, 4, padding=1)
            )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder.
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)


##############################
# Discriminator
##############################

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, out_channels=4, dataset='BRATS2018'):
        super(Discriminator, self).__init__()

        # Output size per conv: np.floor(((inp + 2*pad - dil*(kernel - 1) - 1)/stride) + 1)
        # e.g. inp, stride, pad, dil, kernel = (256, 2, 1, 1, 8)

        if 'BRATS' in dataset:
            def discriminator_block(in_filters, out_filters, normalization=True):
                """Returns downsampling layers of each discriminator block"""
                layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
                if normalization:
                    layers.append(nn.InstanceNorm2d(out_filters))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
                return layers

            self.model = nn.Sequential(
                *discriminator_block(in_channels * 2, 64, normalization=False),
                *discriminator_block(64, 128),
                *discriminator_block(128, 256),
                *discriminator_block(256, 512),
                nn.ZeroPad2d((1, 0, 1, 0)),
                nn.Conv2d(512, out_channels, 4, padding=1, bias=False)
            )
        else:
            # For ISLES2015: stride-4 blocks with far fewer filters.
            def discriminator_block(in_filters, out_filters, normalization=True):
                """Returns downsampling layers of each discriminator block"""
                layers = [nn.Conv2d(in_filters, out_filters, 4, stride=4, padding=1)]
                if normalization:
                    layers.append(nn.InstanceNorm2d(out_filters))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
                return layers

            self.model = nn.Sequential(
                *discriminator_block(in_channels * 2, 8, normalization=False),
                *discriminator_block(8, 8),
                nn.ZeroPad2d((1, 0, 1, 0)),
                nn.Conv2d(8, out_channels, 4, padding=1, bias=False)
            )

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce the input.
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
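
A quick shape check (not part of the commit), assuming the repo root is on PYTHONPATH and BraTS-style inputs with 4 MR sequences as channels on 256x256 slices:

import torch
from modules.advanced_gans.models import GeneratorUNet, Discriminator, weights_init_normal

G = GeneratorUNet(in_channels=4, out_channels=4)
D = Discriminator(in_channels=4, out_channels=4, dataset='BRATS2018')
G.apply(weights_init_normal)   # module-wise init, matching the helper above
D.apply(weights_init_normal)

x = torch.randn(1, 4, 256, 256)
fake = G(x)          # the generator preserves spatial size
patch = D(fake, x)   # PatchGAN grid of real/fake scores
print(fake.shape, patch.shape)  # torch.Size([1, 4, 256, 256]) torch.Size([1, 4, 16, 16])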
