From 51863619943ad55332aed2469fe8833ae63aeb22 Mon Sep 17 00:00:00 2001
From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
Date: Thu, 9 Jan 2025 08:26:57 -0800
Subject: [PATCH] [ET-VK] Fixing conv2d dw incorrect output when stride != dilation issue.

This diff makes the current tiled conv2d dw implementation in the Vulkan backend of ExecuTorch a special case that is used only when stride equals dilation, since that is the only case in which this kind of caching is possible. When stride does not equal dilation, the old implementation is used instead.

Differential Revision: [D67908916](https://our.internmc.facebook.com/intern/diff/D67908916/)

[ghstack-poisoned]
---
 .../graph/ops/glsl/conv2d_dw_output_tile.glsl | 41 ++++++++++++++++++-
 .../graph/ops/glsl/conv2d_dw_output_tile.yaml | 13 ++++++
 .../runtime/graph/ops/impl/Convolution.cpp    | 41 ++++++++++++++-----
 backends/vulkan/test/op_tests/cases.py        | 33 +++++++++++++++
 4 files changed, 116 insertions(+), 12 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
index 32d0229d96..984ebf4b2d 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
@@ -14,6 +14,8 @@
 
 #define TILE_SIZE ${TILE_SIZE}
 
+#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION}
+
 #define BATCH_SIZE_X ${BATCH_SIZE_X}
 
 #define BATCH_SIZE_Y ${BATCH_SIZE_Y}
@@ -36,12 +38,12 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-
 /*
  * Computes a depthwise convolution. Each shader invocation calculates the
  * output at a single output location.
  */
+
+#if STRIDE_EQ_DILATION
 void main() {
   // x and y are divided by batch size to determine 3d position
   // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
@@ -119,3 +121,38 @@ void main() {
     }
   }
 }
+
+#else
+void main() {
+  const ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
+
+  if (any(greaterThanEqual(pos, out_limits))) {
+    return;
+  }
+
+  // Compute the index of the top-left element of the overlay region. Negative
+  // indices indicate that the top-left element is in a region added by padding.
+  const ivec2 ipos = pos.xy * stride - padding;
+
+  // Compute the start and end of the input indices to load. Padding is assumed
+  // to be constant 0 padding, so any reads from the padding region is skipped.
+  const ivec2 start = ipos;
+  const ivec2 end = ipos + overlay_region.xy;
+
+  VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
+  int kx = 0;
+  for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
+    for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
+      // The weight kernel was rearranged such that every NxN filter is
+      // flattened to fit in one row. Each filter was then stacked on top of
+      // each other vertically.
+      const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
+      sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
+      kx++;
+    }
+  }
+
+  imageStore(t_out, pos, op(sum, out_min, out_max));
+}
+
+#endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
index 9cf6c22c6c..d3672f5ec2 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
@@ -12,6 +12,7 @@ conv2d_dw_output_tile:
     TILE_SIZE: 3
     BATCH_SIZE_X: 4
     BATCH_SIZE_Y: 2
+    STRIDE_EQ_DILATION: 0
  generate_variant_forall:
    DTYPE:
      - VALUE: half
@@ -25,3 +26,15 @@ conv2d_dw_output_tile:
    - NAME: conv2d_dw_output_tile_5x5_clamp
      OPERATOR: clamp(X, A, B)
      TILE_SIZE: 5
+  - NAME: conv2d_dw_sed_output_tile_3x3
+    STRIDE_EQ_DILATION: 1
+  - NAME: conv2d_dw_sed_output_tile_3x3_clamp
+    OPERATOR: clamp(X, A, B)
+    STRIDE_EQ_DILATION: 1
+  - NAME: conv2d_dw_sed_output_tile_5x5
+    TILE_SIZE: 5
+    STRIDE_EQ_DILATION: 1
+  - NAME: conv2d_dw_sed_output_tile_5x5_clamp
+    OPERATOR: clamp(X, A, B)
+    TILE_SIZE: 5
+    STRIDE_EQ_DILATION: 1
diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
index 64c145fb7e..9e64184bf4 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -126,13 +126,17 @@ vkapi::ShaderInfo get_conv2d_shader(
     const bool prepack_weights,
     const Conv2dMethod method,
     const ValueRef weight,
-    const bool clamp_out = false) {
+    const bool clamp_out = false,
+    const bool stride_equals_dilation = false) {
   std::string kernel_name;
   kernel_name.reserve(kShaderNameReserve);
   switch (method) {
     case Conv2dMethod::Depthwise:
       kernel_name = "conv2d_dw";
       if (!prepack_weights) {
+        if (stride_equals_dilation) {
+          kernel_name += "_sed";
+        }
         const auto& weight_sizes = graph.get_tref(weight)->sizes;
         if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
           kernel_name += "_output_tile_3x3";
@@ -286,22 +290,33 @@ Conv2dMethod get_conv2d_method(
   return Conv2dMethod::SlidingWindow;
 }
+utils::uvec2 get_conv2d_dw_dispatch_divisor(const std::vector<int64_t>& weight_sizes) {
+  if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
+    return {4u, 2u};
+  }
+  if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) {
+    return {4u, 2u};
+  }
+  return {4u, 2u};
+}
+
 utils::uvec3 create_conv2d_global_wg_size(
     ComputeGraph& graph,
     const Conv2dMethod method,
-    const ValueRef out) {
+    const ValueRef out,
+    const ValueRef weight_data,
+    const bool stride_equals_dilation) {
   if (method == Conv2dMethod::Pointwise) {
     const utils::uvec3 image_extents = graph.logical_limits_of(out);
     return {
         utils::div_up(image_extents[0u], 2u),
         utils::div_up(image_extents[1u], 2u),
         image_extents[2u]};
-  } else if (method == Conv2dMethod::Depthwise) {
-    const utils::uvec3 image_extents = graph.logical_limits_of(out);
-    return {
-        utils::div_up(image_extents[0u], 4u),
-        utils::div_up(image_extents[1u], 2u),
-        image_extents[2u]};
+  } else if (method == Conv2dMethod::Depthwise && stride_equals_dilation) {
+    const utils::uvec3 image_extents = graph.create_global_wg_size(out);
+    const utils::uvec2 div =
+        get_conv2d_dw_dispatch_divisor(graph.get_tref(weight_data)->sizes);
+    return {utils::div_up(image_extents[0], div[0]), utils::div_up(image_extents[1], div[1]), image_extents[2]};
   } else {
     return graph.create_global_wg_size(out);
   }
 }
@@ -364,6 +379,10 @@ void add_conv2d_node(
   Conv2dParams extra_params =
       create_conv2d_params(graph, weight_data, kernel_params, transposed_val);
 
+  const bool stride_equals_dilation =
+      (kernel_params.stride[0] == kernel_params.dilation[0] &&
+       kernel_params.stride[1] == kernel_params.dilation[1]);
+
   OutputParams out_params = {out_min_val, out_max_val};
 
   check_conv2d_params(kernel_params, transposed_val);
@@ -374,9 +393,11 @@ void add_conv2d_node(
       /*prepack_weights = */ false,
       method,
       weight_data,
-      clamp_out);
+      clamp_out,
+      stride_equals_dilation);
 
-  utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);
+  utils::uvec3 wg_size = create_conv2d_global_wg_size(
+      graph, method, out, weight_data, stride_equals_dilation);
 
   if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
     wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 85732d7701..d32fa71573 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -348,6 +348,39 @@ def get_conv_inputs():
                 [0, 0],
                 1,
             ),
+            (
+                (1, 4, 234, 234),
+                (4, 1, 3, 3),
+                (4,),
+                [2, 1],
+                [1, 1],
+                [1, 1],
+                False,
+                [0, 0],
+                4,
+            ),
+            (
+                (1, 4, 234, 234),
+                (4, 1, 3, 3),
+                (4,),
+                [1, 2],
+                [1, 1],
+                [1, 1],
+                False,
+                [0, 0],
+                4,
+            ),
+            (
+                (1, 4, 234, 234),
+                (4, 1, 3, 3),
+                (4,),
+                [2, 2],
+                [1, 1],
+                [1, 1],
+                False,
+                [0, 0],
+                4,
+            ),
         ]
     )
     return test_suite
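
For reference, a minimal eager-mode PyTorch sketch of the configuration the new test cases target (illustrative only, not part of the patch; the variable names x, w, b are arbitrary). It shows a depthwise conv2d whose stride does not equal its dilation, which with this change is routed to the fallback (non-"_sed") shader variant rather than the tiled one:

import torch

# Mirrors the first new test case in cases.py: groups == in_channels == 4,
# stride [2, 1] != dilation [1, 1], the configuration that previously
# produced incorrect output from the Vulkan conv2d dw shader.
x = torch.randn(1, 4, 234, 234)
w = torch.randn(4, 1, 3, 3)  # depthwise weight: one 3x3 filter per channel
b = torch.randn(4)
ref = torch.nn.functional.conv2d(
    x, w, b, stride=[2, 1], padding=[1, 1], dilation=[1, 1], groups=4
)
print(ref.shape)  # torch.Size([1, 4, 117, 234]); a Vulkan backend run should match this reference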