Skip to content

Commit

Permalink
Merge pull request #1 from juglab/dev
Browse files Browse the repository at this point in the history
v0.1.0
  • Loading branch information
tibuch authored Dec 29, 2020
2 parents 20e290a + b578117 commit 8aeb0ec
Show file tree
Hide file tree
Showing 24 changed files with 1,909 additions and 0 deletions.
13 changes: 13 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,16 @@ dmypy.json

# Pyre type checker
.pyre/

# lightning_logs
lightning_logs

# pycharm
.idea

# Singularity
*.simg
*results.json

# lightning_logs
lightning_logs
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
astra-toolbox requires cuda 10.2: conda install -c astra-toolbox/label/dev astra-toolbox

conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

Build Python package:
`python setup.py bdist_wheel`

Build singularity recipe:
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.0-py3-none-any.whl /fourier_image_transformers-0.1.0-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox -c astra-toolbox/label/dev pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch' pip_install='/fourier_image_transformers-0.1.0-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.0.Singularity`

Build singularity container:
`sudo singularity build fit.simg fit.Singularity`
216 changes: 216 additions & 0 deletions examples/DataModule - MNIST Tomo .ipynb

Large diffs are not rendered by default.

515 changes: 515 additions & 0 deletions examples/MNIST - TRec Example.ipynb

Large diffs are not rendered by default.

Empty file added fit/__init__.py
Empty file.
Empty file added fit/datamodules/__init__.py
Empty file.
50 changes: 50 additions & 0 deletions fit/datamodules/tomo_rec/FCDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
import torch
import torch.fft
from torch.utils.data import Dataset


class FCDataset(Dataset):
def __init__(self, ds, part='train', img_shape=42):
self.ds = ds.create_torch_dataset(part=part)
self.img_shape = img_shape
self.angles = ds.ray_trafo.geometry.angles

def __getitem__(self, i):
raise NotImplementedError()

def __len__(self):
return len(self.ds)


class FourierCoefficientDataset(FCDataset):
def __init__(self, ds, part='train', img_shape=42):
super().__init__(ds, part=part, img_shape=img_shape)

def __getitem__(self, item):
sino, img = self.ds[item]
sino_fft = torch.fft.rfftn(torch.roll(sino, sino.shape[1] // 2, 1), dim=[-1])
img_fft = torch.fft.rfftn(torch.roll(img, 2 * (img.shape[0] // 2,), (0, 1)), dim=[0, 1])

sino_mag = sino_fft.abs()
sino_mag[sino_mag == 0] = 1.
sino_mag = torch.log(sino_mag)
sino_phi = sino_fft.angle()

img_mag = img_fft.abs()
img_mag[img_mag == 0] = 1.
img_mag = torch.log(img_mag)
img_phi = img_fft.angle()

mag_min, mag_max = sino_mag.min(), sino_mag.max()

sino_mag = (sino_mag - mag_min) / (mag_max - mag_min)
img_mag = (img_mag - mag_min) / (mag_max - mag_min)

sino_phi = sino_phi / (2 * np.pi)
img_phi = img_phi / (2 * np.pi)

sino_fft = torch.stack([sino_mag.flatten(), sino_phi.flatten()], dim=-1)
img_fft = torch.stack([img_mag.flatten(), img_phi.flatten()], dim=-1)
return sino_fft, img_fft, img, (mag_min.unsqueeze(-1), mag_max.unsqueeze(-1))

79 changes: 79 additions & 0 deletions fit/datamodules/tomo_rec/GroundTruthDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import torch
from dival import Dataset
from dival.datasets.dataset import ObservationGroundTruthPairDataset
from odl import uniform_discr


class GroundTruthDataset(Dataset):
def __init__(self, train_gt_images, val_gt_images, test_gt_images):
self.train_gt_images = train_gt_images
self.val_gt_images = val_gt_images
self.test_gt_images = test_gt_images

self.shape = (self.train_gt_images.shape[1], self.train_gt_images.shape[2])
min_pt = [-self.shape[0] / 2, -self.shape[1] / 2]
max_pt = [self.shape[0] / 2, self.shape[1] / 2]
space = uniform_discr(min_pt, max_pt, self.shape, dtype=np.float32)

self.train_len = self.train_gt_images.shape[0]
self.val_len = self.val_gt_images.shape[0]
self.test_len = self.test_gt_images.shape[0]
self.random_access = True
super().__init__(space=space)

def create_pair_dataset(self, forward_op, post_processor=None,
noise_type=None, noise_kwargs=None,
noise_seeds=None):

dataset = ObservationGroundTruthPairDataset(
self.generator, forward_op, post_processor=post_processor,
train_len=self.train_len, validation_len=self.val_len,
test_len=self.test_len, noise_type=noise_type,
noise_kwargs=noise_kwargs, noise_seeds=noise_seeds)
return dataset

def generator(self, part='train'):
if part == 'train':
gen = self._train_generator()
elif part == 'validation':
gen = self._val_generator()
elif part == 'test':
gen = self._test_generator()
else:
raise NotImplementedError

for gt in gen:
yield gt

def _train_generator(self):
for i in range(self.train_len):
yield (self.train_gt_images[i].type(torch.float32))

def _test_generator(self):
for i in range(self.test_len):
yield (self.test_gt_images[i].type(torch.float32))

def _val_generator(self):
for i in range(self.val_len):
yield (self.val_gt_images[i].type(torch.float32))

def get_sample(self, index, part='train', out=None):
if out == None:
if part == 'train':
return self.train_gt_images[index].type(torch.float32)
elif part == ' validation':
return self.val_gt_images[index].type(torch.float32)
elif part == 'test':
return self.test_gt_images[index].type(torch.float32)
else:
raise NotImplementedError
else:
if part == 'train':
out = self.train_gt_images[index].type(torch.float32)
elif part == 'validation':
out = self.train_gt_images[index].type(torch.float32)
elif part == 'test':
out = self.test_gt_images[index].type(torch.float32)
else:
raise NotImplementedError
51 changes: 51 additions & 0 deletions fit/datamodules/tomo_rec/TRecDataModule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional, Union, List

import numpy as np
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from fit.datamodules.tomo_rec.FCDataset import FourierCoefficientDataset
from fit.datamodules.tomo_rec.GroundTruthDataset import GroundTruthDataset
from fit.utils.tomo_utils import get_projection_dataset


class MNISTTomoFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 28

def __init__(self, root_dir, batch_size, num_angles=15):
"""
:param root_dir:
:param batch_size:
:param num_angles:
"""
super().__init__()
self.root_dir = root_dir
self.batch_size = batch_size
self.num_angles = num_angles
self.gt_ds = None

def setup(self, stage: Optional[str] = None):
mnist_test = MNIST(self.root_dir, train=False, download=True).data
mnist_train_test = MNIST(self.root_dir, train=True, download=True).data
np.random.seed(1612)
perm = np.random.permutation(mnist_train_test.shape[0])
mnist_train = mnist_train_test[perm[:55000]]
mnist_val = mnist_train_test[perm[55000:]]
self.gt_ds = get_projection_dataset(GroundTruthDataset(mnist_train, mnist_val, mnist_test),
num_angles=self.num_angles, IM_SHAPE=(70, 70), impl='astra_cpu')

def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
FourierCoefficientDataset(self.gt_ds, part='train', img_shape=MNISTTomoFourierTargetDataModule.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
FourierCoefficientDataset(self.gt_ds, part='validation', img_shape=MNISTTomoFourierTargetDataModule.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
FourierCoefficientDataset(self.gt_ds, part='test', img_shape=MNISTTomoFourierTargetDataModule.IMG_SHAPE),
batch_size=1)
2 changes: 2 additions & 0 deletions fit/datamodules/tomo_rec/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .FCDataset import FourierCoefficientDataset
from .TRecDataModule import MNISTTomoFourierTargetDataModule
163 changes: 163 additions & 0 deletions fit/modules/TRecTransformerModule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.step_result import TrainResult, EvalResult

from fit.transformers.TRecTransformer import TRecTransformer
from fit.utils import convert2FC, fft_interpolate
from fit.utils.RAdam import RAdam

import numpy as np

from torch.nn import functional as F
import torch.fft


class TRecTransformerModule(LightningModule):
def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords_img, angles, img_shape=362,
lr=0.0001,
weight_decay=0.01,
loss_switch=0.5,
attention_type="linear", n_layers=4, n_heads=4, d_query=4, dropout=0.1, attention_dropout=0.1):
super().__init__()

self.save_hyperparameters("d_model",
"img_shape",
"lr",
"weight_decay",
"loss_switch",
"attention_type",
"n_layers",
"n_heads",
"d_query",
"dropout",
"attention_dropout")
self.y_coords_proj = y_coords_proj
self.x_coords_proj = x_coords_proj
self.y_coords_img = y_coords_img
self.x_coords_img = x_coords_img
self.angles = angles
self.dft_shape = (img_shape, img_shape // 2 + 1)

self.trec = TRecTransformer(d_model=self.hparams.d_model,
y_coords_proj=y_coords_proj, x_coords_proj=x_coords_proj,
y_coords_img=y_coords_img, x_coords_img=x_coords_img,
attention_type=self.hparams.attention_type,
n_layers=self.hparams.n_layers,
n_heads=self.hparams.n_heads,
d_query=self.hparams.d_query,
dropout=self.hparams.dropout,
attention_dropout=self.hparams.attention_dropout)

self.criterion = self._fc_loss
self.using_real_loss = False

def forward(self, x):
return self.trec.forward(x)

def configure_optimizers(self):
optimizer = RAdam(self.trec.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
return optimizer

def _real_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
mag = pred_fc[..., 0]
phi = pred_fc[..., 1]
mag = (mag * (mag_max - mag_min)) + mag_min
mag = torch.exp(mag)

phi = phi * 2 * np.pi
dft = torch.complex(mag * torch.cos(phi), mag * torch.sin(phi))
dft = dft.reshape(-1, *self.dft_shape)
y_hat = torch.roll(torch.fft.irfftn(dft, dim=[1, 2]),
(self.hparams.img_shape // 2, self.hparams.img_shape // 2), (1, 2))
return F.mse_loss(y_hat, target_real)

def _fc_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
return F.mse_loss(pred_fc, target_fc)

def training_step(self, batch, batch_idx):
x_fc, y_fc, y_real, (mag_min, mag_max) = batch
pred = self(x_fc)
loss = self.criterion(pred, y_fc, y_real, mag_min, mag_max)
return loss

def on_train_epoch_start(self):
if not self.using_real_loss and self.current_epoch >= (self.trainer.max_epochs * self.hparams.loss_switch):
self.criterion = self._real_loss
print('Epoch {}/{}: Switched to real loss.'.format(self.current_epoch, self.trainer.max_epochs - 1))
self.using_real_loss = True

def training_epoch_end(self, outputs):
loss = [d['loss'] for d in outputs]
self.log('Train/loss', torch.mean(torch.stack(loss)), logger=True, on_epoch=True)

def validation_step(self, batch, batch_idx):
x_fc, y_fc, y_real, (mag_min, mag_max) = batch
pred = self(x_fc)
val_loss = self.criterion(pred, y_fc, y_real, mag_min, mag_max)
val_mse = self._real_loss(pred, y_fc, y_real, mag_min, mag_max)
self.log_dict({'val_loss': val_loss})
self.log_dict({'val_mse': val_mse})
if batch_idx == 0:
self.log_val_images(pred, x_fc, y_real, mag_min, mag_max)
return {'val_loss': val_loss, 'val_mse': val_mse}

def log_val_images(self, pred, x, y_real, mag_min, mag_max):
x_fc = convert2FC(x, mag_min, mag_max)
pred_fc = convert2FC(pred, mag_min, mag_max)

for i in range(3):
x_dft = fft_interpolate(self.x_coords_proj.cpu().numpy(), self.y_coords_proj.cpu().numpy(),
self.x_coords_img.cpu().numpy(), self.y_coords_img.cpu().numpy(),
x_fc[i].cpu().numpy(), target_shape=self.dft_shape)
x_img = torch.roll(torch.fft.irfftn(torch.from_numpy(x_dft)),
2 * (self.hparams.img_shape // 2,), (0, 1))
x_img = torch.clamp((x_img - x_img.min()) / (x_img.max() - x_img.min()), 0, 1)

pred_img = torch.roll(torch.fft.irfftn(pred_fc[i].reshape(self.dft_shape)),
2 * (self.hparams.img_shape // 2,), (0, 1))
pred_img = torch.clamp((pred_img - pred_img.min()) / (pred_img.max() - pred_img.min()), 0, 1)

y_img = y_real[i]
y_img = torch.clamp((y_img - y_img.min()) / (y_img.max() - y_img.min()), 0, 1)

self.trainer.logger.experiment.add_image('inputs/img_{}'.format(i), x_img.unsqueeze(0),
global_step=self.trainer.global_step)
self.trainer.logger.experiment.add_image('predcitions/img_{}'.format(i), pred_img.unsqueeze(0),
global_step=self.trainer.global_step)
self.trainer.logger.experiment.add_image('targets/img_{}'.format(i), y_img.unsqueeze(0),
global_step=self.trainer.global_step)

def validation_epoch_end(self, outputs):
val_loss = [o['val_loss'] for o in outputs]
val_mse = [o['val_mse'] for o in outputs]
self.log('Train/avg_val_loss', torch.mean(torch.stack(val_loss)), logger=True, on_epoch=True)
self.log('Train/avg_val_mse', torch.mean(torch.stack(val_mse)), logger=True, on_epoch=True)

def test_step(self, batch, batch_idx):
x, y, y_real, (mag_min, mag_max) = batch
assert len(x) == 1, 'Test images have to be evaluated independently.'

pred = self(x)

pred_fc = convert2FC(pred, mag_min, mag_max)
pred_img = torch.roll(torch.fft.irfftn(pred_fc[0].reshape(self.dft_shape)),
2 * (self.hparams.img_shape // 2,), (0, 1))
return self.PSNR(y_real[0], pred_img, drange=torch.tensor(255., dtype=torch.float32))

def test_epoch_end(self, outputs):
outputs = torch.stack(outputs)
self.log('Mean PSNR', torch.mean(outputs).detach().cpu().numpy(), logger=True)
self.log('SEM PSNR', torch.std(outputs/ np.sqrt(len(outputs))).detach().cpu().numpy(),
logger=True)

def normalize_minmse(self, x, target):
"""Affine rescaling of x, such that the mean squared error to target is minimal."""
cov = np.cov(x.detach().cpu().numpy().flatten(), target.detach().cpu().numpy().flatten())
alpha = cov[0, 1] / (cov[0, 0] + 1e-10)
beta = target.mean() - alpha * x.mean()
return alpha * x + beta

def PSNR(self, gt, img, drange):
img = self.normalize_minmse(img, gt)
mse = torch.mean(torch.square(gt - img))
return 20 * torch.log10(drange) - 10 * torch.log10(mse)
1 change: 1 addition & 0 deletions fit/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .TRecTransformerModule import TRecTransformerModule
Loading

0 comments on commit 8aeb0ec

Please sign in to comment.