diff --git a/fit/datamodules/baselines/BaselineDataModule.py b/fit/datamodules/baselines/BaselineDataModule.py deleted file mode 100644 index 8f45108..0000000 --- a/fit/datamodules/baselines/BaselineDataModule.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Optional, Union, List - -import dival -import numpy as np -import torch -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader -from torchvision.datasets import MNIST - -from fit.datamodules.baselines.BaselineDataset import BaselineDataset -from fit.datamodules.tomo_rec.TRecDataModule import get_projection_dataset -from fit.datamodules.tomo_rec.TRecFCDataset import TRecFourierCoefficientDataset -from fit.datamodules.GroundTruthDataset import GroundTruthDataset -import odl -from skimage.transform import resize - -from fit.utils.tomo_utils import get_detector_length -from fit.utils.utils import normalize - - - - -class MNISTBaselineDataModule(LightningDataModule): - IMG_SHAPE = 27 - - def __init__(self, root_dir, batch_size, num_angles=15, inner_circle=True): - """ - :param root_dir: - :param batch_size: - :param num_angles: - """ - super().__init__() - self.root_dir = root_dir - self.batch_size = batch_size - self.num_angles = num_angles - self.inner_circle = inner_circle - self.gt_ds = None - self.mean = None - self.std = None - self.mag_min = None - self.mag_max = None - - def setup(self, stage: Optional[str] = None): - mnist_test = MNIST(self.root_dir, train=False, download=True).data.type(torch.float32) - mnist_train_val = MNIST(self.root_dir, train=True, download=True).data.type(torch.float32) - np.random.seed(1612) - perm = np.random.permutation(mnist_train_val.shape[0]) - mnist_train = mnist_train_val[perm[:55000], 1:, 1:] - mnist_val = mnist_train_val[perm[55000:], 1:, 1:] - mnist_test = mnist_test[:, 1:, 1:] - - assert mnist_train.shape[1] == MNISTBaselineDataModule.IMG_SHAPE - assert mnist_train.shape[2] == MNISTBaselineDataModule.IMG_SHAPE - x, y = torch.meshgrid(torch.arange(-MNISTBaselineDataModule.IMG_SHAPE // 2 + 1, - MNISTBaselineDataModule.IMG_SHAPE // 2 + 1), - torch.arange(-MNISTBaselineDataModule.IMG_SHAPE // 2 + 1, - MNISTBaselineDataModule.IMG_SHAPE // 2 + 1)) - circle = torch.sqrt(x ** 2. + y ** 2.) <= MNISTBaselineDataModule.IMG_SHAPE // 2 - mnist_train = circle * np.clip(mnist_train, 50, 255) - mnist_val = circle * np.clip(mnist_val, 50, 255) - mnist_test = circle * np.clip(mnist_test, 50, 255) - - self.mean = mnist_train.mean() - self.std = mnist_train.std() - - mnist_train = normalize(mnist_train, self.mean, self.std) - mnist_val = normalize(mnist_val, self.mean, self.std) - mnist_test = normalize(mnist_test, self.mean, self.std) - self.gt_ds = get_projection_dataset( - GroundTruthDataset(mnist_train, mnist_val, mnist_test), - num_angles=self.num_angles, im_shape=70, impl='astra_cpu', inner_circle=self.inner_circle) - - tmp_ds = BaselineDataset(self.gt_ds, mean=None, std=None, part='train', - img_shape=MNISTBaselineDataModule.IMG_SHAPE) - self.mag_min = tmp_ds.mean - self.mag_max = tmp_ds.std - - def train_dataloader(self, *args, **kwargs) -> DataLoader: - return DataLoader( - BaselineDataset(self.gt_ds, mean=self.mean, std=self.std, part='train', - img_shape=MNISTBaselineDataModule.IMG_SHAPE), - batch_size=self.batch_size, num_workers=2) - - def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - return DataLoader( - BaselineDataset(self.gt_ds, mean=self.mean, std=self.std, part='validation', - img_shape=MNISTBaselineDataModule.IMG_SHAPE), - batch_size=self.batch_size, num_workers=2) - - def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - return DataLoader( - BaselineDataset(self.gt_ds, mean=self.mean, std=self.std, part='test', - img_shape=MNISTBaselineDataModule.IMG_SHAPE), - batch_size=1) diff --git a/fit/datamodules/baselines/BaselineDataset.py b/fit/datamodules/baselines/BaselineDataset.py deleted file mode 100644 index cba92ff..0000000 --- a/fit/datamodules/baselines/BaselineDataset.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np -import torch -import torch.fft -from skimage.transform import iradon -from torch.utils.data import Dataset - - -class BaselineDataset(Dataset): - def __init__(self, ds, mean, std, part='train', img_shape=42): - self.ds = ds.create_torch_dataset(part=part) - self.img_shape = img_shape - self.angles = ds.ray_trafo.geometry.angles - if mean == None and std == None: - tmp_recos = [] - for i in np.random.permutation(len(self.ds))[:200]: - sino, _ = self.ds[i] - reco = iradon(sino.numpy().T, theta=-np.rad2deg(self.angles), circle=True, - filter_name='cosine').T - tmp_recos.append(torch.from_numpy(reco)) - - tmp_recos = torch.stack(tmp_recos) - self.mean = tmp_recos.mean() - self.std = tmp_recos.std() - else: - self.mean = mean - self.std = std - - def __getitem__(self, item): - sino, img = self.ds[item] - reco = iradon(sino.numpy().T, theta=-np.rad2deg(self.angles), circle=True, - filter_name='cosine').T - reco = torch.from_numpy(reco) - reco = (reco - self.mean)/self.std - return reco.unsqueeze(0), img.unsqueeze(0) - - def __len__(self): - return len(self.ds) diff --git a/fit/datamodules/baselines/__init__.py b/fit/datamodules/baselines/__init__.py deleted file mode 100644 index 6502f31..0000000 --- a/fit/datamodules/baselines/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .BaselineDataModule import MNISTBaselineDataModule \ No newline at end of file diff --git a/fit/modules/ConvBlockBaselineModule.py b/fit/modules/ConvBlockBaselineModule.py deleted file mode 100644 index 5d7728e..0000000 --- a/fit/modules/ConvBlockBaselineModule.py +++ /dev/null @@ -1,113 +0,0 @@ -import torch -from pytorch_lightning import LightningModule -from torch.optim.lr_scheduler import ReduceLROnPlateau - -from fit.baselines.ConvBlockBaseline import ConvBlockBaseline -from fit.datamodules.tomo_rec import MNISTTomoFourierTargetDataModule -from fit.utils import PSNR -from fit.utils.RAdam import RAdam - -import numpy as np - -from torch.nn import functional as F -import torch.fft - -from fit.utils.utils import denormalize - - -class ConvBlockBaselineModule(LightningModule): - def __init__(self, img_shape=27, - lr=0.0001, - weight_decay=0.01, - d_query=4): - super().__init__() - - self.save_hyperparameters("img_shape", - "lr", - "weight_decay", - "d_query") - - self.cblock = ConvBlockBaseline(d_query=self.hparams.d_query) - - x, y = torch.meshgrid(torch.arange(-MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1, - MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1), - torch.arange(-MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1, - MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2 + 1)) - self.register_buffer('circle', torch.sqrt(x ** 2. + y ** 2.) <= MNISTTomoFourierTargetDataModule.IMG_SHAPE // 2) - - def forward(self, x): - return self.cblock.forward(x) - - def configure_optimizers(self): - optimizer = RAdam(self.cblock.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay) - scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, verbose=True) - return { - 'optimizer': optimizer, - 'lr_scheduler': scheduler, - 'monitor': 'Train/avg_val_mse' - } - - def training_step(self, batch, batch_idx): - x, y = batch - - pred = self.cblock.forward(x) - - loss = F.mse_loss(pred, y) - return {'loss': loss} - - def training_epoch_end(self, outputs): - loss = [d['loss'] for d in outputs] - self.log('Train/loss', torch.mean(torch.stack(loss)), logger=True, on_epoch=True) - - def validation_step(self, batch, batch_idx): - x, y = batch - - pred = self.cblock.forward(x) - - val_loss = F.mse_loss(pred, y) - if batch_idx == 0: - self.log_val_images(pred, x, y) - return {'val_loss': val_loss, 'val_mse': val_loss} - - def log_val_images(self, pred_img, x, y): - - for i in range(3): - x_img = x[i] - x_img = torch.clamp((x_img - x_img.min()) / (x_img.max() - x_img.min()), 0, 1) - pred_img_ = pred_img[i] - pred_img_ = torch.clamp((pred_img_ - pred_img_.min()) / (pred_img_.max() - pred_img_.min()), 0, 1) - y_img = y[i] - y_img = torch.clamp((y_img - y_img.min()) / (y_img.max() - y_img.min()), 0, 1) - print(x_img.shape, y_img.shape, pred_img_.shape) - self.trainer.logger.experiment.add_image('inputs/img_{}'.format(i), x_img, - global_step=self.trainer.global_step) - self.trainer.logger.experiment.add_image('predcitions/img_{}'.format(i), pred_img_, - global_step=self.trainer.global_step) - self.trainer.logger.experiment.add_image('targets/img_{}'.format(i), y_img, - global_step=self.trainer.global_step) - - def validation_epoch_end(self, outputs): - val_loss = [o['val_loss'] for o in outputs] - val_mse = [o['val_mse'] for o in outputs] - mean_val_mse = torch.mean(torch.stack(val_mse)) - - self.log('Train/avg_val_loss', torch.mean(torch.stack(val_loss)), logger=True, on_epoch=True) - self.log('Train/avg_val_mse', mean_val_mse, logger=True, on_epoch=True) - - def test_step(self, batch, batch_idx): - x, y = batch - assert len(x) == 1, 'Test images have to be evaluated independently.' - - pred_img = self.cblock.forward(x) - - - gt = denormalize(y[0,0], self.trainer.datamodule.mean, self.trainer.datamodule.std) - pred_img = denormalize(pred_img[0,0], self.trainer.datamodule.mean, self.trainer.datamodule.std) - - return PSNR(self.circle * gt, self.circle * pred_img, drange=torch.tensor(255., dtype=torch.float32)) - - def test_epoch_end(self, outputs): - outputs = torch.stack(outputs) - self.log('Mean PSNR', torch.mean(outputs).detach().cpu().numpy(), logger=True) - self.log('SEM PSNR', torch.std(outputs / np.sqrt(len(outputs))).detach().cpu().numpy(), - logger=True) diff --git a/scripts/Baseline_MNIST.py b/scripts/Baseline_MNIST.py deleted file mode 100755 index 30b8dfa..0000000 --- a/scripts/Baseline_MNIST.py +++ /dev/null @@ -1,73 +0,0 @@ -import argparse -import glob -import json -from os.path import exists - -from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.callbacks import ModelCheckpoint - -from fit.datamodules.baselines import MNISTBaselineDataModule -from fit.modules.ConvBlockBaselineModule import ConvBlockBaselineModule - - -def main(): - seed_everything(28122020) - - parser = argparse.ArgumentParser(description="") - parser.add_argument("--exp_config") - - args = parser.parse_args() - - with open(args.exp_config) as f: - conf = json.load(f) - - dm = MNISTBaselineDataModule(root_dir=conf['root_dir'], batch_size=conf['batch_size']) - dm.setup() - - model = ConvBlockBaselineModule(img_shape=dm.IMG_SHAPE, - lr=conf['lr'], weight_decay=0.01, - d_query=conf['d_query']) - - if exists('lightning_logs'): - print('Some experiments already exist. Abort.') - return 0 - - trainer = Trainer(max_epochs=conf['max_epochs'], - gpus=1, - checkpoint_callback=ModelCheckpoint( - filepath=None, - save_top_k=1, - verbose=False, - save_last=True, - monitor='Train/avg_val_mse', - mode='min', - prefix='best_val_loss_' - ), - deterministic=True) - - trainer.fit(model, datamodule=dm); - - model = ConvBlockBaselineModule.load_from_checkpoint('lightning_logs/version_0/checkpoints/best_val_loss_-last.ckpt') - - test_res = trainer.test(model, datamodule=dm)[0] - out_res = { - "Mean PSNR": test_res["Mean PSNR"].item(), - "SEM PSNR": test_res["SEM PSNR"].item() - } - with open('last_ckpt_results.json', 'w') as f: - json.dump(out_res, f) - - best_path = glob.glob('lightning_logs/version_0/checkpoints/best_val_loss_-epoch*')[0] - model = ConvBlockBaselineModule.load_from_checkpoint(best_path) - - test_res = trainer.test(model, datamodule=dm)[0] - out_res = { - "Mean PSNR": test_res["Mean PSNR"].item(), - "SEM PSNR": test_res["SEM PSNR"].item() - } - with open('best_ckpt_results.json', 'w') as f: - json.dump(out_res, f) - - -if __name__ == "__main__": - main() diff --git a/scripts/baseline_config.json b/scripts/baseline_config.json deleted file mode 100644 index e78997a..0000000 --- a/scripts/baseline_config.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "root_dir": "/home/tibuch/Data/mnist", - "batch_size": 32, - "d_query": 16, - "lr": 0.0001, - "max_epochs": 100 -} \ No newline at end of file