PR #21845: [ROCM] Add missing triton MLIR int4 -> int8 rewrite pass for ROCM

Imported from GitHub PR #21845

```
TritonTest.DotWithInt4WeightsOnLhsFusedWithMultiplyByChannelScales
TritonTest.NonstandardLayoutInt4
TritonTest.DotWithI4WeightsOnLhsWithBitcastTo3dTensor
TritonTest.DotWithI4WeightsOnLhsWithNonStandardLayoutAndMultplyInEpilogue
TritonTest.LHSWithMinorDimEqualTo1
TritonTest.RHSWithMinorDimEqualTo1
TritonTest.LHSNonMinorContractingDim
TritonTest.LHSNonMinorContractingDimWithBatchDim0
TritonTest.LHSMinorContractingDim
TritonTest.ConvertPlusNegate
TritonTest.LHSMinorContractingDimWithBatchDim0
TritonTest.RHSTestWithNotMinorContractingDim
TritonTest.RHSTestWithMinorContractingDim
TritonTest.RHSTestWithMinorContractingDimWithBatchDim
TritonTest.RHSTestWithNotMinorContractingDimWithBatchDim0
ParametrizedTritonTest.Int4WeightsOnTheLhs
ParametrizedTritonTest.Int4WeightsOnTheLhsWithBatchDim
ParametrizedTritonTest.Int4WeightsOnTheRhs
```
The tests above have been failing on ROCm since int4 rewriting was moved from the legacy matmul emitter into an MLIR pass. That pass is missing from the ROCm Triton pipeline, so this change adds it there.

@xla-rotation: would you please take a look?
Copybara import of the project:

--
75e78ad by Jian Li <[email protected]>:

[ROCM] Add missing triton MLIR int4 -> int8 rewrite pass for ROCM

Merging this change closes #21845

COPYBARA_INTEGRATE_REVIEW=#21845 from ROCm:ci_fix_rocm_triton_test 75e78ad
PiperOrigin-RevId: 720233927
amd-jianli12 authored and Google-ML-Automation committed Jan 27, 2025
1 parent c85aa3b commit cb0060b
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
+#include "xla/backends/gpu/codegen/triton/xla_triton_passes.h"
 #include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h"
 #include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/hlo_module_config.h"
@@ -47,6 +48,7 @@ namespace ma = ::mlir::arith;
 namespace mm = ::mlir::math;
 namespace ml = ::mlir::LLVM;
 namespace mt = ::mlir::triton;
+namespace mt_xla = ::mlir::triton::xla;

 using ::llvm::SmallVector;
 using mlir::ArrayRef;
@@ -64,6 +66,10 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
   const int threadsPerWarp = 32;
   auto cc = se::RocmComputeCapability(std::move(arch_name));

+  if (is_xla_fusion) {
+    pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass());
+  }
+
   // Based on make_ttir() in
   // @triton//:third_party/amd/backend/compiler.py
   pm->addPass(mlir::createInlinerPass());
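
For readers unfamiliar with the "int4 -> int8" wording in the commit title, the general idea of packed int4 storage can be sketched as follows: two signed 4-bit values share one 8-bit byte and are recovered with masks, shifts, and sign extension. This is only an illustration of the storage scheme, not the XLA pass itself. The helper names `PackInt4Pair`, `SignExtend4`, and `UnpackInt4Pair` are hypothetical, and the nibble order (first element in the low nibble) is an assumption that may not match what the pass expects.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper (not part of XLA): packs two signed 4-bit values,
// each in [-8, 7], into a single byte.
// Assumption: `lo` occupies the low nibble and `hi` the high nibble.
inline std::int8_t PackInt4Pair(int lo, int hi) {
  assert(lo >= -8 && lo <= 7 && hi >= -8 && hi <= 7);
  const unsigned low = static_cast<unsigned>(lo) & 0xFu;
  const unsigned high = static_cast<unsigned>(hi) & 0xFu;
  return static_cast<std::int8_t>(
      static_cast<std::uint8_t>((high << 4) | low));
}

// Interprets the low 4 bits of `nibble` as a two's-complement int4.
inline int SignExtend4(unsigned nibble) {
  nibble &= 0xFu;
  return (nibble & 0x8u) ? static_cast<int>(nibble) - 16
                         : static_cast<int>(nibble);
}

// Hypothetical helper: recovers the {lo, hi} pair from a packed byte.
inline std::pair<int, int> UnpackInt4Pair(std::int8_t packed) {
  const unsigned bits = static_cast<std::uint8_t>(packed);
  return {SignExtend4(bits & 0xFu), SignExtend4(bits >> 4)};
}
```

Under these assumptions, `UnpackInt4Pair(PackInt4Pair(-3, 5))` yields the pair `{-3, 5}`. A rewrite pass that runs before the rest of the pipeline lets later passes work on ordinary 8-bit tensors instead of a 4-bit element type.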