Commit

Merge pull request #3 from juglab/dev
Dev
tibuch authored Jan 11, 2021
2 parents 2c07d53 + d9bfd06 commit dd80473
Showing 12 changed files with 756 additions and 153 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -6,7 +6,7 @@ Build Python package:
`python setup.py bdist_wheel`

Build singularity recipe:
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.1-py3-none-any.whl /fourier_image_transformers-0.1.1-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.1-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.1.Singularity`
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.2-py3-none-any.whl /fourier_image_transformers-0.1.2-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.2-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.2.Singularity`

Build singularity container:
`sudo singularity build fit.simg fit.Singularity`
`sudo singularity build fit_v0.1.2.simg v0.1.2.Singularity`
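As a usage sketch (not part of this diff): `singularity run fit_v0.1.2.simg` should start Python via the configured entrypoint (`/neurodocker/startup.sh python`) inside the activated `fit` conda environment.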
349 changes: 349 additions & 0 deletions examples/DataModule - LoDoPaB TRec.ipynb

Large diffs are not rendered by default.

110 changes: 64 additions & 46 deletions examples/DataModule - MNIST Tomo .ipynb

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions fit/datamodules/tomo_rec/FCDataset.py
@@ -38,13 +38,12 @@ def __getitem__(self, item):

mag_min, mag_max = sino_mag.min(), sino_mag.max()

- sino_mag = (sino_mag - mag_min) / (mag_max - mag_min)
- img_mag = (img_mag - mag_min) / (mag_max - mag_min)
+ sino_mag = 2 * (sino_mag - mag_min) / (mag_max - mag_min) - 1
+ img_mag = 2 * (img_mag - mag_min) / (mag_max - mag_min) - 1

- sino_phi = sino_phi / (2 * np.pi)
- img_phi = img_phi / (2 * np.pi)
+ sino_phi = 2 * sino_phi / (2 * np.pi) - 1
+ img_phi = 2 * img_phi / (2 * np.pi) - 1

sino_fft = torch.stack([sino_mag.flatten(), sino_phi.flatten()], dim=-1)
img_fft = torch.stack([img_mag.flatten(), img_phi.flatten()], dim=-1)
return sino_fft, img_fft, img, (mag_min.unsqueeze(-1), mag_max.unsqueeze(-1))
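The change above rescales the Fourier magnitude (after min/max scaling) and the phase (previously just divided by 2π) from roughly [0, 1] to [-1, 1]. A minimal standalone sketch of the mapping and its inverse, assuming phase values in [0, 2π) and using hypothetical helper names not present in the repository:

```python
import numpy as np

def to_symmetric_range(mag, phi, mag_min, mag_max):
    # magnitude: min/max scaling followed by a shift to [-1, 1]
    mag = 2 * (mag - mag_min) / (mag_max - mag_min) - 1
    # phase: assumed to lie in [0, 2*pi); mapped to [-1, 1] the same way
    phi = 2 * phi / (2 * np.pi) - 1
    return mag, phi

def from_symmetric_range(mag, phi, mag_min, mag_max):
    # exact inverse, e.g. before reconstructing an image from predicted coefficients
    mag = (mag + 1) / 2 * (mag_max - mag_min) + mag_min
    phi = (phi + 1) * np.pi
    return mag, phi
```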

136 changes: 130 additions & 6 deletions fit/datamodules/tomo_rec/TRecDataModule.py
@@ -1,5 +1,6 @@
from typing import Optional, Union, List

+ import dival
import numpy as np
import torch
from pytorch_lightning import LightningDataModule
@@ -8,7 +9,64 @@

from fit.datamodules.tomo_rec.FCDataset import FourierCoefficientDataset
from fit.datamodules.tomo_rec.GroundTruthDataset import GroundTruthDataset
- from fit.utils.tomo_utils import get_projection_dataset
+ import odl
+ from skimage.transform import resize
+
+ from fit.utils.tomo_utils import get_detector_length
+ from fit.utils.utils import normalize


def get_projection_dataset(dataset, num_angles, im_shape=70, impl='astra_cpu', inner_circle=True):
assert isinstance(dataset, GroundTruthDataset)
reco_space = dataset.space
if inner_circle:
space = odl.uniform_discr(min_pt=reco_space.min_pt,
max_pt=reco_space.max_pt,
shape=(im_shape, im_shape), dtype=np.float32)
min_pt = reco_space.min_pt
max_pt = reco_space.max_pt
proj_space = odl.uniform_discr(min_pt, max_pt, 2 * (2 * int(reco_space.max_pt[0]) - 1,), dtype=np.float32)
detector_length = get_detector_length(proj_space)
det_partition = odl.uniform_partition(-np.sqrt((reco_space.shape[0] / 2.) ** 2 / 2),
np.sqrt((reco_space.shape[0] / 2.) ** 2 / 2),
detector_length)
else:
space = odl.uniform_discr(min_pt=reco_space.min_pt,
max_pt=reco_space.max_pt,
shape=(im_shape, im_shape), dtype=np.float32)
min_pt = reco_space.min_pt
max_pt = reco_space.max_pt
proj_space = odl.uniform_discr(min_pt, max_pt, 2 * (reco_space.shape[0],), dtype=np.float32)
detector_length = get_detector_length(proj_space)
det_partition = odl.uniform_partition(-reco_space.shape[0] / 2., reco_space.shape[0] / 2., detector_length)

angle_partition = odl.uniform_partition(0, np.pi, num_angles)
reco_geometry = odl.tomo.Parallel2dGeometry(angle_partition, det_partition)

ray_trafo = odl.tomo.RayTransform(space, reco_geometry, impl=impl)

def get_reco_ray_trafo(**kwargs):
return odl.tomo.RayTransform(reco_space, reco_geometry, **kwargs)

reco_ray_trafo = get_reco_ray_trafo(impl=impl)

class _ResizeOperator(odl.Operator):
def __init__(self):
super().__init__(reco_space, space)

def _call(self, x, out, **kwargs):
out.assign(space.element(resize(x, (im_shape, im_shape), order=1)))

# forward operator
resize_op = _ResizeOperator()
forward_op = ray_trafo * resize_op

ds = dataset.create_pair_dataset(
forward_op=forward_op, noise_type=None)

ds.get_ray_trafo = get_reco_ray_trafo
ds.ray_trafo = reco_ray_trafo
return ds
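This helper, moved here from `fit/utils/tomo_utils.py` (see the deletion below), wraps a `GroundTruthDataset` into a dival-style pair dataset and attaches the reconstruction-space ray transform. A hedged usage sketch, mirroring the call made in `MNISTTomoFourierTargetDataModule.setup()` further down; `gt_dataset` is a placeholder and the `astra_cuda` backend is only an illustration:

```python
# sketch only; gt_dataset is assumed to be a GroundTruthDataset instance
ds = get_projection_dataset(gt_dataset, num_angles=15, im_shape=70,
                            impl='astra_cpu', inner_circle=True)

# the returned pair dataset carries the reconstruction-space ray transform,
# which downstream code can reuse or re-instantiate with another backend
ray_trafo = ds.ray_trafo
cuda_ray_trafo = ds.get_ray_trafo(impl='astra_cuda')
```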


class MNISTTomoFourierTargetDataModule(LightningDataModule):
@@ -26,15 +84,18 @@ def __init__(self, root_dir, batch_size, num_angles=15, inner_circle=True):
self.num_angles = num_angles
self.inner_circle = inner_circle
self.gt_ds = None
+ self.mean = None
+ self.std = None

def setup(self, stage: Optional[str] = None):
- mnist_test = MNIST(self.root_dir, train=False, download=True).data
- mnist_train_test = MNIST(self.root_dir, train=True, download=True).data
+ mnist_test = MNIST(self.root_dir, train=False, download=True).data.type(torch.float32)
+ mnist_train_val = MNIST(self.root_dir, train=True, download=True).data.type(torch.float32)
np.random.seed(1612)
- perm = np.random.permutation(mnist_train_test.shape[0])
- mnist_train = mnist_train_test[perm[:55000], 1:, 1:]
- mnist_val = mnist_train_test[perm[55000:], 1:, 1:]
+ perm = np.random.permutation(mnist_train_val.shape[0])
+ mnist_train = mnist_train_val[perm[:55000], 1:, 1:]
+ mnist_val = mnist_train_val[perm[55000:], 1:, 1:]
mnist_test = mnist_test[:, 1:, 1:]

assert mnist_train.shape[1] == MNISTTomoFourierTargetDataModule.IMG_SHAPE
assert mnist_train.shape[2] == MNISTTomoFourierTargetDataModule.IMG_SHAPE
x, y = torch.meshgrid(torch.arange(-MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1,
@@ -45,6 +106,13 @@ def setup(self, stage: Optional[str] = None):
mnist_train = circle * np.clip(mnist_train, 50, 255)
mnist_val = circle * np.clip(mnist_val, 50, 255)
mnist_test = circle * np.clip(mnist_test, 50, 255)

+ self.mean = mnist_train.mean()
+ self.std = mnist_train.std()
+
+ mnist_train = normalize(mnist_train, self.mean, self.std)
+ mnist_val = normalize(mnist_val, self.mean, self.std)
+ mnist_test = normalize(mnist_test, self.mean, self.std)
self.gt_ds = get_projection_dataset(
GroundTruthDataset(mnist_train, mnist_val, mnist_test),
num_angles=self.num_angles, im_shape=70, impl='astra_cpu', inner_circle=self.inner_circle)
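The new `mean`/`std` statistics feed `normalize` here and `denormalize` in the test step of `TRecTransformerModule` below. Their implementations are not part of this diff; a plain z-score pair is assumed in this sketch:

```python
import torch

def normalize(x: torch.Tensor, mean, std):
    # assumed behaviour of fit.utils.utils.normalize: z-score with training statistics
    return (x - mean) / std

def denormalize(x: torch.Tensor, mean, std):
    # assumed inverse, applied before computing PSNR on the original intensity scale
    return x * std + mean
```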
@@ -64,3 +132,59 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
return DataLoader(
FourierCoefficientDataset(self.gt_ds, part='test', img_shape=MNISTTomoFourierTargetDataModule.IMG_SHAPE),
batch_size=1)


class LoDoPaBFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 361

def __init__(self, batch_size, num_angles=15):
"""
:param root_dir:
:param batch_size:
:param num_angles:
"""
super().__init__()
self.batch_size = batch_size
self.num_angles = num_angles
self.inner_circle = True
self.gt_ds = None

def setup(self, stage: Optional[str] = None):
lodopab = dival.get_standard_dataset('lodopab', impl='astra_cpu')
gt_train = np.array([lodopab.get_sample(i, part='train', out=(False, True))[1][1:, 1:] for i in range(1000)])
gt_val = np.array([lodopab.get_sample(i, part='validation', out=(False, True))[1][1:, 1:] for i in range(100)])
gt_test = np.array([lodopab.get_sample(i, part='test', out=(False, True))[1][1:, 1:] for i in range(1000)])

gt_train = torch.from_numpy(gt_train)
gt_val = torch.from_numpy(gt_val)
gt_test = torch.from_numpy(gt_test)

assert gt_train.shape[1] == LoDoPaBFourierTargetDataModule.IMG_SHAPE
assert gt_train.shape[2] == LoDoPaBFourierTargetDataModule.IMG_SHAPE
x, y = torch.meshgrid(torch.arange(-LoDoPaBFourierTargetDataModule.IMG_SHAPE // 2 + 1,
LoDoPaBFourierTargetDataModule.IMG_SHAPE // 2 + 1),
torch.arange(-LoDoPaBFourierTargetDataModule.IMG_SHAPE // 2 + 1,
LoDoPaBFourierTargetDataModule.IMG_SHAPE // 2 + 1))
circle = torch.sqrt(x ** 2. + y ** 2.) <= LoDoPaBFourierTargetDataModule.IMG_SHAPE // 2
gt_train *= circle
gt_val *= circle
gt_test *= circle
self.gt_ds = get_projection_dataset(
GroundTruthDataset(gt_train, gt_val, gt_test),
num_angles=self.num_angles, im_shape=450, impl='astra_cpu', inner_circle=self.inner_circle)

def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
FourierCoefficientDataset(self.gt_ds, part='train', img_shape=LoDoPaBFourierTargetDataModule.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
FourierCoefficientDataset(self.gt_ds, part='validation',
img_shape=LoDoPaBFourierTargetDataModule.IMG_SHAPE),
batch_size=self.batch_size, num_workers=2)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
FourierCoefficientDataset(self.gt_ds, part='test', img_shape=LoDoPaBFourierTargetDataModule.IMG_SHAPE),
batch_size=1)
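For orientation, a hedged sketch of how the two data modules might be driven from a training script; the `Trainer` configuration is illustrative and not taken from this repository:

```python
from pytorch_lightning import Trainer

from fit.datamodules.tomo_rec.TRecDataModule import (
    LoDoPaBFourierTargetDataModule,
    MNISTTomoFourierTargetDataModule,
)

# MNIST tomography targets: images are downloaded to root_dir via torchvision
mnist_dm = MNISTTomoFourierTargetDataModule(root_dir='./data', batch_size=8,
                                            num_angles=15, inner_circle=True)

# LoDoPaB CT targets: ground-truth slices come from dival, so no root_dir is needed
lodopab_dm = LoDoPaBFourierTargetDataModule(batch_size=4, num_angles=15)

# model = TRecTransformerModule(...)  # constructed as in the module below
# Trainer(max_epochs=100).fit(model, datamodule=mnist_dm)
```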
37 changes: 22 additions & 15 deletions fit/modules/TRecTransformerModule.py
@@ -1,6 +1,7 @@
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.step_result import TrainResult, EvalResult
+ from torch.optim.lr_scheduler import ReduceLROnPlateau

from fit.datamodules.tomo_rec import MNISTTomoFourierTargetDataModule
from fit.transformers.TRecTransformer import TRecTransformer
@@ -12,14 +13,15 @@
from torch.nn import functional as F
import torch.fft

+ from fit.utils.utils import denormalize


class TRecTransformerModule(LightningModule):
def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords_img, src_flatten_coords,
dst_flatten_coords, dst_order, angles, img_shape=27, detector_len=27, init_bin_factor=4,
alpha=1.5, bin_factor_cd=10,
lr=0.0001,
weight_decay=0.01,
- loss_switch=0.5,
attention_type="linear", n_layers=4, n_heads=4, d_query=4, dropout=0.1, attention_dropout=0.1):
super().__init__()

@@ -31,7 +33,6 @@ def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords
"detector_len",
"lr",
"weight_decay",
"loss_switch",
"attention_type",
"n_layers",
"n_heads",
@@ -68,9 +69,6 @@ def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords
dropout=self.hparams.dropout,
attention_dropout=self.hparams.attention_dropout)

- self.criterion = self._fc_loss
- self.using_real_loss = False

x, y = torch.meshgrid(torch.arange(-MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1,
MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1),
torch.arange(-MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1,
@@ -82,7 +80,12 @@ def forward(self, x, out_pos_emb):

def configure_optimizers(self):
optimizer = RAdam(self.trec.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
- return optimizer
+ scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, verbose=True)
+ return {
+     'optimizer': optimizer,
+     'lr_scheduler': scheduler,
+     'monitor': 'Train/avg_val_mse'
+ }

def _real_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
if self.bin_factor == 1:
@@ -104,9 +107,15 @@ def _real_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
2 * (self.hparams.img_shape // 2,), (1, 2))
return F.mse_loss(y_hat, y_target)

- def _fc_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
+ def _fc_loss(self, pred_fc, target_fc):
return F.mse_loss(pred_fc, target_fc)

+ def criterion(self, pred_fc, target_fc, target_real, mag_min, mag_max):
+     fc_loss = self._fc_loss(pred_fc=pred_fc, target_fc=target_fc)
+     real_loss = self._real_loss(pred_fc=pred_fc, target_fc=target_fc, target_real=target_real, mag_min=mag_min,
+                                 mag_max=mag_max)
+     return fc_loss + real_loss

def _bin_data(self, x_fc, y_fc):
shells = (self.hparams.detector_len // 2 + 1) / self.bin_factor
num_sino_fcs = np.clip(self.num_angles * int(shells + 1), 1, x_fc.shape[1])
@@ -127,12 +136,6 @@ def training_step(self, batch, batch_idx):
loss = self.criterion(pred, y_fc_, y_real, mag_min, mag_max)
return loss

- def on_train_epoch_start(self):
-     if not self.using_real_loss and self.current_epoch >= (self.trainer.max_epochs * self.hparams.loss_switch):
-         self.criterion = self._real_loss
-         print('Epoch {}/{}: Switched to real loss.'.format(self.current_epoch, self.trainer.max_epochs - 1))
-         self.using_real_loss = True

def training_epoch_end(self, outputs):
loss = [d['loss'] for d in outputs]
self.log('Train/loss', torch.mean(torch.stack(loss)), logger=True, on_epoch=True)
@@ -206,7 +209,8 @@ def validation_epoch_end(self, outputs):
bin_mse = [o['bin_mse'] for o in outputs]
mean_val_mse = torch.mean(torch.stack(val_mse))
mean_bin_mse = torch.mean(torch.stack(bin_mse))
- if self.bin_count > self.hparams.bin_factor_cd and mean_val_mse < (self.hparams.alpha * mean_bin_mse) and self.bin_factor > 1:
+ if self.bin_count > self.hparams.bin_factor_cd and mean_val_mse < (
+         self.hparams.alpha * mean_bin_mse) and self.bin_factor > 1:
self.bin_count = 0
self.bin_factor = max(1, self.bin_factor - 1)
self.register_buffer('mask', psfft(self.bin_factor, pixel_res=self.hparams.img_shape).to(self.device))
@@ -233,7 +237,10 @@ def test_step(self, batch, batch_idx):
pred_img = torch.roll(torch.fft.irfftn(pred_dft[0], s=2 * (self.hparams.img_shape,)),
2 * (self.hparams.img_shape // 2,), (0, 1))

- return PSNR(self.circle * y_real[0], self.circle * pred_img, drange=torch.tensor(255., dtype=torch.float32))
+ gt = denormalize(y_real[0], self.trainer.datamodule.mean, self.trainer.datamodule.std)
+ pred_img = denormalize(pred_img, self.trainer.datamodule.mean, self.trainer.datamodule.std)
+
+ return PSNR(self.circle * gt, self.circle * pred_img, drange=torch.tensor(255., dtype=torch.float32))

def test_epoch_end(self, outputs):
outputs = torch.stack(outputs)
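The test step now denormalizes prediction and ground truth before scoring, which keeps `drange=255` meaningful for MNIST-scale intensities. The `PSNR` helper itself is untouched by this commit; for reference, a standard definition over an explicit dynamic range (assumed to match the one in `fit.utils`) is:

```python
import torch

def psnr(gt: torch.Tensor, pred: torch.Tensor, drange: torch.Tensor) -> torch.Tensor:
    # peak signal-to-noise ratio in dB over the given dynamic range
    mse = torch.mean((gt - pred) ** 2)
    return 10 * torch.log10(drange ** 2 / mse)
```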
57 changes: 0 additions & 57 deletions fit/utils/tomo_utils.py
@@ -1,62 +1,5 @@
import numpy as np
import odl
import torch
from skimage.transform import resize

from ..datamodules.tomo_rec.GroundTruthDataset import GroundTruthDataset


def get_projection_dataset(dataset, num_angles, im_shape=70, impl='astra_cpu', inner_circle=True):
assert isinstance(dataset, GroundTruthDataset)
reco_space = dataset.space
if inner_circle:
space = odl.uniform_discr(min_pt=reco_space.min_pt,
max_pt=reco_space.max_pt,
shape=(im_shape, im_shape), dtype=np.float32)
min_pt = reco_space.min_pt
max_pt = reco_space.max_pt
proj_space = odl.uniform_discr(min_pt, max_pt, 2 * (2 * int(reco_space.max_pt[0]) - 1,), dtype=np.float32)
detector_length = get_detector_length(proj_space)
det_partition = odl.uniform_partition(-np.sqrt((reco_space.shape[0] / 2.) ** 2 / 2),
np.sqrt((reco_space.shape[0] / 2.) ** 2 / 2),
detector_length)
else:
space = odl.uniform_discr(min_pt=reco_space.min_pt,
max_pt=reco_space.max_pt,
shape=(im_shape, im_shape), dtype=np.float32)
min_pt = reco_space.min_pt
max_pt = reco_space.max_pt
proj_space = odl.uniform_discr(min_pt, max_pt, 2 * (reco_space.shape[0],), dtype=np.float32)
detector_length = get_detector_length(proj_space)
det_partition = odl.uniform_partition(-reco_space.shape[0] / 2., reco_space.shape[0] / 2., detector_length)

angle_partition = odl.uniform_partition(0, np.pi, num_angles)
reco_geometry = odl.tomo.Parallel2dGeometry(angle_partition, det_partition)

ray_trafo = odl.tomo.RayTransform(space, reco_geometry, impl=impl)

def get_reco_ray_trafo(**kwargs):
return odl.tomo.RayTransform(reco_space, reco_geometry, **kwargs)

reco_ray_trafo = get_reco_ray_trafo(impl=impl)

class _ResizeOperator(odl.Operator):
def __init__(self):
super().__init__(reco_space, space)

def _call(self, x, out, **kwargs):
out.assign(space.element(resize(x, (im_shape, im_shape), order=1)))

# forward operator
resize_op = _ResizeOperator()
forward_op = ray_trafo * resize_op

ds = dataset.create_pair_dataset(
forward_op=forward_op, noise_type=None)

ds.get_ray_trafo = get_reco_ray_trafo
ds.ray_trafo = reco_ray_trafo
return ds


def get_detector_length(proj_space):
