Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vulkan: improve im2col performance #11778

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 23 additions & 34 deletions ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,48 +40,37 @@ void main() {
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;

A_TYPE values[NUM_ITER];
uint offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}

[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {

const uint i = gidx * NUM_ITER + idx;

const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
const uint kx = i / ksize;
const uint kd = kx * ksize;
const uint ky = (i - kd) / p.OW;
const uint ix = i % p.OW;

const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
const int oh_s1 = int(oh) * p.s1;
const uint base_linear_idx = gidx * NUM_ITER;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);

offset_dst[idx] =
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
(ic * (p.KW * p.KH) + ky * p.KW + kx);
uint current_kx = base_linear_idx / ksize;
uint rem = base_linear_idx - (current_kx * ksize); // equivalent to base_linear_idx % ksize
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;

if (i >= p.pelements) {
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}

if (iih < p.IH && iiw < p.IW) {
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
values[idx] = data_a[offset_src + iih * p.IW + iiw];
}
}
int iiw = int(current_ix) * p.s0 + int(current_kx) * p.d0 - p.p0;
int iih = oh_s1 + int(current_ky) * p.d1 - p.p1;
uint dst_offset = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;

[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
bool valid = (iih >= 0 && iih < int(p.IH) && iiw >= 0 && iiw < int(p.IW));
data_d[dst_offset] = D_TYPE(valid ? data_a[src_base + uint(iih) * p.IW + uint(iiw)] : 0);

const uint i = gidx * NUM_ITER + idx;

if (i >= p.pelements) {
continue;
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == (ksize / p.OW)) {
current_ky = 0;
current_kx++;
}
}

data_d[offset_dst[idx]] = D_TYPE(values[idx]);
}

}
Loading