diff --git a/xla/service/gpu/fusions/concatenate_mlir_test.cc b/xla/service/gpu/fusions/concatenate_mlir_test.cc
index c0637cbe12dc7..f7e9865135852 100644
--- a/xla/service/gpu/fusions/concatenate_mlir_test.cc
+++ b/xla/service/gpu/fusions/concatenate_mlir_test.cc
@@ -57,17 +57,17 @@ TEST_F(MlirConcatenateFusionTest, ThreadIdIndexing) {
   constexpr auto kIndexing = R"(
     (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
-      (bl_x * 128 + th_x)
+      (bl_x * 256 + th_x)
     domain:
-    th_x in [0, 127]
+    th_x in [0, 255]
     th_y in [0, 0]
     th_z in [0, 0]
-    bl_x in [0, 3]
+    bl_x in [0, 1]
     bl_y in [0, 0]
     bl_z in [0, 0]
     chunk_id in [0, 0]
     unroll_id in [0, 0]
-    bl_x * 128 + th_x in [0, 399]
+    bl_x * 256 + th_x in [0, 399]
   )";
   auto thread_id_to_output_indexing_0 = fusion.ComputeThreadIdToInputIndexing(
       /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
@@ -102,9 +102,9 @@ TEST_F(MlirConcatenateFusionTest, StandAloneConcatenate) {
   }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0)>
-    // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 200)>
-    // CHECK-DAG: #[[MAP_3:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 600)>
+    // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0, d1) -> (d1 * 256 + d0)>
+    // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0, d1) -> (d1 * 256 + d0 + 200)>
+    // CHECK-DAG: #[[MAP_3:.*]] = affine_map<(d0, d1) -> (d1 * 256 + d0 + 600)>

     // CHECK-LABEL: fused_computation
     // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}},
@@ -254,9 +254,9 @@ TEST_F(MlirConcatenateFusionTest, Vectorization) {
   }
   )";
   TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
-    // CHECK-DAG: affine_map<(d0, d1) -> (d1 * 128 + d0)>
-    // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)>
-    // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)>
+    // CHECK-DAG: affine_map<(d0, d1) -> (d1 * 256 + d0)>
+    // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 512 + s0)>
+    // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 512 + s0 + 640002)>

     // CHECK-LABEL: fused_computation
     // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
diff --git a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
index b68a95e9516bf..70c0460a83a8e 100644
--- a/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
+++ b/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc
@@ -269,8 +269,10 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OperandSubgraphWithTwoRoots) {
     // CHECK-SAME: , %[[ARG4:[^:]+]]: tensor<512x512xf32>
     // CHECK-DAG: %[[C_384:.*]] = arith.constant 384
     // CHECK-DAG: %[[C_0:.*]] = arith.constant 0
-    // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x
-    // CHECK: %[[BLOCK_ID:.*]] = gpu.block_id x
+    // CHECK: %[[TID:.*]] = gpu.thread_id x
+    // CHECK: %[[BID:.*]] = gpu.block_id x
+    // CHECK: %[[BLOCK_ID:.*]] = xla_gpu.apply_indexing #map(%thread_id_x in [0, 255], %block_id_x in [0, 63])
+    // CHECK: %[[THREAD_ID:.*]] = xla_gpu.apply_indexing #map1(%thread_id_x in [0, 255])
     // CHECK: %[[I0:.*]] = xla_gpu.pure_call @dus_fusion_param_2_plus_one
     // CHECK: %[[I1:.*]] = xla_gpu.pure_call @dus_fusion_param_3_plus_one
     // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]]
diff --git a/xla/service/gpu/fusions/loop_mlir_test.cc b/xla/service/gpu/fusions/loop_mlir_test.cc
index 08dcb4df490e5..e42b421ee2621 100644
--- a/xla/service/gpu/fusions/loop_mlir_test.cc
+++ b/xla/service/gpu/fusions/loop_mlir_test.cc
@@ -52,23 +52,22 @@ TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) {
       fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_);
   EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_),
-              MatchIndexingString(R"(
-  (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
-    (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000,
-    ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200,
-    ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id
-  )
-  domain:
-  th_x in [0, 127]
-  th_y in [0, 0]
-  th_z in [0, 0]
-  bl_x in [0, 1007]
-  bl_y in [0, 0]
-  bl_z in [0, 0]
-  chunk_id in [0, 11]
-  unroll_id in [0, 3]
-  bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999]
-)"));
+              MatchIndexingString(R"(
+    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+      (bl_x * 128 + chunk_id * 212992 + th_x) floordiv 15000,
+      ((bl_x * 128 + chunk_id * 212992 + th_x) floordiv 75) mod 200,
+      ((bl_x * 128 + chunk_id * 212992 + th_x) mod 75) * 4 + unroll_id)
+    domain:
+    th_x in [0, 127]
+    th_y in [0, 0]
+    th_z in [0, 0]
+    bl_x in [0, 1663]
+    bl_y in [0, 0]
+    bl_z in [0, 0]
+    chunk_id in [0, 7]
+    unroll_id in [0, 3]
+    bl_x * 128 + chunk_id * 212992 + th_x in [0, 1499999]
+  )"));
 }

 TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) {
@@ -148,37 +147,35 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) {
   EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_),
               MatchIndexingString(R"(
     (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
-      (bl_x * 128 + th_x) floordiv 600,
-      ((bl_x * 128 + th_x) floordiv 30) mod 20,
-      (bl_x * 128 + th_x) mod 30
-    )
-    domain:
-    th_x in [0, 127]
-    th_y in [0, 0]
-    th_z in [0, 0]
-    bl_x in [0, 46]
-    bl_y in [0, 0]
-    bl_z in [0, 0]
-    chunk_id in [0, 0]
-    unroll_id in [0, 0]
-    bl_x * 128 + th_x in [0, 5999]
+      (bl_x * 256 + th_x) floordiv 600,
+      ((bl_x * 256 + th_x) floordiv 30) mod 20,
+      (bl_x * 256 + th_x) mod 30)
+    domain:
+    th_x in [0, 255]
+    th_y in [0, 0]
+    th_z in [0, 0]
+    bl_x in [0, 23]
+    bl_y in [0, 0]
+    bl_z in [0, 0]
+    chunk_id in [0, 0]
+    unroll_id in [0, 0]
+    bl_x * 256 + th_x in [0, 5999]
   )"));
-  auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing(
+  auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing(
       /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_);
-  EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_),
+  EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_),
               MatchIndexingString(R"(
-    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] ->
-      (((bl_x * 128 + th_x) floordiv 30) mod 20)
+    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
+      ((bl_x * 256 + th_x) floordiv 30) mod 20)
     domain:
-    th_x in [0, 127]
+    th_x in [0, 255]
     th_y in [0, 0]
     th_z in [0, 0]
-    bl_x in [0, 46]
+    bl_x in [0, 23]
     bl_y in [0, 0]
     bl_z in [0, 0]
     chunk_id in [0, 0]
     unroll_id in [0, 0]
-    bl_x * 128 + th_x in [0, 5999]
+    bl_x * 256 + th_x in [0, 5999]
   )"));
 }
diff --git a/xla/service/gpu/fusions/mlir/BUILD b/xla/service/gpu/fusions/mlir/BUILD
index 11f992b2ac798..5dd0216d2c315 100644
--- a/xla/service/gpu/fusions/mlir/BUILD
+++ b/xla/service/gpu/fusions/mlir/BUILD
@@ -334,6 +334,7 @@ cc_library(
         "@llvm-project//mlir:FuncToLLVM",
         "@llvm-project//mlir:GPUDialect",
         "@llvm-project//mlir:GPUToNVVMTransforms",
+        "@llvm-project//mlir:GPUToROCDLTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMCommonConversion",
        "@llvm-project//mlir:LLVMDialect",
diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
index 59471a3fb337e..648ab3fad3d3d 100644
--- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
+++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
@@ -1177,20 +1177,13 @@ absl::StatusOr> HloToMlir(
 }  // namespace

 bool IsHloOpSupported(const HloInstruction* instr,
-                      se::CudaComputeCapability compute_capability) {
+                      se::GpuComputeCapability compute_capability) {
   return !(kUnsupportedOps.contains(instr->opcode()) ||
            IsUnsupportedGather(instr));
 }

 bool IsHloConversionSupported(const HloComputation* computation,
                               se::GpuComputeCapability compute_capability) {
-  if (!std::holds_alternative<se::CudaComputeCapability>(compute_capability)) {
-    // ROCM is not tested.
-    return false;
-  }
-  auto cuda_compute_capability =
-      std::get<se::CudaComputeCapability>(compute_capability);
-
   return absl::c_all_of(
       computation->instructions(), [=](const HloInstruction* instr) {
@@ -1199,7 +1192,7 @@ bool IsHloConversionSupported(const HloComputation* computation,
                      return IsHloConversionSupported(
                          called, compute_capability);
                    }) &&
-             IsHloOpSupported(instr, cuda_compute_capability);
+             IsHloOpSupported(instr, compute_capability);
       }) &&
      (computation->IsFusionComputation() ||
       (absl::c_all_of(
@@ -1209,22 +1202,16 @@ bool IsHloConversionSupported(const HloComputation* computation,
 }

 bool IsHloConversionSupported(const HloFusionAdaptor& fusion,
-                              se::GpuComputeCapability compute_capability) {
-  if (!std::holds_alternative<se::CudaComputeCapability>(compute_capability)) {
-    // ROCM is not tested.
-    return false;
-  }
-  auto cuda_compute_capability =
-      std::get<se::CudaComputeCapability>(compute_capability);
-
-  return !HloAnyOf(fusion, [=](HloInstructionAdaptor instr) {
-    return !absl::c_all_of(instr.instruction().called_computations(),
-                           [&](const HloComputation* called) {
-                             return IsHloConversionSupported(
-                                 called, compute_capability);
-                           }) ||
-           !IsHloOpSupported(&instr.instruction(), cuda_compute_capability);
-  });
+                              se::GpuComputeCapability compute_capability) {
+  return !HloBfsFindIf(
+      fusion.GetRoots(), fusion, [=](HloInstructionAdaptor instr) {
+        return !absl::c_all_of(instr.instruction().called_computations(),
+                               [&](const HloComputation* called) {
+                                 return IsHloConversionSupported(
+                                     called, compute_capability);
+                               }) ||
+               !IsHloOpSupported(&instr.instruction(), compute_capability);
+      });
 }

 ValueRange ProvideParameter(const PartitionedComputation& computation,
diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
index 1f52109e34c88..7f3623af4b7be 100644
--- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
+++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
@@ -64,7 +64,7 @@ llvm::SmallVector ProvideParameterRange(

 // Checks whether the given HLO instruction can be converted to MLIR.
 bool IsHloOpSupported(const HloInstruction* instr,
-                      se::CudaComputeCapability compute_capability);
+                      se::GpuComputeCapability compute_capability);

 // Checks whether the given HLO computation is supported by the MLIR converter:
 // - all instructions in it are supported
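Note (illustrative, not part of the patch): the hunks above stop rejecting ROCm and pass the se::GpuComputeCapability variant through unchanged. If backend-specific handling is ever needed again, a caller can still branch on the variant; a minimal sketch, assuming GpuComputeCapability remains the std::variant declared in xla/stream_executor/device_description.h:

#include <variant>
#include "xla/stream_executor/device_description.h"

namespace se = ::stream_executor;

// Hypothetical helper, shown only to document the variant-based dispatch.
bool IsRocmCapability(const se::GpuComputeCapability& cc) {
  return std::holds_alternative<se::RocmComputeCapability>(cc);
}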
diff --git a/xla/service/gpu/fusions/mlir/lower_to_llvm.cc b/xla/service/gpu/fusions/mlir/lower_to_llvm.cc
index 54ce7924f45ab..89962df2d019f 100644
--- a/xla/service/gpu/fusions/mlir/lower_to_llvm.cc
+++ b/xla/service/gpu/fusions/mlir/lower_to_llvm.cc
@@ -16,7 +16,6 @@ limitations under the License.

 #include
 #include

-#include "llvm/Support/LogicalResult.h"
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
diff --git a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc
index 18c3ab3c10f2f..01f31d16b4732 100644
--- a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc
+++ b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc
@@ -42,7 +42,7 @@ limitations under the License.
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
@@ -111,7 +111,7 @@ class MlirFusionEmitterTest : public HloTestBase {

   mlir::MLIRContext context_;
   stream_executor::DeviceDescription device_info_ =
-      TestGpuDeviceInfo::CudaOrRocmDeviceInfo();
+      TestGpuDeviceInfo::TestCudaOrRocmDeviceInfo();
 };

 constexpr absl::string_view kModule = R"(
diff --git a/xla/service/gpu/fusions/mlir_emitter_test_base.h b/xla/service/gpu/fusions/mlir_emitter_test_base.h
index 351e4b4dae319..3bf372c368765 100644
--- a/xla/service/gpu/fusions/mlir_emitter_test_base.h
+++ b/xla/service/gpu/fusions/mlir_emitter_test_base.h
@@ -52,7 +52,7 @@ class MlirEmitterTestBaseImpl : public HloTestBase {
                            std::string_view pattern);

   stream_executor::DeviceDescription device_info_ =
-      TestGpuDeviceInfo::CudaOrRocmDeviceInfo();
+      TestGpuDeviceInfo::TestCudaOrRocmDeviceInfo();
   mlir::MLIRContext mlir_context_;
   AffineMapPrinter thread_id_printer_;
 };
diff --git a/xla/service/gpu/fusions/scatter_mlir_test.cc b/xla/service/gpu/fusions/scatter_mlir_test.cc
index 869d233500182..98df639096908 100644
--- a/xla/service/gpu/fusions/scatter_mlir_test.cc
+++ b/xla/service/gpu/fusions/scatter_mlir_test.cc
@@ -82,20 +82,19 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) {
   constexpr auto kUpdatesIndexing = R"(
     (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (
-      (bl_x * 128 + th_x) floordiv 200,
-      ((bl_x * 128 + th_x) floordiv 20) mod 10,
-      (bl_x * 128 + th_x) mod 20
-    )
+      (bl_x * 256 + th_x) floordiv 200,
+      ((bl_x * 256 + th_x) floordiv 20) mod 10,
+      (bl_x * 256 + th_x) mod 20)
     domain:
-    th_x in [0, 127]
+    th_x in [0, 255]
     th_y in [0, 0]
     th_z in [0, 0]
-    bl_x in [0, 65]
+    bl_x in [0, 32]
     bl_y in [0, 0]
     bl_z in [0, 0]
     chunk_id in [0, 0]
     unroll_id in [0, 0]
-    bl_x * 128 + th_x in [0, 8399]
+    bl_x * 256 + th_x in [0, 8399]
   )";
   EXPECT_THAT(
       fusion
@@ -123,19 +122,19 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) {
               MatchIndexingString(kUpdatesIndexing));

   constexpr auto kIndicesIndexing = R"(
-    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] ->
-      ((bl_x * 128 + th_x) floordiv 200, 0)
+    (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> (
+      (bl_x * 256 + th_x) floordiv 200, 0)
     domain:
-    th_x in [0, 127]
+    th_x in [0, 255]
     th_y in [0, 0]
     th_z in [0, 0]
-    bl_x in [0, 65]
+    bl_x in [0, 32]
     bl_y in [0, 0]
     bl_z in [0, 0]
     chunk_id in [0, 0]
     unroll_id in [0, 0]
     index_id in [0, 0]
-    bl_x * 128 + th_x in [0, 8399]
+    bl_x * 256 + th_x in [0, 8399]
   )";
   EXPECT_THAT(
       fusion
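Note (illustrative, not part of the patch): the updated scatter expectations are the usual row-major delinearization of the flat id bl_x * 256 + th_x, only with 256 threads per block instead of 128. A standalone sketch of the same arithmetic, assuming the 42x10x20 updates shape implied by the divisors and the [0, 8399] bound:

#include <cstdio>

int main() {
  // Mirrors (flat floordiv 200, (flat floordiv 20) mod 10, flat mod 20).
  const int bl_x = 32, th_x = 100;
  const int flat = bl_x * 256 + th_x;  // stays within [0, 8399]
  std::printf("(%d, %d, %d)\n", flat / 200, (flat / 20) % 10, flat % 20);
  return 0;
}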
diff --git a/xla/service/gpu/fusions/transpose_mlir_test.cc b/xla/service/gpu/fusions/transpose_mlir_test.cc
index 1861672a82279..7accd0ddaea90 100644
--- a/xla/service/gpu/fusions/transpose_mlir_test.cc
+++ b/xla/service/gpu/fusions/transpose_mlir_test.cc
@@ -666,7 +666,7 @@ TEST_F(MlirTransposeFusionTest, VectorizedTranspose021) {
   }
   )";
   TF_EXPECT_OK(EmitAndCheckIR(
-      kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x64x65xbf16>"));
+      kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x32x33xbf16>"));
 }

 TEST_F(MlirTransposeFusionTest, VectorizedTranspose210) {
@@ -683,7 +683,7 @@ TEST_F(MlirTransposeFusionTest, VectorizedTranspose210) {
   }
   )";
   TF_EXPECT_OK(EmitAndCheckIR(
-      kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<64x1x65xbf16>"));
+      kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<32x1x33xbf16>"));
 }

 TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize021) {
@@ -700,7 +700,7 @@ TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize021) {
   }
   )";
   TF_EXPECT_OK(EmitAndCheckIR(
-      kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x128x129xi8>"));
+      kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<1x64x65xi8>"));
 }

 TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize210) {
@@ -717,7 +717,7 @@ TEST_F(MlirTransposeFusionTest, PreferLargeVectorSize210) {
   }
   )";
   TF_EXPECT_OK(EmitAndCheckIR(
-      kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<128x1x129xi8>"));
+      kHloString, "// CHECK: xla_gpu.allocate_shared : tensor<64x1x65xi8>"));
 }

 }  // namespace
diff --git a/xla/service/gpu/gpu_device_info_for_tests.cc b/xla/service/gpu/gpu_device_info_for_tests.cc
index 7e6a532d6cf3c..f4bedd79928f7 100644
--- a/xla/service/gpu/gpu_device_info_for_tests.cc
+++ b/xla/service/gpu/gpu_device_info_for_tests.cc
@@ -64,11 +64,11 @@ stream_executor::DeviceDescription TestGpuDeviceInfo::AMDMI210DeviceInfo() {
   return b.BuildObject();
 }

-stream_executor::DeviceDescription TestGpuDeviceInfo::CudaOrRocmDeviceInfo() {
-#if defined(TENSORFLOW_USE_ROCM)
-  return AMDMI210DeviceInfo();
+stream_executor::DeviceDescription TestGpuDeviceInfo::TestCudaOrRocmDeviceInfo() {
+#if !defined(TENSORFLOW_USE_ROCM)
+  return RTXA6000DeviceInfo();
 #else
-  return RTXA6000DeviceInfo();
+  return AMDMI210DeviceInfo();
 #endif  // GOOGLE_CUDA
 }
diff --git a/xla/service/gpu/gpu_device_info_for_tests.h b/xla/service/gpu/gpu_device_info_for_tests.h
index 9085763fee43c..d43e88d4bbc3f 100644
--- a/xla/service/gpu/gpu_device_info_for_tests.h
+++ b/xla/service/gpu/gpu_device_info_for_tests.h
@@ -27,8 +27,7 @@ class TestGpuDeviceInfo {
       stream_executor::GpuComputeCapability cc =
           stream_executor::CudaComputeCapability(8, 9));
   static stream_executor::DeviceDescription AMDMI210DeviceInfo();
-  // Returns deafult RTXA6000 or AMDMI210 device info
-  static stream_executor::DeviceDescription CudaOrRocmDeviceInfo();
+  static stream_executor::DeviceDescription TestCudaOrRocmDeviceInfo();
 };

 }  // namespace gpu
diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD
index abf7395bc024a..5ee93e95d3d06 100644
--- a/xla/service/gpu/tests/BUILD
+++ b/xla/service/gpu/tests/BUILD
@@ -224,6 +224,7 @@ xla_test(
     name = "gpu_hlo_runner_test",
     srcs = ["gpu_hlo_runner_test.cc"],
     backends = ["gpu"],
+    tags = ["no_oss"],
     deps = [
         ":gpu_codegen_test",
         "//xla:error_spec",
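Note (usage sketch, not part of the patch): the renamed TestCudaOrRocmDeviceInfo() is meant to be dropped into test fixtures exactly as mlir_emitter_test_base.h does above, so the same test source compiles against whichever backend the build targets. The fixture name below is hypothetical:

#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/tests/hlo_test_base.h"

class MyMlirEmitterTest : public xla::HloTestBase {
 protected:
  // RTX A6000 description for CUDA builds, MI210 for ROCm builds.
  stream_executor::DeviceDescription device_info_ =
      xla::gpu::TestGpuDeviceInfo::TestCudaOrRocmDeviceInfo();
};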
diff --git a/xla/service/gpu/tests/gpu_hlo_runner_test.cc b/xla/service/gpu/tests/gpu_hlo_runner_test.cc
index af28627561f95..bcc263d6f520a 100644
--- a/xla/service/gpu/tests/gpu_hlo_runner_test.cc
+++ b/xla/service/gpu/tests/gpu_hlo_runner_test.cc
@@ -41,7 +41,9 @@ class HloRunnerTest : public GpuCodegenTest {};

 TEST_F(HloRunnerTest, RunSingle) {
   std::ifstream ifs("input.hlo");
-  ASSERT_TRUE(ifs.good());
+  if (ifs.fail()) {
+    GTEST_SKIP() << "No input HLO file provided!";
+  }

   std::stringstream buffer;
   buffer << ifs.rdbuf();
diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc
index ae841e6740667..1c028176bd22a 100644
--- a/xla/stream_executor/rocm/rocm_blas.cc
+++ b/xla/stream_executor/rocm/rocm_blas.cc
@@ -374,7 +374,7 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
     auto *workspace = GetWorkspace();
     auto *wptr = workspace != nullptr ? workspace->opaque() : nullptr;
     size_t wsize = workspace != nullptr ? workspace->size() : 0;
-    ret = wrap::rocblas_set_workspace(blas_, wptr, wsize);
+    auto ret = wrap::rocblas_set_workspace(blas_, wptr, wsize);
     if (err_on_failure && ret != rocblas_status_success) {
       LOG(ERROR) << "failed to set workspace before " << FuncT::kName << ": "
                  << ToString(ret);