forked from juglab/FourierImageTransformer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
161 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import unittest | ||
|
||
from fit.utils.tomo_utils import get_cartesian_rfft_coords_2D, get_polar_rfft_coords_2D, get_polar_rfft_coords_sinogram, \ | ||
get_cartesian_rfft_coords_sinogram | ||
import torch | ||
|
||
import numpy as np | ||
|
||
|
||
class TestTomoUtils(unittest.TestCase): | ||
|
||
def setUp(self) -> None: | ||
self.img_shape = 27 | ||
self.angles = np.array([0, np.pi / 2, np.pi]) | ||
|
||
def test_cartesian_rfft_2D(self): | ||
x, y, flatten_indices, order = get_cartesian_rfft_coords_2D(self.img_shape) | ||
x_ordered = torch.zeros_like(x) | ||
x_ordered[flatten_indices] = x | ||
x_ordered = x_ordered.reshape(self.img_shape, -1) | ||
|
||
y_ordered = torch.zeros_like(y) | ||
y_ordered[flatten_indices] = y | ||
y_ordered = y_ordered.reshape(self.img_shape, -1) | ||
y_ordered = torch.roll(y_ordered, -(self.img_shape // 2 + 1), 0) | ||
|
||
y_target, x_target = torch.meshgrid(torch.arange(self.img_shape), torch.arange(self.img_shape // 2 + 1)) | ||
|
||
self.assertEqual(order[0, 0], 0, 'Top left pixel should have index 0.') | ||
self.assertTrue(torch.all(x_target == x_ordered) and torch.all(y_target == y_ordered), | ||
'rFFT coordinates are wrong.') | ||
|
||
def test_polar_rfft_2D(self): | ||
r, phi, flatten_indices, order = get_polar_rfft_coords_2D(img_shape=self.img_shape) | ||
|
||
self.assertEqual(order[0, 0], 0, 'Top left pixel should have index 0.') | ||
|
||
r_ordered = torch.zeros_like(r) | ||
r_ordered[flatten_indices] = r | ||
r_ordered = r_ordered.reshape(self.img_shape, -1) | ||
self.assertEqual(r_ordered[0, 0], 0, 'Top left pixel does not have radius 0.') | ||
|
||
phi_ordered = torch.zeros_like(phi) | ||
phi_ordered[flatten_indices] = phi | ||
phi_ordered = phi_ordered.reshape(self.img_shape, -1) | ||
self.assertEqual(phi_ordered[0, 0], 0, 'Top left pixel angle does not correspond to 0.') | ||
self.assertEqual(phi_ordered[self.img_shape // 2, 0], np.pi / 2, 'Phi component is of (test 1).') | ||
self.assertEqual(phi_ordered[self.img_shape - 1, 0], -np.pi / 2, 'Phi component is of (test 2).') | ||
|
||
def test_polar_sinogram(self): | ||
r, phi, flatten_indices = get_polar_rfft_coords_sinogram(self.angles, self.img_shape) | ||
self.assertTrue(torch.all((r[0::3] == r[1::3]) == (r[1::3] == r[2::3])), | ||
'Radii of polar sinogram coords are off.') | ||
|
||
phi_ordered = torch.zeros_like(phi) | ||
phi_ordered[flatten_indices] = phi | ||
self.assertTrue(torch.all(phi_ordered[:self.img_shape // 2 + 1] == np.pi / 2.), | ||
'Phi of polar sinogram coords are off (test1).') | ||
self.assertTrue(torch.all(phi_ordered[self.img_shape // 2 + 1:-(self.img_shape // 2 + 1)] == 0), | ||
'Phi of polar sinogram coords are off (test1).') | ||
self.assertTrue(torch.all(phi_ordered[-(self.img_shape // 2 + 1):] == -np.pi / 2.), | ||
'Phi of polar sinogram coords are off (test2).') | ||
|
||
def test_cartesian_sinogram(self): | ||
x, y, flatten_indices = get_cartesian_rfft_coords_sinogram(self.angles, self.img_shape) | ||
print(x) | ||
self.assertTrue(torch.all(x <= self.img_shape // 2 + 1)) | ||
self.assertTrue(torch.all(x >= 0)) | ||
self.assertTrue(torch.all(y <= self.img_shape)) | ||
self.assertTrue(torch.all(y >= 0)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import unittest | ||
|
||
from fit.utils import cart2pol, pol2cart | ||
import torch | ||
|
||
from fit.utils import normalize, denormalize | ||
from fit.utils import normalize_amp, denormalize_amp | ||
|
||
import numpy as np | ||
|
||
from fit.utils.utils import log_amplitudes, normalize_phi, denormalize_phi, psf_real, normalize_FC, denormalize_FC, \ | ||
convert2DFT | ||
|
||
|
||
class TestUtils(unittest.TestCase): | ||
|
||
def test_cart2pol2cart(self): | ||
x = torch.arange(1, 6, dtype=torch.float32) | ||
y = torch.arange(-2, 3, dtype=torch.float32) | ||
|
||
r, phi = cart2pol(x, y) | ||
x_, y_ = pol2cart(r, phi) | ||
self.assertTrue(torch.allclose(x, x_) and torch.allclose(y, y_), | ||
'Cartesian to polar coordinate transformations are broken.') | ||
|
||
def test_normlize_denormalize_realspace(self): | ||
data = torch.from_numpy(np.array([-1, 2, 4, 0, -5], dtype=np.float32)) | ||
mean = torch.mean(data) | ||
std = torch.std(data) | ||
data_n = normalize(data, mean, std) | ||
self.assertAlmostEqual(torch.mean(data_n).item(), 0, 7) | ||
self.assertAlmostEqual(torch.std(data_n).item(), 1, 7) | ||
|
||
data_dn = denormalize(data_n, mean, std) | ||
self.assertTrue(torch.allclose(data, data_dn)) | ||
|
||
def test_normalize_denormalize_amplitudes(self): | ||
amps = torch.exp(torch.arange(6, dtype=torch.float32)) | ||
log_amps = log_amplitudes(amps) | ||
min_amp = log_amps.min() | ||
max_amp = log_amps.max() | ||
|
||
n_amps = normalize_amp(amps, amp_min=min_amp, amp_max=max_amp) | ||
amps_ = denormalize_amp(n_amps, amp_min=min_amp, amp_max=max_amp) | ||
|
||
self.assertTrue(torch.allclose(amps, amps_)) | ||
|
||
def test_normalize_denormalize_phases(self): | ||
phases = torch.linspace(-np.pi, np.pi, 10) | ||
|
||
phases_n = normalize_phi(phases) | ||
phases_ = denormalize_phi(phases_n) | ||
|
||
self.assertTrue(torch.allclose(phases, phases_)) | ||
|
||
def test_normalize_denormalize_FC(self): | ||
img = psf_real(7, 27) | ||
rfft = torch.fft.rfftn(img) | ||
log_amps = log_amplitudes(rfft.abs()) | ||
min_amp = log_amps.min() | ||
max_amp = log_amps.max() | ||
|
||
amp_n, phi_n = normalize_FC(rfft, amp_min=min_amp, amp_max=max_amp) | ||
fc_n = torch.stack([amp_n, phi_n], -1) | ||
rfft_ = denormalize_FC(fc_n, amp_min=min_amp, amp_max=max_amp) | ||
|
||
self.assertTrue(torch.allclose(rfft, rfft_)) | ||
|
||
def test_convert2DFT(self): | ||
img = psf_real(7, 27) | ||
rfft = torch.fft.rfftn(img) | ||
log_amps = log_amplitudes(rfft.abs()) | ||
min_amp = log_amps.min() | ||
max_amp = log_amps.max() | ||
|
||
order = torch.from_numpy(np.random.permutation(27 * 14)) | ||
amp_n, phi_n = normalize_FC(rfft, amp_min=min_amp, amp_max=max_amp) | ||
fc_n = torch.stack([amp_n.flatten(), phi_n.flatten()], dim=-1)[order] | ||
|
||
dft = convert2DFT(fc_n.unsqueeze(0), amp_min=min_amp, amp_max=max_amp, dst_flatten_order=order, img_shape=27) | ||
img_ = torch.fft.irfftn(dft, s=(27, 27)) | ||
|
||
self.assertTrue(torch.allclose(img, img_)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |