Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rocm jaxlib v0.4.30 qa cleanup #35

Merged
merged 10 commits into from
Aug 29, 2024
35 changes: 35 additions & 0 deletions third_party/llvm/rocdl_shuffle_down.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
From a46b9e979ffa523bfed61487a2404e1f48140288 Mon Sep 17 00:00:00 2001
From: Dragan Mladjenovic <[email protected]>
Date: Fri, 29 Mar 2024 12:27:36 +0000
Subject: [PATCH] Support gpu::ShuffleMode::DOWN lowering

---
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e2cb3687d872..9317e30290c6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -140,7 +140,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
@@ -151,6 +151,10 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
// perf.
switch (op.getMode()) {
+ case gpu::ShuffleMode::DOWN:
+ dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
+ adaptor.getOffset());
+ break;
case gpu::ShuffleMode::XOR:
dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
--
2.25.1

1 change: 1 addition & 0 deletions third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def repo(name):
"//third_party/llvm:mathextras.patch",
"//third_party/llvm:toolchains.patch",
"//third_party/llvm:zstd.patch",
"//third_party/llvm:rocdl_shuffle_down.patch",
],
link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"},
)
35 changes: 35 additions & 0 deletions third_party/tsl/third_party/llvm/rocdl_shuffle_down.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
From a46b9e979ffa523bfed61487a2404e1f48140288 Mon Sep 17 00:00:00 2001
From: Dragan Mladjenovic <[email protected]>
Date: Fri, 29 Mar 2024 12:27:36 +0000
Subject: [PATCH] Support gpu::ShuffleMode::DOWN lowering

---
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e2cb3687d872..9317e30290c6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -140,7 +140,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
Value add = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId, width);
@@ -151,6 +151,10 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
// perf.
switch (op.getMode()) {
+ case gpu::ShuffleMode::DOWN:
+ dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
+ adaptor.getOffset());
+ break;
case gpu::ShuffleMode::XOR:
dstLane = rewriter.create<LLVM::XOrOp>(loc, int32Type, srcLaneId,
adaptor.getOffset());
--
2.25.1

