diff --git a/proximal/lin_ops/conv.py b/proximal/lin_ops/conv.py index 433d3d6..eab5c9c 100644 --- a/proximal/lin_ops/conv.py +++ b/proximal/lin_ops/conv.py @@ -29,15 +29,21 @@ def init_kernel(self): self.kernel = np.stack((self.kernel,) * arg.shape[2], axis=-1) # Halide kernel - if self.implementation == Impl['halide'] and \ - (len(arg.shape) == 2 or (len(arg.shape) == 3 and arg.dims == 2)): + if self.implementation == Impl["halide"] and ( + len(arg.shape) == 2 or (len(arg.shape) == 3 and arg.dims == 2) + ): self.kernel = np.asfortranarray(self.kernel.astype(np.float32)) - # Halide FFT (pack into diag) - # TODO: FIX IMREAL LATER + + # Fourier-transformed real-valued signal has Hermitian symmetry. + # Only the right-half plane is needed by Halide-generated + # algorithms as the input. + # + # TODO: Move this logic to proximal.utils hsize = arg.shape if len(arg.shape) == 3 else arg.shape + (1,) output_fft_tmp = np.zeros((int((hsize[0] + 1) / 2) + 1, hsize[1], hsize[2]), dtype=np.complex64, order='F') + # Halide FFT (pack into diag) Halide('fft2_r2c', target_shape=hsize[:2]).fft2_r2c(self.kernel, self.kernel.shape[1] // 2, self.kernel.shape[0] // 2, output_fft_tmp) self.forward_kernel[:] = 0. @@ -56,8 +62,9 @@ def forward(self, inputs, outputs): Reads from inputs and writes to outputs. """ self.init_kernel() - if self.implementation == Impl['halide'] and \ - (len(self.shape) == 2 or (len(self.shape) == 3 and self.dims == 2)): + if self.implementation == Impl["halide"] and ( + len(self.shape) == 2 or (len(self.shape) == 3 and self.dims == 2) + ): # Halide implementation Halide('A_conv').A_conv(inputs[0], self.kernel, self.tmpout) # Call diff --git a/proximal/prox_fns/sum_squares.py b/proximal/prox_fns/sum_squares.py index 7924a4d..dec9869 100644 --- a/proximal/prox_fns/sum_squares.py +++ b/proximal/prox_fns/sum_squares.py @@ -115,10 +115,16 @@ def __init__(self, self.freq_shape = self.K.orig_end.variables()[0].shape self.freq_diag = np.reshape(self.freq_diag, self.freq_shape) - if implem == Impl['halide'] and \ - (len(self.freq_shape) == 2 or (len(self.freq_shape) == 2 and - self.freq_dims == 2)): - # TODO: FIX REAL TO IMAG + if implem == Impl["halide"] and ( + len(self.freq_shape) == 2 + or (len(self.freq_shape) == 3 and self.freq_dims == 2) + ): + + # Fourier-transformed real-valued signal has Hermitian symmetry. + # Only the right-half plane is needed by Halide-generated + # algorithms as the input. + # + # TODO: Move this logic to proximal.utils hsize = self.freq_shape if len( self.freq_shape) == 3 else (self.freq_shape[0], self.freq_shape[1], 1) @@ -212,9 +218,10 @@ def solve(self, b: memoized_expr, rho=None, v=None, lin_solver="lsqr", hash=None # KtK operator is diagonal in frequency domain. elif self.freq_diag is not None: # Frequency inversion - if self.implementation == Impl['halide'] and \ - (len(self.freq_shape) == 2 or - (len(self.freq_shape) == 2 and self.freq_dims == 2)): + if self.implementation == Impl["halide"] and ( + len(self.freq_shape) == 2 + or (len(self.freq_shape) == 3 and self.freq_dims == 2) + ): ftmp_halide_out = np.empty(self.freq_shape, dtype=np.float32, order='F')