Skip to content

Commit

Permalink
PR openxla#14605: [ROCm] Switch on Triton feature for ROCm.
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla#14605

Last in series of commits to switch on Triton in XLA for ROCm.

This is new version of:
openxla#13003

Changes in third_party/triton/temporary/amd_pr7.patch are already merged on:
triton-lang/triton#4238
Copybara import of the project:

--
c2ce7e0 by Zoran Jovanovic <[email protected]>:

[ROCm] Switch on Triton feature for ROCm.

--
563b303 by Zoran Jovanovic <[email protected]>:

[ROCm] Fixed an issue with test cases from ir_emitter_triton_test.cc

--
a4d2ad8 by Zoran Jovanovic <[email protected]>:

[ROCm] Fixed an issue with gpu_compiler_test.cc

--
a1b9260 by Zoran Jovanovic <[email protected]>:

[ROCm] Applied comments from code review.

--
c694a95 by Zoran Jovanovic <[email protected]>:

[ROCm] Fixed failed tests because of openxla@19c11ba

--
7359619 by Zoran Jovanovic <[email protected]>:

[ROCm] Fixed compilation issue with latest rebase.

--
82f58ce by Zoran Jovanovic <[email protected]>:

[ROCm] Skip SplitLHSInputOutputIsFused test in ir_emitter_triton_test.cc untill issue is fixed.

--
57e776b by Zoran Jovanovic <[email protected]>:

[ROCm] Triton related changes merged thus removed amd_pr7.patch

--
0d09d0e by Zoran Jovanovic <[email protected]>:

[ROCm] Applied comments from code review.

--
7b11147 by Zoran Jovanovic <[email protected]>:

[ROCm] Applied comments from code review.

--
9e7e0c7 by Zoran Jovanovic <[email protected]>:

[ROCm] Modified TestNoAutotuner test case.

Merging this change closes openxla#14605

COPYBARA_INTEGRATE_REVIEW=openxla#14605 from ROCm:rocm_triton_backend_8 9e7e0c7
PiperOrigin-RevId: 652449567
  • Loading branch information
zoranjovanovic-ns committed Nov 12, 2024
1 parent bf81e49 commit 1971ee5
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 29 deletions.
1 change: 1 addition & 0 deletions third_party/tsl/third_party/gpus/rocm_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,7 @@ def _create_local_rocm_repository(repository_ctx):
"-DTENSORFLOW_USE_ROCM=1",
"-D__HIP_PLATFORM_AMD__",
"-DEIGEN_USE_HIP",
"-DUSE_ROCM",
])

rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
Expand Down
4 changes: 3 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -585,12 +585,14 @@ cc_library(
"@triton//:TritonGPUToLLVM",
"@triton//:TritonToTritonGPU",
"@triton//:TritonGPUTransforms",
"@triton//:TritonLLVMIR",
]) + if_cuda_is_configured([
"@triton//third_party/nvidia:NVGPUToLLVM",
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
"@triton//:TritonLLVMIR",
]) + if_rocm_is_configured([
"@tsl//tsl/platform:rocm_rocdl_path",
"@triton//third_party/amd:TritonAMDGPUToLLVM",
"@triton//third_party/amd:TritonAMDGPUTransforms",
]),
)

Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/amdgpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License.
#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "xla/service/gpu/cusolver_rewriter.h"
#include "xla/service/gpu/gemm_algorithm_picker.h"
#include "xla/service/gpu/gpu_algebraic_simplifier.h"
#include "xla/service/gpu/gpu_compiler.h"
#include "xla/service/gpu/gpu_conv_padding_legalization.h"
#include "xla/service/gpu/gpu_conv_rewriter.h"
Expand Down Expand Up @@ -141,7 +142,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
GetAlgebraicSimplifierOptions(hlo_module->config());
options.set_enable_conv_operand_swap(false);
options.set_enable_unconditional_reduce_of_concat_replacement(false);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(options, gpu_version);

// tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
// CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
Expand All @@ -151,7 +152,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
ReshapeMoverOptions reshape_mover_options;
reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true;
pipeline.AddPass<ReshapeMover>(reshape_mover_options);
pipeline.AddPass<AlgebraicSimplifier>(options);
pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
}();

// The reshapes and transposes can possibly be eliminated using
Expand All @@ -162,7 +163,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
[&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
"simplify_after_conv_canonicalization")] {
pipeline.AddPass<ConvertMover>();
pipeline.AddPass<AlgebraicSimplifier>(options);
pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
}();

// GpuConvRewriter, GpuConvPaddingLegalization and
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/fusions/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
triton_config.set_block_k(64);
triton_config.set_block_n(64);
triton_config.set_split_k(1);
triton_config.set_num_stages(1);
triton_config.set_num_warps(2);
triton_config.set_num_ctas(1);

block_level_parameters.num_ctas = 1;
block_level_parameters.num_stages = 1;
Expand Down
7 changes: 5 additions & 2 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1340,13 +1340,16 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
gpu_target_config.device_description.gpu_compute_capability();
pipeline.AddPass<AlgorithmChecker>(gpu_version);
const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version);
const auto* rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);

// Rewrite FP8 GEMMs ahead of Triton which currently lacks support for FP8
// and may rewrite quantized FP8 GEMMs as higher-precision GEMMs.
pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
/*f8_rewrite=*/true);
if (debug_options.xla_gpu_enable_triton_gemm() && cuda_cc != nullptr &&
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
if (debug_options.xla_gpu_enable_triton_gemm() &&
((cuda_cc != nullptr &&
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) ||
rocm_cc != nullptr)) {
pipeline.AddPass<GemvRewriter>();
pipeline.AddPass<GemmFusion>(gpu_version);
}
Expand Down
51 changes: 31 additions & 20 deletions xla/service/gpu/ir_emitter_triton_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ limitations under the License.
==============================================================================*/
// TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is
// included in build.
// #include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
#include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project
Expand All @@ -35,6 +36,10 @@ limitations under the License.
namespace xla {
namespace gpu {

// Value 0 for num_stages is used to represent AMD specific register
// file double buffering.
constexpr int kAmdDoubleBuffering = 0;

namespace ma = ::mlir::arith;
namespace mm = ::mlir::math;
namespace ml = ::mlir::LLVM;
Expand All @@ -55,9 +60,10 @@ absl::Status CreateTritonPipeline(
const int ccAsInt = 0;
// TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
const int threadsPerWarp = 32;
auto ccRocm = std::get<se::RocmComputeCapability>(cc);

// Based on make_ttir() in
// @triton//:third_party/nvidia/backend/compiler.py
// @triton//:third_party/amd/backend/compiler.py
pm.addPass(mlir::createInlinerPass());
pm.addPass(mt::createRewriteTensorPointerPass());
pm.addPass(mt::createCombineOpsPass());
Expand All @@ -68,46 +74,51 @@ absl::Status CreateTritonPipeline(
pm.addPass(mlir::createSymbolDCEPass());

// Based on make_ttgir() in
// @triton//:third_party/nvidia/backend/compiler.py
// @triton//:third_party/amd/backend/compiler.py
pm.addPass(mt::createConvertTritonToTritonGPUPass(
absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps,
threadsPerWarp, block_level_parameters.num_ctas));
absl::StrCat("hip:", ccRocm.gfx_version()),
block_level_parameters.num_warps, threadsPerWarp,
block_level_parameters.num_ctas));
pm.addPass(mt::gpu::createTritonGPUCoalesce());
pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul());
pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
// TODO ROCm Check if we want to compare MI100 and greater
pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());
pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
pm.addPass(mlir::createCSEPass());
pm.addPass(
mt::gpu::createTritonGPUPipeline({block_level_parameters.num_stages}));
pm.addPass(mt::gpu::createTritonGPUPrefetch());

// TODO ROCm Check if we want to compare MI100 and greater
if (block_level_parameters.num_stages == kAmdDoubleBuffering &&
ccRocm.has_amd_matrix_core()) {
pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass());
pm.addPass(mlir::createCanonicalizerPass());
}
pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());
pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
if (block_level_parameters.num_stages != kAmdDoubleBuffering) {
pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
}
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::createCanonicalizerPass());

