Merge pull request #1 from juglab/dev
v0.1.0
Showing 24 changed files with 1,909 additions and 0 deletions.
@@ -0,0 +1,12 @@
astra-toolbox requires CUDA 10.2: `conda install -c astra-toolbox/label/dev astra-toolbox`

Install PyTorch with the matching CUDA toolkit: `conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch`

Build the Python package:
`python setup.py bdist_wheel`

Build the Singularity recipe:
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.0-py3-none-any.whl /fourier_image_transformers-0.1.0-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox -c astra-toolbox/label/dev pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch' pip_install='/fourier_image_transformers-0.1.0-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.0.Singularity`

Build the Singularity container:
`sudo singularity build fit.simg fit.Singularity`
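As a quick sanity check that the resulting `fit` environment works (a minimal sketch, not part of the repository; it only assumes the packages installed above), the CUDA build of PyTorch and the astra-toolbox import can be verified from Python:

```python
# Sanity check for the conda environment described above (sketch only).
import torch
import astra  # provided by the astra-toolbox package

print("torch", torch.__version__, "built against CUDA", torch.version.cuda)  # expected: 10.2
print("CUDA available:", torch.cuda.is_available())
print("astra-toolbox imported from", astra.__file__)
```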
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Empty file.
Empty file.
@@ -0,0 +1,50 @@
import numpy as np
import torch
import torch.fft
from torch.utils.data import Dataset


class FCDataset(Dataset):
    """Base dataset that wraps a dival dataset and exposes its projection angles."""

    def __init__(self, ds, part='train', img_shape=42):
        self.ds = ds.create_torch_dataset(part=part)
        self.img_shape = img_shape
        self.angles = ds.ray_trafo.geometry.angles

    def __getitem__(self, i):
        raise NotImplementedError()

    def __len__(self):
        return len(self.ds)


class FourierCoefficientDataset(FCDataset):
    """Yields (sinogram, image) pairs as normalized Fourier coefficients."""

    def __init__(self, ds, part='train', img_shape=42):
        super().__init__(ds, part=part, img_shape=img_shape)

    def __getitem__(self, item):
        sino, img = self.ds[item]
        # Center the data via a half-shift before the real-valued FFTs.
        sino_fft = torch.fft.rfftn(torch.roll(sino, sino.shape[1] // 2, 1), dim=[-1])
        img_fft = torch.fft.rfftn(torch.roll(img, 2 * (img.shape[0] // 2,), (0, 1)), dim=[0, 1])

        # Log-magnitude (zeros mapped to log(1) = 0) and phase of the spectra.
        sino_mag = sino_fft.abs()
        sino_mag[sino_mag == 0] = 1.
        sino_mag = torch.log(sino_mag)
        sino_phi = sino_fft.angle()

        img_mag = img_fft.abs()
        img_mag[img_mag == 0] = 1.
        img_mag = torch.log(img_mag)
        img_phi = img_fft.angle()

        # Min-max normalize both spectra with the sinogram statistics.
        mag_min, mag_max = sino_mag.min(), sino_mag.max()

        sino_mag = (sino_mag - mag_min) / (mag_max - mag_min)
        img_mag = (img_mag - mag_min) / (mag_max - mag_min)

        sino_phi = sino_phi / (2 * np.pi)
        img_phi = img_phi / (2 * np.pi)

        sino_fft = torch.stack([sino_mag.flatten(), sino_phi.flatten()], dim=-1)
        img_fft = torch.stack([img_mag.flatten(), img_phi.flatten()], dim=-1)
        return sino_fft, img_fft, img, (mag_min.unsqueeze(-1), mag_max.unsqueeze(-1))
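The normalization above is inverted again during training (see `_real_loss` in `TRecTransformerModule` further down). A small stand-alone helper makes the encoding explicit; the name `decode_fc` is hypothetical and not part of the repository:

```python
import numpy as np
import torch


def decode_fc(fc, mag_min, mag_max):
    """Invert the (log-magnitude, phase) encoding produced by FourierCoefficientDataset."""
    mag = fc[..., 0] * (mag_max - mag_min) + mag_min  # undo min-max scaling
    mag = torch.exp(mag)                              # undo log
    phi = fc[..., 1] * 2 * np.pi                      # undo phase scaling
    return torch.complex(mag * torch.cos(phi), mag * torch.sin(phi))
```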
@@ -0,0 +1,79 @@
import numpy as np
import torch
from dival import Dataset
from dival.datasets.dataset import ObservationGroundTruthPairDataset
from odl import uniform_discr


class GroundTruthDataset(Dataset):
    """dival Dataset backed by in-memory ground-truth image tensors for the three splits."""

    def __init__(self, train_gt_images, val_gt_images, test_gt_images):
        self.train_gt_images = train_gt_images
        self.val_gt_images = val_gt_images
        self.test_gt_images = test_gt_images

        self.shape = (self.train_gt_images.shape[1], self.train_gt_images.shape[2])
        min_pt = [-self.shape[0] / 2, -self.shape[1] / 2]
        max_pt = [self.shape[0] / 2, self.shape[1] / 2]
        space = uniform_discr(min_pt, max_pt, self.shape, dtype=np.float32)

        self.train_len = self.train_gt_images.shape[0]
        self.val_len = self.val_gt_images.shape[0]
        self.test_len = self.test_gt_images.shape[0]
        self.random_access = True
        super().__init__(space=space)

    def create_pair_dataset(self, forward_op, post_processor=None,
                            noise_type=None, noise_kwargs=None,
                            noise_seeds=None):
        dataset = ObservationGroundTruthPairDataset(
            self.generator, forward_op, post_processor=post_processor,
            train_len=self.train_len, validation_len=self.val_len,
            test_len=self.test_len, noise_type=noise_type,
            noise_kwargs=noise_kwargs, noise_seeds=noise_seeds)
        return dataset

    def generator(self, part='train'):
        if part == 'train':
            gen = self._train_generator()
        elif part == 'validation':
            gen = self._val_generator()
        elif part == 'test':
            gen = self._test_generator()
        else:
            raise NotImplementedError

        for gt in gen:
            yield gt

    def _train_generator(self):
        for i in range(self.train_len):
            yield self.train_gt_images[i].type(torch.float32)

    def _test_generator(self):
        for i in range(self.test_len):
            yield self.test_gt_images[i].type(torch.float32)

    def _val_generator(self):
        for i in range(self.val_len):
            yield self.val_gt_images[i].type(torch.float32)

    def get_sample(self, index, part='train', out=None):
        if out is None:
            if part == 'train':
                return self.train_gt_images[index].type(torch.float32)
            elif part == 'validation':
                return self.val_gt_images[index].type(torch.float32)
            elif part == 'test':
                return self.test_gt_images[index].type(torch.float32)
            else:
                raise NotImplementedError
        else:
            if part == 'train':
                out = self.train_gt_images[index].type(torch.float32)
            elif part == 'validation':
                out = self.val_gt_images[index].type(torch.float32)
            elif part == 'test':
                out = self.test_gt_images[index].type(torch.float32)
            else:
                raise NotImplementedError
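A minimal usage sketch for this dataset (random placeholder tensors, not from the repository):

```python
import torch

# Three splits of 28x28 ground-truth images (random placeholders).
train = torch.rand(100, 28, 28)
val = torch.rand(10, 28, 28)
test = torch.rand(10, 28, 28)

gt_ds = GroundTruthDataset(train, val, test)
sample = gt_ds.get_sample(0, part='train')
print(sample.shape)  # torch.Size([28, 28])
```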
@@ -0,0 +1,51 @@
from typing import Optional, Union, List

import numpy as np
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from fit.datamodules.tomo_rec.FCDataset import FourierCoefficientDataset
from fit.datamodules.tomo_rec.GroundTruthDataset import GroundTruthDataset
from fit.utils.tomo_utils import get_projection_dataset


class MNISTTomoFourierTargetDataModule(LightningDataModule):
    IMG_SHAPE = 28

    def __init__(self, root_dir, batch_size, num_angles=15):
        """
        :param root_dir: directory into which MNIST is downloaded
        :param batch_size: batch size used for the train and validation loaders
        :param num_angles: number of projection angles for the simulated sinograms
        """
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_angles = num_angles
        self.gt_ds = None

    def setup(self, stage: Optional[str] = None):
        mnist_test = MNIST(self.root_dir, train=False, download=True).data
        mnist_train_val = MNIST(self.root_dir, train=True, download=True).data
        # Fixed seed so the 55000/5000 train/validation split is reproducible.
        np.random.seed(1612)
        perm = np.random.permutation(mnist_train_val.shape[0])
        mnist_train = mnist_train_val[perm[:55000]]
        mnist_val = mnist_train_val[perm[55000:]]
        self.gt_ds = get_projection_dataset(GroundTruthDataset(mnist_train, mnist_val, mnist_test),
                                            num_angles=self.num_angles, IM_SHAPE=(70, 70), impl='astra_cpu')

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        return DataLoader(
            FourierCoefficientDataset(self.gt_ds, part='train', img_shape=MNISTTomoFourierTargetDataModule.IMG_SHAPE),
            batch_size=self.batch_size, num_workers=2)

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            FourierCoefficientDataset(self.gt_ds, part='validation',
                                      img_shape=MNISTTomoFourierTargetDataModule.IMG_SHAPE),
            batch_size=self.batch_size, num_workers=2)

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            FourierCoefficientDataset(self.gt_ds, part='test', img_shape=MNISTTomoFourierTargetDataModule.IMG_SHAPE),
            batch_size=1)
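A rough sketch of driving this DataModule by hand (path and batch size are placeholders; normally the PyTorch Lightning Trainer calls `setup` and the loaders itself):

```python
dm = MNISTTomoFourierTargetDataModule(root_dir='./datasets', batch_size=8, num_angles=15)
dm.setup()

# Each batch mirrors FourierCoefficientDataset.__getitem__: normalized
# (log-magnitude, phase) pairs for sinogram and image, the real image, and the
# min/max statistics needed to undo the normalization.
sino_fft, img_fft, img, (mag_min, mag_max) = next(iter(dm.train_dataloader()))
print(sino_fft.shape, img_fft.shape, img.shape)
```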
@@ -0,0 +1,2 @@
from .FCDataset import FourierCoefficientDataset
from .TRecDataModule import MNISTTomoFourierTargetDataModule
@@ -0,0 +1,163 @@
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.step_result import TrainResult, EvalResult

from fit.transformers.TRecTransformer import TRecTransformer
from fit.utils import convert2FC, fft_interpolate
from fit.utils.RAdam import RAdam

import numpy as np

from torch.nn import functional as F
import torch.fft


class TRecTransformerModule(LightningModule):
    def __init__(self, d_model, y_coords_proj, x_coords_proj, y_coords_img, x_coords_img, angles, img_shape=362,
                 lr=0.0001,
                 weight_decay=0.01,
                 loss_switch=0.5,
                 attention_type="linear", n_layers=4, n_heads=4, d_query=4, dropout=0.1, attention_dropout=0.1):
        super().__init__()

        self.save_hyperparameters("d_model",
                                  "img_shape",
                                  "lr",
                                  "weight_decay",
                                  "loss_switch",
                                  "attention_type",
                                  "n_layers",
                                  "n_heads",
                                  "d_query",
                                  "dropout",
                                  "attention_dropout")
        self.y_coords_proj = y_coords_proj
        self.x_coords_proj = x_coords_proj
        self.y_coords_img = y_coords_img
        self.x_coords_img = x_coords_img
        self.angles = angles
        self.dft_shape = (img_shape, img_shape // 2 + 1)

        self.trec = TRecTransformer(d_model=self.hparams.d_model,
                                    y_coords_proj=y_coords_proj, x_coords_proj=x_coords_proj,
                                    y_coords_img=y_coords_img, x_coords_img=x_coords_img,
                                    attention_type=self.hparams.attention_type,
                                    n_layers=self.hparams.n_layers,
                                    n_heads=self.hparams.n_heads,
                                    d_query=self.hparams.d_query,
                                    dropout=self.hparams.dropout,
                                    attention_dropout=self.hparams.attention_dropout)

        # Start with the Fourier-coefficient loss; on_train_epoch_start switches to the
        # real-space loss after loss_switch * max_epochs epochs.
        self.criterion = self._fc_loss
        self.using_real_loss = False

    def forward(self, x):
        return self.trec.forward(x)

    def configure_optimizers(self):
        optimizer = RAdam(self.trec.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        return optimizer

    def _real_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
        # Undo the dataset normalization, assemble the complex spectrum and compare in image space.
        mag = pred_fc[..., 0]
        phi = pred_fc[..., 1]
        mag = (mag * (mag_max - mag_min)) + mag_min
        mag = torch.exp(mag)

        phi = phi * 2 * np.pi
        dft = torch.complex(mag * torch.cos(phi), mag * torch.sin(phi))
        dft = dft.reshape(-1, *self.dft_shape)
        y_hat = torch.roll(torch.fft.irfftn(dft, dim=[1, 2]),
                           (self.hparams.img_shape // 2, self.hparams.img_shape // 2), (1, 2))
        return F.mse_loss(y_hat, target_real)

    def _fc_loss(self, pred_fc, target_fc, target_real, mag_min, mag_max):
        return F.mse_loss(pred_fc, target_fc)

    def training_step(self, batch, batch_idx):
        x_fc, y_fc, y_real, (mag_min, mag_max) = batch
        pred = self(x_fc)
        loss = self.criterion(pred, y_fc, y_real, mag_min, mag_max)
        return loss

    def on_train_epoch_start(self):
        if not self.using_real_loss and self.current_epoch >= (self.trainer.max_epochs * self.hparams.loss_switch):
            self.criterion = self._real_loss
            print('Epoch {}/{}: Switched to real loss.'.format(self.current_epoch, self.trainer.max_epochs - 1))
            self.using_real_loss = True

    def training_epoch_end(self, outputs):
        loss = [d['loss'] for d in outputs]
        self.log('Train/loss', torch.mean(torch.stack(loss)), logger=True, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        x_fc, y_fc, y_real, (mag_min, mag_max) = batch
        pred = self(x_fc)
        val_loss = self.criterion(pred, y_fc, y_real, mag_min, mag_max)
        val_mse = self._real_loss(pred, y_fc, y_real, mag_min, mag_max)
        self.log_dict({'val_loss': val_loss})
        self.log_dict({'val_mse': val_mse})
        if batch_idx == 0:
            self.log_val_images(pred, x_fc, y_real, mag_min, mag_max)
        return {'val_loss': val_loss, 'val_mse': val_mse}

    def log_val_images(self, pred, x, y_real, mag_min, mag_max):
        x_fc = convert2FC(x, mag_min, mag_max)
        pred_fc = convert2FC(pred, mag_min, mag_max)

        for i in range(3):
            x_dft = fft_interpolate(self.x_coords_proj.cpu().numpy(), self.y_coords_proj.cpu().numpy(),
                                    self.x_coords_img.cpu().numpy(), self.y_coords_img.cpu().numpy(),
                                    x_fc[i].cpu().numpy(), target_shape=self.dft_shape)
            x_img = torch.roll(torch.fft.irfftn(torch.from_numpy(x_dft)),
                               2 * (self.hparams.img_shape // 2,), (0, 1))
            x_img = torch.clamp((x_img - x_img.min()) / (x_img.max() - x_img.min()), 0, 1)

            pred_img = torch.roll(torch.fft.irfftn(pred_fc[i].reshape(self.dft_shape)),
                                  2 * (self.hparams.img_shape // 2,), (0, 1))
            pred_img = torch.clamp((pred_img - pred_img.min()) / (pred_img.max() - pred_img.min()), 0, 1)

            y_img = y_real[i]
            y_img = torch.clamp((y_img - y_img.min()) / (y_img.max() - y_img.min()), 0, 1)

            self.trainer.logger.experiment.add_image('inputs/img_{}'.format(i), x_img.unsqueeze(0),
                                                     global_step=self.trainer.global_step)
            self.trainer.logger.experiment.add_image('predictions/img_{}'.format(i), pred_img.unsqueeze(0),
                                                     global_step=self.trainer.global_step)
            self.trainer.logger.experiment.add_image('targets/img_{}'.format(i), y_img.unsqueeze(0),
                                                     global_step=self.trainer.global_step)

    def validation_epoch_end(self, outputs):
        val_loss = [o['val_loss'] for o in outputs]
        val_mse = [o['val_mse'] for o in outputs]
        self.log('Train/avg_val_loss', torch.mean(torch.stack(val_loss)), logger=True, on_epoch=True)
        self.log('Train/avg_val_mse', torch.mean(torch.stack(val_mse)), logger=True, on_epoch=True)

    def test_step(self, batch, batch_idx):
        x, y, y_real, (mag_min, mag_max) = batch
        assert len(x) == 1, 'Test images have to be evaluated independently.'

        pred = self(x)

        pred_fc = convert2FC(pred, mag_min, mag_max)
        pred_img = torch.roll(torch.fft.irfftn(pred_fc[0].reshape(self.dft_shape)),
                              2 * (self.hparams.img_shape // 2,), (0, 1))
        return self.PSNR(y_real[0], pred_img, drange=torch.tensor(255., dtype=torch.float32))

    def test_epoch_end(self, outputs):
        outputs = torch.stack(outputs)
        self.log('Mean PSNR', torch.mean(outputs).detach().cpu().numpy(), logger=True)
        self.log('SEM PSNR', (torch.std(outputs) / np.sqrt(len(outputs))).detach().cpu().numpy(),
                 logger=True)

    def normalize_minmse(self, x, target):
        """Affine rescaling of x, such that the mean squared error to target is minimal."""
        cov = np.cov(x.detach().cpu().numpy().flatten(), target.detach().cpu().numpy().flatten())
        alpha = cov[0, 1] / (cov[0, 0] + 1e-10)
        beta = target.mean() - alpha * x.mean()
        return alpha * x + beta

    def PSNR(self, gt, img, drange):
        img = self.normalize_minmse(img, gt)
        mse = torch.mean(torch.square(gt - img))
        return 20 * torch.log10(drange) - 10 * torch.log10(mse)
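As a quick illustration of the PSNR used in `test_step` (a stand-alone sketch with made-up values; it applies the same formula as `PSNR` above but skips the `normalize_minmse` rescaling, which would cancel a constant offset):

```python
import torch


def psnr(gt, img, drange=torch.tensor(255.)):
    mse = torch.mean(torch.square(gt - img))
    return 20 * torch.log10(drange) - 10 * torch.log10(mse)


gt = torch.full((4, 4), 100.)
noisy = gt + 5.          # constant offset, so MSE = 25
print(psnr(gt, noisy))   # 20*log10(255) - 10*log10(25) ≈ 34.2 dB
```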
@@ -0,0 +1 @@
from .TRecTransformerModule import TRecTransformerModule