PR openxla#11353: Implement async dynamic slicing
Imported from GitHub PR openxla#11353

This implements async dynamic-slice and dynamic-update-slice for host memory offloading on GPU.

Since the emitter does not understand dynamic slicing instructions in async computations, we wrap them in a fusion node and mark the fusion for execution on a different stream. This is all we need to execute the offloading of slices asynchronously.
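
Concretely, for the dynamic-update-slice test case added below, the rewrite of the async computation looks roughly like this. The wrapped_<opcode> names follow the pass's naming scheme and the queue id matches the test expectation, but the fusion kind and the textual backend-config rendering are an approximate sketch, not verbatim compiler output:

  // Before: the async computation created for dynamic-update-slice-start
  // contains a bare dynamic-update-slice, which the emitter cannot handle.
  %dynamic-update-slice = f32[256,128,128]{2,1,0:S(5)}
      dynamic-update-slice(%param_0, %param_1, %izero, %izero, %izero)

  // After: the instruction is wrapped in a fusion, and the fusion carries the
  // stream annotation so the runtime launches it on a separate stream.
  %wrapped_dynamic-update-slice = f32[256,128,128]{2,1,0:S(5)}
      fusion(%param_0, %param_1, %izero, %izero, %izero), kind=kLoop,
      calls=%wrapped_dynamic-update-slice_computation,
      backend_config={"operation_queue_id":"1"}

The fused computation contains the original dynamic-update-slice unchanged; only the surrounding fusion instruction is annotated with the operation queue id.
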
Copybara import of the project:

--
905db6d by Jaroslav Sevcik <[email protected]>:

Wrap async dynamic slicing into fusion

Merging this change closes openxla#11353

COPYBARA_INTEGRATE_REVIEW=openxla#11353 from jaro-sevcik:host-offloading-async-fusion-dynamic-slicing 905db6d
PiperOrigin-RevId: 625428675
jaro-sevcik authored and copybara-github committed Apr 16, 2024
1 parent d5c5e4d commit 1f99384
Showing 3 changed files with 120 additions and 1 deletion.
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
@@ -5823,6 +5823,7 @@ cc_library(
hdrs = ["stream_attribute_annotator.h"],
deps = [
":backend_configs_cc",
":gpu_fusible",
"//xla:comparison_util",
"//xla:status",
"//xla:statusor",
45 changes: 44 additions & 1 deletion xla/service/gpu/stream_attribute_annotator.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_fusible.h"
#include "xla/service/gpu/runtime/thunk.h"
#include "xla/statusor.h"
#include "xla/util.h"
@@ -102,6 +103,39 @@ absl::StatusOr<bool> AnnotateStreamAttributesForCopyStart(
return true;
}

absl::StatusOr<bool> WrapIntoFusionAndAnnotateStreamAttributes(
HloInstruction* instruction, int64_t channel_id,
GpuBackendConfig& instr_gpu_config) {
auto* computation = instruction->parent();
auto* module = computation->parent();
auto* fusion_instruction =
computation->AddInstruction(HloInstruction::CreateFusion(
instruction->shape(), ChooseFusionKind(*instruction, *instruction),
instruction));
const absl::string_view wrapped_opcode =
HloOpcodeString(instruction->opcode());
module->SetAndUniquifyInstrName(fusion_instruction,
absl::StrCat("wrapped_", wrapped_opcode));
module->SetAndUniquifyComputationName(
fusion_instruction->fused_instructions_computation(),
absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
if (module->has_schedule()) {
module->schedule().replace_instruction(computation, instruction,
fusion_instruction);
}
TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction));
TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));

instr_gpu_config.set_operation_queue_id(channel_id);
TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config));
VLOG(3) << "Add async stream " << channel_id << " and wrapped instruction "
<< instruction->ToString();
VLOG(3) << " Fusion wrapper: " << fusion_instruction->ToString();
return true;
}

absl::StatusOr<bool> AnnotateStreamAttributesForUsers(
HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
bool changed = false;
@@ -140,7 +174,8 @@ absl::StatusOr<bool> StreamAttributeAnnotator::Run(
5, "StreamAttributeAnnotator::Run(), before:\n" + module->ToString());
bool changed = false;
int64_t channel_id = hlo_query::NextChannelId(*module);
for (const HloComputation* comp : module->computations(execution_threads)) {
for (const HloComputation* comp :
module->MakeComputationPostOrder(execution_threads)) {
for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
if (!instr_gpu_config.ok()) {
@@ -160,6 +195,14 @@
instr, channel_id, instr_gpu_config.value()));
changed |= comp_result;
continue;
} else if (comp->IsAsyncComputation() &&
(instr->opcode() == HloOpcode::kDynamicSlice ||
instr->opcode() == HloOpcode::kDynamicUpdateSlice)) {
TF_ASSIGN_OR_RETURN(bool comp_result,
WrapIntoFusionAndAnnotateStreamAttributes(
instr, channel_id, instr_gpu_config.value()));
changed |= comp_result;
continue;
}

TF_ASSIGN_OR_RETURN(
75 changes: 75 additions & 0 deletions xla/service/gpu/stream_attribute_annotator_test.cc
@@ -208,5 +208,80 @@ TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) {
EXPECT_EQ(gpu_config.operation_queue_id(), 1);
}
}

TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) {
constexpr absl::string_view kHloString = R"(
HloModule ModuleWithAsyncDynamicUpdateSlice
ENTRY entry (param_0: f32[256,128,128], param_1: f32[1,128,128]) -> f32[256,128,128] {
param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0)
param_1 = f32[1,128,128]{2,1,0} parameter(1)
izero = s32[] constant(0)
dynamic-update-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, f32[1,128,128]{2,1,0}, s32[], s32[], s32[]), f32[256,128,128]{2,1,0:S(5)}, u32[])
dynamic-update-slice-start(param_0, param_1, izero, izero, izero)
ROOT dynamic-update-slice-done.2 = f32[256,128,128]{2,1,0:S(5)}
dynamic-update-slice-done(dynamic-update-slice-start.2)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloString));

TF_ASSERT_OK_AND_ASSIGN(bool changed,
StreamAttributeAnnotator().Run(module.get()));
EXPECT_TRUE(changed);

// Check that the dynamic-update-slice instruction is wrapped in a fusion
// and the fusion is annotated with the correct operation_queue_id.
const HloInstruction* dus =
FindInstruction(module.get(), HloOpcode::kDynamicUpdateSlice);
const HloComputation* computation = dus->parent();
EXPECT_TRUE(computation->IsFusionComputation());
const HloInstruction* fusion = computation->FusionInstruction();
EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
EXPECT_TRUE(fusion->parent()->IsAsyncComputation());

EXPECT_TRUE(fusion->has_backend_config());
TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
fusion->backend_config<GpuBackendConfig>());
EXPECT_EQ(gpu_config.operation_queue_id(), 1);
}

TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) {
constexpr absl::string_view kHloString = R"(
HloModule ModuleWithAsyncDynamicSlice
ENTRY entry (param_0: f32[256,128,128]) -> f32[1,128,128] {
param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0)
izero = s32[] constant(0)
dynamic-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, s32[], s32[], s32[]), f32[1,128,128]{2,1,0}, u32[])
dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128}
ROOT dynamic-slice-done.2 = f32[1,128,128]{2,1,0}
dynamic-slice-done(dynamic-slice-start.2)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloString));

TF_ASSERT_OK_AND_ASSIGN(bool changed,
StreamAttributeAnnotator().Run(module.get()));
EXPECT_TRUE(changed);

// Check that the dynamic-slice instruction is wrapped in a fusion
// and the fusion is annotated with the correct operation_queue_id.
const HloInstruction* ds =
FindInstruction(module.get(), HloOpcode::kDynamicSlice);
const HloComputation* computation = ds->parent();
EXPECT_TRUE(computation->IsFusionComputation());
const HloInstruction* fusion = computation->FusionInstruction();
EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
EXPECT_TRUE(fusion->parent()->IsAsyncComputation());

EXPECT_TRUE(fusion->has_backend_config());
TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
fusion->backend_config<GpuBackendConfig>());
EXPECT_EQ(gpu_config.operation_queue_id(), 1);
}
} // namespace
} // namespace xla::gpu
