From 5ea8773d00586b7430341799f857dca4fa761b99 Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 15 Nov 2023 14:54:40 -0700 Subject: [PATCH] Fix phase ramp --- scico/linop/_circconv.py | 12 ++++++++++-- scico/test/linop/test_circconv.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index dfbd0bb86..d6817e421 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -120,7 +120,6 @@ def __init__( output_dtype = snp.dtype(input_dtype) # cannot infer from h_dft because it is complex else: fft_shape = input_shape[-self.ndims :] - pad = () fft_axes = list(range(h.ndim - self.ndims, h.ndim)) self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes) output_dtype = result_type(h.dtype, input_dtype) @@ -140,7 +139,16 @@ def __init__( offset = -snp.array(self.h_center) shifts: Tuple[np.ndarray, ...] = np.ix_( *tuple( - np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s)) # type: ignore + np.select( + # see "Closed Form Variable Fractional Time Delay Using FFT" or + # "Comments on 'Sinc Interpolation of Discrete Periodic Signals'" + [np.arange(s) < s / 2, np.arange(s) == s / 2, np.arange(s) > s / 2], + [ + np.exp(-1j * k * 2 * np.pi * np.arange(s) / s), + np.cos(k * np.pi), + np.exp(1j * k * 2 * np.pi * (s - np.arange(s)) / s), + ], # type: ignore + ) for k, s in zip(offset, input_shape[-self.ndims :]) ) ) diff --git a/scico/test/linop/test_circconv.py b/scico/test/linop/test_circconv.py index 3a775b917..aaa3122ec 100644 --- a/scico/test/linop/test_circconv.py +++ b/scico/test/linop/test_circconv.py @@ -160,10 +160,11 @@ def test_center(self, center): def test_fractional_center(self): """A fractional center should keep outputs real""" - x, _ = uniform(minval=-1, maxval=1, shape=(3, 4), key=self.key) + x, _ = uniform(minval=-1, maxval=1, shape=(4, 5), key=self.key) h, _ = uniform(minval=-1, maxval=1, shape=(2, 2), key=self.key) A = CircularConvolve(h=h, input_shape=x.shape, h_center=[0.1, 2.7]) + # taken from CircularConvolve._eval x_dft = snp.fft.fftn(x, axes=A.x_fft_axes) hx = snp.fft.ifftn( A.h_dft * x_dft,