diff --git a/xla/service/gpu/amdgpu_compiler.cc b/xla/service/gpu/amdgpu_compiler.cc
index 04483ce86c78c..1b7128e421529 100644
--- a/xla/service/gpu/amdgpu_compiler.cc
+++ b/xla/service/gpu/amdgpu_compiler.cc
@@ -138,11 +138,12 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
   // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter
   // introduces reshapes and transposes that can be eliminated using
   // AlgebraicSimplifier We run algsimp to a fixed point.
-  AlgebraicSimplifierOptions options =
+  AlgebraicSimplifierOptions algsimp_options =
       GetAlgebraicSimplifierOptions(hlo_module->config());
-  options.set_enable_conv_operand_swap(false);
-  options.set_enable_unconditional_reduce_of_concat_replacement(false);
-  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options, gpu_version);
+  algsimp_options.set_supports_non_canonical_dots(false);
+  algsimp_options.set_enable_conv_operand_swap(false);
+  algsimp_options.set_enable_unconditional_reduce_of_concat_replacement(false);
+  pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(algsimp_options, gpu_version);
 
   // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
   // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
@@ -152,7 +153,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
     ReshapeMoverOptions reshape_mover_options;
     reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true;
     pipeline.AddPass<ReshapeMover>(reshape_mover_options);
-    pipeline.AddPass<AlgebraicSimplifier>(options, gpu_version);
+    pipeline.AddPass<AlgebraicSimplifier>(algsimp_options, gpu_version);
   }();
 
   // The reshapes and transposes can possibly be eliminated using
@@ -163,7 +164,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
   [&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
           "simplify_after_conv_canonicalization")] {
     pipeline.AddPass<ConvertMover>();
-    pipeline.AddPass<AlgebraicSimplifier>(options, gpu_version);
+    pipeline.AddPass<AlgebraicSimplifier>(algsimp_options, gpu_version);
   }();
 
   // GpuConvRewriter, GpuConvPaddingLegalization and
diff --git a/xla/service/gpu/buffer_sharing.cc b/xla/service/gpu/buffer_sharing.cc
index 624d324e739e9..9c2931c2f7df8 100644
--- a/xla/service/gpu/buffer_sharing.cc
+++ b/xla/service/gpu/buffer_sharing.cc
@@ -83,7 +83,7 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
   bool is_reduction_emitter = analysis.GetEmitterFusionKind() ==
                               HloFusionAnalysis::EmitterFusionKind::kReduction;
   const HloInstruction* reduction_hero =
-      is_reduction_emitter ? reduction_hero = analysis.FindHeroReduction()
+      is_reduction_emitter ? analysis.FindHeroReduction()
                            : nullptr;
 
   // We need to make sure that the fusion parameter is accessed in the same
diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc
index dc9c30630d041..2fcc580996024 100644
--- a/xla/stream_executor/rocm/rocm_executor.cc
+++ b/xla/stream_executor/rocm/rocm_executor.cc
@@ -237,7 +237,20 @@ absl::Status GpuExecutor::Init() {
     return status;
   }
 
-  return GpuDriver::GetGpuISAVersion(&version_, device_);
+  status = GpuDriver::GetGpuISAVersion(&version_, device_);
+  if (!status.ok()) {
+    return status;
+  }
+  // We initialize BLAS interfaces early here since otherwise it might create
+  // us problems during hipBlasLt initialization under graph capture.
+  // There is no real advantage of explicitly using 'lazy initialization' on
+  // ROCM platform because rocBLAS/hipBlasLt already use 'lazy initialization'
+  // internally
+  PluginRegistry* registry = PluginRegistry::Instance();
+  TF_ASSIGN_OR_RETURN(auto factory,
+      registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId));
+  blas_.reset(factory(this));
+  return absl::OkStatus();
 }
 
 // Returns the path to the running executable.
@@ -681,22 +694,6 @@ absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
 }
 
 blas::BlasSupport* GpuExecutor::AsBlas() {
-  absl::MutexLock lock(&mu_);
-  if (blas_ != nullptr) {
-    return blas_.get();
-  }
-
-  PluginRegistry* registry = PluginRegistry::Instance();
-  absl::StatusOr<PluginRegistry::BlasFactory> status =
-      registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId);
-  if (!status.ok()) {
-    LOG(ERROR) << "Unable to retrieve BLAS factory: "
-               << status.status().message();
-    return nullptr;
-  }
-
-  auto blas = status.value()(this);
-  blas_.reset(blas);
   return blas_.get();
 }
