diff --git a/third_party/triton/llvm_integration/cl718838076.patch b/third_party/triton/llvm_integration/cl718838076.patch index fad390e9f45ac5..16502e0ec5e2f7 100644 --- a/third_party/triton/llvm_integration/cl718838076.patch +++ b/third_party/triton/llvm_integration/cl718838076.patch @@ -1,29 +1,29 @@ +Utility.h defined a bunch of helper macros like `shl`, which (surprise! +surprise!) get wrongly expanded in LLVM's own headers. --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp 2024-11-07 04:49:10.000000000 -0800 +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp 2025-01-23 08:33:43.000000000 -0800 -@@ -1,10 +1,10 @@ +@@ -1,8 +1,8 @@ -#include "PatternTritonGPUOpToLLVM.h" -#include "TargetInfo.h" -#include "Utility.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" - #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/PatternMatch.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" - #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" --- a/third_party/nvidia/BUILD 2025-01-23 03:06:03.000000000 -0800 +++ b/third_party/nvidia/BUILD 2025-01-23 08:24:44.000000000 -0800 -@@ -84,6 +84,7 @@ +@@ -83,6 +83,7 @@ "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", -+ "//:TritonAnalysis", "//:TritonDialects", ++ "//:TritonGPUToLLVM", ], ) diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 51c2a30f0e226e..5c74b63dced4c8 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,7 +8,7 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ - "//third_party/triton:llvm_integration/cl718257410.patch", + # "//third_party/triton:llvm_integration/cl718257410.patch", # upstream on next integrate "//third_party/triton:llvm_integration/cl718838076.patch", # Add new patches just above this line ] diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 562c1ca2e08305..8f1be8536deb2c 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl715271136" - TRITON_SHA256 = "f08445eac5df52173b50aebfb0a811b295287e2657f5ef73e778b3feface8d68" + TRITON_COMMIT = "cl718295721" + TRITON_SHA256 = "0006885c62fed68632a2a5b582ae17a9f591ded466fc6cd3dcd3b41d7f74322e" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/xla/backends/gpu/codegen/triton/BUILD b/xla/backends/gpu/codegen/triton/BUILD index 886ce3e0eed0dc..ae11a39dbeab0d 100644 --- a/xla/backends/gpu/codegen/triton/BUILD +++ b/xla/backends/gpu/codegen/triton/BUILD @@ -472,7 +472,6 @@ cc_library( "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - "@triton//:TritonAnalysis", "@triton//:TritonDialects", "@triton//:TritonGPUToLLVM", "@triton//:TritonGPUTransforms", diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc index 2b7e81d50d447d..55a0f8adfe0849 100644 --- a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc +++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc @@ -103,6 +103,8 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm, } pm->addPass(mlir::createTritonAMDGPUCanonicalizePointersPass()); pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createTritonAMDGPUConvertToBufferOpsPass(arch_name)); + pm->addPass(mlir::createCanonicalizerPass()); pm->addPass(mlir::createCSEPass()); pm->addPass(mlir::createSymbolDCEPass()); diff --git a/xla/backends/gpu/codegen/triton/xla_triton_attrs.cc b/xla/backends/gpu/codegen/triton/xla_triton_attrs.cc index e133a93ecbd052..66d658c7407d3b 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_attrs.cc +++ b/xla/backends/gpu/codegen/triton/xla_triton_attrs.cc @@ -69,7 +69,7 @@ SmallVector SparseDotMetaEncodingAttr::getThreadOrder() const { SmallVector SparseDotMetaEncodingAttr::getSizePerThread() const { return gpu::getSizePerThread(getParent()); } -std::optional SparseDotMetaEncodingAttr::toLinearLayout( +LinearLayout SparseDotMetaEncodingAttr::toLinearLayout( ArrayRef shape) const { return gpu::toLinearLayout(shape, getParent()); } diff --git a/xla/backends/gpu/codegen/triton/xla_triton_sparse_passes.cc b/xla/backends/gpu/codegen/triton/xla_triton_sparse_passes.cc index 46eddd850ddb8b..9aa1f43cc5958f 100644 --- a/xla/backends/gpu/codegen/triton/xla_triton_sparse_passes.cc +++ b/xla/backends/gpu/codegen/triton/xla_triton_sparse_passes.cc @@ -460,22 +460,24 @@ class SparseLocalLoadToLLVM // Calculate number of tile repetitions. Value tensor = op.getSrc(); - auto shape = cast(tensor.getType()).getShape(); + auto mem_desc = cast(tensor.getType()); + auto shape = mem_desc.getShape(); int rep_m = shape[0] / shape_per_cta_tile[0]; int rep_k = shape[1] / shape_per_cta_tile[1]; CHECK_GT(rep_m, 0) << shape[0] << "/" << shape_per_cta_tile[0]; CHECK_GT(rep_k, 0) << shape[1] << "/" << shape_per_cta_tile[1]; // Load sparse metadata from shared memory. - auto elem_ty = getTypeConverter()->convertType( - cast(tensor.getType()).getElementType()); + auto elem_ty = getTypeConverter()->convertType(mem_desc.getElementType()); auto s_mem_obj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getSrc(), elem_ty, rewriter); - Value stride_m = s_mem_obj.strides[0]; - Value stride_k = s_mem_obj.strides[1]; + const SmallVector strides = + s_mem_obj.getStrides(mem_desc, loc, rewriter); + Value stride_m = strides[0]; + Value stride_k = strides[1]; MLIRContext *ctx = tensor.getContext(); Type ptr_ty = ptr_ty(ctx, 3); - Value base = gep(ptr_ty, i16_ty, s_mem_obj.base, i32_val(0)); + Value base = gep(ptr_ty, i16_ty, s_mem_obj.getBase(), i32_val(0)); SmallVector values; for (int k = 0; k < rep_k; ++k) { @@ -740,8 +742,8 @@ LogicalResult convertSparseWGMMA(SparseDotOp op, SparseDotOp::Adaptor adaptor, int64_t swizzling = getSwizzlingFromLayout(sharedLayout, shape[ord[0]] * byteSize); Value baseDesc = createDescriptor(rewriter, loc, swizzling, shape[ord[1]]); - baseDesc = - add(baseDesc, lshr(ptrtoint(i64_ty, sharedObj.base), int_val(64, 4))); + baseDesc = add(baseDesc, + lshr(ptrtoint(i64_ty, sharedObj.getBase()), int_val(64, 4))); return std::make_tuple(shape, ord, baseDesc); };