Skip to content

Commit

Permalink
[ET-VK] Fixing conv2d dw incorrect output when stride != dilation issue.
Browse files Browse the repository at this point in the history
This diff moves current implementation of conv2d dw as a special case when stride equals dilation in the Vulkan backend of Executorch, since that's the only time this kind of caching is possible.

If stride does not equal dilation the old implementation is used.

Differential Revision: [D67908916](https://our.internmc.facebook.com/intern/diff/D67908916/)

[ghstack-poisoned]
  • Loading branch information
trivedivivek committed Jan 9, 2025
1 parent 218850f commit 5186361
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 12 deletions.
41 changes: 39 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#define TILE_SIZE ${TILE_SIZE}

#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION}

#define BATCH_SIZE_X ${BATCH_SIZE_X}

#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
Expand All @@ -36,12 +38,12 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
*/

#if STRIDE_EQ_DILATION
void main() {
// x and y are divided by batch size to determine 3d position
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
Expand Down Expand Up @@ -119,3 +121,38 @@ void main() {
}
}
}

#else
void main() {
const ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);

if (any(greaterThanEqual(pos, out_limits))) {
return;
}

// Compute the index of the top-left element of the overlay region. Negative
// indices indicate that the top-left element is in a region added by padding.
const ivec2 ipos = pos.xy * stride - padding;

// Compute the start and end of the input indices to load. Padding is assumed
// to be constant 0 padding, so any reads from the padding region is skipped.
const ivec2 start = ipos;
const ivec2 end = ipos + overlay_region.xy;

VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
int kx = 0;
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
// The weight kernel was rearranged such that every NxN filter is
// flattened to fit in one row. Each filter was then stacked on top of
// each other vertically.
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
kx++;
}
}

imageStore(t_out, pos, op(sum, out_min, out_max));
}

#endif
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ conv2d_dw_output_tile:
TILE_SIZE: 3
BATCH_SIZE_X: 4
BATCH_SIZE_Y: 2
STRIDE_EQ_DILATION: 0
generate_variant_forall:
DTYPE:
- VALUE: half
Expand All @@ -25,3 +26,15 @@ conv2d_dw_output_tile:
- NAME: conv2d_dw_output_tile_5x5_clamp
OPERATOR: clamp(X, A, B)
TILE_SIZE: 5
- NAME: conv2d_dw_sed_output_tile_3x3
STRIDE_EQ_DILATION: 1
- NAME: conv2d_dw_sed_output_tile_3x3_clamp
OPERATOR: clamp(X, A, B)
STRIDE_EQ_DILATION: 1
- NAME: conv2d_dw_sed_output_tile_5x5
TILE_SIZE: 5
STRIDE_EQ_DILATION: 1
- NAME: conv2d_dw_sed_output_tile_5x5_clamp
OPERATOR: clamp(X, A, B)
TILE_SIZE: 5
STRIDE_EQ_DILATION: 1
41 changes: 31 additions & 10 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,17 @@ vkapi::ShaderInfo get_conv2d_shader(
const bool prepack_weights,
const Conv2dMethod method,
const ValueRef weight,
const bool clamp_out = false) {
const bool clamp_out = false,
const bool stride_equals_dilation = false) {
std::string kernel_name;
kernel_name.reserve(kShaderNameReserve);
switch (method) {
case Conv2dMethod::Depthwise:
kernel_name = "conv2d_dw";
if (!prepack_weights) {
if (stride_equals_dilation) {
kernel_name += "_sed";
}
const auto& weight_sizes = graph.get_tref(weight)->sizes;
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
kernel_name += "_output_tile_3x3";
Expand Down Expand Up @@ -286,22 +290,33 @@ Conv2dMethod get_conv2d_method(
return Conv2dMethod::SlidingWindow;
}

utils::uvec2 get_conv2d_dw_dispatch_divisor(const std::vector<int64_t>& weight_sizes) {
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
return {4u, 2u};
}
if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) {
return {4u, 2u};
}
return {4u, 2u};
}

utils::uvec3 create_conv2d_global_wg_size(
ComputeGraph& graph,
const Conv2dMethod method,
const ValueRef out) {
const ValueRef out,
const ValueRef weight_data,
const bool stride_equals_dilation) {
if (method == Conv2dMethod::Pointwise) {
const utils::uvec3 image_extents = graph.logical_limits_of(out);
return {
utils::div_up(image_extents[0u], 2u),
utils::div_up(image_extents[1u], 2u),
image_extents[2u]};
} else if (method == Conv2dMethod::Depthwise) {
const utils::uvec3 image_extents = graph.logical_limits_of(out);
return {
utils::div_up(image_extents[0u], 4u),
utils::div_up(image_extents[1u], 2u),
image_extents[2u]};
} else if (method == Conv2dMethod::Depthwise && stride_equals_dilation) {
const utils::uvec3 image_extents = graph.create_global_wg_size(out);
const utils::uvec2 div =
get_conv2d_dw_dispatch_divisor(graph.get_tref(weight_data)->sizes);
return {utils::div_up(image_extents[0], div[0]), utils::div_up(image_extents[1], div[1]), image_extents[2]};
} else {
return graph.create_global_wg_size(out);
}
Expand Down Expand Up @@ -364,6 +379,10 @@ void add_conv2d_node(
Conv2dParams extra_params =
create_conv2d_params(graph, weight_data, kernel_params, transposed_val);

const bool stride_equals_dilation =
(kernel_params.stride[0] == kernel_params.dilation[0] &&
kernel_params.stride[1] == kernel_params.dilation[1]);

OutputParams out_params = {out_min_val, out_max_val};

check_conv2d_params(kernel_params, transposed_val);
Expand All @@ -374,9 +393,11 @@ void add_conv2d_node(
/*prepack_weights = */ false,
method,
weight_data,
clamp_out);
clamp_out,
stride_equals_dilation);

utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);
utils::uvec3 wg_size = create_conv2d_global_wg_size(
graph, method, out, weight_data, stride_equals_dilation);

if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
Expand Down
33 changes: 33 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,39 @@ def get_conv_inputs():
[0, 0],
1,
),
(
(1, 4, 234, 234),
(4, 1, 3, 3),
(4,),
[2, 1],
[1, 1],
[1, 1],
False,
[0, 0],
4,
),
(
(1, 4, 234, 234),
(4, 1, 3, 3),
(4,),
[1, 2],
[1, 1],
[1, 1],
False,
[0, 0],
4,
),
(
(1, 4, 234, 234),
(4, 1, 3, 3),
(4,),
[2, 2],
[1, 1],
[1, 1],
False,
[0, 0],
4,
),
]
)
return test_suite
Expand Down

0 comments on commit 5186361

Please sign in to comment.