diff --git a/xla/tests/matmul_test.cc b/xla/tests/matmul_test.cc
index 668fa32425391..5544048e21105 100644
--- a/xla/tests/matmul_test.cc
+++ b/xla/tests/matmul_test.cc
@@ -51,6 +51,29 @@ class MatmulTestWithCublas : public HloTestBase,
   const bool use_cublas_lt_{GetParam()};
 };
 
+TEST_P(MatmulTestWithCublas, GemmRewriter_NonCanonicalDots) {
+  const char* module_str = R"(
+  HloModule m
+  a {
+    p0 = f32[] parameter(0)
+    p1 = f32[] parameter(1)
+    ROOT r = f32[] add(p0, p1)
+  }
+  test {
+    p0 = f32[32,8,5,6] parameter(0)
+    p1 = f32[8,32,6,7] parameter(1)
+    d = f32[32,8,5,7] dot(p0, p1),
+      lhs_batch_dims={0,1},
+      rhs_batch_dims={1,0},
+      rhs_contracting_dims={2},
+      lhs_contracting_dims={3}
+    c = f32[] constant(0)
+    ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a
+  }
+  )";
+  EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-4, 1e-4}));
+}
+
 TEST_P(MatmulTestWithCublas, GemmRewriter_RegressionTestF64) {
   const char* module_str = R"(
 HloModule GeneralMatMulActivation.7, entry_computation_layout={(f64[2,2,2]{2,1,0}, f64[2,2,2]{2,1,0})->f64[2,2,2]{2,1,0}}
diff --git a/xla/tools/multihost_hlo_runner/data/sharded_computation.hlo b/xla/tools/multihost_hlo_runner/data/sharded_computation.hlo
new file mode 100644
index 0000000000000..9347b90d026d9
--- /dev/null
+++ b/xla/tools/multihost_hlo_runner/data/sharded_computation.hlo
@@ -0,0 +1,123 @@
+HloModule pjit_ref_func
+
+region_0.23 {
+  Arg_0.24 = f32[] parameter(0)
+  Arg_1.25 = f32[] parameter(1)
+  ROOT maximum.26 = f32[] maximum(Arg_0.24, Arg_1.25)
+}
+
+region_1.35 {
+  Arg_0.36 = f32[] parameter(0)
+  Arg_1.37 = f32[] parameter(1)
+  ROOT add.38 = f32[] add(Arg_0.36, Arg_1.37)
+}
+
+integer_pow.45 {
+  constant.47 = f32[] constant(1)
+  broadcast.48 = f32[32,12,512,1]{3,2,1,0} broadcast(constant.47), dimensions={}
+  Arg_0.46 = f32[32,12,512,1]{3,2,1,0} parameter(0)
+  multiply.49 = f32[32,12,512,1]{3,2,1,0} multiply(Arg_0.46, Arg_0.46)
+  ROOT divide.50 = f32[32,12,512,1]{3,2,1,0} divide(broadcast.48, multiply.49)
+}
+
+region_2.54 {
+  Arg_0.55 = f32[] parameter(0)
+  Arg_1.56 = f32[] parameter(1)
+  ROOT add.57 = f32[] add(Arg_0.55, Arg_1.56)
+}
+
+region_3.72 {
+  Arg_0.73 = f32[] parameter(0)
+  Arg_1.74 = f32[] parameter(1)
+  ROOT add.75 = f32[] add(Arg_0.73, Arg_1.74)
+}
+
+region_4.83 {
+  Arg_0.84 = f32[] parameter(0)
+  Arg_1.85 = f32[] parameter(1)
+  ROOT add.86 = f32[] add(Arg_0.84, Arg_1.85)
+}
+
+ENTRY main.107 {
+  Arg_0.1 = f16[32,512,3,12,64]{4,3,2,1,0} parameter(0), sharding={devices=[2,1,1,1,1]<=[2]}
+  slice.14 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [2:3], [0:12], [0:64]}
+  reshape.17 = f16[32,512,12,64]{3,2,1,0} reshape(slice.14)
+  convert.20 = f32[32,512,12,64]{3,2,1,0} convert(reshape.17)
+  slice.12 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [0:1], [0:12], [0:64]}
+  reshape.15 = f16[32,512,12,64]{3,2,1,0} reshape(slice.12)
+  convert.18 = f32[32,512,12,64]{3,2,1,0} convert(reshape.15)
+  constant.6 = f32[] constant(8)
+  broadcast.7 = f32[32,512,12,64]{3,2,1,0} broadcast(constant.6), dimensions={}
+  divide.21 = f32[32,512,12,64]{3,2,1,0} divide(convert.18, broadcast.7)
+  slice.13 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [1:2], [0:12], [0:64]}
+  reshape.16 = f16[32,512,12,64]{3,2,1,0} reshape(slice.13)
+  convert.19 = f32[32,512,12,64]{3,2,1,0} convert(reshape.16)
+  dot.22 = f32[32,12,512,512]{3,2,1,0} dot(divide.21, convert.19), lhs_batch_dims={0,2}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={3}
+  constant.11 = f32[] constant(-inf)
+  reduce.27 = f32[32,12,512]{2,1,0} reduce(dot.22, constant.11), dimensions={3}, to_apply=region_0.23
+  constant.4 = f32[] constant(-inf)
+  broadcast.5 = f32[32,12,512]{2,1,0} broadcast(constant.4), dimensions={}
+  maximum.28 = f32[32,12,512]{2,1,0} maximum(reduce.27, broadcast.5)
+  reshape.29 = f32[32,12,512,1]{3,2,1,0} reshape(maximum.28)
+  broadcast.30 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.29), dimensions={0,1,2,3}
+  reshape.31 = f32[32,12,512]{2,1,0} reshape(broadcast.30)
+  broadcast.32 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.31), dimensions={0,1,2}
+  subtract.33 = f32[32,12,512,512]{3,2,1,0} subtract(dot.22, broadcast.32)
+  exponential.34 = f32[32,12,512,512]{3,2,1,0} exponential(subtract.33)
+  constant.10 = f32[] constant(0)
+  reduce.39 = f32[32,12,512]{2,1,0} reduce(exponential.34, constant.10), dimensions={3}, to_apply=region_1.35
+  reshape.40 = f32[32,12,512,1]{3,2,1,0} reshape(reduce.39)
+  broadcast.41 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.40), dimensions={0,1,2,3}
+  reshape.42 = f32[32,12,512]{2,1,0} reshape(broadcast.41)
+  broadcast.43 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.42), dimensions={0,1,2}
+  divide.44 = f32[32,12,512,512]{3,2,1,0} divide(exponential.34, broadcast.43)
+  dot.52 = f32[32,12,64,512]{3,2,1,0} dot(convert.20, divide.44), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  transpose.53 = f32[32,512,12,64]{1,3,2,0} transpose(dot.52), dimensions={0,3,1,2}
+  reduce.58 = f32[] reduce(transpose.53, constant.10), dimensions={0,1,2,3}, to_apply=region_2.54
+  constant.9 = f32[] constant(12582912)
+  divide.59 = f32[] divide(reduce.58, constant.9)
+  convert.60 = f16[] convert(divide.59)
+  reshape.104 = f16[] reshape(convert.60), sharding={replicated}
+  constant.2 = f32[] constant(7.94728621e-08)
+  broadcast.3 = f32[32,12,64,512]{3,2,1,0} broadcast(constant.2), dimensions={}
+  dot.62 = f32[32,12,64,512]{3,2,1,0} dot(broadcast.3, divide.44), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
+  transpose.63 = f32[32,512,12,64]{1,3,2,0} transpose(dot.62), dimensions={0,3,1,2}
+  convert.64 = f16[32,512,12,64]{1,3,2,0} convert(transpose.63)
+  reshape.65 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.64)
+  constant.8 = f16[] constant(0)
+  pad.66 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.65, constant.8), padding=0_0x0_0x2_0x0_0x0_0
+  dot.61 = f32[32,12,512,512]{3,2,1,0} dot(broadcast.3, convert.20), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={3}
+  broadcast.79 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.40), dimensions={0,1,2,3}
+  reshape.80 = f32[32,12,512]{2,1,0} reshape(broadcast.79)
+  broadcast.81 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.80), dimensions={0,1,2}
+  divide.82 = f32[32,12,512,512]{3,2,1,0} divide(dot.61, broadcast.81)
+  call.51 = f32[32,12,512,1]{3,2,1,0} call(reshape.40), to_apply=integer_pow.45
+  broadcast.67 = f32[32,12,512,1]{3,2,1,0} broadcast(call.51), dimensions={0,1,2,3}
+  reshape.68 = f32[32,12,512]{2,1,0} reshape(broadcast.67)
+  broadcast.69 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.68), dimensions={0,1,2}
+  multiply.70 = f32[32,12,512,512]{3,2,1,0} multiply(dot.61, broadcast.69)
+  multiply.71 = f32[32,12,512,512]{3,2,1,0} multiply(multiply.70, exponential.34)
+  reduce.76 = f32[32,12,512]{2,1,0} reduce(multiply.71, constant.10), dimensions={3}, to_apply=region_3.72
+  reshape.77 = f32[32,12,512,1]{3,2,1,0} reshape(reduce.76)
+  negate.78 = f32[32,12,512,1]{3,2,1,0} negate(reshape.77)
+  reduce.87 = f32[32,12,512]{2,1,0} reduce(negate.78, constant.10), dimensions={3}, to_apply=region_4.83
+  broadcast.88 = f32[32,12,512,512]{3,2,1,0} broadcast(reduce.87), dimensions={0,1,2}
+  add.89 = f32[32,12,512,512]{3,2,1,0} add(divide.82, broadcast.88)
+  multiply.90 = f32[32,12,512,512]{3,2,1,0} multiply(add.89, exponential.34)
+  dot.91 = f32[32,12,512,64]{3,2,1,0} dot(multiply.90, divide.21), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}
+  transpose.92 = f32[32,512,12,64]{3,1,2,0} transpose(dot.91), dimensions={0,2,1,3}
+  convert.95 = f16[32,512,12,64]{3,1,2,0} convert(transpose.92)
+  reshape.96 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.95)
+  pad.97 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.96, constant.8), padding=0_0x0_0x1_1x0_0x0_0
+  add.98 = f16[32,512,3,12,64]{4,3,2,1,0} add(pad.66, pad.97)
+  dot.93 = f32[32,12,512,64]{3,2,1,0} dot(multiply.90, convert.19), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}
+  transpose.94 = f32[32,512,12,64]{3,1,2,0} transpose(dot.93), dimensions={0,2,1,3}
+  divide.99 = f32[32,512,12,64]{3,1,2,0} divide(transpose.94, broadcast.7)
+  convert.100 = f16[32,512,12,64]{3,1,2,0} convert(divide.99)
+  reshape.101 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.100)
+  pad.102 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.101, constant.8), padding=0_0x0_0x0_2x0_0x0_0
+  add.103 = f16[32,512,3,12,64]{4,3,2,1,0} add(add.98, pad.102)
+  reshape.105 = f16[32,512,3,12,64]{4,3,2,1,0} reshape(add.103), sharding={devices=[2,1,1,1,1]<=[2]}
+  ROOT tuple.106 = (f16[], f16[32,512,3,12,64]{4,3,2,1,0}) tuple(reshape.104, reshape.105), sharding={{replicated}, {devices=[2,1,1,1,1]<=[2]}}
+} // main.107
+
diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
index 4018460bc90e3..8b9be29e5d6e6 100644
--- a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
+++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
@@ -177,6 +177,40 @@ TEST_F(FunctionalHloRunnerTest, UseUninitializedInputs) {
       InputFormat::kText));
 }
 
