From 5c8740ce3083a739aa36839ceffef5de2c4beb42 Mon Sep 17 00:00:00 2001 From: paugier Date: Wed, 17 Apr 2024 14:41:52 +0200 Subject: [PATCH] Correl: fix _like_fftshift --- src/fluidimage/calcul/correl.py | 36 ++++++++++++++++++---------- src/fluidimage/calcul/test_correl.py | 10 ++++++-- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/fluidimage/calcul/correl.py b/src/fluidimage/calcul/correl.py index be8f4e2f..762683e1 100644 --- a/src/fluidimage/calcul/correl.py +++ b/src/fluidimage/calcul/correl.py @@ -597,25 +597,37 @@ def _like_fftshift(arr: A2dC): """ n0, n1 = arr.shape - assert n0 % 2 == 0 - assert n1 % 2 == 0 - arr = np.ascontiguousarray(arr[::-1, ::-1]) tmp = np.empty_like(arr) - for i0 in range(n0): - for i1 in range(n1 // 2): - tmp[i0, n1 // 2 + i1] = arr[i0, i1] - tmp[i0, i1] = arr[i0, n1 // 2 + i1] + if n1 % 2 == 0: + for i0 in range(n0): + for i1 in range(n1 // 2): + tmp[i0, n1 // 2 + i1] = arr[i0, i1] + tmp[i0, i1] = arr[i0, n1 // 2 + i1] + else: + for i0 in range(n0): + for i1 in range(n1 // 2 + 1): + tmp[i0, n1 // 2 + i1] = arr[i0, i1] + + for i1 in range(n1 // 2): + tmp[i0, i1] = arr[i0, n1 // 2 + 1 + i1] arr_1d_view = arr.ravel() tmp_1d_view = tmp.ravel() - n_half = n0 * n1 // 2 - for idx in range(n_half): - arr_1d_view[idx + n_half] = tmp_1d_view[idx] - arr_1d_view[idx] = tmp_1d_view[idx + n_half] - + if n0 % 2 == 0: + n_half = (n0 // 2) * n1 + for idx in range(n_half): + arr_1d_view[idx + n_half] = tmp_1d_view[idx] + arr_1d_view[idx] = tmp_1d_view[idx + n_half] + else: + n_half_a = (n0 // 2 + 1) * n1 + n_half_b = (n0 // 2) * n1 + for idx in range(n_half_a): + arr_1d_view[idx + n_half_b] = tmp_1d_view[idx] + for idx in range(n_half_b): + arr_1d_view[idx] = tmp_1d_view[idx + n_half_a] return arr diff --git a/src/fluidimage/calcul/test_correl.py b/src/fluidimage/calcul/test_correl.py index a6072577..3d368ada 100644 --- a/src/fluidimage/calcul/test_correl.py +++ b/src/fluidimage/calcul/test_correl.py @@ -242,8 +242,7 @@ def _test2(self, cls=cls, k=k): exec("TestCorrel2.test_correl_images_diff_sizes_" + k + " = _test2") -def test_like_fftshift(): - n0, n1 = 24, 32 +def _test_like_fftshift(n0, n1): correl = np.reshape(np.arange(n0 * n1, dtype=np.float32), (n0, n1)) assert np.allclose( _like_fftshift(correl), @@ -251,5 +250,12 @@ def test_like_fftshift(): ) +def test_like_fftshift(): + _test_like_fftshift(24, 32) + _test_like_fftshift(21, 32) + _test_like_fftshift(12, 13) + _test_like_fftshift(7, 9) + + if __name__ == "__main__": unittest.main()