From 610b3ac3cddcca15d3066df3aeb3261fa8748628 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 27 Jan 2025 15:04:44 +0100 Subject: [PATCH 1/6] ggml : x2 speed for WASM by optimizing SIMD --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 881 ++++++++++++++++++++++------ 1 file changed, 710 insertions(+), 171 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 88303ff0e61c7..7f11b3567a230 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -747,7 +747,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); } } -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ for (int i = 0; i < nb; i++) { v128_t srcv [8]; v128_t asrcv[8]; @@ -1037,7 +1037,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ for (int i = 0; i < nb; i++) { v128_t srcv [8]; v128_t asrcv[8]; @@ -1653,7 +1653,105 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1 //===================================== Q8_K ============================================== void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { +#ifdef __wasm_simd128__ + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + block_q8_K * restrict yc = y; // Cast to proper type + + for (int i = 0; i < nb; i++) { + const float * x_block = x + i * QK_K; + v128_t amax_vec = wasm_f32x4_splat(0.0f); + v128_t max_vec = wasm_f32x4_splat(0.0f); + + // Vectorized max abs value search + for (int j = 0; j < QK_K; j += 4) { + v128_t x_vec = wasm_v128_load(x_block + j); + v128_t abs_x = wasm_f32x4_abs(x_vec); + v128_t mask = wasm_f32x4_gt(abs_x, amax_vec); + amax_vec = wasm_v128_bitselect(abs_x, amax_vec, mask); + max_vec = wasm_v128_bitselect(x_vec, max_vec, mask); + } + + // Manual unroll for lane extraction + float amax = wasm_f32x4_extract_lane(amax_vec, 0); + float max_val = wasm_f32x4_extract_lane(max_vec, 0); + #define UPDATE_MAX(lane) \ + { \ + float a = wasm_f32x4_extract_lane(amax_vec, lane); \ + if (a > amax) { \ + amax = a; \ + max_val = wasm_f32x4_extract_lane(max_vec, lane); \ + } \ + } + UPDATE_MAX(1) + UPDATE_MAX(2) + UPDATE_MAX(3) + #undef UPDATE_MAX + + if (amax == 0.0f) { + yc[i].d = 0.0f; + const v128_t zero = wasm_i8x16_splat(0); + for (int j = 0; j < QK_K; j += 16) { + wasm_v128_store(yc[i].qs + j, zero); + } + memset(yc[i].bsums, 0, QK_K/16 * sizeof(int)); + continue; + } + + const float iscale = -127.0f / max_val; + const v128_t scale_vec = wasm_f32x4_splat(iscale); + + // Process 16 elements per iteration + for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) { + // Load and quantize 16 floats + v128_t x0 = wasm_v128_load(x_block + j); + v128_t x1 = wasm_v128_load(x_block + j + 4); + v128_t x2 = wasm_v128_load(x_block + j + 8); + v128_t x3 = wasm_v128_load(x_block + j + 12); + + v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec)); + v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec)); + v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec)); + v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec)); + + // Convert to i32 with saturation + v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0); + v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1); + v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2); + v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3); + + // Pack into 16 i8 values + v128_t i8 = wasm_i8x16_narrow_i16x8( + wasm_i16x8_narrow_i32x4( + wasm_i32x4_min(wasm_i32x4_max(i0, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)), + wasm_i32x4_min(wasm_i32x4_max(i1, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)) + ), + wasm_i16x8_narrow_i32x4( + wasm_i32x4_min(wasm_i32x4_max(i2, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)), + wasm_i32x4_min(wasm_i32x4_max(i3, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)) + ) + ); + wasm_v128_store(yc[i].qs + j, i8); + + // Calculate bsums using SIMD + v128_t sum16 = wasm_i16x8_add( + wasm_i16x8_extend_low_i8x16(i8), + wasm_i16x8_extend_high_i8x16(i8) + ); + v128_t sum32 = wasm_i32x4_add( + wasm_i32x4_extend_low_i16x8(sum16), + wasm_i32x4_extend_high_i16x8(sum16) + ); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1)); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2)); + yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0); + } + + yc[i].d = 1.0f / iscale; + } +#else quantize_row_q8_K_ref(x, y, k); +#endif } //===================================== Dot products ================================= @@ -2011,6 +2109,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + const v128_t m4b = wasm_i8x16_splat(0x0F); + const v128_t s8b = wasm_i8x16_splat(0x8); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // Load and process x0 + v128_t v0_0 = wasm_v128_load(x0->qs); + v128_t v0_0l = wasm_v128_and(v0_0, m4b); + v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); + v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); + v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); + + // Load y0 vectors + v128_t y0_l = wasm_v128_load(y0->qs); + v128_t y0_h = wasm_v128_load(y0->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls); + v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls); + v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs); + v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs); + + v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l); + v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l); + v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h); + v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h); + + v128_t dp0 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0l, dy0ll), + wasm_i32x4_dot_i16x8(dx0h, dy0lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0hl, dy0hl), + wasm_i32x4_dot_i16x8(dx0hh, dy0hh) + ) + ); + + // Load and process x1 + v128_t v0_1 = wasm_v128_load(x1->qs); + v128_t v0_1l = wasm_v128_and(v0_1, m4b); + v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); + v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); + v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); + + // Load y1 vectors + v128_t y1_l = wasm_v128_load(y1->qs); + v128_t y1_h = wasm_v128_load(y1->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls); + v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls); + v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs); + v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs); + + v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l); + v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l); + v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h); + v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h); + + v128_t dp1 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1l, dy1ll), + wasm_i32x4_dot_i16x8(dx1h, dy1lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1hl, dy1hl), + wasm_i32x4_dot_i16x8(dx1hh, dy1hh) + ) + ); + + // Accumulate results with scaling + float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d); + + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0))); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2696,10 +2882,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ v128_t sumv = wasm_f32x4_splat(0.0f); - uint32_t qh; + uint32_t qh_; uint64_t tmp[4]; // TODO: check if unrolling this is better @@ -2710,12 +2896,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh_, x0->qh, sizeof(qh_)); - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; + tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh_ >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3057,12 +3243,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) +#elif defined __wasm_simd128__ v128_t sumv = wasm_f32x4_splat(0.0f); float summs = 0.0f; - uint32_t qh; + uint32_t qh_; uint64_t tmp[4]; // TODO: check if unrolling this is better @@ -3075,12 +3261,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); + memcpy(&qh_, x0->qh, sizeof(qh_)); - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; + tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh_ >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); @@ -3573,6 +3759,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + for (; ib < nb; ++ib) { + const block_q8_0 * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; + + const v128_t x0_0 = wasm_v128_load(x0->qs); + const v128_t x0_1 = wasm_v128_load(x0->qs + 16); + const v128_t y0_0 = wasm_v128_load(y0->qs); + const v128_t y0_1 = wasm_v128_load(y0->qs + 16); + + // Extend 8-bit to 16-bit + const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0); + const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0); + const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1); + const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1); + + const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0); + const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0); + const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1); + const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1); + + // Compute dot products + const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l); + const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h); + const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l); + const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h); + + // Sum all dot products + const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1)); + + // Convert to float and accumulate + const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -4447,6 +4672,106 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + // Vectorized summs calculation + v128_t summs_vec = wasm_i32x4_splat(0); + { + v128_t sc_vec = wasm_v128_load(sc); + v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4); + + v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper); + v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper); + + v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]); + v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]); + + summs_vec = wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1), + wasm_i32x4_dot_i16x8(sc_high, bsums2)), + summs_vec + ); + + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1)); + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2)); + } + int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0); + + // Vectorized isum calculation + int32_t isum = 0; + const uint8_t * sc_ptr = sc; + const int k_iters = QK_K/128; + + for (int k = 0; k < k_iters; ++k) { + v128_t isum_vec = wasm_i32x4_splat(0); + int shift = 0; + + for (int j = 0; j < 4; ++j) { + const int d0 = (sc_ptr[0] & 0xF); + const int d1 = (sc_ptr[1] & 0xF); + sc_ptr += 2; + + // Process first 16 elements + v128_t q2_0 = wasm_v128_load(q2); + v128_t q8_0 = wasm_v128_load(q8); + v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift); + v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03)); + + // Process next 16 elements + v128_t q2_1 = wasm_v128_load(q2 + 16); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift); + v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03)); + + // Calculate dot products + v128_t p0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_0), + wasm_i16x8_extend_low_i8x16(q2_bits_0) + ); + v128_t p1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_0), + wasm_i16x8_extend_high_i8x16(q2_bits_0) + ); + v128_t p2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_1), + wasm_i16x8_extend_low_i8x16(q2_bits_1) + ); + v128_t p3 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_1), + wasm_i16x8_extend_high_i8x16(q2_bits_1) + ); + + // Accumulate scaled results + v128_t scaled = wasm_i32x4_add( + wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)), + wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1)) + ); + + isum_vec = wasm_i32x4_add(isum_vec, scaled); + q8 += 32; + shift += 2; + } + q2 += 32; + + // Horizontal sum of isum_vec + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1)); + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2)); + isum += wasm_i32x4_extract_lane(isum_vec, 0); + } + + const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf += dall * isum - dmin * summs; + } + + *s = sumf; + #elif defined __riscv_v_intrinsic float sumf = 0; @@ -5129,6 +5454,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); +#elif defined __wasm_simd128__ + int8_t aux8[QK_K]; + float sums[8] = {0}; + uint32_t auxs[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + // Process blocks with SIMD + int8_t * a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int shift = 0; shift <= 6; shift += 2) { + v128_t v_m = wasm_i8x16_splat(m); + for (int l = 0; l < 32; l += 16) { + v128_t v_q3 = wasm_v128_load(q3 + l); + v128_t v_shift = wasm_i8x16_shr(v_q3, shift); + v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03)); + + v128_t v_hm = wasm_v128_load(hm + l); + v128_t v_mask = wasm_v128_and(v_hm, v_m); + v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0)); + + v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask))); + wasm_v128_store(a + l, v_low2); + } + a += 32; + m <<= 1; + } + q3 += 32; + } + + // Extract scales + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + const int8_t * scales = (const int8_t *)auxs; + + // SIMD dot product with register accumulators + v128_t v_acc0 = wasm_i32x4_splat(0); + v128_t v_acc1 = wasm_i32x4_splat(0); + a = aux8; + for (int j = 0; j < QK_K/16; ++j) { + const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32); + + // Process 16 elements per iteration + for (int k = 0; k < 2; ++k) { + const v128_t v_q8 = wasm_i16x8_load8x8(q8); + const v128_t v_a = wasm_i16x8_load8x8(a); + + v128_t v_prod = wasm_i16x8_mul(v_q8, v_a); + v_prod = wasm_i16x8_mul(v_prod, v_scale); + + v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod)); + v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod)); + + q8 += 8; + a += 8; + } + } + + // Accumulate results + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const v128_t v_d = wasm_f32x4_splat(d); + v128_t v_sum = wasm_f32x4_add( + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d), + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d) + ); + + // Accumulate into sums vector + wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum)); + } + + // Horizontal sum + v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4)); + sumf = wasm_f32x4_extract_lane(v_sum, 0) + + wasm_f32x4_extract_lane(v_sum, 1) + + wasm_f32x4_extract_lane(v_sum, 2) + + wasm_f32x4_extract_lane(v_sum, 3); + + *s = sumf; + #elif defined __riscv_v_intrinsic uint32_t aux[3]; @@ -5573,88 +5986,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r uint32_t utmp[4]; -#ifdef __ARM_FEATURE_SVE - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, K_SCALE_SIZE); - - uint32x2_t mins8 = { 0 }; - mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); - mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); - - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; - - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const int vector_length = ggml_cpu_get_sve_cnt()*8; - const svuint8_t m4b = svdup_n_u8(0xf); - const svint32_t mzero = svdup_n_s32(0); - svint32_t sumi1 = svdup_n_s32(0); - svint32_t sumi1_1 = svdup_n_s32(0); - svint32_t sumi1_2 = svdup_n_s32(0); - svint32_t sumi2 = svdup_n_s32(0); - svint32_t sumi2_1 = svdup_n_s32(0); - svint32_t sumi2_2 = svdup_n_s32(0); - switch (vector_length) { - case 128: - { - for (int j = 0; j < QK_K/64; ++j) { - svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b)); - svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; - sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); - q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b)); - q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; - sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); - - q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4)); - q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; - sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); - q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4)); - q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; - sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); - q4 += 32; - } - sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2); - sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2); - sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2))); - } break; - case 256: - case 512: - { - for (int j = 0; j < QK_K/64; ++j) { - const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32; - svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b)); - svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; - sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); - - q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4)); - q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; - sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); - } - sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2))); - } break; - default: - assert(false && "Unsupported vector length"); - break; - } - } - *s = sumf; -#elif __ARM_NEON +#ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); @@ -5717,6 +6049,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; +#elif defined __wasm_simd128__ + const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi = 0; + const int16_t * restrict q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + // Load 64 4-bit weights (32 bytes) + const v128_t q4x0 = wasm_v128_load(q4); + const v128_t q4x1 = wasm_v128_load(q4 + 16); + q4 += 32; + + // Split into low/high nibbles + const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F)); + const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4); + const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F)); + const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4); + + // Load 64 8-bit values (64 bytes) + const v128_t q8x0 = wasm_v128_load(q8); + const v128_t q8x1 = wasm_v128_load(q8 + 16); + const v128_t q8x2 = wasm_v128_load(q8 + 32); + const v128_t q8x3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Low nibble products + v128_t vacc1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l0), + wasm_i16x8_extend_low_i8x16(q8x0) + ); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l0), + wasm_i16x8_extend_high_i8x16(q8x0) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l1), + wasm_i16x8_extend_low_i8x16(q8x1) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l1), + wasm_i16x8_extend_high_i8x16(q8x1) + )); + + // High nibble products + v128_t vacc2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h0), + wasm_i16x8_extend_low_i8x16(q8x2) + ); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h0), + wasm_i16x8_extend_high_i8x16(q8x2) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h1), + wasm_i16x8_extend_low_i8x16(q8x3) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h1), + wasm_i16x8_extend_high_i8x16(q8x3) + )); + + // Accumulate scaled results + int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) + + wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3); + sumi1 += vacc1_sum * scales[2*j]; + + int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) + + wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3); + sumi2 += vacc2_sum * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf; + #elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); @@ -6469,6 +6902,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc) + summs; +#elif defined __wasm_simd128__ + //const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi_mins = 0; + const int16_t * restrict q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi_mins; // Correct subtraction + + v128_t qh0 = wasm_v128_load(qh); + v128_t qh1 = wasm_v128_load(qh + 16); + const uint8_t * sc = (const uint8_t *)utmp; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const int shift = j * 2; + v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift); + v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift); + + v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3); + v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3); + + v128_t q5_0 = wasm_v128_load(q5); + v128_t q5_1 = wasm_v128_load(q5 + 16); + q5 += 32; + + v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0); + v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0); + v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1); + v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1); + + v128_t q8_0 = wasm_v128_load(q8); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q8_2 = wasm_v128_load(q8 + 32); + v128_t q8_3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Process low quants + v128_t pl0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_0), + wasm_i16x8_extend_low_i8x16(q8_0) + ); + pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_0), + wasm_i16x8_extend_high_i8x16(q8_0) + )); + v128_t pl1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_1), + wasm_i16x8_extend_low_i8x16(q8_1) + ); + pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_1), + wasm_i16x8_extend_high_i8x16(q8_1) + )); + v128_t sum_low = wasm_i32x4_add(pl0, pl1); + + // Process high quants + v128_t ph0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_0), + wasm_i16x8_extend_low_i8x16(q8_2) + ); + ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_0), + wasm_i16x8_extend_high_i8x16(q8_2) + )); + v128_t ph1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_1), + wasm_i16x8_extend_low_i8x16(q8_3) + ); + ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_1), + wasm_i16x8_extend_high_i8x16(q8_3) + )); + v128_t sum_high = wasm_i32x4_add(ph0, ph1); + + // Accumulate with scale factors + int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) + + wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3); + int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) + + wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3); + + sumi += sl * sc[2*j] + sh * sc[2*j+1]; + } + + sumf += d * sumi; + } + + *s = sumf; + #elif defined __riscv_v_intrinsic const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -7132,89 +7677,83 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = hsum_float_8(acc); -#elif defined __riscv_v_intrinsic +#elif defined __wasm_simd128__ + int8_t aux8[QK_K] __attribute__((aligned(16))); + int32_t aux32[8] __attribute__((aligned(16))) = {0}; + float sums[8] __attribute__((aligned(16))) = {0}; - float sumf = 0; for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; + // Unpack 6-bit quantized data into aux8 (unchanged) + const uint8_t * restrict q4 = x[i].ql; const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - size_t vl; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - int sum_t = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - vl = 32; - - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + int8_t * a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } - q6 += 64; qh += 32; q8 += 128; is=8; + const int8_t * restrict a_ptr = aux8; + const int8_t * restrict q8 = y[i].qs; + v128_t acc0 = wasm_i32x4_splat(0); + v128_t acc1 = wasm_i32x4_splat(0); + for (int j = 0; j < QK_K/16; ++j) { + const int scale = x[i].scales[j]; + const v128_t vscale = wasm_i32x4_splat(scale); + + // Load 16 elements from a and q8 + const v128_t a_vec = wasm_v128_load(a_ptr); + const v128_t q8_vec = wasm_v128_load(q8); + + // Process low 8 elements + v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec); + v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec); + v128_t prod_low = wasm_i16x8_mul(a_low, q8_low); + v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low); + v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low); + + // Process high 8 elements + v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec); + v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec); + v128_t prod_high = wasm_i16x8_mul(a_high, q8_high); + v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high); + v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high); + + // Scale and accumulate + prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale); + prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale); + prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale); + prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale); + + acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo)); + acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi)); + + a_ptr += 16; + q8 += 16; } - sumf += d * sum_t; + // Store accumulated results + wasm_v128_store(&aux32[0], acc0); + wasm_v128_store(&aux32[4], acc1); + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) { + sums[l] += d * aux32[l]; + } } + // Sum final results + float sumf = 0; + for (int l = 0; l < 8; ++l) { + sumf += sums[l]; + } *s = sumf; #elif defined(__POWER9_VECTOR__) From e5aeb423a5ebd8b391d2ef9415f1f0efa350dffc Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 27 Jan 2025 15:24:46 +0100 Subject: [PATCH 2/6] fix bad merging --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 168 +++++++++++++++++++++++++++- 1 file changed, 167 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 7f11b3567a230..2e5659cf3475d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -5986,7 +5986,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r uint32_t utmp[4]; -#ifdef __ARM_NEON +#ifdef __ARM_FEATURE_SVE + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, K_SCALE_SIZE); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + const svuint8_t m4b = svdup_n_u8(0xf); + const svint32_t mzero = svdup_n_s32(0); + svint32_t sumi1 = svdup_n_s32(0); + svint32_t sumi1_1 = svdup_n_s32(0); + svint32_t sumi1_2 = svdup_n_s32(0); + svint32_t sumi2 = svdup_n_s32(0); + svint32_t sumi2_1 = svdup_n_s32(0); + svint32_t sumi2_2 = svdup_n_s32(0); + switch (vector_length) { + case 128: + { + for (int j = 0; j < QK_K/64; ++j) { + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b)); + svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4 += 32; + } + sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2); + sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2); + sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2))); + } break; + case 256: + case 512: + { + for (int j = 0; j < QK_K/64; ++j) { + const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32; + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b)); + svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4)); + q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + } + sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2))); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + *s = sumf; +#elif defined __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); @@ -7756,6 +7837,91 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r } *s = sumf; +#elif defined __riscv_v_intrinsic + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + size_t vl; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; + #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); const vector int v0 = vec_splats((int32_t)0); From 226d59270bfeff5d4b327e96d333aaf5f91d871e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 27 Jan 2025 15:26:01 +0100 Subject: [PATCH 3/6] rm trailing spaces --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 32 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 2e5659cf3475d..e22c3dbec18b3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -2190,7 +2190,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r // Accumulate results with scaling float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d); - + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0))); sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1))); } @@ -4685,19 +4685,19 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r { v128_t sc_vec = wasm_v128_load(sc); v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4); - + v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper); v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper); - + v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]); v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]); - + summs_vec = wasm_i32x4_add( wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1), wasm_i32x4_dot_i16x8(sc_high, bsums2)), summs_vec ); - + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1)); summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2)); } @@ -4707,11 +4707,11 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r int32_t isum = 0; const uint8_t * sc_ptr = sc; const int k_iters = QK_K/128; - + for (int k = 0; k < k_iters; ++k) { v128_t isum_vec = wasm_i32x4_splat(0); int shift = 0; - + for (int j = 0; j < 4; ++j) { const int d0 = (sc_ptr[0] & 0xF); const int d1 = (sc_ptr[1] & 0xF); @@ -4722,7 +4722,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r v128_t q8_0 = wasm_v128_load(q8); v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift); v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03)); - + // Process next 16 elements v128_t q2_1 = wasm_v128_load(q2 + 16); v128_t q8_1 = wasm_v128_load(q8 + 16); @@ -4752,7 +4752,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)), wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1)) ); - + isum_vec = wasm_i32x4_add(isum_vec, scaled); q8 += 32; shift += 2; @@ -5464,7 +5464,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const uint8_t * restrict q3 = x[i].qs; const uint8_t * restrict hm = x[i].hmask; const int8_t * restrict q8 = y[i].qs; - + // Process blocks with SIMD int8_t * a = aux8; uint8_t m = 1; @@ -5475,11 +5475,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r v128_t v_q3 = wasm_v128_load(q3 + l); v128_t v_shift = wasm_i8x16_shr(v_q3, shift); v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03)); - + v128_t v_hm = wasm_v128_load(hm + l); v128_t v_mask = wasm_v128_and(v_hm, v_m); v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0)); - + v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask))); wasm_v128_store(a + l, v_low2); } @@ -5509,13 +5509,13 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r for (int k = 0; k < 2; ++k) { const v128_t v_q8 = wasm_i16x8_load8x8(q8); const v128_t v_a = wasm_i16x8_load8x8(a); - + v128_t v_prod = wasm_i16x8_mul(v_q8, v_a); v_prod = wasm_i16x8_mul(v_prod, v_scale); - + v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod)); v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod)); - + q8 += 8; a += 8; } @@ -5528,7 +5528,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d), wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d) ); - + // Accumulate into sums vector wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum)); } From 9517aee23c9b2a0506b524813bbab04e47632063 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 27 Jan 2025 16:36:20 +0100 Subject: [PATCH 4/6] rm redundant clamp --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index e22c3dbec18b3..dafa80f2cb6a3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -1722,14 +1722,8 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { // Pack into 16 i8 values v128_t i8 = wasm_i8x16_narrow_i16x8( - wasm_i16x8_narrow_i32x4( - wasm_i32x4_min(wasm_i32x4_max(i0, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)), - wasm_i32x4_min(wasm_i32x4_max(i1, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)) - ), - wasm_i16x8_narrow_i32x4( - wasm_i32x4_min(wasm_i32x4_max(i2, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)), - wasm_i32x4_min(wasm_i32x4_max(i3, wasm_i32x4_splat(-127)), wasm_i32x4_splat(127)) - ) + wasm_i16x8_narrow_i32x4(i0, i1), + wasm_i16x8_narrow_i32x4(i2, i3) ); wasm_v128_store(yc[i].qs + j, i8); From 10dacabbcd4f53546fafe3f76b8fa0687fe6c883 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 2 Feb 2025 12:04:55 +0100 Subject: [PATCH 5/6] better quantize_row_q8_K Co-authored-by: camel-cdr --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 39 +++++++++++------------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index dafa80f2cb6a3..bef43a29eb0a1 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -1660,33 +1660,22 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { for (int i = 0; i < nb; i++) { const float * x_block = x + i * QK_K; - v128_t amax_vec = wasm_f32x4_splat(0.0f); - v128_t max_vec = wasm_f32x4_splat(0.0f); - // Vectorized max abs value search - for (int j = 0; j < QK_K; j += 4) { + v128_t min_vec = wasm_v128_load(x_block); + v128_t max_vec = min_vec; + + for (int j = 4; j < QK_K; j += 4) { v128_t x_vec = wasm_v128_load(x_block + j); - v128_t abs_x = wasm_f32x4_abs(x_vec); - v128_t mask = wasm_f32x4_gt(abs_x, amax_vec); - amax_vec = wasm_v128_bitselect(abs_x, amax_vec, mask); - max_vec = wasm_v128_bitselect(x_vec, max_vec, mask); - } - - // Manual unroll for lane extraction - float amax = wasm_f32x4_extract_lane(amax_vec, 0); - float max_val = wasm_f32x4_extract_lane(max_vec, 0); - #define UPDATE_MAX(lane) \ - { \ - float a = wasm_f32x4_extract_lane(amax_vec, lane); \ - if (a > amax) { \ - amax = a; \ - max_val = wasm_f32x4_extract_lane(max_vec, lane); \ - } \ - } - UPDATE_MAX(1) - UPDATE_MAX(2) - UPDATE_MAX(3) - #undef UPDATE_MAX + max_vec = wasm_f32x4_pmax(max_vec, x_vec); + min_vec = wasm_f32x4_pmin(min_vec, x_vec); + } + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1)); + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2)); + float max = wasm_f32x4_extract_lane(max_vec, 0); + float min = wasm_f32x4_extract_lane(min_vec, 0); + float amax = -min > max ? min : max; if (amax == 0.0f) { yc[i].d = 0.0f; From 2ab608b1526772d4586951a1934bf40372358ad3 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 8 Feb 2025 22:19:44 +0100 Subject: [PATCH 6/6] remove memset that causes buffer overflow Co-authored-by: camel-cdr --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index f859469cf4c41..1b4bd66e80c68 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -1674,11 +1674,10 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { for (int j = 0; j < QK_K; j += 16) { wasm_v128_store(yc[i].qs + j, zero); } - memset(yc[i].bsums, 0, QK_K/16 * sizeof(int)); continue; } - const float iscale = -127.0f / max_val; + const float iscale = -127.0f / amax; const v128_t scale_vec = wasm_f32x4_splat(iscale); // Process 16 elements per iteration