Merge pull request #454 from DiamondLightSource/improve-paganin-memory-estimation

Improve paganin filter method's memory estimation
yousefmoazzam authored Sep 26, 2024
2 parents 351e028 + b32a2c8 commit 0852f3b
Showing 1 changed file with 97 additions and 36 deletions.
@@ -24,6 +24,8 @@
 from typing import Tuple
 import numpy as np
 
+from httomo.cufft import CufftType, cufft_estimate_2d
+
 __all__ = [
     "_calc_memory_bytes_paganin_filter_savu",
     "_calc_memory_bytes_paganin_filter_tomopy",
@@ -37,20 +39,55 @@ def _calc_memory_bytes_paganin_filter_savu(
 ) -> Tuple[int, int]:
     pad_x = kwargs["pad_x"]
     pad_y = kwargs["pad_y"]
-    input_size = np.prod(non_slice_dims_shape) * dtype.itemsize
-    in_slice_size = (
-        (non_slice_dims_shape[0] + 2 * pad_y)
-        * (non_slice_dims_shape[1] + 2 * pad_x)
-        * dtype.itemsize
-    )
-    # FFT needs complex inputs, so copy to complex happens first
-    complex_slice = in_slice_size / dtype.itemsize * np.complex64().nbytes
-    fftplan_slice = complex_slice
-    filter_size = complex_slice
-    res_slice = np.prod(non_slice_dims_shape) * np.float32().nbytes
-    tot_memory_bytes = (
-        input_size + in_slice_size + complex_slice + fftplan_slice + res_slice
-    )
+
+    # Input (unpadded)
+    unpadded_in_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize
+
+    # Padded input
+    padded_non_slice_dims_shape = (
+        non_slice_dims_shape[0] + 2 * pad_y,
+        non_slice_dims_shape[1] + 2 * pad_x,
+    )
+    padded_in_slice_size = (
+        padded_non_slice_dims_shape[0] * padded_non_slice_dims_shape[1] * dtype.itemsize
+    )
+
+    # Padded input cast to `complex64`
+    complex_slice = padded_in_slice_size / dtype.itemsize * np.complex64().nbytes
+
+    # Plan size for 2D FFT
+    fftplan_slice_size = cufft_estimate_2d(
+        nx=padded_non_slice_dims_shape[1],
+        ny=padded_non_slice_dims_shape[0],
+        fft_type=CufftType.CUFFT_C2C,
+    )
+
+    # Shape of 2D filter is the same as the padded `complex64` slice shape, so the size
+    # will be the same
+    filter_size = complex_slice
+
+    # Size of cropped/unpadded + cast to float32 result of 2D IFFT
+    cropped_float32_res_slice = np.prod(non_slice_dims_shape) * np.float32().nbytes
+
+    # If the FFT plan size is negligible for some reason, this changes where the peak GPU
+    # memory usage occurs. Hence, the if/else branching below for calculating the total bytes.
+    NEGLIGIBLE_FFT_PLAN_SIZE = 16
+    if fftplan_slice_size < NEGLIGIBLE_FFT_PLAN_SIZE:
+        tot_memory_bytes = int(
+            unpadded_in_slice_size + padded_in_slice_size + complex_slice
+        )
+    else:
+        tot_memory_bytes = int(
+            unpadded_in_slice_size
+            + padded_in_slice_size
+            + complex_slice
+            # The padded float32 array is deallocated when a copy is made when casting to
+            # complex64 and the variable `padded_tomo` is reassigned to the complex64 version
+            - padded_in_slice_size
+            + fftplan_slice_size
+            + cropped_float32_res_slice
+        )
 
     return (tot_memory_bytes, filter_size)
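
Elsewhere in the framework, these per-slice estimates determine how many slices can be processed per GPU chunk: each estimator returns a (per-slice bytes, slice-count-independent bytes) pair. A minimal sketch of that bookkeeping (the helper name and the safety margin are illustrative assumptions, not httomo's actual scheduler):

def max_slices_for_budget(
    per_slice_bytes: int,
    fixed_bytes: int,
    gpu_bytes_available: int,
    safety_fraction: float = 0.8,  # illustrative headroom, not a httomo constant
) -> int:
    # Subtract the slice-count-independent cost (e.g. the filter returned as
    # the second tuple element above) once, then divide what remains among
    # slices at the per-slice rate.
    usable = int(gpu_bytes_available * safety_fraction) - fixed_bytes
    return max(usable // per_slice_bytes, 0)
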


@@ -61,11 +98,14 @@ def _calc_memory_bytes_paganin_filter_tomopy(
 ) -> Tuple[int, int]:
     from httomolibgpu.prep.phase import _shift_bit_length
 
+    # Input (unpadded)
+    unpadded_in_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize
+
     # estimate padding size here based on non_slice dimensions
     pad_tup = []
-    for index, element in enumerate(non_slice_dims_shape):
-        diff = _shift_bit_length(element + 1) - element
-        if element % 2 == 0:
+    for dim_len in non_slice_dims_shape:
+        diff = _shift_bit_length(dim_len + 1) - dim_len
+        if dim_len % 2 == 0:
             pad_width = diff // 2
             pad_width = (pad_width, pad_width)
         else:
@@ -75,34 +115,55 @@
             pad_width = (left_pad, right_pad)
         pad_tup.append(pad_width)
 
-    input_size = np.prod(non_slice_dims_shape) * dtype.itemsize
-
-    in_slice_size = (
+    # Padded input
+    padded_in_slice_size = (
         (non_slice_dims_shape[0] + pad_tup[0][0] + pad_tup[0][1])
         * (non_slice_dims_shape[1] + pad_tup[1][0] + pad_tup[1][1])
         * dtype.itemsize
     )
-    out_slice_size = (
-        (non_slice_dims_shape[0] + pad_tup[0][0] + pad_tup[0][1])
-        * (non_slice_dims_shape[1] + pad_tup[1][0] + pad_tup[1][1])
-        * dtype.itemsize
-    )
 
-    # FFT needs complex inputs, so copy to complex happens first
-    complex_slice = in_slice_size / dtype.itemsize * np.complex64().nbytes
-    fftplan_slice = complex_slice
-    grid_size = np.prod(non_slice_dims_shape) * np.float32().nbytes
+    # Padded input cast to `complex64`
+    complex_slice = padded_in_slice_size / dtype.itemsize * np.complex64().nbytes
+
+    # Plan size for 2D FFT
+    ny = non_slice_dims_shape[0] + pad_tup[0][0] + pad_tup[0][1]
+    nx = non_slice_dims_shape[1] + pad_tup[1][0] + pad_tup[1][1]
+    fftplan_slice_size = cufft_estimate_2d(
+        nx=nx,
+        ny=ny,
+        fft_type=CufftType.CUFFT_C2C,
+    )
+
+    # Size of "reciprocal grid" generated, based on padded projections shape
+    grid_size = np.prod((ny, nx)) * np.float32().nbytes
     filter_size = grid_size
-    res_slice = grid_size
-
-    tot_memory_bytes = int(
-        input_size
-        + in_slice_size
-        + out_slice_size
-        + 2 * complex_slice
-        + 0.5 * fftplan_slice
-        + 2 * res_slice
-    )
+
+    # Size of cropped/unpadded + cast to float32 result of 2D IFFT
+    cropped_float32_res_slice = np.prod(non_slice_dims_shape) * np.float32().nbytes
+
+    # Size of negative log of cropped float32 result of 2D IFFT
+    negative_log_slice = cropped_float32_res_slice
+
+    # If the FFT plan size is negligible for some reason, this changes where the peak GPU
+    # memory usage occurs. Hence, the if/else branching below for calculating the total bytes.
+    NEGLIGIBLE_FFT_PLAN_SIZE = 16
+    if fftplan_slice_size < NEGLIGIBLE_FFT_PLAN_SIZE:
+        tot_memory_bytes = int(
+            unpadded_in_slice_size + padded_in_slice_size + complex_slice
+        )
+    else:
+        tot_memory_bytes = int(
+            unpadded_in_slice_size
+            + padded_in_slice_size
+            + complex_slice
+            # The padded float32 array is deallocated when a copy is made when casting to
+            # complex64 and the variable `padded_tomo` is reassigned to the complex64 version
+            - padded_in_slice_size
+            + fftplan_slice_size
+            + cropped_float32_res_slice
+            + negative_log_slice
+        )
 
     subtract_bytes = int(filter_size + grid_size)
 
     return (tot_memory_bytes, subtract_bytes)
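
For intuition about the padding loop above: _shift_bit_length rounds a dimension up to a power of two (an FFT-friendly size), and the loop splits the extra between the two sides. The sketch below assumes the usual next-power-of-two implementation and an as-even-as-possible split in the odd case (the real details live in httomolibgpu and in the collapsed lines of the hunk above):

def _shift_bit_length(x: int) -> int:
    # Assumed implementation: smallest power of two >= x.
    return 1 << (x - 1).bit_length()


def pad_widths(shape):
    """Mirror of the estimator's padding loop, for illustration only."""
    pad_tup = []
    for dim_len in shape:
        diff = _shift_bit_length(dim_len + 1) - dim_len
        if dim_len % 2 == 0:
            pad_width = diff // 2
            pad_tup.append((pad_width, pad_width))
        else:
            left_pad = diff // 2  # assumed odd-case split
            pad_tup.append((left_pad, diff - left_pad))
    return pad_tup


# e.g. a (180, 160) angles-by-detector slice pads to (256, 256):
# pad_widths((180, 160)) == [(38, 38), (48, 48)]
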