1 change: 1 addition & 0 deletions third_party/tsl/third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def repo(name):
"//third_party/llvm:mathextras.patch",
"//third_party/llvm:toolchains.patch",
"//third_party/llvm:zstd.patch",
"//third_party/llvm:rocdl_shuffle_down.patch",
],
link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"},
)
17 changes: 13 additions & 4 deletions xla/service/algorithm_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,17 @@ bool IsSupportedDotAlgorithmOnGpu(
std::get<se::CudaComputeCapability>(gpu_compute_capability)
.IsAtLeast(8, 9);

const bool is_rocm_mi100_and_above =
std::holds_alternative<se::RocmComputeCapability>(
gpu_compute_capability) &&
std::get<se::RocmComputeCapability>(gpu_compute_capability)
.gfx9_mi100_or_later();

switch (algorithm) {
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
// Other F8 types are actually not supported by NVIDIA GPUs.
return is_cuda_ge_ada &&
return (is_cuda_ge_ada || is_rocm_mi100_and_above) &&

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if this is really correct. I guess FP8 support begins from MI300 arch? But I see that we have the same check also on upstream

(input_storage_type == F8E5M2 || input_storage_type == F8E4M3FN) &&
(output_storage_type == F8E5M2 ||
output_storage_type == F8E4M3FN || output_storage_type == F16 ||
Expand All @@ -168,14 +174,17 @@ bool IsSupportedDotAlgorithmOnGpu(
return input_storage_type == F16 &&
(output_storage_type == F16 || output_storage_type == F32);
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
return is_cuda_ge_ampere && input_storage_type == BF16 &&
return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
input_storage_type == BF16 &&
(output_storage_type == BF16 || output_storage_type == F32);
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
return is_cuda_ge_ampere && input_storage_type == F32 &&
return (is_cuda_ge_ampere || is_rocm_mi100_and_above)
&& input_storage_type == F32 &&
output_storage_type == F32;
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
return is_cuda_ge_ampere && input_storage_type == F32 &&
return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
input_storage_type == F32 &&
output_storage_type == F32;
case PrecisionConfig::ALG_DOT_F32_F32_F32:
return input_storage_type == F32 && output_storage_type == F32;
Expand Down
7 changes: 6 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ cc_library(
testonly = 1,
srcs = ["gpu_device_info_for_tests.cc"],
hdrs = ["gpu_device_info_for_tests.h"],
local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
compatible_with = get_compatible_with_portable(),
deps = [
"//xla/stream_executor:device_description",
Expand Down Expand Up @@ -5680,6 +5681,8 @@ cc_library(
xla_test(
name = "dot_algorithm_support_test",
srcs = if_gpu_is_configured(["dot_algorithm_support_test.cc"]),
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) +
if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
backends = [
"gpu_v100",
"gpu_a100",
Expand All @@ -5698,7 +5701,9 @@ xla_test(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_googletest//:gtest",
],
] + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
]),
)

cc_library(
Expand Down
10 changes: 10 additions & 0 deletions xla/service/gpu/buffer_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ absl::StatusOr<bool> BufferComparator::CompareEqual(
stream, current, expected, "fp8_e5m2_comparison",
buffer_comparator::fp8_e5m2_comparison());
#endif // GOOGLE_CUDA
#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
case xla::F8E4M3FNUZ:
return CompareEqualParameterized<tsl::float8_e4m3fnuz, float>(
stream, current, expected, "fp8_e4m3fnuz_comparison",
buffer_comparator::fp8_e4m3fnuz_comparison());
case xla::F8E5M2FNUZ:
return CompareEqualParameterized<tsl::float8_e5m2fnuz, float>(
stream, current, expected, "fp8_e5m2fnuz_comparison",
buffer_comparator::fp8_e5m2fnuz_comparison());
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
case xla::F16:
return CompareEqualParameterized<Eigen::half, float>(
stream, current, expected, "fp16_comparison",
Expand Down
61 changes: 61 additions & 0 deletions xla/service/gpu/buffer_comparator.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ using bfloat16 = __nv_bfloat16;
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>

#include "rocm/rocm_config.h"
#if TF_ROCM_VERSION >= 60200
#include <hip/hip_fp8.h>
#endif // TF_ROCM_VERSION >= 60200

using bfloat16 = hip_bfloat16;
#define BF16_TO_F32 float

Expand Down Expand Up @@ -97,6 +102,52 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
}
#endif // GOOGLE_CUDA

#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
__global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
__hip_fp8_storage_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= buffer_length) return;
__hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
elem_a_fp8.__x = buffer_a[idx];
elem_b_fp8.__x = buffer_b[idx];
float elem_a = static_cast<float>(elem_a_fp8);
float elem_b = static_cast<float>(elem_b_fp8);
elem_a = Canonicalize(elem_a);
elem_b = Canonicalize(elem_b);
if (isnan(elem_a) && isnan(elem_b)) return;

float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);

if (rel_error > rel_error_threshold || isnan(rel_error))
atomicAdd(mismatch_count, 1);
}

__global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
__hip_fp8_storage_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= buffer_length) return;
__hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
elem_a_fp8.__x = buffer_a[idx];
elem_b_fp8.__x = buffer_b[idx];
float elem_a = static_cast<float>(elem_a_fp8);
float elem_b = static_cast<float>(elem_b_fp8);
elem_a = Canonicalize(elem_a);
elem_b = Canonicalize(elem_b);
if (isnan(elem_a) && isnan(elem_b)) return;

float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);

if (rel_error > rel_error_threshold || isnan(rel_error))
atomicAdd(mismatch_count, 1);
}
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

__global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
Expand Down Expand Up @@ -206,6 +257,16 @@ void* fp8_e5m2_comparison() {
}
#endif

#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
void* fp8_e4m3fnuz_comparison() {
return reinterpret_cast<void*>(&xla_fp8_e4m3fnuz_comparison);
}

void* fp8_e5m2fnuz_comparison() {
return reinterpret_cast<void*>(&xla_fp8_e5m2fnuz_comparison);
}
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

void* fp16_comparison() {
return reinterpret_cast<void*>(&xla_fp16_comparison);
}
Expand Down
8 changes: 8 additions & 0 deletions xla/service/gpu/buffer_comparator.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream_executor.h"

#if TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#endif

namespace xla::gpu {

// A device-side comparator that compares buffers.
Expand Down Expand Up @@ -76,6 +80,10 @@ namespace buffer_comparator {
// Returns a pointer to CUDA C++ device function implementing comparison.
void* fp8_e4m3fn_comparison();
void* fp8_e5m2_comparison();
#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
void* fp8_e4m3fnuz_comparison();
void* fp8_e5m2fnuz_comparison();
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
void* fp16_comparison();
void* bf16_comparison();
void* fp32_comparison();
Expand Down
Loading
Loading