From ecd85fcbf7cfd72a08e5368306f6424a3424628c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=81kos=20Hadnagy?=
Date: Mon, 6 Jan 2025 23:39:07 +0100
Subject: [PATCH] fix(library): Propagate upstream Marlin kernel fix

Increase shared mem. size

Fix shared mem. size, re-activate test

Remove debugging-related stuff
---
 .../cuda/marlin/marlin_cuda_kernel.cu         | 22 ++++++++++++----------
 .../test_marlin_int4_weight_qbits_tensor.py   |  5 ++---
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu b/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
index d34f3340..b18b0469 100644
--- a/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
+++ b/optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
@@ -374,8 +374,10 @@ __global__ void Marlin(
   int4* sh_a = sh;
   int4* sh_b = sh_a + (stages * a_sh_stage);
   int4* sh_s = sh_b + (stages * b_sh_stage);
+  int4* sh_red = sh_s + (stages * s_sh_stage);
   // ADDED: shared memory storage for scaled zero points
-  int4* sh_sz = sh_s + (stages * s_sh_stage);
+  int4* sh_sz = sh_red + (stages * s_sh_stage);
+
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
   I4 frag_b_quant[2];
@@ -499,13 +501,13 @@ __global__ void Marlin(
             for (int j = 0; j < 4 * 2; j++) {
               int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
               if (i < red_off) {
-                float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
-                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+                float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);
+                float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
                 #pragma unroll
                 for (int k = 0; k < 4; k++)
                   reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
               }
-              sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
+              sh_red[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
             }
           }
           __syncthreads();
@@ -513,7 +515,7 @@ __global__ void Marlin(
         if (red_idx == 0) {
           #pragma unroll
           for (int i = 0; i < 4 * 2; i++) {
-            float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+            float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
             #pragma unroll
             for (int j = 0; j < 4; j++)
               reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
@@ -548,7 +550,7 @@ __global__ void Marlin(
         #pragma unroll
         for (int i = 0; i < thread_m_blocks * 4; i++) {
           cp_async4_pred(
-            &sh[c_sh_wr + c_sh_wr_delta * i],
+            &sh_red[c_sh_wr + c_sh_wr_delta * i],
             &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
             i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
           );
@@ -561,7 +563,7 @@ __global__ void Marlin(
       for (int i = 0; i < thread_m_blocks * 4; i++) {
         if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
           if (!first) {
-            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+            int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
             #pragma unroll
             for (int j = 0; j < 2 * 4; j++) {
               reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
@@ -605,7 +607,7 @@ __global__ void Marlin(
       half2 res = __halves2half2(__float2half(c0), __float2half(c1));
       if (group_blocks == -1) // for per-column quantization we finally apply the scale here
         res = __hmul2(res, s[0]);
-      ((half2*) sh)[idx] = res;
+      ((half2*) sh_red)[idx] = res;
     };
     if (threadIdx.x / 32 < thread_n_blocks / 4) {
       #pragma unroll
@@ -626,7 +628,7 @@ __global__ void Marlin(
     #pragma unroll
     for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh[c_sh_rd];
+        C[c_gl_wr] = sh_red[c_sh_rd];
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
@@ -726,7 +728,7 @@ __global__ void Marlin(
 // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
 const int THREADS = 256;
 const int STAGES = 4; // 4 pipeline stages fit into shared memory
-const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
+const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.0
 
 // ADDED: add scaled zero pointer
 #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
diff --git a/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py b/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py
index a44db5b2..c1a8d93d 100644
--- a/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py
+++ b/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py
@@ -131,15 +131,14 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features,
     )
 
 
-@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)
+# Tests previous Marlin kernel bug: https://github.com/huggingface/optimum-quanto/issues/332
 @pytest.mark.skipif(
     not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
     reason="CUDA >= sm80 not available",
 )
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize("tokens", [48, 64])
-# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
-@pytest.mark.parametrize("in_features", [4096, 16384])
+@pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
 @pytest.mark.parametrize("out_features", [2048, 4096])
 def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
     dtype = torch.float16