From ad28415ecf9cc7dd86f1b6112203a0b7f94c2bef Mon Sep 17 00:00:00 2001
From: Pavel Emeliyanenko
Date: Mon, 20 Jan 2025 16:57:32 +0000
Subject: [PATCH 1/4] Avoid lazy init of blas handles and fix for non-canonical dots

---
 xla/service/gpu/amdgpu_compiler.cc            |  13 +-
 xla/service/gpu/buffer_sharing.cc             |   2 +-
 xla/stream_executor/rocm/rocm_executor.cc     |  31 ++---
 xla/tests/matmul_test.cc                      |  23 ++++
 .../data/sharded_computation.hlo              | 123 ++++++++++++++++++
 .../functional_hlo_runner_test.cc             |  34 +++++
 6 files changed, 202 insertions(+), 24 deletions(-)
 create mode 100644 xla/tools/multihost_hlo_runner/data/sharded_computation.hlo

diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc
index 04483ce86c78c..1b7128e421529 100644
--- a/xla/service/gpu/amdgpu_compiler.cc
+++ b/xla/service/gpu/amdgpu_compiler.cc
@@ -138,11 +138,12 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
   // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter
   // introduces reshapes and transposes that can be eliminated using
   // AlgebraicSimplifier. We run algsimp to a fixed point.
-  AlgebraicSimplifierOptions options =
+  AlgebraicSimplifierOptions algsimp_options =
       GetAlgebraicSimplifierOptions(hlo_module->config());
-  options.set_enable_conv_operand_swap(false);
-  options.set_enable_unconditional_reduce_of_concat_replacement(false);
-  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options, gpu_version);
+  algsimp_options.set_supports_non_canonical_dots(false);
+  algsimp_options.set_enable_conv_operand_swap(false);
+  algsimp_options.set_enable_unconditional_reduce_of_concat_replacement(false);
+  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(algsimp_options, gpu_version);
 
   // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
   // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
@@ -152,7 +153,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
     ReshapeMoverOptions reshape_mover_options;
     reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true;
     pipeline.AddPass<ReshapeMover>(reshape_mover_options);
-    pipeline.AddPass<AlgebraicSimplifier>(options, gpu_version);
+    pipeline.AddPass<AlgebraicSimplifier>(algsimp_options, gpu_version);
   }();
 
   // The reshapes and transposes can possibly be eliminated using
@@ -163,7 +164,7 @@
   [&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
           "simplify_after_conv_canonicalization")] {
     pipeline.AddPass<ConvertMover>();
-    pipeline.AddPass<AlgebraicSimplifier>(options, gpu_version);
+    pipeline.AddPass<AlgebraicSimplifier>(algsimp_options, gpu_version);
   }();
 
   // GpuConvRewriter, GpuConvPaddingLegalization and
diff --git a/xla/service/gpu/buffer_sharing.cc b/xla/service/gpu/buffer_sharing.cc
index 624d324e739e9..9c2931c2f7df8 100644
--- a/xla/service/gpu/buffer_sharing.cc
+++ b/xla/service/gpu/buffer_sharing.cc
@@ -83,7 +83,7 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
   bool is_reduction_emitter = analysis.GetEmitterFusionKind() ==
                               HloFusionAnalysis::EmitterFusionKind::kReduction;
   const HloInstruction* reduction_hero =
