From 97fec20575c0ccd521402419d6e5d2f7bbe90a9b Mon Sep 17 00:00:00 2001 From: Antony Chan Date: Sun, 7 Apr 2024 21:08:11 -0700 Subject: [PATCH] Check for FFT2 image size When the input image dimensions does not match that of the Halide-accelerated FFT2 input dimensions, print a warning message. Sometimes, it is desirable to apply circular boundary conditions to fill the missing input pixels. --- proximal/halide/halide.py | 9 +++++++++ proximal/halide/interface/fft2_r2c.cpp | 5 +++++ proximal/halide/interface/ifft2_c2r.cpp | 5 +++++ proximal/halide/meson.build | 2 ++ 4 files changed, 21 insertions(+) diff --git a/proximal/halide/halide.py b/proximal/halide/halide.py index 94c2cb0..f2395c2 100644 --- a/proximal/halide/halide.py +++ b/proximal/halide/halide.py @@ -93,6 +93,15 @@ def run(self, *args): launch = importlib.import_module( 'proximal.halide.build.{}'.format(self.module_name)) + + if self.module_name[:4] == 'fft2': + expected_shape = (launch.htarget, launch.wtarget) + if np.any(expected_shape != self.target_shape): + print('Warning: FFT2 shape mismatch. Expected {expected_shape}, found {self.target_shape}. Please recompile.') + + if np.any(expected_shape != args[0].shape): + print('Warning: Input image shape mismatch for FFT2. Expected {expected_shape}, found {self.args[0].shape}. Applying circular boundary condition.') + error = launch.run(*args) if error != 0: diff --git a/proximal/halide/interface/fft2_r2c.cpp b/proximal/halide/interface/fft2_r2c.cpp index 46c52dd..fd74048 100644 --- a/proximal/halide/interface/fft2_r2c.cpp +++ b/proximal/halide/interface/fft2_r2c.cpp @@ -3,6 +3,9 @@ namespace proximal { +constexpr int32_t wtarget{CONFIG_FFT_WIDTH}; +constexpr int32_t htarget{CONFIG_FFT_HEIGHT}; + int fft2_r2c_glue(const array_float_t input, const int xshift, const int yshift, array_cxfloat_t output) { @@ -16,4 +19,6 @@ int fft2_r2c_glue(const array_float_t input, const int xshift, PYBIND11_MODULE(fft2_r2c, m) { m.def("run", &proximal::fft2_r2c_glue, "Apply 2D adjoint convolution"); + m.attr("wtarget") = pybind11::int_(proximal::wtarget); + m.attr("htarget") = pybind11::int_(proximal::htarget); } \ No newline at end of file diff --git a/proximal/halide/interface/ifft2_c2r.cpp b/proximal/halide/interface/ifft2_c2r.cpp index eae3893..c284aba 100644 --- a/proximal/halide/interface/ifft2_c2r.cpp +++ b/proximal/halide/interface/ifft2_c2r.cpp @@ -3,6 +3,9 @@ namespace proximal { +constexpr int32_t wtarget{CONFIG_FFT_WIDTH}; +constexpr int32_t htarget{CONFIG_FFT_HEIGHT}; + int ifft2_c2r_glue(const array_cxfloat_t input, array_float_t output) { auto input_buf = getHalideComplexBuffer<4>(input); @@ -15,4 +18,6 @@ int ifft2_c2r_glue(const array_cxfloat_t input, array_float_t output) { PYBIND11_MODULE(ifft2_c2r, m) { m.def("run", &proximal::ifft2_c2r_glue, "Apply 2D ifft"); + m.attr("wtarget") = pybind11::int_(proximal::wtarget); + m.attr("htarget") = pybind11::int_(proximal::htarget); } \ No newline at end of file diff --git a/proximal/halide/meson.build b/proximal/halide/meson.build index 7a29520..0be2eca 100644 --- a/proximal/halide/meson.build +++ b/proximal/halide/meson.build @@ -209,6 +209,8 @@ foreach p : pipeline_name ], cpp_args: [ '-fvisibility=hidden', + '-DCONFIG_FFT_WIDTH=@0@'.format(get_option('wtarget')), + '-DCONFIG_FFT_HEIGHT=@0@'.format(get_option('htarget')), ], link_with: p['link_with'], dependencies: [