Integrate Triton up to [515467a9](https://github.com/openai/triton/co…
gflegar authored and Google-ML-Automation committed Jan 24, 2025
1 parent 44a3cd4 commit 08c5584
Showing 7 changed files with 21 additions and 18 deletions.
10 changes: 5 additions & 5 deletions third_party/triton/llvm_integration/cl718838076.patch
@@ -1,29 +1,29 @@
Utility.h defined a bunch of helper macros like `shl`, which (surprise!
surprise!) get wrongly expanded in LLVM's own headers.

--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp 2024-11-07 04:49:10.000000000 -0800
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp 2025-01-23 08:33:43.000000000 -0800
@@ -1,10 +1,10 @@
@@ -1,8 +1,8 @@
-#include "PatternTritonGPUOpToLLVM.h"
-#include "TargetInfo.h"
-#include "Utility.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/PatternMatch.h"
+#include "PatternTritonGPUOpToLLVM.h"
+#include "TargetInfo.h"
+#include "Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

--- a/third_party/nvidia/BUILD 2025-01-23 03:06:03.000000000 -0800
+++ b/third_party/nvidia/BUILD 2025-01-23 08:24:44.000000000 -0800
@@ -84,6 +84,7 @@
@@ -83,6 +83,7 @@
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
+ "//:TritonAnalysis",
"//:TritonDialects",
+ "//:TritonGPUToLLVM",
],
)
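
The note at the top of cl718838076.patch explains why the includes get reordered: Utility.h defines helper macros such as `shl`, and if those macros are visible when LLVM/MLIR headers are parsed they expand inside declarations that happen to use the same names. A minimal, self-contained C++ sketch of that failure mode (all names below are hypothetical, not taken from either repository):

```cpp
// Hypothetical, self-contained sketch -- not code from Triton or LLVM.
#include <cstdio>

// Stand-in for an LLVM-style header that declares a method named `shl`.
// With the fixed include order, declarations like this are parsed *before*
// any helper macro named `shl` exists.
struct APIntLike {
  unsigned value = 0;
  unsigned shl(unsigned amount) const { return value << amount; }
};

// Compiled before the macro below is defined, so the member call is untouched.
unsigned UseMethod(const APIntLike &x) { return x.shl(2); }

// Stand-in for a Utility.h-style helper macro (the definition is made up here).
// Had this been visible above `APIntLike`, the preprocessor would try to expand
// `shl(unsigned amount)` as a two-argument macro call and the declaration would
// no longer parse -- the kind of breakage the reordered includes avoid.
#define shl(lhs, rhs) ((lhs) << (rhs))

int main() {
  std::printf("method: %u\n", UseMethod(APIntLike{3}));  // 3 << 2 == 12
  std::printf("macro:  %u\n", shl(3u, 2u));              // also 12
  return 0;
}
```

Moving the project-local includes below the MLIR/LLVM ones, as the patch does, keeps those macros out of scope while the LLVM headers are parsed.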

2 changes: 1 addition & 1 deletion third_party/triton/llvm_integration/series.bzl
@@ -8,7 +8,7 @@ LLVM nor MLIR integrator, please do not add any patches to this list.
"""

llvm_patch_list = [
"//third_party/triton:llvm_integration/cl718257410.patch",
# "//third_party/triton:llvm_integration/cl718257410.patch", # upstream on next integrate
"//third_party/triton:llvm_integration/cl718838076.patch",
# Add new patches just above this line
]
4 changes: 2 additions & 2 deletions third_party/triton/workspace.bzl
@@ -8,8 +8,8 @@ load("//third_party/triton:xla_extensions/series.bzl", "extensions_files_patch_l
def repo():
"""Imports Triton."""

TRITON_COMMIT = "cl715271136"
TRITON_SHA256 = "f08445eac5df52173b50aebfb0a811b295287e2657f5ef73e778b3feface8d68"
TRITON_COMMIT = "cl718295721"
TRITON_SHA256 = "0006885c62fed68632a2a5b582ae17a9f591ded466fc6cd3dcd3b41d7f74322e"
tf_http_archive(
name = "triton",
sha256 = TRITON_SHA256,
1 change: 0 additions & 1 deletion xla/backends/gpu/codegen/triton/BUILD
@@ -472,7 +472,6 @@ cc_library(
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@triton//:TritonAnalysis",
"@triton//:TritonDialects",
"@triton//:TritonGPUToLLVM",
"@triton//:TritonGPUTransforms",
2 changes: 2 additions & 0 deletions xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
@@ -103,6 +103,8 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
}
pm->addPass(mlir::createTritonAMDGPUCanonicalizePointersPass());
pm->addPass(mlir::createCanonicalizerPass());
pm->addPass(mlir::createTritonAMDGPUConvertToBufferOpsPass(arch_name));
pm->addPass(mlir::createCanonicalizerPass());
pm->addPass(mlir::createCSEPass());
pm->addPass(mlir::createSymbolDCEPass());

2 changes: 1 addition & 1 deletion xla/backends/gpu/codegen/triton/xla_triton_attrs.cc
@@ -69,7 +69,7 @@ SmallVector<unsigned> SparseDotMetaEncodingAttr::getThreadOrder() const {
SmallVector<unsigned> SparseDotMetaEncodingAttr::getSizePerThread() const {
return gpu::getSizePerThread(getParent());
}
std::optional<LinearLayout> SparseDotMetaEncodingAttr::toLinearLayout(
LinearLayout SparseDotMetaEncodingAttr::toLinearLayout(
ArrayRef<int64_t> shape) const {
return gpu::toLinearLayout(shape, getParent());
}
18 changes: 10 additions & 8 deletions xla/backends/gpu/codegen/triton/xla_triton_sparse_passes.cc
@@ -460,22 +460,24 @@ class SparseLocalLoadToLLVM

// Calculate number of tile repetitions.
Value tensor = op.getSrc();
auto shape = cast<triton::gpu::MemDescType>(tensor.getType()).getShape();
auto mem_desc = cast<triton::gpu::MemDescType>(tensor.getType());
auto shape = mem_desc.getShape();
int rep_m = shape[0] / shape_per_cta_tile[0];
int rep_k = shape[1] / shape_per_cta_tile[1];
CHECK_GT(rep_m, 0) << shape[0] << "/" << shape_per_cta_tile[0];
CHECK_GT(rep_k, 0) << shape[1] << "/" << shape_per_cta_tile[1];

// Load sparse metadata from shared memory.
auto elem_ty = getTypeConverter()->convertType(
cast<triton::gpu::MemDescType>(tensor.getType()).getElementType());
auto elem_ty = getTypeConverter()->convertType(mem_desc.getElementType());
auto s_mem_obj = LLVM::getSharedMemoryObjectFromStruct(
loc, adaptor.getSrc(), elem_ty, rewriter);
Value stride_m = s_mem_obj.strides[0];
Value stride_k = s_mem_obj.strides[1];
const SmallVector<Value> strides =
s_mem_obj.getStrides(mem_desc, loc, rewriter);
Value stride_m = strides[0];
Value stride_k = strides[1];
MLIRContext *ctx = tensor.getContext();
Type ptr_ty = ptr_ty(ctx, 3);
Value base = gep(ptr_ty, i16_ty, s_mem_obj.base, i32_val(0));
Value base = gep(ptr_ty, i16_ty, s_mem_obj.getBase(), i32_val(0));
SmallVector<Value> values;

for (int k = 0; k < rep_k; ++k) {
@@ -740,8 +742,8 @@ LogicalResult convertSparseWGMMA(SparseDotOp op, SparseDotOp::Adaptor adaptor,
int64_t swizzling =
getSwizzlingFromLayout(sharedLayout, shape[ord[0]] * byteSize);
Value baseDesc = createDescriptor(rewriter, loc, swizzling, shape[ord[1]]);
baseDesc =
add(baseDesc, lshr(ptrtoint(i64_ty, sharedObj.base), int_val(64, 4)));
baseDesc = add(baseDesc,
lshr(ptrtoint(i64_ty, sharedObj.getBase()), int_val(64, 4)));
return std::make_tuple(shape, ord, baseDesc);
};
