Skip to content

Commit

Permalink
fix(library): Propagate upstream Marlin kernel fix
Browse files Browse the repository at this point in the history
Increase shared mem. size

Fix shared mem. size, re-activate test

Remove debugging-related stuff
  • Loading branch information
ahadnagy committed Jan 6, 2025
1 parent 66dca87 commit ecd85fc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
22 changes: 12 additions & 10 deletions optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,10 @@ __global__ void Marlin(
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_s = sh_b + (stages * b_sh_stage);
int4* sh_red = sh_s + (stages * s_sh_stage);
// ADDED: shared memory storage for scaled zero points
int4* sh_sz = sh_s + (stages * s_sh_stage);
int4* sh_sz = sh_red + (stages * s_sh_stage);

// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2];
Expand Down Expand Up @@ -499,21 +501,21 @@ __global__ void Marlin(
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
}
sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
sh_red[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
Expand Down Expand Up @@ -548,7 +550,7 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&sh_red[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
);
Expand All @@ -561,7 +563,7 @@ __global__ void Marlin(
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
Expand Down Expand Up @@ -605,7 +607,7 @@ __global__ void Marlin(
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
if (group_blocks == -1) // for per-column quantization we finally apply the scale here
res = __hmul2(res, s[0]);
((half2*) sh)[idx] = res;
((half2*) sh_red)[idx] = res;
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
Expand All @@ -626,7 +628,7 @@ __global__ void Marlin(
#pragma unroll
for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd];
C[c_gl_wr] = sh_red[c_sh_rd];
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
Expand Down Expand Up @@ -726,7 +728,7 @@ __global__ void Marlin(
// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
const int THREADS = 256;
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
const int SHARED_MEM = 164 * 1000; // max shared memory on compute capability 8.0

// ADDED: add scaled zero pointer
#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,14 @@ def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features,
)


@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)
#Tests previous Marlin kernel bug: https://github.com/huggingface/optimum-quanto/issues/332
@pytest.mark.skipif(
not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
reason="CUDA >= sm80 not available",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [48, 64])
# @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
@pytest.mark.parametrize("in_features", [4096, 16384])
@pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384])
@pytest.mark.parametrize("out_features", [2048, 4096])
def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):
dtype = torch.float16
Expand Down

0 comments on commit ecd85fc

Please sign in to comment.