Avoid lazy init of blas handles and fix for non-canonical dots #95

Merged
13 changes: 7 additions & 6 deletions xla/service/gpu/amdgpu_compiler.cc
@@ -138,11 +138,12 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
   // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter
   // introduces reshapes and transposes that can be eliminated using
   // AlgebraicSimplifier We run algsimp to a fixed point.
-  AlgebraicSimplifierOptions options =
+  AlgebraicSimplifierOptions algsimp_options =
       GetAlgebraicSimplifierOptions(hlo_module->config());
-  options.set_enable_conv_operand_swap(false);
-  options.set_enable_unconditional_reduce_of_concat_replacement(false);
-  pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(options, gpu_version);
+  algsimp_options.set_supports_non_canonical_dots(false);
+  algsimp_options.set_enable_conv_operand_swap(false);
+  algsimp_options.set_enable_unconditional_reduce_of_concat_replacement(false);
+  pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(algsimp_options, gpu_version);
 
   // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
   // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
@@ -152,7 +153,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
     ReshapeMoverOptions reshape_mover_options;
     reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true;
     pipeline.AddPass<ReshapeMover>(reshape_mover_options);
-    pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
+    pipeline.AddPass<GpuAlgebraicSimplifier>(algsimp_options, gpu_version);
   }();
 
   // The reshapes and transposes can possibly be eliminated using
@@ -163,7 +164,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
   [&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
           "simplify_after_conv_canonicalization")] {
     pipeline.AddPass<ConvertMover>();
-    pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
+    pipeline.AddPass<GpuAlgebraicSimplifier>(algsimp_options, gpu_version);
   }();
 
   // GpuConvRewriter, GpuConvPaddingLegalization and
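For context on the "non-canonical dots" part of this change: the dot exercised by the new test in xla/tests/matmul_test.cc below lists the same batch dimensions in swapped order on its two operands (lhs_batch_dims={0,1} vs. rhs_batch_dims={1,0}). The standalone sketch below is illustrative only (plain C++ loops, no XLA APIs; the shapes are copied from that test) and spells out how those dimension numbers index the operands.

#include <cstdio>
#include <vector>

// Reference loops for the non-canonical dot from the new matmul test:
//   d = f32[32,8,5,7] dot(f32[32,8,5,6] p0, f32[8,32,6,7] p1),
//       lhs_batch_dims={0,1}, rhs_batch_dims={1,0},
//       lhs_contracting_dims={3}, rhs_contracting_dims={2}
// Both operands share the logical batch pair (b0, b1), but the pair appears
// in swapped order on the rhs; that permutation is what makes the dot
// non-canonical.
int main() {
  constexpr int B0 = 32, B1 = 8, M = 5, K = 6, N = 7;
  std::vector<float> p0(B0 * B1 * M * K, 1.0f);  // shape [32,8,5,6]
  std::vector<float> p1(B1 * B0 * K * N, 1.0f);  // shape [8,32,6,7]
  std::vector<float> d(B0 * B1 * M * N, 0.0f);   // shape [32,8,5,7]

  // Row-major linear index for a rank-4 array with dimensions (d0,d1,d2,d3).
  auto idx = [](int i0, int i1, int i2, int i3, int d1, int d2, int d3) {
    return ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
  };

  for (int b0 = 0; b0 < B0; ++b0)
    for (int b1 = 0; b1 < B1; ++b1)
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k)
            acc += p0[idx(b0, b1, m, k, B1, M, K)] *
                   p1[idx(b1, b0, k, n, B0, K, N)];  // batch dims swapped here
          d[idx(b0, b1, m, n, B1, M, N)] = acc;
        }

  std::printf("d[0,0,0,0] = %g (expected %d with all-ones inputs)\n", d[0], K);
  return 0;
}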
2 changes: 1 addition & 1 deletion xla/service/gpu/buffer_sharing.cc
@@ -83,7 +83,7 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
   bool is_reduction_emitter = analysis.GetEmitterFusionKind() ==
                               HloFusionAnalysis::EmitterFusionKind::kReduction;
   const HloInstruction* reduction_hero =
-      is_reduction_emitter ? reduction_hero = analysis.FindHeroReduction()
+      is_reduction_emitter ? analysis.FindHeroReduction()
                            : nullptr;
 
   // We need to make sure that the fusion parameter is accessed in the same
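The buffer_sharing.cc change is a cleanup of the initializer above: the old ternary assigned to reduction_hero inside reduction_hero's own initializer. A minimal standalone sketch of the before/after pattern, in generic C++ rather than XLA code:

#include <iostream>

// Stand-in for analysis.FindHeroReduction().
const char* FindHero() { return "reduce"; }

int main() {
  bool is_reduction_emitter = true;

  // Before: the true branch of the ternary contained a redundant
  // self-assignment ("hero = FindHero()"), i.e. it wrote to hero while hero
  // was still being initialized. It produces the same value, but it is
  // confusing and can trigger compiler warnings.
  //
  // After: the ternary simply yields the value used to initialize hero.
  const char* hero = is_reduction_emitter ? FindHero() : nullptr;

  std::cout << (hero ? hero : "none") << "\n";
  return 0;
}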
31 changes: 14 additions & 17 deletions xla/stream_executor/rocm/rocm_executor.cc
@@ -237,7 +237,20 @@ absl::Status GpuExecutor::Init() {
     return status;
   }
 
-  return GpuDriver::GetGpuISAVersion(&version_, device_);
+  status = GpuDriver::GetGpuISAVersion(&version_, device_);
+  if (!status.ok()) {
+    return status;
+  }
+  // We initialize the BLAS interface early here since otherwise it might
+  // cause problems during hipBlasLt initialization under graph capture.
+  // There is no real advantage to explicit lazy initialization on the ROCm
+  // platform because rocBLAS/hipBlasLt already use lazy initialization
+  // internally.
+  PluginRegistry* registry = PluginRegistry::Instance();
+  TF_ASSIGN_OR_RETURN(auto factory,
+      registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId));
+  blas_.reset(factory(this));
+  return absl::OkStatus();
 }
 
 // Returns the path to the running executable.
@@ -681,22 +694,6 @@ absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
 }
 
 blas::BlasSupport* GpuExecutor::AsBlas() {
-  absl::MutexLock lock(&mu_);
-  if (blas_ != nullptr) {
-    return blas_.get();
-  }
-
-  PluginRegistry* registry = PluginRegistry::Instance();
-  absl::StatusOr<PluginRegistry::BlasFactory> status =
-      registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId);
-  if (!status.ok()) {
-    LOG(ERROR) << "Unable to retrieve BLAS factory: "
-               << status.status().message();
-    return nullptr;
-  }
-
-  auto blas = status.value()(this);
-  blas_.reset(blas);
   return blas_.get();
 }

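The net effect of the rocm_executor.cc change is that the BLAS plugin is created once in GpuExecutor::Init() and AsBlas() becomes a plain accessor. The sketch below uses made-up types (not the real StreamExecutor API) to show the shape of the refactor; the motivation, per the comment added above and the runner test below, is that creating hipBLAS/hipBlasLt handles lazily can fail while a HIP graph is being captured, presumably because handle creation allocates and issues work on the stream.

#include <memory>

// Stand-in for the real BLAS plugin object.
struct BlasSupport {};

class Executor {
 public:
  // Eager path (what this PR switches to): the expensive handle creation
  // happens here, before any graph capture can be in progress, and a failure
  // surfaces at startup instead of at the first GEMM call.
  bool Init() {
    blas_ = std::make_unique<BlasSupport>();
    return blas_ != nullptr;
  }

  // After the change this is just an accessor: no lock, no plugin-registry
  // lookup, and nothing that could allocate or launch work mid-capture.
  BlasSupport* AsBlas() { return blas_.get(); }

 private:
  std::unique_ptr<BlasSupport> blas_;
};

int main() {
  Executor exec;
  if (!exec.Init()) return 1;
  return exec.AsBlas() != nullptr ? 0 : 1;
}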
23 changes: 23 additions & 0 deletions xla/tests/matmul_test.cc
@@ -51,6 +51,29 @@ class MatmulTestWithCublas : public HloTestBase,
const bool use_cublas_lt_{GetParam()};
};

TEST_P(MatmulTestWithCublas, GemmRewriter_NonCanonicalDots) {
const char* module_str = R"(
HloModule m
a {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p0, p1)
}
test {
p0 = f32[32,8,5,6] parameter(0)
p1 = f32[8,32,6,7] parameter(1)
d = f32[32,8,5,7] dot(p0, p1),
lhs_batch_dims={0,1},
rhs_batch_dims={1,0},
rhs_contracting_dims={2},
lhs_contracting_dims={3}
c = f32[] constant(0)
ROOT r = f32[8,5,7] reduce(d,c), dimensions={0}, to_apply=a
}
)";
EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-4, 1e-4}));
}

TEST_P(MatmulTestWithCublas, GemmRewriter_RegressionTestF64) {
const char* module_str = R"(
HloModule GeneralMatMulActivation.7, entry_computation_layout={(f64[2,2,2]{2,1,0}, f64[2,2,2]{2,1,0})->f64[2,2,2]{2,1,0}}
123 changes: 123 additions & 0 deletions xla/tools/multihost_hlo_runner/data/sharded_computation.hlo
@@ -0,0 +1,123 @@
HloModule pjit_ref_func

region_0.23 {
Arg_0.24 = f32[] parameter(0)
Arg_1.25 = f32[] parameter(1)
ROOT maximum.26 = f32[] maximum(Arg_0.24, Arg_1.25)
}

region_1.35 {
Arg_0.36 = f32[] parameter(0)
Arg_1.37 = f32[] parameter(1)
ROOT add.38 = f32[] add(Arg_0.36, Arg_1.37)
}

integer_pow.45 {
constant.47 = f32[] constant(1)
broadcast.48 = f32[32,12,512,1]{3,2,1,0} broadcast(constant.47), dimensions={}
Arg_0.46 = f32[32,12,512,1]{3,2,1,0} parameter(0)
multiply.49 = f32[32,12,512,1]{3,2,1,0} multiply(Arg_0.46, Arg_0.46)
ROOT divide.50 = f32[32,12,512,1]{3,2,1,0} divide(broadcast.48, multiply.49)
}

region_2.54 {
Arg_0.55 = f32[] parameter(0)
Arg_1.56 = f32[] parameter(1)
ROOT add.57 = f32[] add(Arg_0.55, Arg_1.56)
}

region_3.72 {
Arg_0.73 = f32[] parameter(0)
Arg_1.74 = f32[] parameter(1)
ROOT add.75 = f32[] add(Arg_0.73, Arg_1.74)
}

region_4.83 {
Arg_0.84 = f32[] parameter(0)
Arg_1.85 = f32[] parameter(1)
ROOT add.86 = f32[] add(Arg_0.84, Arg_1.85)
}

ENTRY main.107 {
Arg_0.1 = f16[32,512,3,12,64]{4,3,2,1,0} parameter(0), sharding={devices=[2,1,1,1,1]<=[2]}
slice.14 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [2:3], [0:12], [0:64]}
reshape.17 = f16[32,512,12,64]{3,2,1,0} reshape(slice.14)
convert.20 = f32[32,512,12,64]{3,2,1,0} convert(reshape.17)
slice.12 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [0:1], [0:12], [0:64]}
reshape.15 = f16[32,512,12,64]{3,2,1,0} reshape(slice.12)
convert.18 = f32[32,512,12,64]{3,2,1,0} convert(reshape.15)
constant.6 = f32[] constant(8)
broadcast.7 = f32[32,512,12,64]{3,2,1,0} broadcast(constant.6), dimensions={}
divide.21 = f32[32,512,12,64]{3,2,1,0} divide(convert.18, broadcast.7)
slice.13 = f16[32,512,1,12,64]{4,3,2,1,0} slice(Arg_0.1), slice={[0:32], [0:512], [1:2], [0:12], [0:64]}
reshape.16 = f16[32,512,12,64]{3,2,1,0} reshape(slice.13)
convert.19 = f32[32,512,12,64]{3,2,1,0} convert(reshape.16)
dot.22 = f32[32,12,512,512]{3,2,1,0} dot(divide.21, convert.19), lhs_batch_dims={0,2}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={3}
constant.11 = f32[] constant(-inf)
reduce.27 = f32[32,12,512]{2,1,0} reduce(dot.22, constant.11), dimensions={3}, to_apply=region_0.23
constant.4 = f32[] constant(-inf)
broadcast.5 = f32[32,12,512]{2,1,0} broadcast(constant.4), dimensions={}
maximum.28 = f32[32,12,512]{2,1,0} maximum(reduce.27, broadcast.5)
reshape.29 = f32[32,12,512,1]{3,2,1,0} reshape(maximum.28)
broadcast.30 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.29), dimensions={0,1,2,3}
reshape.31 = f32[32,12,512]{2,1,0} reshape(broadcast.30)
broadcast.32 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.31), dimensions={0,1,2}
subtract.33 = f32[32,12,512,512]{3,2,1,0} subtract(dot.22, broadcast.32)
exponential.34 = f32[32,12,512,512]{3,2,1,0} exponential(subtract.33)
constant.10 = f32[] constant(0)
reduce.39 = f32[32,12,512]{2,1,0} reduce(exponential.34, constant.10), dimensions={3}, to_apply=region_1.35
reshape.40 = f32[32,12,512,1]{3,2,1,0} reshape(reduce.39)
broadcast.41 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.40), dimensions={0,1,2,3}
reshape.42 = f32[32,12,512]{2,1,0} reshape(broadcast.41)
broadcast.43 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.42), dimensions={0,1,2}
divide.44 = f32[32,12,512,512]{3,2,1,0} divide(exponential.34, broadcast.43)
dot.52 = f32[32,12,64,512]{3,2,1,0} dot(convert.20, divide.44), lhs_batch_dims={0,2}, lhs_contracting_dims={1}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}
transpose.53 = f32[32,512,12,64]{1,3,2,0} transpose(dot.52), dimensions={0,3,1,2}
reduce.58 = f32[] reduce(transpose.53, constant.10), dimensions={0,1,2,3}, to_apply=region_2.54
constant.9 = f32[] constant(12582912)
divide.59 = f32[] divide(reduce.58, constant.9)
convert.60 = f16[] convert(divide.59)
reshape.104 = f16[] reshape(convert.60), sharding={replicated}
constant.2 = f32[] constant(7.94728621e-08)
broadcast.3 = f32[32,12,64,512]{3,2,1,0} broadcast(constant.2), dimensions={}
dot.62 = f32[32,12,64,512]{3,2,1,0} dot(broadcast.3, divide.44), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}
transpose.63 = f32[32,512,12,64]{1,3,2,0} transpose(dot.62), dimensions={0,3,1,2}
convert.64 = f16[32,512,12,64]{1,3,2,0} convert(transpose.63)
reshape.65 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.64)
constant.8 = f16[] constant(0)
pad.66 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.65, constant.8), padding=0_0x0_0x2_0x0_0x0_0
dot.61 = f32[32,12,512,512]{3,2,1,0} dot(broadcast.3, convert.20), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={3}
broadcast.79 = f32[32,12,512,1]{3,2,1,0} broadcast(reshape.40), dimensions={0,1,2,3}
reshape.80 = f32[32,12,512]{2,1,0} reshape(broadcast.79)
broadcast.81 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.80), dimensions={0,1,2}
divide.82 = f32[32,12,512,512]{3,2,1,0} divide(dot.61, broadcast.81)
call.51 = f32[32,12,512,1]{3,2,1,0} call(reshape.40), to_apply=integer_pow.45
broadcast.67 = f32[32,12,512,1]{3,2,1,0} broadcast(call.51), dimensions={0,1,2,3}
reshape.68 = f32[32,12,512]{2,1,0} reshape(broadcast.67)
broadcast.69 = f32[32,12,512,512]{3,2,1,0} broadcast(reshape.68), dimensions={0,1,2}
multiply.70 = f32[32,12,512,512]{3,2,1,0} multiply(dot.61, broadcast.69)
multiply.71 = f32[32,12,512,512]{3,2,1,0} multiply(multiply.70, exponential.34)
reduce.76 = f32[32,12,512]{2,1,0} reduce(multiply.71, constant.10), dimensions={3}, to_apply=region_3.72
reshape.77 = f32[32,12,512,1]{3,2,1,0} reshape(reduce.76)
negate.78 = f32[32,12,512,1]{3,2,1,0} negate(reshape.77)
reduce.87 = f32[32,12,512]{2,1,0} reduce(negate.78, constant.10), dimensions={3}, to_apply=region_4.83
broadcast.88 = f32[32,12,512,512]{3,2,1,0} broadcast(reduce.87), dimensions={0,1,2}
add.89 = f32[32,12,512,512]{3,2,1,0} add(divide.82, broadcast.88)
multiply.90 = f32[32,12,512,512]{3,2,1,0} multiply(add.89, exponential.34)
dot.91 = f32[32,12,512,64]{3,2,1,0} dot(multiply.90, divide.21), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}
transpose.92 = f32[32,512,12,64]{3,1,2,0} transpose(dot.91), dimensions={0,2,1,3}
convert.95 = f16[32,512,12,64]{3,1,2,0} convert(transpose.92)
reshape.96 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.95)
pad.97 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.96, constant.8), padding=0_0x0_0x1_1x0_0x0_0
add.98 = f16[32,512,3,12,64]{4,3,2,1,0} add(pad.66, pad.97)
dot.93 = f32[32,12,512,64]{3,2,1,0} dot(multiply.90, convert.19), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,2}, rhs_contracting_dims={1}
transpose.94 = f32[32,512,12,64]{3,1,2,0} transpose(dot.93), dimensions={0,2,1,3}
divide.99 = f32[32,512,12,64]{3,1,2,0} divide(transpose.94, broadcast.7)
convert.100 = f16[32,512,12,64]{3,1,2,0} convert(divide.99)
reshape.101 = f16[32,512,1,12,64]{4,3,2,1,0} reshape(convert.100)
pad.102 = f16[32,512,3,12,64]{4,3,2,1,0} pad(reshape.101, constant.8), padding=0_0x0_0x0_2x0_0x0_0
add.103 = f16[32,512,3,12,64]{4,3,2,1,0} add(add.98, pad.102)
reshape.105 = f16[32,512,3,12,64]{4,3,2,1,0} reshape(add.103), sharding={devices=[2,1,1,1,1]<=[2]}
ROOT tuple.106 = (f16[], f16[32,512,3,12,64]{4,3,2,1,0}) tuple(reshape.104, reshape.105), sharding={{replicated}, {devices=[2,1,1,1,1]<=[2]}}
} // main.107

34 changes: 34 additions & 0 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
@@ -177,6 +177,40 @@ TEST_F(FunctionalHloRunnerTest, UseUninitializedInputs) {
InputFormat::kText));
}

// ROCM Error:
// E0000 00:00:1737155629.780742 137227 pjrt_stream_executor_client.cc:3045] Execution of replica 0 failed: INTERNAL: Failed to end stream capture: hipError_t(901)
TEST_F(FunctionalHloRunnerTest, ShardedComputationUnderStreamCapture) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
GetPjRtClient());

constexpr int kRequiredDeviceCount = 2;
const int kDeviceCount = client->device_count();
if (kDeviceCount < kRequiredDeviceCount) {
GTEST_SKIP() << "Requires " << kRequiredDeviceCount
<< " devices, but found only " << kDeviceCount;
return;
}

// NOTE: debug_options sent to FunctionalHloRunner::LoadAndRunAndDump() get
// lost during the creation of the XlaComputation from the HloModuleProto in
// FunctionalHloRunner::Compile.
xla::DebugOptions debug_options;
FunctionalHloRunner::PreprocessingOptions preproc_options;
FunctionalHloRunner::RawCompileOptions raw_compile_options;
raw_compile_options.spmd_mode =
FunctionalHloRunner::SpmdMode::kUseSpmdPartitioning;
raw_compile_options.num_replicas = 1;
raw_compile_options.num_partitions = 2;
FunctionalHloRunner::RunningOptions running_options;
running_options.module_argument_mode =
FunctionalHloRunner::ModuleArgumentMode::kUseRandomInputs;

TF_EXPECT_OK(FunctionalHloRunner::LoadAndRunAndDump(
*client, debug_options, preproc_options, raw_compile_options,
running_options, {GetHloPath("sharded_computation.hlo")},
InputFormat::kText));
}

TEST_F(FunctionalHloRunnerTest, UseUninitializedInputsWithTupledArguments) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> client,
GetPjRtClient());