-      is_reduction_emitter ? reduction_hero = analysis.FindHeroReduction()
+      is_reduction_emitter ?
analysis.FindHeroReduction() : nullptr; // We need to make sure that the fusion parameter is accessed in the same diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index dc9c30630d041..2fcc580996024 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -237,7 +237,20 @@ absl::Status GpuExecutor::Init() { return status; } - return GpuDriver::GetGpuISAVersion(&version_, device_); + status = GpuDriver::GetGpuISAVersion(&version_, device_); + if (!status.ok()) { + return status; + } + // We initialize BLAS interfaces early here since otherwise it might create + // us problems during hipBlasLt initialization under graph capture. + // There is no real advantage of explicitly using 'lazy initialization' on + // ROCM platform because rocBLAS/hipBlasLt already use 'lazy initialization' + // internally + PluginRegistry* registry = PluginRegistry::Instance(); + TF_ASSIGN_OR_RETURN(auto factory, + registry->GetFactory(rocm::kROCmPlatformId)); + blas_.reset(factory(this)); + return absl::OkStatus(); } // Returns the path to the running executable. @@ -681,22 +694,6 @@ absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { } blas::BlasSupport* GpuExecutor::AsBlas() { - absl::MutexLock lock(&mu_); - if (blas_ != nullptr) { - return blas_.get(); - } - - PluginRegistry* registry = PluginRegistry::Instance(); - absl::StatusOr status = - registry->GetFactory(rocm::kROCmPlatformId); - if (!status.ok()) { - LOG(ERROR) << "Unable to retrieve BLAS factory: " - << status.status().message(); - return nullptr; - } - - auto blas = status.value()(this); - blas_.reset(blas); return blas_.get(); } diff --git a/xla/tests/matmul_test.cc b/xla/tests/matmul_test.cc index 668fa32425391..5544048e21105 100644 --- a/xla/tests/matmul_test.cc +++ b/xla/tests/matmul_test.cc @@ -51,6 +51,29 @@ class MatmulTestWithCublas : public HloTestBase, const bool use_cublas_lt_{GetParam()}; }; +TEST_P(MatmulTestWithCublas, GemmRewriter_NonCanonicalDots) { + const char* module_str = R"( + HloModule m + a { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[] add(p0, p1) + } + test { + p0 = f32[32,8,5,6] parameter(0) + p1 = f32[8,32,6,7] parameter(1) + d = f32[32,8,5,7] dot(p0, p1), + lhs_batch_dims={0,1}, + rhs_batch_dims={1,0}, + rhs_contracting_dims={2}, + lhs_contracting_dims={3} + c = f32[] constant(0) + ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a + } + )"; + EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-4, 1e-4})); +} + TEST_P(MatmulTestWithCublas, GemmRewriter_RegressionTestF64) { const char* module_str = R"( HloModule GeneralMatMulActivation.7, entry_computation_layout={(f64[2,2,2]{2,1,0}, f64[2,2,2]{2,1,0})->f64[2,2,2]{2,1,0}} diff --git a/xla/tools/multihost_hlo_runner/data/sharded_computation.hlo b/xla/tools/multihost_hlo_runner/data/sharded_computation.hlo new file mode 100644 index 0000000000000..9347b90d026d9 --- /dev/null +++ b/xla/tools/multihost_hlo_runner/data/sharded_computation.hlo @@ -0,0 +1,123 @@ +HloModule pjit_ref_func + +region_0.23 { + Arg_0.24 = f32[] parameter(0) + Arg_1.25 = f32[] parameter(1) + ROOT maximum.26 = f32[] maximum(Arg_0.24, Arg_1.25) +} + +region_1.35 { + Arg_0.36 = f32[] parameter(0) + Arg_1.37 = f32[] parameter(1) + ROOT add.38 = f32[] add(Arg_0.36, Arg_1.37) +} + +integer_pow.45 { + constant.47 = f32[] constant(1) + broadcast.48 = f32[32,12,512,1]{3,2,1,0} broadcast(constant.47), dimensions={} + Arg_0.46 = f32[32,12,512,1]{3,2,1,0} parameter(0) + 
multiply.49 = f32[32,12,512,1]{3,2,1,0} multiply(Arg_0.46, Arg_0.46) + ROOT divide.50 = f32[32,12,512,1]{3,2,1,0} divide(broadcast.48, multiply.49) +} + +region_2.54 { + Arg_0.55 = f32[] parameter(0) + Arg_1.56 = f32[] parameter(1) + ROOT add.57 = f32[] add(Arg_0.55, Arg_1.56) +} + +region_3.72 { + Arg_0.73 = f32[] parameter(0) + Arg_1.74 = f32[] parameter(1) + ROOT add.75 = f32[] add(Arg_0.73, Arg_1.74) +} + +region_4.83 { + Arg_0.84 = f32[] parameter(0) + Arg_1.85 = f32[] parameter(1) + ROOT add.86 = f32[] add(Arg_0.84, Arg_1.85) +} + +ENTRY main.107 { + Arg_0.1 = f16[32,512,3,12,64]{4,3,2,1,0} parameter(0), sharding={devices=[2,1,1,1,1]<=[2]} + slice.14 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [2:3], [0:12], [0:64]} + reshape.17 = f16[32,512,12,64]{3,2,1,0} reshape(slice.14) + convert.20 = f32[32,512,12,64]{3,2,1,0} convert(reshape.17) + slice.12 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [0:1], [0:12], [0:64]} + reshape.15 = f16[32,512,12,64]{3,2,1,0} reshape(slice.12) + convert.18 = f32[32,512,12,64]{3,2,1,0} convert(reshape.15) + constant.6 = f32[] constant(8) + broadcast.7 = f32[32,512,12,64]{3,2,1,0} broadcast(constant.6), dimensions={} + divide.21 = f32[32,512,12,64]{3,2,1,0} divide(convert.18, broadcast.7) + slice.13 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [1:2], [0:12], [0:64]} + reshape.16 = f16[32,512,12,64]{3,2,1,0} reshape(slice.13) + convert.19 = f32[32,512,12,64]{3,2,1,0} convert(reshape.16) + dot.22 = f32[32,12,512,512]{3,2,1,0} dot(divide.21, convert.19), lhs_batch_dims={0,2}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={3} + constant.11 = f32[] constant(-inf) + reduce.27 = f32[32,12,512]{2,1,0} reduce(dot.22, constant.11), dimensions={3}, to_apply=region_0.23 + constant.4 = f32[] constant(-inf) + broadcast.5 = f32[32,12,512]{2,1,0} broadcast(constant.4), dimensions={} + maximum.28 = f32[32,12,512]{2,1,0} maximum(reduce.27, broadcast.5) + reshape.29 = f32[32,12,512,1]{3,2,1,0} reshape(maximum.28) + broadcast.30 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.29), dimensions={0,1,2,3} + reshape.31 = f32[32,12,512]{2,1,0} reshape(broadcast.30) + broadcast.32 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.31), dimensions={0,1,2} + subtract.33 = f32[32,12,512,512]{3,2,1,0} subtract(dot.22, broadcast.32) + exponential.34 = f32[32,12,512,512]{3,2,1,0} exponential(subtract.33) + constant.10 = f32[] constant(0) + reduce.39 = f32[32,12,512]{2,1,0} reduce(exponential.34, constant.10), dimensions={3}, to_apply=region_1.35 + reshape.40 = f32[32,12,512,1]{3,2,1,0} reshape(reduce.39) + broadcast.41 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.40), dimensions={0,1,2,3} + reshape.42 = f32[32,12,512]{2,1,0} reshape(broadcast.41) + broadcast.43 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.42), dimensions={0,1,2} + divide.44 = f32[32,12,512,512]{3,2,1,0} divide(exponential.34, broadcast.43) + dot.52 = f32[32,12,64,512]{3,2,1,0} dot(convert.20, divide.44), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} + transpose.53 = f32[32,512,12,64]{1,3,2,0} transpose(dot.52), dimensions={0,3,1,2} + reduce.58 = f32[] reduce(transpose.53, constant.10), dimensions={0,1,2,3}, to_apply=region_2.54 + constant.9 = f32[] constant(12582912) + divide.59 = f32[] divide(reduce.58, constant.9) + convert.60 = f16[] convert(divide.59) + reshape.104 = f16[] reshape(convert.60), sharding={replicated} + constant.2 = f32[] constant(7.94728621e-08) + 
broadcast.3 = f32[32,12,64,512]{3,2,1,0} broadcast(constant.2), dimensions={} + dot.62 = f32[32,12,64,512]{3,2,1,0} dot(broadcast.3, divide.44), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + transpose.63 = f32[32,512,12,64]{1,3,2,0} transpose(dot.62), dimensions={0,3,1,2} + convert.64 = f16[32,512,12,64]{1,3,2,0} convert(transpose.63) + reshape.65 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.64) + constant.8 = f16[] constant(0) + pad.66 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.65, constant.8), padding=0_0x0_0x2_0x0_0x0_0 + dot.61 = f32[32,12,512,512]{3,2,1,0} dot(broadcast.3, convert.20), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={3} + broadcast.79 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.40), dimensions={0,1,2,3} + reshape.80 = f32[32,12,512]{2,1,0} reshape(broadcast.79) + broadcast.81 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.80), dimensions={0,1,2} + divide.82 = f32[32,12,512,512]{3,2,1,0} divide(dot.61, broadcast.81) + call.51 = f32[32,12,512,1]{3,2,1,0} call(reshape.40), to_apply=integer_pow.45 + broadcast.67 = f32[32,12,512,1]{3,2,1,0} broadcast(call.51), dimensions={0,1,2,3} + reshape.68 = f32[32,12,512]{2,1,0} reshape(broadcast.67) + broadcast.69 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.68), dimensions={0,1,2} + multiply.70 = f32[32,12,512,512]{3,2,1,0} multiply(dot.61, broadcast.69) + multiply.71 = f32[32,12,512,512]{3,2,1,0} multiply(multiply.70, exponential.34) + reduce.76 = f32[32,12,512]{2,1,0} reduce(multiply.71, constant.10), dimensions={3}, to_apply=region_3.72 + reshape.77 = f32[32,12,512,1]{3,2,1,0} reshape(reduce.76) + negate.78 = f32[32,12,512,1]{3,2,1,0} negate(reshape.77) + reduce.87 = f32[32,12,512]{2,1,0} reduce(negate.78, constant.10), dimensions={3}, to_apply=region_4.83 + broadcast.88 = f32[32,12,512,512]{3,2,1,0} broadcast(reduce.87), dimensions={0,1,2} + add.89 = f32[32,12,512,512]{3,2,1,0} add(divide.82, broadcast.88) + multiply.90 = f32[32,12,512,512]{3,2,1,0} multiply(add.89, exponential.34) + dot.91 = f32[32,12,512,64]{3,2,1,0} dot(multiply.90, divide.21), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={1} + transpose.92 = f32[32,512,12,64]{3,1,2,0} transpose(dot.91), dimensions={0,2,1,3} + convert.95 = f16[32,512,12,64]{3,1,2,0} convert(transpose.92) + reshape.96 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.95) + pad.97 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.96, constant.8), padding=0_0x0_0x1_1x0_0x0_0 + add.98 = f16[32,512,3,12,64]{4,3,2,1,0} add(pad.66, pad.97) + dot.93 = f32[32,12,512,64]{3,2,1,0} dot(multiply.90, convert.19), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={1} + transpose.94 = f32[32,512,12,64]{3,1,2,0} transpose(dot.93), dimensions={0,2,1,3} + divide.99 = f32[32,512,12,64]{3,1,2,0} divide(transpose.94, broadcast.7) + convert.100 = f16[32,512,12,64]{3,1,2,0} convert(divide.99) + reshape.101 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.100) + pad.102 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.101, constant.8), padding=0_0x0_0x0_2x0_0x0_0 + add.103 = f16[32,512,3,12,64]{4,3,2,1,0} add(add.98, pad.102) + reshape.105 = f16[32,512,3,12,64]{4,3,2,1,0} reshape(add.103), sharding={devices=[2,1,1,1,1]<=[2]} + ROOT tuple.106 = (f16[], f16[32,512,3,12,64]{4,3,2,1,0}) tuple(reshape.104, reshape.105), sharding={{replicated}, {devices=[2,1,1,1,1]<=[2]}} +} // main.107 + diff --git 
a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index 4018460bc90e3..8b9be29e5d6e6 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -177,6 +177,40 @@ TEST_F(FunctionalHloRunnerTest, UseUninitializedInputs) { InputFormat::kText)); } +// ROCM Error: +// E0000 00:00:1737155629.780742 137227 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: INTERNAL: Failed to end stream capture: hipError_t(901) +TEST_F(FunctionalHloRunnerTest, ShardedComputationUnderStreamCapture) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetPjRtClient()); + + constexpr int kRequiredDeviceCount = 2; + const int kDeviceCount = client->device_count(); + if (kDeviceCount < kRequiredDeviceCount) { + GTEST_SKIP() << "Requires " << kRequiredDeviceCount + << " devices, but found only " << kDeviceCount; + return; + } + + // NOTE: debug_options sent to FunctionalHloRunner::LoadAndRunAndDump() get + // lost during the creating of XlaComputation from HloModuleProto in + // FunctionalHloRunner::Compile + xla::DebugOptions debug_options; + FunctionalHloRunner::PreprocessingOptions preproc_options; + FunctionalHloRunner::RawCompileOptions raw_compile_options; + raw_compile_options.spmd_mode = + FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning; + raw_compile_options.num_replicas = 1; + raw_compile_options.num_partitions = 2; + FunctionalHloRunner::RunningOptions running_options; + running_options.module_argument_mode = + FunctionalHloRunner::ModuleArgumentMode::kUseRandomInputs; + + TF_EXPECT_OK(FunctionalHloRunner::LoadAndRunAndDump( + *client, debug_options, preproc_options, raw_compile_options, + running_options, {GetHloPath("sharded_computation.hlo")}, + InputFormat::kText)); +} + TEST_F(FunctionalHloRunnerTest, UseUninitializedInputsWithTupledArguments) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, GetPjRtClient()); From 70a94c66ca973cc70940018d3a6291d186c1e857 Mon Sep 17 00:00:00 2001 From: pemeliya <141146080+pemeliya@users.noreply.github.com> Date: Tue, 1 Oct 2024 12:32:45 -0700 Subject: [PATCH 2/4] PR #16520: [ROCM] ResetStream function for GemmAlgorithmPicker (BlasSupport interface) --- xla/service/gpu/gemm_algorithm_picker.cc | 8 +---- xla/service/gpu/gemm_algorithm_picker_test.cc | 22 +++++++++++-- xla/stream_executor/BUILD | 1 + xla/stream_executor/blas.h | 10 +++--- xla/stream_executor/rocm/BUILD | 1 + xla/stream_executor/rocm/rocm_blas.cc | 33 +++++++++++-------- xla/stream_executor/rocm/rocm_blas.h | 2 +- 7 files changed, 49 insertions(+), 28 deletions(-) diff --git a/xla/service/gpu/gemm_algorithm_picker.cc b/xla/service/gpu/gemm_algorithm_picker.cc index df9b16c3d04a1..a68a3ec1e25f5 100644 --- a/xla/service/gpu/gemm_algorithm_picker.cc +++ b/xla/service/gpu/gemm_algorithm_picker.cc @@ -98,13 +98,7 @@ class GemmAutotuner { explicit GemmAutotuner(const AutotuneConfig& autotune_config) : autotune_config_(autotune_config) {} - ~GemmAutotuner() { - if (stream_ != nullptr) { - if (auto blas = stream_->parent()->AsBlas()) blas->ResetStream(); - } - } - - const AutotuneConfig& config() { return autotune_config_; } + const AutotuneConfig& config() const { return autotune_config_; } size_t num_algorithms_left() const { return num_algorithms_left_; } diff --git a/xla/service/gpu/gemm_algorithm_picker_test.cc b/xla/service/gpu/gemm_algorithm_picker_test.cc index 92e9cf95825b1..7542e6679a49b 100644 --- 
a/xla/service/gpu/gemm_algorithm_picker_test.cc
+++ b/xla/service/gpu/gemm_algorithm_picker_test.cc
@@ -56,11 +56,11 @@ class GemmAlgorithmPickerTest : public HloTestBase,
   se::StreamExecutor *stream_exec() {
     return backend().default_stream_executor();
   }