+// ROCM Error:
+// E0000 00:00:1737155629.780742  137227 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: INTERNAL: Failed to end stream capture: hipError_t(901)
+TEST_F(FunctionalHloRunnerTest, ShardedComputationUnderStreamCapture) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
+                          GetPjRtClient());
+
+  constexpr int kRequiredDeviceCount = 2;
+  const int kDeviceCount = client->device_count();
+  if (kDeviceCount < kRequiredDeviceCount) {
+    GTEST_SKIP() << "Requires " << kRequiredDeviceCount
+                 << " devices, but found only " << kDeviceCount;
+    return;
+  }
+
+  // NOTE: debug_options sent to FunctionalHloRunner::LoadAndRunAndDump() get
+  // lost during the creating of XlaComputation from HloModuleProto in
+  // FunctionalHloRunner::Compile
+  xla::DebugOptions debug_options;
+  FunctionalHloRunner::PreprocessingOptions preproc_options;
+  FunctionalHloRunner::RawCompileOptions raw_compile_options;
+  raw_compile_options.spmd_mode =
+      FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning;
+  raw_compile_options.num_replicas = 1;
+  raw_compile_options.num_partitions = 2;
+  FunctionalHloRunner::RunningOptions running_options;
+  running_options.module_argument_mode =
+      FunctionalHloRunner::ModuleArgumentMode::kUseRandomInputs;
+
+  TF_EXPECT_OK(FunctionalHloRunner::LoadAndRunAndDump(
+      *client, debug_options, preproc_options, raw_compile_options,
+      running_options, {GetHloPath("sharded_computation.hlo")},
+      InputFormat::kText));
+}
+
 TEST_F(FunctionalHloRunnerTest, UseUninitializedInputsWithTupledArguments) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtClient> client,
                           GetPjRtClient());