Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rocm jaxlib v0.4.31 qa misc backport #85

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions xla/service/gpu/fusions/concatenate_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,17 @@ TEST_F(MlirConcatenateFusionTest, ThreadIdIndexing) {

constexpr auto kIndexing = R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
(bl_x * 128 + th_x)
(bl_x * 256 + th_x)
domain:
th_x in [0, 127]
th_x in [0, 255]
th_y in [0, 0]
th_z in [0, 0]
bl_x in [0, 3]
bl_x in [0, 1]
bl_y in [0, 0]
bl_z in [0, 0]
chunk_id in [0, 0]
unroll_id in [0, 0]
bl_x * 128 + th_x in [0, 399]
bl_x * 256 + th_x in [0, 399]
)";
auto thread_id_to_output_indexing_0 = fusion.ComputeThreadIdToInputIndexing(
/*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
Expand Down Expand Up @@ -102,9 +102,9 @@ TEST_F(MlirConcatenateFusionTest, StandAloneConcatenate) {
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
// CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0)>
// CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 200)>
// CHECK-DAG: #[[MAP_3:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 600)>
// CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0, d1) -> (d1 * 256 + d0)>
// CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0, d1) -> (d1 * 256 + d0 + 200)>
// CHECK-DAG: #[[MAP_3:.*]] = affine_map<(d0, d1) -> (d1 * 256 + d0 + 600)>

// CHECK-LABEL: fused_computation
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}},
Expand Down Expand Up @@ -254,9 +254,9 @@ TEST_F(MlirConcatenateFusionTest, Vectorization) {
}
)";
TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
// CHECK-DAG: affine_map<(d0, d1) -> (d1 * 128 + d0)>
// CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)>
// CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)>
// CHECK-DAG: affine_map<(d0, d1) -> (d1 * 256 + d0)>
// CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 512 + s0)>
// CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 512 + s0 + 640002)>

// CHECK-LABEL: fused_computation
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,10 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OperandSubgraphWithTwoRoots) {
// CHECK-SAME: , %[[ARG4:[^:]+]]: tensor<512x512xf32>
// CHECK-DAG: %[[C_384:.*]] = arith.constant 384
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0
// CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x
// CHECK: %[[BLOCK_ID:.*]] = gpu.block_id x
// CHECK: %[[TID:.*]] = gpu.thread_id x
// CHECK: %[[BID:.*]] = gpu.block_id x
// CHECK: %[[BLOCK_ID:.*]] = xla_gpu.apply_indexing #map(%thread_id_x in [0, 255], %block_id_x in [0, 63])
// CHECK: %[[THREAD_ID:.*]] = xla_gpu.apply_indexing #map1(%thread_id_x in [0, 255])
// CHECK: %[[I0:.*]] = xla_gpu.pure_call @dus_fusion_param_2_plus_one
// CHECK: %[[I1:.*]] = xla_gpu.pure_call @dus_fusion_param_3_plus_one
// CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]]
Expand Down
73 changes: 35 additions & 38 deletions xla/service/gpu/fusions/loop_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,22 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) {
fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_);

EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_),
MatchIndexingString(R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
(bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000,
((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
)
domain:
th_x in [0, 127]
th_y in [0, 0]
th_z in [0, 0]
bl_x in [0, 1007]
bl_y in [0, 0]
bl_z in [0, 0]
chunk_id in [0, 11]
unroll_id in [0, 3]
bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999]
)"));
MatchIndexingString(R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
(bl_x * 128 + chunk_id * 212992 + th_x) floordiv 15000,
((bl_x * 128 + chunk_id * 212992 + th_x) floordiv 75) mod 200,
((bl_x * 128 + chunk_id * 212992 + th_x) mod 75) * 4 + unroll_id)
domain:
th_x in [0, 127]
th_y in [0, 0]
th_z in [0, 0]
bl_x in [0, 1663]
bl_y in [0, 0]
bl_z in [0, 0]
chunk_id in [0, 7]
unroll_id in [0, 3]
bl_x * 128 + chunk_id * 212992 + th_x in [0, 1499999]
)"));
}

TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) {
Expand Down Expand Up @@ -148,37 +147,35 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) {
EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_),
MatchIndexingString(R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
(bl_x * 128 + th_x) floordiv 600,
((bl_x * 128 + th_x) floordiv 30) mod 20,
(bl_x * 128 + th_x) mod 30
)
domain:
th_x in [0, 127]
th_y in [0, 0]
th_z in [0, 0]
bl_x in [0, 46]
bl_y in [0, 0]
bl_z in [0, 0]
chunk_id in [0, 0]
unroll_id in [0, 0]
bl_x * 128 + th_x in [0, 5999]
(bl_x * 256 + th_x) floordiv 600,
((bl_x * 256 + th_x) floordiv 30) mod 20,
(bl_x * 256 + th_x) mod 30)
domain:th_x in [0, 255]
th_y in [0, 0]
th_z in [0, 0]
bl_x in [0, 23]
bl_y in [0, 0]
bl_z in [0, 0]
chunk_id in [0, 0]
unroll_id in [0, 0]
bl_x * 256 + th_x in [0, 5999]
)"));
auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing(
auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing(
/*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_),
EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_),
MatchIndexingString(R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
(((bl_x * 128 + th_x) floordiv 30) mod 20)
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
((bl_x * 256 + th_x) floordiv 30) mod 20)
domain:
th_x in [0, 127]
th_x in [0, 255]
th_y in [0, 0]
th_z in [0, 0]
bl_x in [0, 46]
bl_x in [0, 23]
bl_y in [0, 0]
bl_z in [0, 0]
chunk_id in [0, 0]
unroll_id in [0, 0]
bl_x * 128 + th_x in [0, 5999]
bl_x * 256 + th_x in [0, 5999]
)"));
}

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ cc_library(
"@llvm-project//mlir:FuncToLLVM",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUToNVVMTransforms",
"@llvm-project//mlir:GPUToROCDLTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:LLVMDialect",
Expand Down
37 changes: 12 additions & 25 deletions xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1177,20 +1177,13 @@ absl::StatusOr<SmallVector<Value, 1>> HloToMlir(
} // namespace

bool IsHloOpSupported(const HloInstruction* instr,
se::CudaComputeCapability compute_capability) {
se::GpuComputeCapability compute_capability) {
return !(kUnsupportedOps.contains(instr->opcode()) ||
IsUnsupportedGather(instr));
}

bool IsHloConversionSupported(const HloComputation* computation,
se::GpuComputeCapability compute_capability) {
if (!std::holds_alternative<se::CudaComputeCapability>(compute_capability)) {
// ROCM is not tested.
return false;
}
auto cuda_compute_capability =
std::get<se::CudaComputeCapability>(compute_capability);

return absl::c_all_of(
computation->instructions(),
[=](const HloInstruction* instr) {
Expand All @@ -1199,7 +1192,7 @@ bool IsHloConversionSupported(const HloComputation* computation,
return IsHloConversionSupported(
called, compute_capability);
}) &&
IsHloOpSupported(instr, cuda_compute_capability);
IsHloOpSupported(instr, compute_capability);
}) &&
(computation->IsFusionComputation() ||
(absl::c_all_of(
Expand All @@ -1209,22 +1202,16 @@ bool IsHloConversionSupported(const HloComputation* computation,
}

bool IsHloConversionSupported(const HloFusionAdaptor& fusion,
se::GpuComputeCapability compute_capability) {
if (!std::holds_alternative<se::CudaComputeCapability>(compute_capability)) {
// ROCM is not tested.
return false;
}
auto cuda_compute_capability =
std::get<se::CudaComputeCapability>(compute_capability);

return !HloAnyOf(fusion, [=](HloInstructionAdaptor instr) {
return !absl::c_all_of(instr.instruction().called_computations(),
[&](const HloComputation* called) {
return IsHloConversionSupported(
called, compute_capability);
}) ||
!IsHloOpSupported(&instr.instruction(), cuda_compute_capability);
});
se::GpuComputeCapability compute_capability) {
return !HloBfsFindIf(
fusion.GetRoots(), fusion, [=](HloInstructionAdaptor instr) {
return !absl::c_all_of(instr.instruction().called_computations(),
[&](const HloComputation* called) {
return IsHloConversionSupported(
called, compute_capability);
}) ||
!IsHloOpSupported(&instr.instruction(), compute_capability);
});
}

ValueRange ProvideParameter(const PartitionedComputation& computation,
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ llvm::SmallVector<mlir::Value, 2> ProvideParameterRange(

// Checks whether the given HLO instruction can be converted to MLIR.
bool IsHloOpSupported(const HloInstruction* instr,
se::CudaComputeCapability compute_capability);
se::GpuComputeCapability compute_capability);

// Checks whether the given HLO computation is supported by the MLIR converter:
// - all instructions in it are supported
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/fusions/mlir/lower_to_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include <memory>
#include <utility>

#include "llvm/Support/LogicalResult.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ limitations under the License.
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
Expand Down Expand Up @@ -111,7 +111,7 @@ class MlirFusionEmitterTest : public HloTestBase {

mlir::MLIRContext context_;
stream_executor::DeviceDescription device_info_ =
TestGpuDeviceInfo::CudaOrRocmDeviceInfo();
TestGpuDeviceInfo::TestCudaOrRocmDeviceInfo();
};

constexpr absl::string_view kModule = R"(
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/mlir_emitter_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MlirEmitterTestBaseImpl : public HloTestBase {
std::string_view pattern);

stream_executor::DeviceDescription device_info_ =
TestGpuDeviceInfo::CudaOrRocmDeviceInfo();
TestGpuDeviceInfo::TestCudaOrRocmDeviceInfo();
mlir::MLIRContext mlir_context_;
AffineMapPrinter thread_id_printer_;
};
Expand Down
23 changes: 11 additions & 12 deletions xla/service/gpu/fusions/scatter_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,19 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) {

constexpr auto kUpdatesIndexing = R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
(bl_x * 128 + th_x) floordiv 200,
((bl_x * 128 + th_x) floordiv 20) mod 10,
(bl_x * 128 + th_x) mod 20
)
(bl_x * 256 + th_x) floordiv 200,
((bl_x * 256 + th_x) floordiv 20) mod 10,
(bl_x * 256 + th_x) mod 20)
domain:
th_x in [0, 127]
th_x in [0, 255]
th_y in [0, 0]
th_z in [0, 0]
bl_x in [0, 65]
bl_x in [0, 32]
bl_y in [0, 0]
bl_z in [0, 0]
chunk_id in [0, 0]
unroll_id in [0, 0]
bl_x * 128 + th_x in [0, 8399]
bl_x * 256 + th_x in [0, 8399]
)";
EXPECT_THAT(
fusion
Expand Down Expand Up @@ -123,19 +122,19 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) {
MatchIndexingString(kUpdatesIndexing));

constexpr auto kIndicesIndexing = R"(
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] ->
((bl_x * 128 + th_x) floordiv 200, 0)
(th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> (
(bl_x * 256 + th_x) floordiv 200, 0)
domain:
th_x in [0, 127]
th_x in [0, 255]
th_y in [0, 0]
th_z in [0, 0]
bl_x in [0, 65]
bl_x in [0, 32]
bl_y in [0, 0]
bl_z in [0, 0]
chunk_id in [0, 0]
unroll_id in [0, 0]
index_id in [0, 0]
bl_x * 128 + th_x in [0, 8399]
bl_x * 256 + th_x in [0, 8399]
)";
EXPECT_THAT(
fusion
Expand Down
8 changes: 4 additions & 4 deletions xla/service/gpu/fusions/transpose_mlir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ TEST_F(MlirTransposeFusionTest, VectorizedTranspose021) {
}
)";
TF_EXPECT_OK(EmitAndCheckIR(
kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x64x65xbf16>"));
kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x32x33xbf16>"));
}

TEST_F(MlirTransposeFusionTest, VectorizedTranspose210) {
Expand All @@ -683,7 +683,7 @@ TEST_F(MlirTransposeFusionTest, VectorizedTranspose210) {
}
)";
TF_EXPECT_OK(EmitAndCheckIR(
kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<64x1x65xbf16>"));
kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<32x1x33xbf16>"));
}

TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize021) {
Expand All @@ -700,7 +700,7 @@ TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize021) {
}
)";
TF_EXPECT_OK(EmitAndCheckIR(
kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x128x129xi8>"));
kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x64x65xi8>"));
}

TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize210) {
Expand All @@ -717,7 +717,7 @@ TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize210) {
}
)";
TF_EXPECT_OK(EmitAndCheckIR(
kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<128x1x129xi8>"));
kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<64x1x65xi8>"));
}

} // namespace
Expand Down
8 changes: 4 additions & 4 deletions xla/service/gpu/gpu_device_info_for_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ stream_executor::DeviceDescription TestGpuDeviceInfo::AMDMI210DeviceInfo() {
return b.BuildObject();
}

stream_executor::DeviceDescription TestGpuDeviceInfo::CudaOrRocmDeviceInfo() {
#if defined(TENSORFLOW_USE_ROCM)
return AMDMI210DeviceInfo();
stream_executor::DeviceDescription TestGpuDeviceInfo::TestCudaOrRocmDeviceInfo() {
#if !defined(TENSORFLOW_USE_ROCM)
return RTXA6000DeviceInfo()
#else
return RTXA6000DeviceInfo();
return AMDMI210DeviceInfo();
#endif // GOOGLE_CUDA
}

Expand Down
3 changes: 1 addition & 2 deletions xla/service/gpu/gpu_device_info_for_tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ class TestGpuDeviceInfo {
stream_executor::GpuComputeCapability cc =
stream_executor::CudaComputeCapability(8, 9));
static stream_executor::DeviceDescription AMDMI210DeviceInfo();
// Returns deafult RTXA6000 or AMDMI210 device info
static stream_executor::DeviceDescription CudaOrRocmDeviceInfo();
static stream_executor::DeviceDescription TestCudaOrRocmDeviceInfo();
};

} // namespace gpu
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ xla_test(
name = "gpu_hlo_runner_test",
srcs = ["gpu_hlo_runner_test.cc"],
backends = ["gpu"],
tags=["no_oss"],
deps = [
":gpu_codegen_test",
"//xla:error_spec",
Expand Down
Loading
Loading