Skip to content

Commit

Permalink
Add utils tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
tibuch committed Mar 30, 2021
1 parent 931f91c commit 3ff7976
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 0 deletions.
Empty file added fit_tests/__init__.py
Empty file.
Empty file added fit_tests/utils/__init__.py
Empty file.
74 changes: 74 additions & 0 deletions fit_tests/utils/test_tomo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import unittest

from fit.utils.tomo_utils import get_cartesian_rfft_coords_2D, get_polar_rfft_coords_2D, get_polar_rfft_coords_sinogram, \
get_cartesian_rfft_coords_sinogram
import torch

import numpy as np


class TestTomoUtils(unittest.TestCase):

def setUp(self) -> None:
self.img_shape = 27
self.angles = np.array([0, np.pi / 2, np.pi])

def test_cartesian_rfft_2D(self):
x, y, flatten_indices, order = get_cartesian_rfft_coords_2D(self.img_shape)
x_ordered = torch.zeros_like(x)
x_ordered[flatten_indices] = x
x_ordered = x_ordered.reshape(self.img_shape, -1)

y_ordered = torch.zeros_like(y)
y_ordered[flatten_indices] = y
y_ordered = y_ordered.reshape(self.img_shape, -1)
y_ordered = torch.roll(y_ordered, -(self.img_shape // 2 + 1), 0)

y_target, x_target = torch.meshgrid(torch.arange(self.img_shape), torch.arange(self.img_shape // 2 + 1))

self.assertEqual(order[0, 0], 0, 'Top left pixel should have index 0.')
self.assertTrue(torch.all(x_target == x_ordered) and torch.all(y_target == y_ordered),
'rFFT coordinates are wrong.')

def test_polar_rfft_2D(self):
r, phi, flatten_indices, order = get_polar_rfft_coords_2D(img_shape=self.img_shape)

self.assertEqual(order[0, 0], 0, 'Top left pixel should have index 0.')

r_ordered = torch.zeros_like(r)
r_ordered[flatten_indices] = r
r_ordered = r_ordered.reshape(self.img_shape, -1)
self.assertEqual(r_ordered[0, 0], 0, 'Top left pixel does not have radius 0.')

phi_ordered = torch.zeros_like(phi)
phi_ordered[flatten_indices] = phi
phi_ordered = phi_ordered.reshape(self.img_shape, -1)
self.assertEqual(phi_ordered[0, 0], 0, 'Top left pixel angle does not correspond to 0.')
self.assertEqual(phi_ordered[self.img_shape // 2, 0], np.pi / 2, 'Phi component is of (test 1).')
self.assertEqual(phi_ordered[self.img_shape - 1, 0], -np.pi / 2, 'Phi component is of (test 2).')

def test_polar_sinogram(self):
r, phi, flatten_indices = get_polar_rfft_coords_sinogram(self.angles, self.img_shape)
self.assertTrue(torch.all((r[0::3] == r[1::3]) == (r[1::3] == r[2::3])),
'Radii of polar sinogram coords are off.')

phi_ordered = torch.zeros_like(phi)
phi_ordered[flatten_indices] = phi
self.assertTrue(torch.all(phi_ordered[:self.img_shape // 2 + 1] == np.pi / 2.),
'Phi of polar sinogram coords are off (test1).')
self.assertTrue(torch.all(phi_ordered[self.img_shape // 2 + 1:-(self.img_shape // 2 + 1)] == 0),
'Phi of polar sinogram coords are off (test1).')
self.assertTrue(torch.all(phi_ordered[-(self.img_shape // 2 + 1):] == -np.pi / 2.),
'Phi of polar sinogram coords are off (test2).')

def test_cartesian_sinogram(self):
x, y, flatten_indices = get_cartesian_rfft_coords_sinogram(self.angles, self.img_shape)
print(x)
self.assertTrue(torch.all(x <= self.img_shape // 2 + 1))
self.assertTrue(torch.all(x >= 0))
self.assertTrue(torch.all(y <= self.img_shape))
self.assertTrue(torch.all(y >= 0))


if __name__ == '__main__':
unittest.main()
87 changes: 87 additions & 0 deletions fit_tests/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest

from fit.utils import cart2pol, pol2cart
import torch

from fit.utils import normalize, denormalize
from fit.utils import normalize_amp, denormalize_amp

import numpy as np

from fit.utils.utils import log_amplitudes, normalize_phi, denormalize_phi, psf_real, normalize_FC, denormalize_FC, \
convert2DFT


class TestUtils(unittest.TestCase):

def test_cart2pol2cart(self):
x = torch.arange(1, 6, dtype=torch.float32)
y = torch.arange(-2, 3, dtype=torch.float32)

r, phi = cart2pol(x, y)
x_, y_ = pol2cart(r, phi)
self.assertTrue(torch.allclose(x, x_) and torch.allclose(y, y_),
'Cartesian to polar coordinate transformations are broken.')

def test_normlize_denormalize_realspace(self):
data = torch.from_numpy(np.array([-1, 2, 4, 0, -5], dtype=np.float32))
mean = torch.mean(data)
std = torch.std(data)
data_n = normalize(data, mean, std)
self.assertAlmostEqual(torch.mean(data_n).item(), 0, 7)
self.assertAlmostEqual(torch.std(data_n).item(), 1, 7)

data_dn = denormalize(data_n, mean, std)
self.assertTrue(torch.allclose(data, data_dn))

def test_normalize_denormalize_amplitudes(self):
amps = torch.exp(torch.arange(6, dtype=torch.float32))
log_amps = log_amplitudes(amps)
min_amp = log_amps.min()
max_amp = log_amps.max()

n_amps = normalize_amp(amps, amp_min=min_amp, amp_max=max_amp)
amps_ = denormalize_amp(n_amps, amp_min=min_amp, amp_max=max_amp)

self.assertTrue(torch.allclose(amps, amps_))

def test_normalize_denormalize_phases(self):
phases = torch.linspace(-np.pi, np.pi, 10)

phases_n = normalize_phi(phases)
phases_ = denormalize_phi(phases_n)

self.assertTrue(torch.allclose(phases, phases_))

def test_normalize_denormalize_FC(self):
img = psf_real(7, 27)
rfft = torch.fft.rfftn(img)
log_amps = log_amplitudes(rfft.abs())
min_amp = log_amps.min()
max_amp = log_amps.max()

amp_n, phi_n = normalize_FC(rfft, amp_min=min_amp, amp_max=max_amp)
fc_n = torch.stack([amp_n, phi_n], -1)
rfft_ = denormalize_FC(fc_n, amp_min=min_amp, amp_max=max_amp)

self.assertTrue(torch.allclose(rfft, rfft_))

def test_convert2DFT(self):
img = psf_real(7, 27)
rfft = torch.fft.rfftn(img)
log_amps = log_amplitudes(rfft.abs())
min_amp = log_amps.min()
max_amp = log_amps.max()

order = torch.from_numpy(np.random.permutation(27 * 14))
amp_n, phi_n = normalize_FC(rfft, amp_min=min_amp, amp_max=max_amp)
fc_n = torch.stack([amp_n.flatten(), phi_n.flatten()], dim=-1)[order]

dft = convert2DFT(fc_n.unsqueeze(0), amp_min=min_amp, amp_max=max_amp, dst_flatten_order=order, img_shape=27)
img_ = torch.fft.irfftn(dft, s=(27, 27))

self.assertTrue(torch.allclose(img, img_))


if __name__ == '__main__':
unittest.main()

0 comments on commit 3ff7976

Please sign in to comment.