[ET-VK] Changing all conv 2d pw ints from uint16 to int since it slightly improves perf. (#7566)

* [ET-VK] Adding a common utility function to calculate 3d output position based on unique index.

Pull Request resolved: #7522

This diff adds an indexing utils header file used by the Vulkan backend of ExecuTorch. The header includes functions for converting a flat global index to u16 positions based on the input sizes.
ghstack-source-id: 260707858
@exported-using-ghexport

Differential Revision: [D67821941](https://our.internmc.facebook.com/intern/diff/D67821941/)
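
For reference, a minimal sketch of what such an x-wise index-to-position helper can look like. The signature mirrors the idx_to_ipos_x_wise call sites in the diff below; the body is an illustrative assumption, not the actual contents of the header:

    // Maps a flat invocation index to a 3D position with x varying fastest.
    // Sketch only: the real helper lives in the indexing utils header.
    ivec3 idx_to_ipos_x_wise(uint idx, int size_x, int size_y) {
      const int i = int(idx);
      return ivec3(i % size_x, (i / size_x) % size_y, i / (size_x * size_y));
    }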

* [ET-VK] Adding batch processing in x axis to conv2d dw shader by caching input texel for reuse.

Pull Request resolved: #7526

This diff adds batch processing along the x axis to the conv2d dw shader by reusing input texels that overlap between consecutive tiles. The changes modify the GLSL code for the conv2d dw output tile, add a new parameter to the yaml file, and update Convolution.cpp to use the new parameter.
ghstack-source-id: 260707856

Differential Revision: [D67868671](https://our.internmc.facebook.com/intern/diff/D67868671/)
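
To illustrate the caching idea, here is a rough sketch, not the actual shader: BATCH_SIZE_X, KERNEL_W, kernel_row, ipos, and pos stand in for the real names and setup, and stride 1 along x is assumed:

    // Sketch, assuming compile-time constants KERNEL_W and BATCH_SIZE_X: the
    // input windows of BATCH_SIZE_X horizontally adjacent outputs overlap in
    // all but one texel per step, so each texel is fetched once and reused
    // for every output that needs it.
    vec4 kernel_row[KERNEL_W];  // depthwise weights, assumed loaded earlier
    vec4 in_texels[KERNEL_W + BATCH_SIZE_X - 1];
    for (int i = 0; i < KERNEL_W + BATCH_SIZE_X - 1; ++i) {
      in_texels[i] = texelFetch(t_in, ivec3(ipos.x + i, ipos.y, pos.z), 0);
    }

    vec4 sum[BATCH_SIZE_X];
    for (int b = 0; b < BATCH_SIZE_X; ++b) {
      sum[b] = vec4(0);
      for (int kx = 0; kx < KERNEL_W; ++kx) {
        sum[b] = fma(in_texels[b + kx], kernel_row[kx], sum[b]);
      }
    }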

* [ET-VK] Changing all conv 2d pw ints from uint16 to int since it slightly improves perf.

Pull Request resolved: #7545

This diff changes all integers in the conv 2d pw op shader from uint16 to int in the Vulkan backend of ExecuTorch. The change improves performance: the shader does not appear to be register bound, so the narrower 16-bit types offer no register-pressure benefit while adding conversion overhead.
ghstack-source-id: 260707857

Differential Revision: [D67906023](https://our.internmc.facebook.com/intern/diff/D67906023/)
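
For a concrete picture, the depth loop from the diff below captures the flavor of the change (loop bodies elided):

    // Before: 16-bit counter, requires GL_EXT_shader_explicit_arithmetic_types_int16.
    for (uint16_t z = uint16_t(0); z < uint16_t(in_group_size); z += uint16_t(4), ++z4) { /* ... */ }

    // After: plain 32-bit int.
    for (int z = 0; z < in_group_size; z += 4, ++z4) { /* ... */ }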

---------

Co-authored-by: Vivek Trivedi <[email protected]>
pytorchbot and trivedivivek authored Jan 9, 2025
1 parent c7098ca commit b63d9fa
Showing 2 changed files with 18 additions and 20 deletions.
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
@@ -35,7 +35,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
  * output at a single output location.
  */
 void main() {
-  const ivec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
+  const ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
 
   if (any(greaterThanEqual(pos, out_limits))) {
     return;
36 changes: 17 additions & 19 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
@@ -32,10 +32,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-
 // shared memory to hold calculated positions, this would reduce register usage thus improving performance.
-shared u16vec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
+shared ivec2 pos_shared[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z * TILE_SIZE * TILE_SIZE];
 
 /*
  * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
@@ -46,18 +44,18 @@ void main() {
   const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
   const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
 
-  const u16vec3 gpos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits_scaled.x, out_limits_scaled.y);
+  const ivec3 gpos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits_scaled.x, out_limits_scaled.y);
 
   // Output position for TILE_SIZE = 2
   // +--------+--------+
   // | pos[0] | pos[1] |
   // +--------+--------+
   // | pos[2] | pos[3] |
   // +--------+--------+
-  u16vec2 pos[TILE_SIZE * TILE_SIZE];
+  ivec2 pos[TILE_SIZE * TILE_SIZE];
   for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
     for (int x = 0; x < TILE_SIZE; ++x) {
-      pos[i] = u16vec2(
+      pos[i] = ivec2(
         gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y);
       pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
       i++;
@@ -66,38 +64,38 @@ void main() {
 
   // If the top left position is out of bounds, then this invocation will have
   // no work to do.
-  if (any(greaterThanEqual(u16vec3(pos[0], gpos.z), out_limits))) {
+  if (any(greaterThanEqual(ivec3(pos[0], gpos.z), out_limits))) {
     return;
   }
 
   // Compute the index of the input texture that needs to be loaded for each
   // output position. Note that negative indices can be produced indicating that
   // the top-left element is in a region added by padding.
-  u16vec2 ipos[TILE_SIZE * TILE_SIZE];
+  ivec2 ipos[TILE_SIZE * TILE_SIZE];
   for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-    ipos[i] = pos[i] * u16vec2(stride) - u16vec2(padding);
+    ipos[i] = pos[i] * stride - padding;
   }
 
   vec4 sum[TILE_SIZE * TILE_SIZE];
-  sum[0] = texelFetch(t_bias, u16vec2(gpos.z, 0), 0);
+  sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0);
   for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
     sum[i] = sum[0];
   }
 
   int z4 = 0;
   // Since the kernel is 1x1, we only have to loop over the depth dimension.
-  for (uint16_t z = uint16_t(0); z < uint16_t(in_group_size); z += uint16_t(4), ++z4) {
+  for (int z = 0; z < in_group_size; z += 4, ++z4) {
     // During prepacking, the weight tensor has been permuted so that the
     // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
     // the z-axis.
-    const vec4 ktex_0 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(0, 0));
-    const vec4 ktex_1 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(1, 0));
-    const vec4 ktex_2 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(2, 0));
-    const vec4 ktex_3 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(3, 0));
+    const vec4 ktex_0 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(0, 0));
+    const vec4 ktex_1 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(1, 0));
+    const vec4 ktex_2 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(2, 0));
+    const vec4 ktex_3 = texelFetchOffset(t_kernel, ivec2(z, gpos.z), 0, ivec2(3, 0));
 
     #pragma unroll
     for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-      const vec4 in_tex = texelFetch(t_in, u16vec3(ipos[i], z4), 0);
+      const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
       // For 2x2 tile size algorithm works as follows.
       // To explain the calculations below, the contents of one in_tex and the
       // group of 4 texels loaded from t_kernel are shown:
@@ -139,9 +137,9 @@ }
   }
 
   for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
-    const u16vec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
-    if (all(lessThan(u16vec3(pos, gpos.z), out_limits))) {
-      imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max));
+    const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
+    if (all(lessThan(ivec3(pos, gpos.z), out_limits))) {
+      imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max));
     }
   }
 }
