From bb9ea43080725bc43102357f400665f5821b033a Mon Sep 17 00:00:00 2001 From: tibuch Date: Thu, 1 Apr 2021 18:41:04 +0200 Subject: [PATCH] Add image prediction function to TRec models. --- fit/modules/TRecTransformerModule.py | 33 ++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/fit/modules/TRecTransformerModule.py b/fit/modules/TRecTransformerModule.py index 0186a56..41d22fc 100644 --- a/fit/modules/TRecTransformerModule.py +++ b/fit/modules/TRecTransformerModule.py @@ -12,7 +12,7 @@ from torch.nn import functional as F import torch.fft -from fit.utils.utils import denormalize, denormalize_amp, denormalize_phi +from fit.utils.utils import denormalize, denormalize_amp, denormalize_phi, denormalize_FC class TRecTransformerModule(LightningModule): @@ -27,7 +27,6 @@ def __init__(self, d_model, sinogram_coords, target_coords, src_flatten_coords, attention_type="linear", n_layers=4, n_heads=4, d_query=4, dropout=0.1, attention_dropout=0.1, only_FBP=False, only_convblock=False, - no_convblock=False, d_conv=8): super().__init__() @@ -46,7 +45,8 @@ def __init__(self, d_model, sinogram_coords, target_coords, src_flatten_coords, "d_query", "dropout", "attention_dropout", - "convblock_only", + "only_FBP", + "only_convblock", "d_conv") self.sinogram_coords = sinogram_coords self.target_coords = target_coords @@ -145,7 +145,7 @@ def _fc_loss(self, pred_fc, target_fc, amp_min, amp_max): return torch.mean(amp_loss + phi_loss), torch.mean(amp_loss), torch.mean(phi_loss) def criterion(self, pred_fc, pred_img, target_fc, amp_min, amp_max): - if self.hparams.convblock_only: + if self.hparams.only_convblock: return self._real_loss(pred_img=pred_img, target_fc=target_fc, amp_min=amp_min, amp_max=amp_max), torch.tensor(0.0), torch.tensor(0.0) else: @@ -337,3 +337,28 @@ def test_epoch_end(self, outputs): self.log('Mean PSNR', torch.mean(outputs).detach().cpu().numpy(), logger=True) self.log('SEM PSNR', torch.std(outputs / np.sqrt(len(outputs))).detach().cpu().numpy(), logger=True) + + def get_imgs(self, x, fbp, y, amp_min, amp_max): + self.eval() + + self.bin_factor = 1 + self.register_buffer('mask', psf_rfft(self.bin_factor, pixel_res=self.hparams.img_shape).to(self.device)) + + x_fc_, fbp_fc_, y_fc_ = self._bin_data(x, fbp, y) + + pred_fc, pred_img = self.trec.forward(x_fc_, fbp_fc_, amp_min=amp_min, amp_max=amp_max, + dst_flatten_coords=self.dst_flatten_order, + img_shape=self.hparams.img_shape, + attenuation=self.mask) + + tmp = denormalize_FC(pred_fc, amp_min=amp_min, amp_max=amp_max) + pred_fc_ = torch.ones(x.shape[0], self.hparams.img_shape * (self.hparams.img_shape // 2 + 1), dtype=x.dtype, + device=x.device) + pred_fc_[:, :tmp.shape[1]] = tmp + + dft_pred_fc = convert2DFT(x=pred_fc, amp_min=amp_min, amp_max=amp_max, + dst_flatten_order=self.dst_flatten_order, img_shape=self.hparams.img_shape) + img_pred_before_conv = torch.roll(torch.fft.irfftn(dft_pred_fc, dim=[1, 2], s=2 * (self.hparams.img_shape,)), + 2 * (self.hparams.img_shape // 2,), (1, 2)) + + return pred_img, img_pred_before_conv