Make zero_start_index_M optional for dynamic BF16 Grouped Gemm (pytorch#3553)

Summary:
Pull Request resolved: pytorch#3553

X-link: facebookresearch/FBGEMM#639

There is some value in being able to invoke a grouped gemm and directly return a single unified tensor rather than an array of tensors, especially if the goal is to concatenate the groups immediately after returning, since that concat adds an extra copy.

This diff makes zero_start_index_M optional for the dynamic version of bf16 grouped gemm. When it is not provided, we allocate a single coalesced tensor for all outputs. While "dynamic" is a bit of a misnomer in this case, it's the most convenient way to allow returning a tensor rather than a list.
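
A rough sketch of the two calling conventions this enables (group sizes, device setup, and the exact dtype/semantics of zero_start_index_M below are illustrative assumptions, not taken from this diff):

import torch

# Hypothetical group sizes; the dynamic op requires a shared N and K across groups.
ms, N, K = [64, 128, 256], 1024, 512
G = len(ms)
A = [torch.randn(m, K, dtype=torch.bfloat16, device="cuda") for m in ms]
B = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]

# Without zero_start_index_M: one coalesced [sum(ms), N] output is returned,
# which can be split back into per-group views without a copy.
y_flat = torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(A, B)
y_groups = torch.split(y_flat, ms, dim=0)

# With zero_start_index_M (prior behavior): inputs are padded to a shared M,
# the tensor (assumed here to hold the valid row count per group) marks where
# zero padding starts, and a zero-filled [G, M, N] output is returned.
M = max(ms)
A_padded = [torch.randn(M, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]
zero_start_index_M = torch.tensor(ms, dtype=torch.int64, device="cuda")
y_padded = torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic(A_padded, B, zero_start_index_M)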

Reviewed By: jasonjk-park, jianyuh

Differential Revision: D67884826

fbshipit-source-id: d2c0f83f7c55b2dc41000bb1f4dc27fbdf82515d
jwfromm authored and facebook-github-bot committed Jan 8, 2025
1 parent e4880a0 commit 8b748c6
Showing 4 changed files with 100 additions and 59 deletions.
@@ -276,11 +276,7 @@ at::Tensor get_grouped_kernel_args(
 
   if (zero_start_index_M.has_value()) {
     set_dynamic_kernel_args(
-        kernel_args,
-        A,
-        B,
-        output,
-        zero_start_index_M.value());
+        kernel_args, A, B, output, zero_start_index_M.value());
   } else {
     set_static_kernel_args(kernel_args, A, B, output);
   }
@@ -294,23 +290,18 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
   TORCH_CHECK(
-      A.size() == B.size(),
-      "A and B must have the same number of groups.");
+      A.size() == B.size(), "A and B must have the same number of groups.");
   int group_count = A.size();
   // Iterate over inputs and check they are valid.
   for (at::Tensor a : A) {
     TORCH_CHECK(a.is_cuda() && a.is_contiguous());
     TORCH_CHECK(a.dim() == 2, "Inputs must be 2D.");
-    TORCH_CHECK(
-        a.dtype() == at::kBFloat16,
-        "Inputs must be type bfloat16.");
+    TORCH_CHECK(a.dtype() == at::kBFloat16, "Inputs must be type bfloat16.");
   }
   for (at::Tensor b : B) {
     TORCH_CHECK(b.is_cuda() && b.is_contiguous());
     TORCH_CHECK(b.dim() == 2, "Inputs must be 2D.");
-    TORCH_CHECK(
-        b.dtype() == at::kBFloat16,
-        "Inputs must be type bfloat16.");
+    TORCH_CHECK(b.dtype() == at::kBFloat16, "Inputs must be type bfloat16.");
   }
 
   std::vector<at::Tensor> Y;
@@ -340,8 +331,7 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
   }
 
   // Prepare kernel arguments by copying them to the proper device location.
-  at::Tensor kernel_args = get_grouped_kernel_args(
-      A, B, std::nullopt, Y);
+  at::Tensor kernel_args = get_grouped_kernel_args(A, B, std::nullopt, Y);
 
   // Perform shape lookup to find best kernel.
   // We use the largest of each shape for heuristics.
@@ -353,50 +343,70 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
     MaxN = max(MaxN, B[i].size(0));
     MaxK = max(MaxK, A[i].size(1));
   }
-  GroupedKernel selected_kernel =
-      grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+  GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
   return selected_kernel(A, B, kernel_args, Y);
 }
 
 at::Tensor bf16bf16bf16_grouped_dynamic(
     at::TensorList A,
     at::TensorList B,
-    at::Tensor zero_start_index_M) {
+    std::optional<at::Tensor> zero_start_index_M = std::nullopt) {
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
   TORCH_CHECK(
-      A.size() == B.size(),
-      "A and B must have the same number of groups.");
+      A.size() == B.size(), "A and B must have the same number of groups.");
   int group_count = A.size();
   // Iterate over inputs and check they are valid.
   for (at::Tensor a : A) {
     TORCH_CHECK(a.is_cuda() && a.is_contiguous());
     TORCH_CHECK(a.dim() == 2, "Inputs must be 2D.");
-    TORCH_CHECK(
-        a.dtype() == at::kBFloat16,
-        "Inputs must be type bfloat16.");
+    TORCH_CHECK(a.dtype() == at::kBFloat16, "Inputs must be type bfloat16.");
   }
   for (at::Tensor b : B) {
     TORCH_CHECK(b.is_cuda() && b.is_contiguous());
     TORCH_CHECK(b.dim() == 2, "Inputs must be 2D.");
-    TORCH_CHECK(
-        b.dtype() == at::kBFloat16,
-        "Inputs must be type bfloat16.");
+    TORCH_CHECK(b.dtype() == at::kBFloat16, "Inputs must be type bfloat16.");
   }
 
   std::vector<at::Tensor> Y;
-  int M = A[0].size(0);
+  at::Tensor Y_full;
   int N = B[0].size(0);
-  // Fill output with zeros to simplify integration. This prevents nans from
-  // showing up in the tensor.
-  at::Tensor Y_full =
-      at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
-  // Split the output into groups.
-  Y = at::unbind(Y_full, 0);
+  int K = A[0].size(1);
+
+  if (zero_start_index_M.has_value()) {
+    int M = A[0].size(0);
+    // Fill output with zeros to simplify integration. This prevents nans from
+    // showing up in the tensor.
+    Y_full =
+        at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
+    // Split the output into groups.
+    Y = at::unbind(Y_full, 0);
+  } else {
+    // If not provided, we try to allocate a single blob that can store each
+    // group.
+    int total_M = 0;
+    std::vector<int> group_sizes = {};
+    for (int i = 0; i < group_count; i++) {
+      TORCH_CHECK(
+          A[i].size(1) == K && B[i].size(0) == N,
+          "Dynamic grouped gemm requires fixed N and K.");
+      int group_M = A[i].size(0);
+      total_M += group_M;
+      group_sizes.push_back(group_M);
+    }
+    // Allocate a contiguous array for all groups.
+    Y_full = at::empty({total_M, N}, A[0].options().dtype(at::kBFloat16));
+    // Split the full array into appropriate groups.
+    // We do this with narrow to make sure there are no extra copies.
+    int offset = 0;
+    for (int size : group_sizes) {
+      Y.push_back(Y_full.narrow(0, offset, size));
+      offset += size;
+    }
+  }
 
   // Prepare kernel arguments by copying them to the proper device location.
-  at::Tensor kernel_args = get_grouped_kernel_args(
-      A, B, zero_start_index_M, Y);
+  at::Tensor kernel_args = get_grouped_kernel_args(A, B, zero_start_index_M, Y);
 
   // Perform shape lookup to find best kernel.
   // We use the largest of each shape for heuristics.
@@ -408,8 +418,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
     MaxN = max(MaxN, B[i].size(0));
     MaxK = max(MaxK, A[i].size(1));
   }
-  GroupedKernel selected_kernel =
-      grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
+  GroupedKernel selected_kernel = grouped_heuristic_dispatch(MaxM, MaxN, MaxK);
   // Run kernel to populate output.
   selected_kernel(A, B, kernel_args, Y);
   // Return coalesced view of output tensor.
@@ -499,20 +499,47 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
 at::Tensor bf16bf16bf16_grouped_dynamic(
     at::TensorList x_group, // BF16
     at::TensorList w_group, // BF16
-    at::Tensor zero_start_index_M) {
-  std::vector<at::Tensor> output_tensor;
+    std::optional<at::Tensor> zero_start_index_M = std::nullopt) {
+  std::vector<at::Tensor> output_groups;
+  at::Tensor output_full;
   int problem_count = x_group.size();
-  int M = x_group[0].size(0);
   int N = w_group[0].size(0);
-  // Fill output with zeros to simplify integration. This prevents nans from
-  // showing up in the tensor.
-  at::Tensor output_full = at::zeros(
-      {problem_count, M, N}, x_group[0].options().dtype(at::kBFloat16));
-  // Split the output into groups.
-  output_tensor = at::unbind(output_full, 0);
+  int K = x_group[0].size(1);
+  if (zero_start_index_M.has_value()) {
+    int M = x_group[0].size(0);
+    // Fill output with zeros to simplify integration. This prevents nans from
+    // showing up in the tensor.
+    output_full = at::zeros(
+        {problem_count, M, N}, x_group[0].options().dtype(at::kBFloat16));
+    // Split the output into groups.
+    output_groups = at::unbind(output_full, 0);
+  } else {
+    // If not provided, we try to allocate a single blob that can store each
+    // group.
+    int total_M = 0;
+    std::vector<int> group_sizes = {};
+    for (int i = 0; i < problem_count; i++) {
+      TORCH_CHECK(
+          x_group[i].size(1) == K && w_group[i].size(0) == N,
+          "Dynamic grouped gemm requires fixed N and K.");
+      int group_M = x_group[i].size(0);
+      total_M += group_M;
+      group_sizes.push_back(group_M);
+    }
+    // Allocate a contiguous array for all groups.
+    output_full =
+        at::empty({total_M, N}, x_group[0].options().dtype(at::kBFloat16));
+    // Split the full array into appropriate groups.
+    // We do this with narrow to make sure there are no extra copies.
+    int offset = 0;
+    for (int size : group_sizes) {
+      output_groups.push_back(output_full.narrow(0, offset, size));
+      offset += size;
+    }
+  }
   // Run kernel to populate output tensor.
   dispatch_bf16_grouped_kernel(
-      x_group, w_group, output_tensor, zero_start_index_M);
+      x_group, w_group, output_groups, zero_start_index_M);
   // Return coalesced view of output.
   return output_full;
 }
@@ -530,7 +557,7 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
 at::Tensor bf16bf16bf16_grouped_dynamic(
     at::TensorList /* x_group */, // BF16
     at::TensorList /* w_group */, // BF16
-    at::Tensor /* zero_start_index_M */) {
+    std::optional<at::Tensor> /* zero_start_index_M */) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
6 changes: 3 additions & 3 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -68,7 +68,7 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
 at::Tensor bf16bf16bf16_grouped_dynamic(
     at::TensorList X,
     at::TensorList W,
-    at::Tensor zero_start_index_M);
+    std::optional<at::Tensor> zero_start_index_M = std::nullopt);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -209,7 +209,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor[](a!)? output=None) -> Tensor[]");
   m.def(
-      "bf16bf16bf16_grouped_dynamic(Tensor[] X, Tensor[] W, Tensor zero_start_index_M) -> Tensor");
+      "bf16bf16bf16_grouped_dynamic(Tensor[] X, Tensor[] W, Tensor? zero_start_index_M=None) -> Tensor");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
@@ -506,7 +506,7 @@ std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
 at::Tensor bf16bf16bf16_grouped_dynamic_meta(
     at::TensorList X,
     at::TensorList W,
-    at::Tensor /* zero_start_index_M = std::nullopt */) {
+    std::optional<at::Tensor> /* zero_start_index_M = std::nullopt */) {
   int G = X.size();
   int M = X[0].size(0);
   int N = W[0].size(0);
21 changes: 13 additions & 8 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -734,8 +734,9 @@ def fp8_loopover_bmm(
         M=st.sampled_from([2048, 3584]),
         N=st.sampled_from([1024, 6144]),
         K=st.sampled_from([512, 3584]),
-        use_cudagraph=st.sampled_from([True, False]),
-        use_padding_zeros=st.sampled_from([True, False]),
+        use_cudagraph=st.booleans(),
+        use_padding_zeros=st.booleans(),
+        use_dynamic=st.booleans(),
     )
     def test_fp8_grouped_gemm(
         self,
@@ -745,6 +746,7 @@ def test_fp8_grouped_gemm(
         K: int,
         use_cudagraph: bool,
         use_padding_zeros: bool,
+        use_dynamic: bool,
     ) -> None:
         ms = (
             torch.randint(
@@ -755,8 +757,8 @@ )
             )
             * 64
         )
-        # When using padding, Ns and Ks should be fixed.
-        if use_padding_zeros:
+        # When using padding or the dynamic kernel, Ns and Ks should be fixed.
+        if use_padding_zeros or use_dynamic:
             ns = [N] * G
             ks = [K] * G
         # Otherwise, any value is supported.
@@ -828,13 +830,13 @@ def test_fp8_grouped_gemm(
 
         # BF16 grouped gemm kernel
         bf16_args = (
-            [x_group, w_group, zero_start_index_M]
-            if use_padding_zeros
+            [x_group, w_group, zero_start_index_M if use_padding_zeros else None]
+            if use_dynamic
             else [x_group, w_group]
         )
         bf16_op = (
             torch.ops.fbgemm.bf16bf16bf16_grouped_dynamic
-            if use_padding_zeros
+            if use_dynamic
             else torch.ops.fbgemm.bf16bf16bf16_grouped
         )
         if use_cudagraph:
@@ -850,7 +852,10 @@ def test_fp8_grouped_gemm(
 
         # View output as list if needed.
         if not isinstance(y_bf16_group, (tuple, list)):
-            y_bf16_group = torch.unbind(y_bf16_group)
+            if y_bf16_group.ndim == 2:
+                y_bf16_group = torch.split(y_bf16_group, tuple(ms.tolist()), dim=0)
+            else:
+                y_bf16_group = torch.unbind(y_bf16_group)
 
         # BF16 loopover gemm reference
         y_group_ref = []
