diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip index 54a60a009..a00564b90 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip @@ -276,11 +276,7 @@ at::Tensor get_grouped_kernel_args( if (zero_start_index_M.has_value()) { set_dynamic_kernel_args( - kernel_args, - A, - B, - output, - zero_start_index_M.value()); + kernel_args, A, B, output, zero_start_index_M.value()); } else { set_static_kernel_args(kernel_args, A, B, output); } @@ -294,23 +290,18 @@ std::vector bf16bf16bf16_grouped( // Check that input datatypes are valid. // First confirm that there are the same number of groups in all inputs. TORCH_CHECK( - A.size() == B.size(), - "A and B must have the same number of groups."); + A.size() == B.size(), "A and B must have the same number of groups."); int group_count = A.size(); // Iterate over inputs and check they are valid. for (at::Tensor a : A) { TORCH_CHECK(a.is_cuda() && a.is_contiguous()); TORCH_CHECK(a.dim() == 2, "Inputs must be 2D."); - TORCH_CHECK( - a.dtype() == at::kBFloat16, - "Inputs must be type bfloat16."); + TORCH_CHECK(a.dtype() == at::kBFloat16, "Inputs must be type bfloat16."); } for (at::Tensor b : B) { TORCH_CHECK(b.is_cuda() && b.is_contiguous()); TORCH_CHECK(b.dim() == 2, "Inputs must be 2D."); - TORCH_CHECK( - b.dtype() == at::kBFloat16, - "Inputs must be type bfloat16."); + TORCH_CHECK(b.dtype() == at::kBFloat16, "Inputs must be type bfloat16."); } std::vector Y; @@ -340,8 +331,7 @@ std::vector bf16bf16bf16_grouped( } // Prepare kernel arguments by copying them to the proper device location. - at::Tensor kernel_args = get_grouped_kernel_args( - A, B, std::nullopt, Y); + at::Tensor kernel_args = get_grouped_kernel_args(A, B, std::nullopt, Y); // Perform shape lookup to find best kernel. // We use the largest of each shape for heuristics. @@ -353,50 +343,70 @@ std::vector bf16bf16bf16_grouped( MaxN = max(MaxN, B[i].size(0)); MaxK = max(MaxK, A[i].size(1)); } - GroupedKernel selected_kernel = - grouped_heuristic_dispatch(MaxM, MaxN, MaxK); + GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK); return selected_kernel(A, B, kernel_args, Y); } at::Tensor bf16bf16bf16_grouped_dynamic( at::TensorList A, at::TensorList B, - at::Tensor zero_start_index_M) { + std::optional zero_start_index_M = std::nullopt) { // Check that input datatypes are valid. // First confirm that there are the same number of groups in all inputs. TORCH_CHECK( - A.size() == B.size(), - "A and B must have the same number of groups."); + A.size() == B.size(), "A and B must have the same number of groups."); int group_count = A.size(); // Iterate over inputs and check they are valid. for (at::Tensor a : A) { TORCH_CHECK(a.is_cuda() && a.is_contiguous()); TORCH_CHECK(a.dim() == 2, "Inputs must be 2D."); - TORCH_CHECK( - a.dtype() == at::kBFloat16, - "Inputs must be type bfloat16."); + TORCH_CHECK(a.dtype() == at::kBFloat16, "Inputs must be type bfloat16."); } for (at::Tensor b : B) { TORCH_CHECK(b.is_cuda() && b.is_contiguous()); TORCH_CHECK(b.dim() == 2, "Inputs must be 2D."); - TORCH_CHECK( - b.dtype() == at::kBFloat16, - "Inputs must be type bfloat16."); + TORCH_CHECK(b.dtype() == at::kBFloat16, "Inputs must be type bfloat16."); } std::vector Y; - int M = A[0].size(0); + at::Tensor Y_full; int N = B[0].size(0); - // Fill output with zeros to simplify integration. This prevents nans from - // showing up in the tensor. - at::Tensor Y_full = - at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16)); - // Split the output into groups. - Y = at::unbind(Y_full, 0); + int K = A[0].size(1); + + if (zero_start_index_M.has_value()) { + int M = A[0].size(0); + // Fill output with zeros to simplify integration. This prevents nans from + // showing up in the tensor. + Y_full = + at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16)); + // Split the output into groups. + Y = at::unbind(Y_full, 0); + } else { + // If not provided, we try to allocate a single blob that can store each + // group. + int total_M = 0; + std::vector group_sizes = {}; + for (int i = 0; i < group_count; i++) { + TORCH_CHECK( + A[i].size(1) == K && B[i].size(0) == N, + "Dynamic grouped gemm requires fixed N and K."); + int group_M = A[i].size(0); + total_M += group_M; + group_sizes.push_back(group_M); + } + // Allocate a contiguous array for all groups. + Y_full = at::empty({total_M, N}, A[0].options().dtype(at::kBFloat16)); + // Split the full array into appropriate groups. + // We do this with narrow to make sure there are no extra copies. + int offset = 0; + for (int size : group_sizes) { + Y.push_back(Y_full.narrow(0, offset, size)); + offset += size; + } + } // Prepare kernel arguments by copying them to the proper device location. - at::Tensor kernel_args = get_grouped_kernel_args( - A, B, zero_start_index_M, Y); + at::Tensor kernel_args = get_grouped_kernel_args(A, B, zero_start_index_M, Y); // Perform shape lookup to find best kernel. // We use the largest of each shape for heuristics. @@ -408,8 +418,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic( MaxN = max(MaxN, B[i].size(0)); MaxK = max(MaxK, A[i].size(1)); } - GroupedKernel selected_kernel = - grouped_heuristic_dispatch(MaxM, MaxN, MaxK); + GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK); // Run kernel to populate output. selected_kernel(A, B, kernel_args, Y); // Return coalesced view of output tensor. diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu index ab38fb548..5e051cd73 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu @@ -499,20 +499,47 @@ std::vector bf16bf16bf16_grouped( at::Tensor bf16bf16bf16_grouped_dynamic( at::TensorList x_group, // BF16 at::TensorList w_group, // BF16 - at::Tensor zero_start_index_M) { - std::vector output_tensor; + std::optional zero_start_index_M = std::nullopt) { + std::vector output_groups; + at::Tensor output_full; int problem_count = x_group.size(); - int M = x_group[0].size(0); int N = w_group[0].size(0); - // Fill output with zeros to simplify integration. This prevents nans from - // showing up in the tensor. - at::Tensor output_full = at::zeros( - {problem_count, M, N}, x_group[0].options().dtype(at::kBFloat16)); - // Split the output into groups. - output_tensor = at::unbind(output_full, 0); + int K = x_group[0].size(1); + if (zero_start_index_M.has_value()) { + int M = x_group[0].size(0); + // Fill output with zeros to simplify integration. This prevents nans from + // showing up in the tensor. + output_full = at::zeros( + {problem_count, M, N}, x_group[0].options().dtype(at::kBFloat16)); + // Split the output into groups. + output_groups = at::unbind(output_full, 0); + } else { + // If not provided, we try to allocate a single blob that can store each + // group. + int total_M = 0; + std::vector group_sizes = {}; + for (int i = 0; i < problem_count; i++) { + TORCH_CHECK( + x_group[i].size(1) == K && w_group[i].size(0) == N, + "Dynamic grouped gemm requires fixed N and K."); + int group_M = x_group[i].size(0); + total_M += group_M; + group_sizes.push_back(group_M); + } + // Allocate a contiguous array for all groups. + output_full = + at::empty({total_M, N}, x_group[0].options().dtype(at::kBFloat16)); + // Split the full array into appropriate groups. + // We do this with narrow to make sure there are no extra copies. + int offset = 0; + for (int size : group_sizes) { + output_groups.push_back(output_full.narrow(0, offset, size)); + offset += size; + } + } // Run kernel to populate output tensor. dispatch_bf16_grouped_kernel( - x_group, w_group, output_tensor, zero_start_index_M); + x_group, w_group, output_groups, zero_start_index_M); // Return coalesced view of output. return output_full; } @@ -530,7 +557,7 @@ std::vector bf16bf16bf16_grouped( at::Tensor bf16bf16bf16_grouped_dynamic( at::TensorList /* x_group */, // BF16 at::TensorList /* w_group */, // BF16 - at::Tensor /* zero_start_index_M */) { + std::optional /* zero_start_index_M */) { throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 9d24ecb9d..34ea4ec42 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -68,7 +68,7 @@ std::vector bf16bf16bf16_grouped( at::Tensor bf16bf16bf16_grouped_dynamic( at::TensorList X, at::TensorList W, - at::Tensor zero_start_index_M); + std::optional zero_start_index_M = std::nullopt); at::Tensor f8f8bf16_rowwise( at::Tensor XQ, at::Tensor WQ, @@ -209,7 +209,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor[](a!)? output=None) -> Tensor[]"); m.def( - "bf16bf16bf16_grouped_dynamic(Tensor[] X, Tensor[] W, Tensor zero_start_index_M) -> Tensor"); + "bf16bf16bf16_grouped_dynamic(Tensor[] X, Tensor[] W, Tensor? zero_start_index_M=None) -> Tensor"); m.def( "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor"); m.def( @@ -506,7 +506,7 @@ std::vector bf16bf16bf16_grouped_meta( at::Tensor bf16bf16bf16_grouped_dynamic_meta( at::TensorList X, at::TensorList W, - at::Tensor /* zero_start_index_M = std::nullopt */) { + std::optional /* zero_start_index_M = std::nullopt */) { int G = X.size(); int M = X[0].size(0); int N = W[0].size(0); diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index e2126b5a0..ac21c5706 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -734,8 +734,9 @@ def fp8_loopover_bmm( M=st.sampled_from([2048, 3584]), N=st.sampled_from([1024, 6144]), K=st.sampled_from([512, 3584]), - use_cudagraph=st.sampled_from([True, False]), - use_padding_zeros=st.sampled_from([True, False]), + use_cudagraph=st.booleans(), + use_padding_zeros=st.booleans(), + use_dynamic=st.booleans(), ) def test_fp8_grouped_gemm( self, @@ -745,6 +746,7 @@ def test_fp8_grouped_gemm( K: int, use_cudagraph: bool, use_padding_zeros: bool, + use_dynamic: bool, ) -> None: ms = ( torch.randint( @@ -755,8 +757,8 @@ def test_fp8_grouped_gemm( ) * 64 ) - # When using padding, Ns and Ks should be fixed. - if use_padding_zeros: + # When using padding or the dynamic kernel, Ns and Ks should be fixed. + if use_padding_zeros or use_dynamic: ns = [N] * G ks = [K] * G # Otherwise, any value is supported. @@ -828,13 +830,13 @@ def test_fp8_grouped_gemm( # BF16 grouped gemm kernel bf16_args = ( - [x_group, w_group, zero_start_index_M] - if use_padding_zeros + [x_group, w_group, zero_start_index_M if use_padding_zeros else None] + if use_dynamic else [x_group, w_group] ) bf16_op = ( torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic - if use_padding_zeros + if use_dynamic else torch.ops.fbgemm.bf16bf16bf16_grouped ) if use_cudagraph: @@ -850,7 +852,10 @@ def test_fp8_grouped_gemm( # View output as list if needed. if not isinstance(y_bf16_group, (tuple, list)): - y_bf16_group = torch.unbind(y_bf16_group) + if y_bf16_group.ndim == 2: + y_bf16_group = torch.split(y_bf16_group, tuple(ms.tolist()), dim=0) + else: + y_bf16_group = torch.unbind(y_bf16_group) # BF16 loopover gemm reference y_group_ref = []