Skip to content

Commit

Permalink
Check for FFT2 image size
Browse files Browse the repository at this point in the history
When the input image dimensions does not match that of the
Halide-accelerated FFT2 input dimensions, print a warning message.
Sometimes, it is desirable to apply circular boundary conditions to fill
the missing input pixels.
  • Loading branch information
antonysigma committed Jul 18, 2024
1 parent f76a9d0 commit 97fec20
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 0 deletions.
9 changes: 9 additions & 0 deletions proximal/halide/halide.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def run(self, *args):

launch = importlib.import_module(
'proximal.halide.build.{}'.format(self.module_name))

if self.module_name[:4] == 'fft2':
expected_shape = (launch.htarget, launch.wtarget)
if np.any(expected_shape != self.target_shape):
print('Warning: FFT2 shape mismatch. Expected {expected_shape}, found {self.target_shape}. Please recompile.')

if np.any(expected_shape != args[0].shape):
print('Warning: Input image shape mismatch for FFT2. Expected {expected_shape}, found {self.args[0].shape}. Applying circular boundary condition.')

error = launch.run(*args)

if error != 0:
Expand Down
5 changes: 5 additions & 0 deletions proximal/halide/interface/fft2_r2c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

namespace proximal {

constexpr int32_t wtarget{CONFIG_FFT_WIDTH};
constexpr int32_t htarget{CONFIG_FFT_HEIGHT};

int fft2_r2c_glue(const array_float_t input, const int xshift,
const int yshift, array_cxfloat_t output) {

Expand All @@ -16,4 +19,6 @@ int fft2_r2c_glue(const array_float_t input, const int xshift,

PYBIND11_MODULE(fft2_r2c, m) {
m.def("run", &proximal::fft2_r2c_glue, "Apply 2D adjoint convolution");
m.attr("wtarget") = pybind11::int_(proximal::wtarget);
m.attr("htarget") = pybind11::int_(proximal::htarget);
}
5 changes: 5 additions & 0 deletions proximal/halide/interface/ifft2_c2r.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

namespace proximal {

constexpr int32_t wtarget{CONFIG_FFT_WIDTH};
constexpr int32_t htarget{CONFIG_FFT_HEIGHT};

int ifft2_c2r_glue(const array_cxfloat_t input, array_float_t output) {

auto input_buf = getHalideComplexBuffer<4>(input);
Expand All @@ -15,4 +18,6 @@ int ifft2_c2r_glue(const array_cxfloat_t input, array_float_t output) {

PYBIND11_MODULE(ifft2_c2r, m) {
m.def("run", &proximal::ifft2_c2r_glue, "Apply 2D ifft");
m.attr("wtarget") = pybind11::int_(proximal::wtarget);
m.attr("htarget") = pybind11::int_(proximal::htarget);
}
2 changes: 2 additions & 0 deletions proximal/halide/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ foreach p : pipeline_name
],
cpp_args: [
'-fvisibility=hidden',
'-DCONFIG_FFT_WIDTH=@0@'.format(get_option('wtarget')),
'-DCONFIG_FFT_HEIGHT=@0@'.format(get_option('htarget')),
],
link_with: p['link_with'],
dependencies: [
Expand Down

0 comments on commit 97fec20

Please sign in to comment.