-  const se::DeviceDescription& gpu_device_desc() {
+  const se::DeviceDescription& device_desc() {
     return stream_exec()->GetDeviceDescription();
   }
   const se::GpuComputeCapability& gpu_comp() {
-    return gpu_device_desc().gpu_compute_capability();
+    return device_desc().gpu_compute_capability();
   }
 
   void SetUp() override {
@@ -82,6 +82,15 @@ class GemmAlgorithmPickerTest : public HloTestBase,
   }
 };
 
+TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) {
+  auto* blas = backend().default_stream_executor()->AsBlas();
+  ASSERT_TRUE(blas != nullptr);
+  std::string version;
+  ASSERT_TRUE(blas->GetVersion(&version).ok());
+  VLOG(0) << "Blas version: " << version;
+  ASSERT_TRUE(!version.empty());
+}
+
 TEST_P(GemmAlgorithmPickerTest, SkipAlgorithmsWithAccuracyCheck) {
   constexpr absl::string_view kHlo = R"(
 HloModule module
@@ -117,6 +126,15 @@ TF_ASSERT_OK_AND_ASSIGN(auto module,
     if(num_left1 < 2) {
       GTEST_SKIP() << "Too few algorithms left after the first step";
     }
+
+    // Test that the function to get the current stream value works fine:
+    auto* blas = stream_exec()->AsBlas();
+    ASSERT_TRUE(blas != nullptr);
+    TF_ASSERT_OK_AND_ASSIGN(bool is_main_stream, blas->IsMainStreamSet());
+    // ROCM only: the CUDA blas API does not reset the stream after each blas call.
+    if (std::holds_alternative<se::RocmComputeCapability>(gpu_comp())) {
+      ASSERT_TRUE(is_main_stream);
+    }
   }
 
   // Clear cache before the second run!
diff --git a/xla/stream_executor/BUILD b/xla/stream_executor/BUILD
index e52d5f842ad8a..d55ed1f73e656 100644
--- a/xla/stream_executor/BUILD
+++ b/xla/stream_executor/BUILD
@@ -364,6 +364,7 @@ cc_library(
         "//xla/stream_executor/platform",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@tsl//tsl/platform:errors",
diff --git a/xla/stream_executor/blas.h b/xla/stream_executor/blas.h
index 26443fd3bbe45..2135569936230 100644
--- a/xla/stream_executor/blas.h
+++ b/xla/stream_executor/blas.h
@@ -29,7 +29,7 @@ limitations under the License.
 #include
 #include
 
-#include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/types/span.h"
 #include "xla/stream_executor/data_type.h"
 #include "xla/stream_executor/device_memory.h"
@@ -221,9 +221,10 @@ class BlasSupport {
   virtual ~BlasSupport() {}
 
   virtual gpu::BlasLt *GetBlasLt() = 0;
-  // resets the underlying blas stream to its default value
-  virtual bool ResetStream() = 0;
+  // For tests only: sets *is_main_stream to true if the underlying Blas
+  // library has stream 0 set as its current stream.
+  virtual absl::StatusOr<bool> IsMainStreamSet() const = 0;
   // Performs a BLAS y <- ax+y operation.
   virtual bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<float> *y, int incy) = 0;
@@ -233,7 +234,6 @@ class BlasSupport {
   virtual bool DoBlasCopy(Stream *stream, uint64_t elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<float> *y, int incy) = 0;
-
   // Computes the product of a vector by a scalar: x <- a*x.
   virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
                           DeviceMemory<float> *x, int incx) = 0;
@@ -750,13 +750,13 @@ class BlasSupport {
 // Macro used to quickly declare overrides for abstract virtuals in the
 // BlasSupport base class.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ + absl::StatusOr IsMainStreamSet() const override; \ bool DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha, \ const DeviceMemory &x, int incx, \ DeviceMemory *y, int incy) override; \ bool DoBlasCopy(Stream *stream, uint64_t elem_count, \ const DeviceMemory &x, int incx, \ DeviceMemory *y, int incy) override; \ - bool ResetStream() override; \ bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, \ DeviceMemory *x, int incx) override; \ bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, \ diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 4bb4aa3ce5fed..761c9b32d6c31 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -304,6 +304,7 @@ cc_library( "//xla/stream_executor/platform:dso_loader", "//xla/tsl/util:determinism_hdr_lib", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", diff --git a/xla/stream_executor/rocm/rocm_blas.cc b/xla/stream_executor/rocm/rocm_blas.cc index a796c44f484fb..2ce66c7c3b8cb 100644 --- a/xla/stream_executor/rocm/rocm_blas.cc +++ b/xla/stream_executor/rocm/rocm_blas.cc @@ -22,8 +22,10 @@ limitations under the License. #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive #include "rocm/rocm_config.h" @@ -149,18 +151,10 @@ ROCMBlas::~ROCMBlas() { } } -bool ROCMBlas::ResetStream() { - absl::MutexLock lock{&mu_}; - return SetStream(nullptr); -} - bool ROCMBlas::SetStream(Stream *stream) { CHECK(blas_ != nullptr); - gpu::ScopedActivateExecutorContext sac{parent_}; - - GpuStreamHandle handle = (stream != nullptr) ? AsGpuStreamValue(stream) : 0; - - if (auto ret = wrap::rocblas_set_stream(blas_, handle); + auto handle = (stream != nullptr) ? AsGpuStreamValue(stream) : nullptr; + if (auto ret = wrap::rocblas_set_stream(blas_, handle); ret != rocblas_status_success) { LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret); return false; @@ -168,6 +162,17 @@ bool ROCMBlas::SetStream(Stream *stream) { return true; } +absl::StatusOr ROCMBlas::IsMainStreamSet() const { + absl::MutexLock lock{&mu_}; + CHECK(blas_ != nullptr); + GpuStreamHandle handle{}; + if (auto ret = wrap::rocblas_get_stream(blas_, &handle); + ret != rocblas_status_success) { + return absl::InternalError("failed to get the current stream value"); + } + return (handle == nullptr); +} + namespace { // Helper functions transforming blas arguments into rocBLAS arguments. 
@@ -343,12 +348,12 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
   absl::MutexLock lock{&mu_};
 
   CHECK(blas_ != nullptr);
+  gpu::ScopedActivateExecutorContext sac{parent_};
   if (!SetStream(stream)) {
     return absl::InternalError("Setting stream failed");
   }
 
-  gpu::ScopedActivateExecutorContext sac{parent_};
-
+  rocblas_status ret;
   // set the atomics mode, leaving default to library
   bool allow_atomics = !OpDeterminismRequired();
   if (!allow_atomics) {
@@ -371,7 +376,9 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
   }
 #endif
 
-  auto ret = rocblas_func(blas_, std::forward<Args>(args)...);
+  ret = rocblas_func(blas_, std::forward<Args>(args)...);
+  SetStream(nullptr);  // Reset the stream after the rocBLAS call.
+
   if (ret != rocblas_status_success) {
     auto err_str =
         absl::StrFormat("%s failed with: %s", FuncT::kName, ToString(ret));
diff --git a/xla/stream_executor/rocm/rocm_blas.h b/xla/stream_executor/rocm/rocm_blas.h
index c9d8fc33e0c20..80d6dff953cca 100644
--- a/xla/stream_executor/rocm/rocm_blas.h
+++ b/xla/stream_executor/rocm/rocm_blas.h
@@ -185,7 +185,7 @@ class ROCMBlas : public blas::BlasSupport {
       ScratchAllocator *scratch_allocator);
 
   // mutex that guards the rocBLAS handle for this device.
-  absl::Mutex mu_;
+  mutable absl::Mutex mu_;
 
   // GpuExecutor which instantiated this ROCMBlas.
   // Immutable post-initialization.

From c758bffa980b4e8a4a25b5f6cd6e7fd160854838 Mon Sep 17 00:00:00 2001
From: Pavel Emeliyanenko
Date: Mon, 20 Jan 2025 17:28:03 +0000
Subject: [PATCH 3/4] test fix

---
 xla/service/gpu/gemm_algorithm_picker_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xla/service/gpu/gemm_algorithm_picker_test.cc b/xla/service/gpu/gemm_algorithm_picker_test.cc
index 7542e6679a49b..cbf8ce1b5f408 100644
--- a/xla/service/gpu/gemm_algorithm_picker_test.cc
+++ b/xla/service/gpu/gemm_algorithm_picker_test.cc
@@ -272,7 +272,7 @@ ENTRY main {
 
   changed = false;
   DevicelessConfig deviceless_config{
-      gpu_device_desc().model_str(), gpu_comp()};
+      device_desc().model_str(), gpu_comp()};
   AutotuneConfig deviceless_cfg{deviceless_config, opts};
   TF_ASSERT_OK_AND_ASSIGN(
       changed,

From 1e29fe904e8ba8935a3beca186c0441b5172a6e5 Mon Sep 17 00:00:00 2001
From: Pavel Emeliyanenko
Date: Mon, 20 Jan 2025 17:57:59 +0000
Subject: [PATCH 4/4] added missing hlo file

---
 xla/tools/multihost_hlo_runner/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/xla/tools/multihost_hlo_runner/BUILD b/xla/tools/multihost_hlo_runner/BUILD
index d3610ab406b1d..3ae75616a0f88 100644
--- a/xla/tools/multihost_hlo_runner/BUILD
+++ b/xla/tools/multihost_hlo_runner/BUILD
@@ -173,6 +173,7 @@ xla_test(
         "data/sharded_2_devices.hlo",
         "data/single_device.hlo",
         "data/single_device_tupled.hlo",
+        "data/sharded_computation.hlo",
     ],
     tags = ["nomac"],
    deps = [