From 5f7326b0c14046c139e9bc7c86449ebc018e02d2 Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Sun, 9 Feb 2025 23:30:13 +0000 Subject: [PATCH 1/5] vulkan: improve im2col performance --- .../ggml-vulkan/vulkan-shaders/im2col.comp | 48 ++++++------------- 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 122b1e93fb496..302e61cb3a06e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -37,51 +37,33 @@ void main() { const uint gidx = gl_GlobalInvocationID.x; const uint oh = gl_GlobalInvocationID.y; - const uint batch = gl_GlobalInvocationID.z / p.IC; - const uint ic = gl_GlobalInvocationID.z % p.IC; + const uint batch_ic = gl_GlobalInvocationID.z; - A_TYPE values[NUM_ITER]; - uint offset_dst[NUM_ITER]; - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - values[idx] = A_TYPE(0); - } + const uint batch = batch_ic / p.IC; + const uint ic = batch_ic % p.IC; + + const uint ksize = p.OW * ((p.KH > 1) ? p.KW : 1); + const uint src_base = ic * p.offset_delta + batch * p.batch_offset; + const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * p.KW * p.KH; [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { const uint i = gidx * NUM_ITER + idx; + if (i >= p.pelements) continue; - const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); const uint kx = i / ksize; - const uint kd = kx * ksize; - const uint ky = (i - kd) / p.OW; + const uint ky = (i % ksize) / p.OW; const uint ix = i % p.OW; - const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; - - offset_dst[idx] = - ((batch * p.OH + oh) * p.OW + ix) * p.CHW + - (ic * (p.KW * p.KH) + ky * p.KW + kx); - - if (i >= p.pelements) { - continue; - } + const int iiw = int(ix * uint(p.s0)) + int(kx * uint(p.d0)) - p.p0; + const int iih = int(oh * uint(p.s1)) + int(ky * uint(p.d1)) - p.p1; - if (iih < p.IH && iiw < p.IW) { - const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; - values[idx] = data_a[offset_src + iih * p.IW + iiw]; - } - } - - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - - const uint i = gidx * NUM_ITER + idx; + const uint dst_offset = dst_base + ix * p.CHW + ky * p.KW + kx; - if (i >= p.pelements) { - continue; - } + const bool valid = iih >= 0 && iih < int(p.IH) && iiw >= 0 && iiw < int(p.IW); + const uint src_offset = src_base + uint(iih) * p.IW + uint(iiw); - data_d[offset_dst[idx]] = D_TYPE(values[idx]); + data_d[dst_offset] = D_TYPE(valid ? data_a[src_offset] : 0.0); } } From 8446b617c5b52005778a63f554c8743777f3879d Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Sun, 9 Feb 2025 23:49:17 +0000 Subject: [PATCH 2/5] Fixed if formatting --- ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 302e61cb3a06e..569c75ad55bb5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -49,7 +49,9 @@ void main() { [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { const uint i = gidx * NUM_ITER + idx; - if (i >= p.pelements) continue; + if (i >= p.pelements) { + continue; + } const uint kx = i / ksize; const uint ky = (i % ksize) / p.OW; From c982f009ebe93f5c1973effb25fd1ecb2db1cad4 Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Mon, 10 Feb 2025 22:14:07 +0000 Subject: [PATCH 3/5] im2col test 2 --- .../ggml-vulkan/vulkan-shaders/im2col.comp | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 569c75ad55bb5..6715901215f25 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -37,35 +37,37 @@ void main() { const uint gidx = gl_GlobalInvocationID.x; const uint oh = gl_GlobalInvocationID.y; - const uint batch_ic = gl_GlobalInvocationID.z; + const uint batch = gl_GlobalInvocationID.z / p.IC; + const uint ic = gl_GlobalInvocationID.z % p.IC; - const uint batch = batch_ic / p.IC; - const uint ic = batch_ic % p.IC; - - const uint ksize = p.OW * ((p.KH > 1) ? p.KW : 1); + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); const uint src_base = ic * p.offset_delta + batch * p.batch_offset; - const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * p.KW * p.KH; + const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); + const int oh_s1 = int(oh) * p.s1; + const uint base_idx = gidx * NUM_ITER; [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - const uint i = gidx * NUM_ITER + idx; + uint i = base_idx + idx; + if (i >= p.pelements) { - continue; + break; } - const uint kx = i / ksize; - const uint ky = (i % ksize) / p.OW; - const uint ix = i % p.OW; + const uint kx = i / ksize; + const uint rem = i % ksize; + const uint ky = rem / p.OW; + const uint ix = rem % p.OW; - const int iiw = int(ix * uint(p.s0)) + int(kx * uint(p.d0)) - p.p0; - const int iih = int(oh * uint(p.s1)) + int(ky * uint(p.d1)) - p.p1; + int iiw = int(ix) * p.s0 + int(kx) * p.d0 - p.p0; + int iih = oh_s1 + int(ky) * p.d1 - p.p1; const uint dst_offset = dst_base + ix * p.CHW + ky * p.KW + kx; - const bool valid = iih >= 0 && iih < int(p.IH) && iiw >= 0 && iiw < int(p.IW); - const uint src_offset = src_base + uint(iih) * p.IW + uint(iiw); - - data_d[dst_offset] = D_TYPE(valid ? data_a[src_offset] : 0.0); + data_d[dst_offset] = D_TYPE((iih >= 0 && iih < int(p.IH) && + iiw >= 0 && iiw < int(p.IW)) + ? data_a[src_base + uint(iih) * p.IW + uint(iiw)] + : 0); } } From a6b70d4b1abe647af00758278f83925518edd2d3 Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Mon, 10 Feb 2025 23:01:09 +0000 Subject: [PATCH 4/5] convert int to constant int --- ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 6715901215f25..8b81107ca3929 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -48,7 +48,7 @@ void main() { [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - uint i = base_idx + idx; + const uint i = base_idx + idx; if (i >= p.pelements) { break; @@ -59,8 +59,8 @@ void main() { const uint ky = rem / p.OW; const uint ix = rem % p.OW; - int iiw = int(ix) * p.s0 + int(kx) * p.d0 - p.p0; - int iih = oh_s1 + int(ky) * p.d1 - p.p1; + const int iiw = int(ix) * p.s0 + int(kx) * p.d0 - p.p0; + const int iih = oh_s1 + int(ky) * p.d1 - p.p1; const uint dst_offset = dst_base + ix * p.CHW + ky * p.KW + kx; From e2c2a1a23040b5549678ea29ab4af6d00bb77830 Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Wed, 12 Feb 2025 04:27:37 +0100 Subject: [PATCH 5/5] im2col test 3 --- .../ggml-vulkan/vulkan-shaders/im2col.comp | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 8b81107ca3929..4d0b2601b02d7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -41,33 +41,36 @@ void main() { const uint ic = gl_GlobalInvocationID.z % p.IC; const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + const int oh_s1 = int(oh) * p.s1; + const uint base_linear_idx = gidx * NUM_ITER; const uint src_base = ic * p.offset_delta + batch * p.batch_offset; const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); - const int oh_s1 = int(oh) * p.s1; - const uint base_idx = gidx * NUM_ITER; - - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - const uint i = base_idx + idx; + uint current_kx = base_linear_idx / ksize; + uint rem = base_linear_idx - (current_kx * ksize); // equivalent to init_val % ksize + uint current_ky = rem / p.OW; + uint current_ix = rem % p.OW; - if (i >= p.pelements) { - break; + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + uint linear_idx = base_linear_idx + idx; + if (linear_idx >= p.pelements) { + continue; } - const uint kx = i / ksize; - const uint rem = i % ksize; - const uint ky = rem / p.OW; - const uint ix = rem % p.OW; - - const int iiw = int(ix) * p.s0 + int(kx) * p.d0 - p.p0; - const int iih = oh_s1 + int(ky) * p.d1 - p.p1; + int iiw = int(current_ix) * p.s0 + int(current_kx) * p.d0 - p.p0; + int iih = oh_s1 + int(current_ky) * p.d1 - p.p1; + uint dst_offset = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx; - const uint dst_offset = dst_base + ix * p.CHW + ky * p.KW + kx; + bool valid = (iih >= 0 && iih < int(p.IH) && iiw >= 0 && iiw < int(p.IW)); + data_d[dst_offset] = D_TYPE(valid ? data_a[src_base + uint(iih) * p.IW + uint(iiw)] : 0); - data_d[dst_offset] = D_TYPE((iih >= 0 && iih < int(p.IH) && - iiw >= 0 && iiw < int(p.IW)) - ? data_a[src_base + uint(iih) * p.IW + uint(iiw)] - : 0); + if (++current_ix == p.OW) { + current_ix = 0; + if (++current_ky == (ksize / p.OW)) { + current_ky = 0; + current_kx++; + } + } } }