From 9bfe284b29fdebe32b93d1ce00e607e068abc74c Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Mon, 6 Jan 2025 10:58:51 -0800 Subject: [PATCH] [ET-VK] Adding batch processing in x axis to conv2d dw shader by caching input texel for reuse. This diff adds batch processing in the x axis to the conv2d dw shader by reusing input texel overlapping between consecutive tiles. The changes include modifying the glsl code for the conv2d dw output tile, adding a new parameter to the yaml file, and modifying the Convolution.cpp file to use the new parameter. Differential Revision: [D67868671](https://our.internmc.facebook.com/intern/diff/D67868671/) [ghstack-poisoned] --- .../graph/ops/glsl/conv2d_dw_output_tile.glsl | 42 ++++++++++++------- .../graph/ops/glsl/conv2d_dw_output_tile.yaml | 1 + .../runtime/graph/ops/impl/Convolution.cpp | 2 +- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index 20fb9374be..a8f4c940d4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -14,6 +14,8 @@ #define TILE_SIZE ${TILE_SIZE} +#define BATCH_SIZE_X ${BATCH_SIZE_X} + #define BATCH_SIZE_Y ${BATCH_SIZE_Y} #define op(X, A, B) ${OPERATOR} @@ -41,13 +43,15 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; * output at a single output location. */ void main() { + // x divided up by batch size is used to determine 3d position // y divided up by batch size is used to determine 3d position // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z - const int out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y; + const ivec2 out_limits_xy_scaled = ivec2(out_limits.xy + ivec2(BATCH_SIZE_X, BATCH_SIZE_Y) - 1) / ivec2(BATCH_SIZE_X, BATCH_SIZE_Y); - u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits_y_scaled); + u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits_xy_scaled.x, out_limits_xy_scaled.y); - // scale pos.y by batch size, because that's the top pixel to be processed + // scale pos.xy by batch sizes, because that's the top pixel to be processed + pos.x *= uint16_t(BATCH_SIZE_X); pos.y *= uint16_t(BATCH_SIZE_Y); // do not process if top pixel does not fit within the output range @@ -65,30 +69,34 @@ void main() { const u16vec2 end = ipos + u16vec2(overlay_region.xy); // sum outputs - VEC4_T sum[BATCH_SIZE_Y]; + VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X]; - sum[0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0); - for (int i = 1; i < BATCH_SIZE_Y; i++) { - sum[i] = sum[0]; + sum[0][0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0); + for (int y = 0; y < BATCH_SIZE_Y; y++) { + for (int x = 0; x < BATCH_SIZE_X; x++) { + sum[y][x] = sum[0][0]; + } } // array to store input texels - VEC4_T in_texels[TILE_SIZE]; + VEC4_T in_texels[TILE_SIZE + BATCH_SIZE_X - 1]; // array to store kernel data of previous y VEC4_T prev_kernel_line[TILE_SIZE]; uint16_t kx = uint16_t(0); for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) { - for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) { + for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE + BATCH_SIZE_X - 1); x += uint16_t(dilation.x), j++) { in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0); } // from 2nd iteration onwards accumulate dot product in 2nd sum // based on kernel line data fetched in previous iteration and input texel from this iteration if (i > uint16_t(0)) { - for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) { - sum[1] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[1]); + for (uint16_t s = uint16_t(0); s < uint16_t(BATCH_SIZE_X); s++) { + for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) { + sum[1][int(s)] = fma(in_texels[int(j+s)], prev_kernel_line[int(j)], sum[1][int(s)]); + } } } @@ -96,15 +104,19 @@ void main() { if (i < uint16_t(TILE_SIZE)) { for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) { prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0); - sum[0] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[0]); + for (uint16_t s = uint16_t(0); s < uint16_t(BATCH_SIZE_X); s++) { + sum[0][int(s)] = fma(in_texels[int(j+s)], prev_kernel_line[int(j)], sum[0][int(s)]); + } } } } for (int i = 0; i < BATCH_SIZE_Y; i++) { - if (any(greaterThanEqual(u16vec3(pos.x, pos.y + i, pos.z), out_limits))) { - continue; + for (int j = 0; j < BATCH_SIZE_X; j++) { + if (any(greaterThanEqual(u16vec3(pos.x + j, pos.y + i, pos.z), out_limits))) { + continue; + } + imageStore(t_out, u16vec3(pos.x + j, pos.y + i, pos.z), op(sum[i][j], out_min, out_max)); } - imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max)); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml index bb197c2c18..9cf6c22c6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml @@ -10,6 +10,7 @@ conv2d_dw_output_tile: NDIM: 3 DTYPE: float TILE_SIZE: 3 + BATCH_SIZE_X: 4 BATCH_SIZE_Y: 2 generate_variant_forall: DTYPE: diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 3519635ac7..64c145fb7e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -299,7 +299,7 @@ utils::uvec3 create_conv2d_global_wg_size( } else if (method == Conv2dMethod::Depthwise) { const utils::uvec3 image_extents = graph.logical_limits_of(out); return { - utils::div_up(image_extents[0u], 1u), + utils::div_up(image_extents[0u], 4u), utils::div_up(image_extents[1u], 2u), image_extents[2u]}; } else {