
Merge pull request #95 from ROCm/rocm-jaxlib-v0.4.30-qa-blas-and-dot-fixes

Avoid lazy init of blas handles and fix for non-canonical dots
i-chaochen authored Jan 20, 2025
2 parents a4a8653 + ca02536 commit a3c22a0
Showing 6 changed files with 202 additions and 24 deletions.
13 changes: 7 additions & 6 deletions xla/service/gpu/amdgpu_compiler.cc
@@ -138,11 +138,12 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
// tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter
// introduces reshapes and transposes that can be eliminated using
// AlgebraicSimplifier We run algsimp to a fixed point.
AlgebraicSimplifierOptions options =
AlgebraicSimplifierOptions algsimp_options =
GetAlgebraicSimplifierOptions(hlo_module->config());
options.set_enable_conv_operand_swap(false);
options.set_enable_unconditional_reduce_of_concat_replacement(false);
pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(options, gpu_version);
algsimp_options.set_supports_non_canonical_dots(false);
algsimp_options.set_enable_conv_operand_swap(false);
algsimp_options.set_enable_unconditional_reduce_of_concat_replacement(false);
pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(algsimp_options, gpu_version);

// tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
// CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
@@ -152,7 +153,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
ReshapeMoverOptions reshape_mover_options;
reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true;
pipeline.AddPass<ReshapeMover>(reshape_mover_options);
pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
pipeline.AddPass<GpuAlgebraicSimplifier>(algsimp_options, gpu_version);
}();

// The reshapes and transposes can possibly be eliminated using
@@ -163,7 +164,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
[&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
"simplify_after_conv_canonicalization")] {
pipeline.AddPass<ConvertMover>();
pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
pipeline.AddPass<GpuAlgebraicSimplifier>(algsimp_options, gpu_version);
}();

// GpuConvRewriter, GpuConvPaddingLegalization and
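The substantive change in this file is the new set_supports_non_canonical_dots(false): a dot is canonical, roughly, when it has a single contracting dimension and its batch dimensions come first and in the same order on both operands, which is the form the GPU GEMM rewriter is generally able to pattern-match. The sketch below shows the same flag applied to a standalone AlgebraicSimplifier run; it is illustrative only, and the parsed HloModule in module is an assumption rather than part of this commit (the pipeline above derives its options from the module config instead of default-constructing them).

// Minimal sketch, not from this commit: apply the same option to a
// standalone AlgebraicSimplifier run. 'module' is an assumed
// std::unique_ptr<HloModule> obtained elsewhere (e.g. from a parser).
AlgebraicSimplifierOptions options;
// Tell the simplifier that downstream passes cannot handle dots that are not
// in canonical form (single contracting dimension, batch dimensions in
// matching order on both operands), so it must not create them.
options.set_supports_non_canonical_dots(false);
AlgebraicSimplifier simplifier(options);
absl::StatusOr<bool> changed = simplifier.Run(module.get());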
2 changes: 1 addition & 1 deletion xla/service/gpu/buffer_sharing.cc
@@ -83,7 +83,7 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
bool is_reduction_emitter = analysis.GetEmitterFusionKind() ==
HloFusionAnalysis::EmitterFusionKind::kReduction;
const HloInstruction* reduction_hero =
is_reduction_emitter ? reduction_hero = analysis.FindHeroReduction()
is_reduction_emitter ? analysis.FindHeroReduction()
: nullptr;

// We need to make sure that the fusion parameter is accessed in the same
31 changes: 14 additions & 17 deletions xla/stream_executor/rocm/rocm_executor.cc
@@ -237,7 +237,20 @@ absl::Status GpuExecutor::Init() {
return status;
}

return GpuDriver::GetGpuISAVersion(&version_, device_);
status = GpuDriver::GetGpuISAVersion(&version_, device_);
if (!status.ok()) {
return status;
}
// We initialize the BLAS interface early here since otherwise it can cause
// problems during hipBlasLt initialization under graph capture.
// There is no real advantage to explicit lazy initialization on the ROCm
// platform, because rocBLAS/hipBlasLt already initialize lazily internally.
PluginRegistry* registry = PluginRegistry::Instance();
TF_ASSIGN_OR_RETURN(auto factory,
registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId));
blas_.reset(factory(this));
return absl::OkStatus();
}

// Returns the path to the running executable.
@@ -681,22 +694,6 @@ absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
}

blas::BlasSupport* GpuExecutor::AsBlas() {
absl::MutexLock lock(&mu_);
if (blas_ != nullptr) {
return blas_.get();
}

PluginRegistry* registry = PluginRegistry::Instance();
absl::StatusOr<PluginRegistry::BlasFactory> status =
registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId);
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve BLAS factory: "
<< status.status().message();
return nullptr;
}

auto blas = status.value()(this);
blas_.reset(blas);
return blas_.get();
}

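With the BLAS plugin now created eagerly in GpuExecutor::Init(), the lazy path in AsBlas() is deleted; reconstructed from the surviving context lines above, the accessor reduces to returning the handle built during initialization, roughly:

blas::BlasSupport* GpuExecutor::AsBlas() {
  // blas_ was created eagerly in GpuExecutor::Init(), so no lock or plugin
  // registry lookup is needed here.
  return blas_.get();
}

One consequence is that a failure to obtain the BLAS factory now surfaces as a non-OK status from Init() instead of a logged error and a nullptr from AsBlas().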
23 changes: 23 additions & 0 deletions xla/tests/matmul_test.cc
@@ -51,6 +51,29 @@ class MatmulTestWithCublas : public HloTestBase,
const bool use_cublas_lt_{GetParam()};
};

TEST_P(MatmulTestWithCublas, GemmRewriter_NonCanonicalDots) {
const char* module_str = R"(
HloModule m
a {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p0, p1)
}
test {
p0 = f32[32,8,5,6] parameter(0)
p1 = f32[8,32,6,7] parameter(1)
d = f32[32,8,5,7] dot(p0, p1),
lhs_batch_dims={0,1},
rhs_batch_dims={1,0},
rhs_contracting_dims={2},
lhs_contracting_dims={3}
c = f32[] constant(0)
ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a
}
)";
EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-4, 1e-4}));
}

TEST_P(MatmulTestWithCublas, GemmRewriter_RegressionTestF64) {
const char* module_str = R"(
HloModule GeneralMatMulActivation.7, entry_computation_layout={(f64[2,2,2]{2,1,0}, f64[2,2,2]{2,1,0})->f64[2,2,2]{2,1,0}}
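The dot in the new GemmRewriter_NonCanonicalDots test is deliberately non-canonical: its rhs batch dimensions {1,0} are permuted relative to the lhs batch dimensions {0,1}, and a reduce over a batch dimension sits on top of it. A hypothetical companion check, not part of this commit, could run the GPU algebraic simplifier directly on that module and confirm that with supports_non_canonical_dots disabled the reduce/dot pair is left intact for the GemmRewriter; the backend(), ParseAndReturnVerifiedModule() and FindInstruction() helpers below come from HloTestBase, and their availability in this fixture is assumed.

// Hypothetical companion check, sketched under the assumption that the
// HloTestBase helpers and a GPU backend are available in this fixture.
TF_ASSERT_OK_AND_ASSIGN(auto module,
                        ParseAndReturnVerifiedModule(module_str));
AlgebraicSimplifierOptions options;
options.set_supports_non_canonical_dots(false);
GpuAlgebraicSimplifier simplifier(options,
                                  backend()
                                      .default_stream_executor()
                                      ->GetDeviceDescription()
                                      .gpu_compute_capability());
TF_ASSERT_OK(simplifier.Run(module.get()).status());
// With the flag off, the reduce should not have been folded into the dot.
EXPECT_NE(FindInstruction(module.get(), HloOpcode::kReduce), nullptr);
EXPECT_NE(FindInstruction(module.get(), HloOpcode::kDot), nullptr);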
123 changes: 123 additions & 0 deletions xla/tools/multihost_hlo_runner/data/sharded_computation.hlo
@@ -0,0 +1,123 @@
HloModule pjit_ref_func

region_0.23 {
Arg_0.24 = f32[] parameter(0)
Arg_1.25 = f32[] parameter(1)
ROOT maximum.26 = f32[] maximum(Arg_0.24, Arg_1.25)
}

region_1.35 {
Arg_0.36 = f32[] parameter(0)
Arg_1.37 = f32[] parameter(1)
ROOT add.38 = f32[] add(Arg_0.36, Arg_1.37)
}

integer_pow.45 {
constant.47 = f32[] constant(1)
broadcast.48 = f32[32,12,512,1]{3,2,1,0} broadcast(constant.47), dimensions={}
Arg_0.46 = f32[32,12,512,1]{3,2,1,0} parameter(0)
multiply.49 = f32[32,12,512,1]{3,2,1,0} multiply(Arg_0.46, Arg_0.46)
ROOT divide.50 = f32[32,12,512,1]{3,2,1,0} divide(broadcast.48, multiply.49)
}

region_2.54 {
Arg_0.55 = f32[] parameter(0)
Arg_1.56 = f32[] parameter(1)
ROOT add.57 = f32[] add(Arg_0.55, Arg_1.56)
}

region_3.72 {
Arg_0.73 = f32[] parameter(0)
Arg_1.74 = f32[] parameter(1)
ROOT add.75 = f32[] add(Arg_0.73, Arg_1.74)
}

region_4.83 {
Arg_0.84 = f32[] parameter(0)
Arg_1.85 = f32[] parameter(1)
ROOT add.86 = f32[] add(Arg_0.84, Arg_1.85)
}

ENTRY main.107 {
Arg_0.1 = f16[32,512,3,12,64]{4,3,2,1,0} parameter(0), sharding={devices=[2,1,1,1,1]<=[2]}
slice.14 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [2:3], [0:12], [0:64]}
reshape.17 = f16[32,512,12,64]{3,2,1,0} reshape(slice.14)
convert.20 = f32[32,512,12,64]{3,2,1,0} convert(reshape.17)
slice.12 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [0:1], [0:12], [0:64]}
reshape.15 = f16[32,512,12,64]{3,2,1,0} reshape(slice.12)
convert.18 = f32[32,512,12,64]{3,2,1,0} convert(reshape.15)
constant.6 = f32[] constant(8)
broadcast.7 = f32[32,512,12,64]{3,2,1,0} broadcast(constant.6), dimensions={}
divide.21 = f32[32,512,12,64]{3,2,1,0} divide(convert.18, broadcast.7)
slice.13 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [1:2], [0:12], [0:64]}
reshape.16 = f16[32,512,12,64]{3,2,1,0} reshape(slice.13)
convert.19 = f32[32,512,12,64]{3,2,1,0} convert(reshape.16)
dot.22 = f32[32,12,512,512]{3,2,1,0} dot(divide.21, convert.19), lhs_batch_dims={0,2}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={3}
constant.11 = f32[] constant(-inf)
reduce.27 = f32[32,12,512]{2,1,0} reduce(dot.22, constant.11), dimensions={3}, to_apply=region_0.23
constant.4 = f32[] constant(-inf)
broadcast.5 = f32[32,12,512]{2,1,0} broadcast(constant.4), dimensions={}
maximum.28 = f32[32,12,512]{2,1,0} maximum(reduce.27, broadcast.5)
reshape.29 = f32[32,12,512,1]{3,2,1,0} reshape(maximum.28)
broadcast.30 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.29), dimensions={0,1,2,3}
reshape.31 = f32[32,12,512]{2,1,0} reshape(broadcast.30)
broadcast.32 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.31), dimensions={0,1,2}
subtract.33 = f32[32,12,512,512]{3,2,1,0} subtract(dot.22, broadcast.32)
exponential.34 = f32[32,12,512,512]{3,2,1,0} exponential(subtract.33)
constant.10 = f32[] constant(0)
reduce.39 = f32[32,12,512]{2,1,0} reduce(exponential.34, constant.10), dimensions={3}, to_apply=region_1.35
reshape.40 = f32[32,12,512,1]{3,2,1,0} reshape(reduce.39)
broadcast.41 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.40), dimensions={0,1,2,3}
reshape.42 = f32[32,12,512]{2,1,0} reshape(broadcast.41)
broadcast.43 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.42), dimensions={0,1,2}
divide.44 = f32[32,12,512,512]{3,2,1,0} divide(exponential.34, broadcast.43)
dot.52 = f32[32,12,64,512]{3,2,1,0} dot(convert.20, divide.44), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
transpose.53 = f32[32,512,12,64]{1,3,2,0} transpose(dot.52), dimensions={0,3,1,2}
reduce.58 = f32[] reduce(transpose.53, constant.10), dimensions={0,1,2,3}, to_apply=region_2.54
constant.9 = f32[] constant(12582912)
divide.59 = f32[] divide(reduce.58, constant.9)
convert.60 = f16[] convert(divide.59)
reshape.104 = f16[] reshape(convert.60), sharding={replicated}
constant.2 = f32[] constant(7.94728621e-08)
broadcast.3 = f32[32,12,64,512]{3,2,1,0} broadcast(constant.2), dimensions={}
dot.62 = f32[32,12,64,512]{3,2,1,0} dot(broadcast.3, divide.44), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
transpose.63 = f32[32,512,12,64]{1,3,2,0} transpose(dot.62), dimensions={0,3,1,2}
convert.64 = f16[32,512,12,64]{1,3,2,0} convert(transpose.63)
reshape.65 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.64)
constant.8 = f16[] constant(0)
pad.66 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.65, constant.8), padding=0_0x0_0x2_0x0_0x0_0
dot.61 = f32[32,12,512,512]{3,2,1,0} dot(broadcast.3, convert.20), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={3}
broadcast.79 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.40), dimensions={0,1,2,3}
reshape.80 = f32[32,12,512]{2,1,0} reshape(broadcast.79)
broadcast.81 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.80), dimensions={0,1,2}
divide.82 = f32[32,12,512,512]{3,2,1,0} divide(dot.61, broadcast.81)
call.51 = f32[32,12,512,1]{3,2,1,0} call(reshape.40), to_apply=integer_pow.45
broadcast.67 = f32[32,12,512,1]{3,2,1,0} broadcast(call.51), dimensions={0,1,2,3}
reshape.68 = f32[32,12,512]{2,1,0} reshape(broadcast.67)
broadcast.69 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.68), dimensions={0,1,2}
multiply.70 = f32[32,12,512,512]{3,2,1,0} multiply(dot.61, broadcast.69)
multiply.71 = f32[32,12,512,512]{3,2,1,0} multiply(multiply.70, exponential.34)
reduce.76 = f32[32,12,512]{2,1,0} reduce(multiply.71, constant.10), dimensions={3}, to_apply=region_3.72
reshape.77 = f32[32,12,512,1]{3,2,1,0} reshape(reduce.76)
negate.78 = f32[32,12,512,1]{3,2,1,0} negate(reshape.77)
reduce.87 = f32[32,12,512]{2,1,0} reduce(negate.78, constant.10), dimensions={3}, to_apply=region_4.83
broadcast.88 = f32[32,12,512,512]{3,2,1,0} broadcast(reduce.87), dimensions={0,1,2}
add.89 = f32[32,12,512,512]{3,2,1,0} add(divide.82, broadcast.88)
multiply.90 = f32[32,12,512,512]{3,2,1,0} multiply(add.89, exponential.34)
dot.91 = f32[32,12,512,64]{3,2,1,0} dot(multiply.90, divide.21), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}
transpose.92 = f32[32,512,12,64]{3,1,2,0} transpose(dot.91), dimensions={0,2,1,3}
convert.95 = f16[32,512,12,64]{3,1,2,0} convert(transpose.92)
reshape.96 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.95)
pad.97 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.96, constant.8), padding=0_0x0_0x1_1x0_0x0_0
add.98 = f16[32,512,3,12,64]{4,3,2,1,0} add(pad.66, pad.97)
dot.93 = f32[32,12,512,64]{3,2,1,0} dot(multiply.90, convert.19), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}
transpose.94 = f32[32,512,12,64]{3,1,2,0} transpose(dot.93), dimensions={0,2,1,3}
divide.99 = f32[32,512,12,64]{3,1,2,0} divide(transpose.94, broadcast.7)
convert.100 = f16[32,512,12,64]{3,1,2,0} convert(divide.99)
reshape.101 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.100)
pad.102 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.101, constant.8), padding=0_0x0_0x0_2x0_0x0_0
add.103 = f16[32,512,3,12,64]{4,3,2,1,0} add(add.98, pad.102)
reshape.105 = f16[32,512,3,12,64]{4,3,2,1,0} reshape(add.103), sharding={devices=[2,1,1,1,1]<=[2]}
ROOT tuple.106 = (f16[], f16[32,512,3,12,64]{4,3,2,1,0}) tuple(reshape.104, reshape.105), sharding={{replicated}, {devices=[2,1,1,1,1]<=[2]}}
} // main.107

34 changes: 34 additions & 0 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
@@ -177,6 +177,40 @@ TEST_F(FunctionalHloRunnerTest, UseUninitializedInputs) {
InputFormat::kText));
}

// ROCM Error:
// E0000 00:00:1737155629.780742 137227 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: INTERNAL: Failed to end stream capture: hipError_t(901)
TEST_F(FunctionalHloRunnerTest, ShardedComputationUnderStreamCapture) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
GetPjRtClient());

constexpr int kRequiredDeviceCount = 2;
const int kDeviceCount = client->device_count();
if (kDeviceCount < kRequiredDeviceCount) {
GTEST_SKIP() << "Requires " << kRequiredDeviceCount
<< " devices, but found only " << kDeviceCount;
return;
}

// NOTE: debug_options sent to FunctionalHloRunner::LoadAndRunAndDump() get
// lost during the creation of the XlaComputation from the HloModuleProto in
// FunctionalHloRunner::Compile.
xla::DebugOptions debug_options;
FunctionalHloRunner::PreprocessingOptions preproc_options;
FunctionalHloRunner::RawCompileOptions raw_compile_options;
raw_compile_options.spmd_mode =
FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning;
raw_compile_options.num_replicas = 1;
raw_compile_options.num_partitions = 2;
FunctionalHloRunner::RunningOptions running_options;
running_options.module_argument_mode =
FunctionalHloRunner::ModuleArgumentMode::kUseRandomInputs;

TF_EXPECT_OK(FunctionalHloRunner::LoadAndRunAndDump(
*client, debug_options, preproc_options, raw_compile_options,
running_options, {GetHloPath("sharded_computation.hlo")},
InputFormat::kText));
}

TEST_F(FunctionalHloRunnerTest, UseUninitializedInputsWithTupledArguments) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
GetPjRtClient());
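The NOTE in the new test records that DebugOptions handed to FunctionalHloRunner::LoadAndRunAndDump() are dropped when the XlaComputation is rebuilt from the HloModuleProto inside FunctionalHloRunner::Compile. For reference, below is a sketch of where DebugOptions normally live when compiling through a PjRt client directly; it is hypothetical here, since this test goes through the runner's RawCompileOptions instead.

// Hypothetical sketch, not part of this commit: when calling
// PjRtClient::Compile() directly, DebugOptions are carried inside the
// executable build options of the compile options.
xla::CompileOptions compile_options;
*compile_options.executable_build_options.mutable_debug_options() =
    debug_options;  // the DebugOptions instance from the test above
// client->Compile(computation, compile_options) would then observe them.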