// Based on make_llir() in
// @triton//:third_party/nvidia/backend/compiler.py
// pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
// @triton//:third_party/amd/backend/compiler.py
pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(
ccRocm.gfx_version()));
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
// pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(
mt::createConvertTritonAMDGPUToLLVMPass(ccRocm.gfx_version(), true));
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
// Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertControlFlowToLLVMPass());

pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mt::createConvertBuiltinFuncToLLVMPass());
// There is no clusters in ROCm for now.
out_cluster_info.clusterDimX = 1;
out_cluster_info.clusterDimY = 1;
Expand Down
16 changes: 16 additions & 0 deletions xla/service/gpu/ir_emitter_triton_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2414,6 +2414,9 @@ ENTRY e {

TEST_F(TritonGemmTestAny,
DoNotFuseConcatenationOfSplitNonContractingDimension) {
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "Not using autotuner on ROCM yet.";
}
if (SkipBF16Tests()) {
GTEST_SKIP() << "BF16 not supported.";
}
Expand Down Expand Up @@ -3235,6 +3238,10 @@ TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) {
if (SkipBF16Tests()) {
GTEST_SKIP() << "BF16 not supported.";
}
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "Skipped until corresponding issue on ROCm is fixed.";
}

const std::string kHloText = R"(
ENTRY e {
p0t = (s8[5,18,20,150]) parameter(0)
Expand Down Expand Up @@ -3565,6 +3572,9 @@ ENTRY e {
}

TEST_F(CompareTest, UsingOptinSharedMemoryOnAmpereProducesSameResult) {
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "No Optin Shared Memory on AMD.";
}
const se::DeviceDescription dev_info =
backend().default_stream_executor()->GetDeviceDescription();
constexpr int kBytesOfSharedMemoryTested = 64 * 1024;
Expand Down Expand Up @@ -5011,6 +5021,9 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16>
}

TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) {
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on ROCM.";
}
const char* kHloText = R"(
HloModule t
Expand Down Expand Up @@ -5347,6 +5360,9 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16>
}

TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) {
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X3 not supported on ROCM.";
}
const char* kHloText = R"(
HloModule t
Expand Down
6 changes: 3 additions & 3 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1588,8 +1588,8 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall(

absl::Status IrEmitterUnnested::EmitTritonCustomCall(
const HloCustomCallInstruction* instr) {
#if !GOOGLE_CUDA
return absl::UnimplementedError("Triton support requires CUDA");
#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
return absl::UnimplementedError("Triton support requires CUDA or ROCm");
#else
auto generate = [this, &instr]() -> absl::StatusOr<KernelReuseCache::Entry> {
mlir::MLIRContext& mlir_context = *ir_emitter_context_->mlir_context();
Expand Down Expand Up @@ -1617,7 +1617,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
TF_ASSIGN_OR_RETURN(
auto result,
CompileTritonToLLVM(hlo_module->config(), hlo_module->name(),
ir_emitter_context_->cuda_compute_capability(),
ir_emitter_context_->gpu_compute_capability(),
ir_emitter_context_->gpu_device_info(),
block_level_parameters, triton_module.get(),
ir_emitter_context_->llvm_module(), mlir_context));
Expand Down
5 changes: 5 additions & 0 deletions xla/stream_executor/device_description.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ class RocmComputeCapability {

bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); }

bool has_amd_matrix_core() const {
return (gfx9_mi100_or_later() || gfx_version().find("gfx11") ||
gfx_version().find("gfx12"));
}

bool has_fp16_atomics_support() const {
// TODO(rocm): Check. This should be the same as has_fast_fp16_support().
return gfx9_mi200_or_later();
Expand Down

0 comments on commit 1971ee5

Please sign in to comment.