diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index f4eaf93779d3a..4f7e308db31e6 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -5823,6 +5823,7 @@ cc_library(
     hdrs = ["stream_attribute_annotator.h"],
     deps = [
         ":backend_configs_cc",
+        ":gpu_fusible",
         "//xla:comparison_util",
         "//xla:status",
         "//xla:statusor",
diff --git a/xla/service/gpu/stream_attribute_annotator.cc b/xla/service/gpu/stream_attribute_annotator.cc
index 0bfa2cef837e2..0b8d00984df7a 100644
--- a/xla/service/gpu/stream_attribute_annotator.cc
+++ b/xla/service/gpu/stream_attribute_annotator.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/utils/hlo_query.h"
 #include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/gpu_fusible.h"
 #include "xla/service/gpu/runtime/thunk.h"
 #include "xla/statusor.h"
 #include "xla/util.h"
@@ -102,6 +103,39 @@ absl::StatusOr<bool> AnnotateStreamAttributesForCopyStart(
   return true;
 }
 
+absl::StatusOr<bool> WrapIntoFusionAndAnnotateStreamAttributes(
+    HloInstruction* instruction, int64_t channel_id,
+    GpuBackendConfig& instr_gpu_config) {
+  auto* computation = instruction->parent();
+  auto* module = computation->parent();
+  auto* fusion_instruction =
+      computation->AddInstruction(HloInstruction::CreateFusion(
+          instruction->shape(), ChooseFusionKind(*instruction, *instruction),
+          instruction));
+  const absl::string_view wrapped_opcode =
+      HloOpcodeString(instruction->opcode());
+  module->SetAndUniquifyInstrName(fusion_instruction,
+                                  absl::StrCat("wrapped_", wrapped_opcode));
+  module->SetAndUniquifyComputationName(
+      fusion_instruction->fused_instructions_computation(),
+      absl::StrCat("wrapped_", wrapped_opcode, "_computation"));
+  if (module->has_schedule()) {
+    module->schedule().replace_instruction(computation, instruction,
+                                           fusion_instruction);
+  }
+  TF_RETURN_IF_ERROR(fusion_instruction->CopyAllControlDepsFrom(instruction));
+  TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+  TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(fusion_instruction));
+  TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+
+  instr_gpu_config.set_operation_queue_id(channel_id);
+  TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(instr_gpu_config));
+  VLOG(3) << "Add async stream " << channel_id << " and wrapped instruction "
+          << instruction->ToString();
+  VLOG(3) << "  Fusion wrapper: " << fusion_instruction->ToString();
+  return true;
+}
+
 absl::StatusOr<bool> AnnotateStreamAttributesForUsers(
     HloInstruction* instr, GpuBackendConfig& instr_gpu_config) {
   bool changed = false;
@@ -140,7 +174,8 @@ absl::StatusOr<bool> StreamAttributeAnnotator::Run(
       5, "StreamAttributeAnnotator::Run(), before:\n" + module->ToString());
   bool changed = false;
   int64_t channel_id = hlo_query::NextChannelId(*module);
-  for (const HloComputation* comp : module->computations(execution_threads)) {
+  for (const HloComputation* comp :
+       module->MakeComputationPostOrder(execution_threads)) {
     for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
       auto instr_gpu_config = instr->backend_config<GpuBackendConfig>();
       if (!instr_gpu_config.ok()) {
@@ -160,6 +195,14 @@ absl::StatusOr<bool> StreamAttributeAnnotator::Run(
                               instr, channel_id, instr_gpu_config.value()));
       changed |= comp_result;
       continue;
+    } else if (comp->IsAsyncComputation() &&
+               (instr->opcode() == HloOpcode::kDynamicSlice ||
+                instr->opcode() == HloOpcode::kDynamicUpdateSlice)) {
+      TF_ASSIGN_OR_RETURN(bool comp_result,
+                          WrapIntoFusionAndAnnotateStreamAttributes(
+                              instr, channel_id, instr_gpu_config.value()));
+      changed |= comp_result;
+      continue;
     }
     TF_ASSIGN_OR_RETURN(
         bool user_result,
diff --git a/xla/service/gpu/stream_attribute_annotator_test.cc b/xla/service/gpu/stream_attribute_annotator_test.cc
index 2861f9a82a7ef..17d9b2f1e212d 100644
--- a/xla/service/gpu/stream_attribute_annotator_test.cc
+++ b/xla/service/gpu/stream_attribute_annotator_test.cc
@@ -208,5 +208,80 @@ TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) {
     EXPECT_EQ(gpu_config.operation_queue_id(), 1);
   }
 }
+
+TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsyncDynamicUpdateSlice
+
+  ENTRY entry (param_0: f32[256,128,128], param_1: f32[1,128,128]) -> f32[256,128,128] {
+    param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0)
+    param_1 = f32[1,128,128]{2,1,0} parameter(1)
+    izero = s32[] constant(0)
+    dynamic-update-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, f32[1,128,128]{2,1,0}, s32[], s32[], s32[]), f32[256,128,128]{2,1,0:S(5)}, u32[])
+      dynamic-update-slice-start(param_0, param_1, izero, izero, izero)
+    ROOT dynamic-update-slice-done.2 = f32[256,128,128]{2,1,0:S(5)}
+      dynamic-update-slice-done(dynamic-update-slice-start.2)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          StreamAttributeAnnotator().Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // Check that the dynamic-update-slice instruction is wrapped in a fusion
+  // and the fusion is annotated with the correct operation_queue_id.
+  const HloInstruction* dus =
+      FindInstruction(module.get(), HloOpcode::kDynamicUpdateSlice);
+  const HloComputation* computation = dus->parent();
+  EXPECT_TRUE(computation->IsFusionComputation());
+  const HloInstruction* fusion = computation->FusionInstruction();
+  EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
+  EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
+
+  EXPECT_TRUE(fusion->has_backend_config());
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          fusion->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+}
+
+TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) {
+  constexpr absl::string_view kHloString = R"(
+  HloModule ModuleWithAsyncDynamicSlice
+
+  ENTRY entry (param_0: f32[256,128,128]) -> f32[1,128,128] {
+    param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0)
+    izero = s32[] constant(0)
+    dynamic-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, s32[], s32[], s32[]), f32[1,128,128]{2,1,0}, u32[])
+      dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128}
+    ROOT dynamic-slice-done.2 = f32[1,128,128]{2,1,0}
+      dynamic-slice-done(dynamic-slice-start.2)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          StreamAttributeAnnotator().Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  // Check that the dynamic-slice instruction is wrapped in a fusion
+  // and the fusion is annotated with the correct operation_queue_id.
+  const HloInstruction* ds =
+      FindInstruction(module.get(), HloOpcode::kDynamicSlice);
+  const HloComputation* computation = ds->parent();
+  EXPECT_TRUE(computation->IsFusionComputation());
+  const HloInstruction* fusion = computation->FusionInstruction();
+  EXPECT_EQ(fusion->opcode(), HloOpcode::kFusion);
+  EXPECT_TRUE(fusion->parent()->IsAsyncComputation());
+
+  EXPECT_TRUE(fusion->has_backend_config());
+  TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config,
+                          fusion->backend_config<GpuBackendConfig>());
+  EXPECT_EQ(gpu_config.operation_queue_id(), 1);
+}
 }  // namespace
 }  // namespace xla::gpu