From aac6ceab23b4e6c4061a7c2b68ec64054ead0498 Mon Sep 17 00:00:00 2001 From: Zoran Jovanovic Date: Tue, 12 Nov 2024 15:17:28 +0000 Subject: [PATCH] Fixed test issues. --- third_party/triton/temporary/amd_pr7.patch | 27 ++++++++++++++++++++++ third_party/triton/temporary/series.bzl | 4 +++- xla/debug_options_flags.cc | 2 +- xla/service/gpu/gemm_fusion.cc | 8 ++++--- xla/service/gpu/ir_emitter_triton_test.cc | 7 ++++++ 5 files changed, 43 insertions(+), 5 deletions(-) create mode 100644 third_party/triton/temporary/amd_pr7.patch diff --git a/third_party/triton/temporary/amd_pr7.patch b/third_party/triton/temporary/amd_pr7.patch new file mode 100644 index 0000000000000..0d704da4f21bb --- /dev/null +++ b/third_party/triton/temporary/amd_pr7.patch @@ -0,0 +1,27 @@ +diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +index b0976f8..bcdc5c7 100644 +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +@@ -956,6 +956,22 @@ struct FpToFpOpConversion + for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { + inVals.push_back(operands[i][0]); + } ++ ++ bool isSrcFP16 = srcElementType.isF16(); ++ bool isSrcBF16 = srcElementType.isBF16(); ++ ++ if ((isSrcFP16 || isSrcBF16) ++ && isDstFP32) { ++ SmallVector outVals; ++ for (Value &v : inVals) { ++ if(isSrcFP16) ++ outVals.push_back(convertFp16ToFp32(loc, rewriter, v)); ++ else ++ outVals.push_back(convertBf16ToFp32(loc, rewriter, v)); ++ } ++ return outVals; ++ } ++ + if (useFP16IntermediateSrc) + for (Value &v : inVals) + v = cvtFp32ToFp16(loc, rewriter, v, \ No newline at end of file diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 53a5059fe432d..e8ac2dd33f159 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -5,4 +5,6 @@ These are created temporarily and should be moved to the first copybara workflow internal patch during the next triton integration process. """ -temporary_patch_list = [] +temporary_patch_list = [ + "//third_party/triton/temporary:amd_pr7.patch", +] diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index 3b103bf8edd02..8e79b75113f2f 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -180,7 +180,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_partitioning_algorithm( DebugOptions::PARTITIONING_ALGORITHM_NOOP); - opts.set_xla_gpu_enable_triton_gemm(true); + opts.set_xla_gpu_enable_triton_gemm(false); opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); opts.set_xla_gpu_enable_triton_softmax_fusion(false); diff --git a/xla/service/gpu/gemm_fusion.cc b/xla/service/gpu/gemm_fusion.cc index af5da900c99bf..b926a3eee2861 100644 --- a/xla/service/gpu/gemm_fusion.cc +++ b/xla/service/gpu/gemm_fusion.cc @@ -801,10 +801,12 @@ absl::StatusOr GemmFusion::Run( const absl::flat_hash_set& execution_threads) { auto cuda_compute_capability = std::get_if(&gpu_version_); - if (!cuda_compute_capability) { + auto rocm_compute_capability = + std::get_if(&gpu_version_); + if (!cuda_compute_capability && !rocm_compute_capability) { return absl::FailedPreconditionError( - "Triton support is only enabled for CUDA GPUs."); - } else if (!cuda_compute_capability->IsAtLeastAmpere()) { + "Triton support is only enabled for CUDA and ROCM GPUs."); + } else if (cuda_compute_capability && !cuda_compute_capability->IsAtLeastAmpere()) { return absl::FailedPreconditionError( absl::StrCat("Triton support is only enabled for Ampere GPUs (compute ", "capability 8.0) and up, but got compute capability ", diff --git a/xla/service/gpu/ir_emitter_triton_test.cc b/xla/service/gpu/ir_emitter_triton_test.cc index d3298763ee99a..060069c5ebcec 100644 --- a/xla/service/gpu/ir_emitter_triton_test.cc +++ b/xla/service/gpu/ir_emitter_triton_test.cc @@ -82,6 +82,7 @@ class TritonGemmTest : public TritonTest { debug_options.set_xla_gpu_enable_split_k_autotuning(false); // Always rewrite Gemms with Triton regardless of size. debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + debug_options.set_xla_gpu_enable_triton_gemm(true); return debug_options; } @@ -3313,6 +3314,9 @@ ENTRY e { TEST_F(TritonGemmTestAny, LowerDotWithLhsWithoutNonContractingDimThroughTriton) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Not enough memory to allocate on ROCM."; + } const std::string hlo_text = R"( HloModule t @@ -3335,6 +3339,9 @@ ENTRY e { TEST_F(TritonGemmTestAny, LowerDotWithRhsWithoutNonContractingDimThroughTriton) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Not enough memory to allocate on ROCM."; + } const std::string hlo_text = R"( HloModule t