Skip to content

Commit

Permalink
Add image prediction function to TRec models.
Browse files Browse the repository at this point in the history
  • Loading branch information
tibuch committed Apr 1, 2021
1 parent 8f38bfa commit bb9ea43
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions fit/modules/TRecTransformerModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.nn import functional as F
import torch.fft

from fit.utils.utils import denormalize, denormalize_amp, denormalize_phi
from fit.utils.utils import denormalize, denormalize_amp, denormalize_phi, denormalize_FC


class TRecTransformerModule(LightningModule):
Expand All @@ -27,7 +27,6 @@ def __init__(self, d_model, sinogram_coords, target_coords, src_flatten_coords,
attention_type="linear", n_layers=4, n_heads=4, d_query=4, dropout=0.1, attention_dropout=0.1,
only_FBP=False,
only_convblock=False,
no_convblock=False,
d_conv=8):
super().__init__()

Expand All @@ -46,7 +45,8 @@ def __init__(self, d_model, sinogram_coords, target_coords, src_flatten_coords,
"d_query",
"dropout",
"attention_dropout",
"convblock_only",
"only_FBP",
"only_convblock",
"d_conv")
self.sinogram_coords = sinogram_coords
self.target_coords = target_coords
Expand Down Expand Up @@ -145,7 +145,7 @@ def _fc_loss(self, pred_fc, target_fc, amp_min, amp_max):
return torch.mean(amp_loss + phi_loss), torch.mean(amp_loss), torch.mean(phi_loss)

def criterion(self, pred_fc, pred_img, target_fc, amp_min, amp_max):
if self.hparams.convblock_only:
if self.hparams.only_convblock:
return self._real_loss(pred_img=pred_img, target_fc=target_fc, amp_min=amp_min,
amp_max=amp_max), torch.tensor(0.0), torch.tensor(0.0)
else:
Expand Down Expand Up @@ -337,3 +337,28 @@ def test_epoch_end(self, outputs):
self.log('Mean PSNR', torch.mean(outputs).detach().cpu().numpy(), logger=True)
self.log('SEM PSNR', torch.std(outputs / np.sqrt(len(outputs))).detach().cpu().numpy(),
logger=True)

def get_imgs(self, x, fbp, y, amp_min, amp_max):
self.eval()

self.bin_factor = 1
self.register_buffer('mask', psf_rfft(self.bin_factor, pixel_res=self.hparams.img_shape).to(self.device))

x_fc_, fbp_fc_, y_fc_ = self._bin_data(x, fbp, y)

pred_fc, pred_img = self.trec.forward(x_fc_, fbp_fc_, amp_min=amp_min, amp_max=amp_max,
dst_flatten_coords=self.dst_flatten_order,
img_shape=self.hparams.img_shape,
attenuation=self.mask)

tmp = denormalize_FC(pred_fc, amp_min=amp_min, amp_max=amp_max)
pred_fc_ = torch.ones(x.shape[0], self.hparams.img_shape * (self.hparams.img_shape // 2 + 1), dtype=x.dtype,
device=x.device)
pred_fc_[:, :tmp.shape[1]] = tmp

dft_pred_fc = convert2DFT(x=pred_fc, amp_min=amp_min, amp_max=amp_max,
dst_flatten_order=self.dst_flatten_order, img_shape=self.hparams.img_shape)
img_pred_before_conv = torch.roll(torch.fft.irfftn(dft_pred_fc, dim=[1, 2], s=2 * (self.hparams.img_shape,)),
2 * (self.hparams.img_shape // 2,), (1, 2))

return pred_img, img_pred_before_conv

0 comments on commit bb9ea43

Please sign in to comment.