diff --git a/proximal/halide/halide.py b/proximal/halide/halide.py index 94c2cb0..f2395c2 100644 --- a/proximal/halide/halide.py +++ b/proximal/halide/halide.py @@ -93,6 +93,15 @@ def run(self, *args): launch = importlib.import_module( 'proximal.halide.build.{}'.format(self.module_name)) + + if self.module_name[:4] == 'fft2': + expected_shape = (launch.htarget, launch.wtarget) + if np.any(expected_shape != self.target_shape): + print('Warning: FFT2 shape mismatch. Expected {expected_shape}, found {self.target_shape}. Please recompile.') + + if np.any(expected_shape != args[0].shape): + print('Warning: Input image shape mismatch for FFT2. Expected {expected_shape}, found {self.args[0].shape}. Applying circular boundary condition.') + error = launch.run(*args) if error != 0: diff --git a/proximal/halide/interface/fft2_r2c.cpp b/proximal/halide/interface/fft2_r2c.cpp index 46c52dd..fd74048 100644 --- a/proximal/halide/interface/fft2_r2c.cpp +++ b/proximal/halide/interface/fft2_r2c.cpp @@ -3,6 +3,9 @@ namespace proximal { +constexpr int32_t wtarget{CONFIG_FFT_WIDTH}; +constexpr int32_t htarget{CONFIG_FFT_HEIGHT}; + int fft2_r2c_glue(const array_float_t input, const int xshift, const int yshift, array_cxfloat_t output) { @@ -16,4 +19,6 @@ int fft2_r2c_glue(const array_float_t input, const int xshift, PYBIND11_MODULE(fft2_r2c, m) { m.def("run", &proximal::fft2_r2c_glue, "Apply 2D adjoint convolution"); + m.attr("wtarget") = pybind11::int_(proximal::wtarget); + m.attr("htarget") = pybind11::int_(proximal::htarget); } \ No newline at end of file diff --git a/proximal/halide/interface/ifft2_c2r.cpp b/proximal/halide/interface/ifft2_c2r.cpp index eae3893..c284aba 100644 --- a/proximal/halide/interface/ifft2_c2r.cpp +++ b/proximal/halide/interface/ifft2_c2r.cpp @@ -3,6 +3,9 @@ namespace proximal { +constexpr int32_t wtarget{CONFIG_FFT_WIDTH}; +constexpr int32_t htarget{CONFIG_FFT_HEIGHT}; + int ifft2_c2r_glue(const array_cxfloat_t input, array_float_t output) { auto input_buf = getHalideComplexBuffer<4>(input); @@ -15,4 +18,6 @@ int ifft2_c2r_glue(const array_cxfloat_t input, array_float_t output) { PYBIND11_MODULE(ifft2_c2r, m) { m.def("run", &proximal::ifft2_c2r_glue, "Apply 2D ifft"); + m.attr("wtarget") = pybind11::int_(proximal::wtarget); + m.attr("htarget") = pybind11::int_(proximal::htarget); } \ No newline at end of file diff --git a/proximal/halide/meson.build b/proximal/halide/meson.build index 7a29520..0be2eca 100644 --- a/proximal/halide/meson.build +++ b/proximal/halide/meson.build @@ -209,6 +209,8 @@ foreach p : pipeline_name ], cpp_args: [ '-fvisibility=hidden', + '-DCONFIG_FFT_WIDTH=@0@'.format(get_option('wtarget')), + '-DCONFIG_FFT_HEIGHT=@0@'.format(get_option('htarget')), ], link_with: p['link_with'], dependencies: [