diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc
index f85342a913fc8..f9af62dcaac72 100644
--- a/xla/service/latency_hiding_scheduler.cc
+++ b/xla/service/latency_hiding_scheduler.cc
@@ -214,6 +214,7 @@ bool AsyncTracker::IsSupportedAsyncDone(const HloInstruction& hlo) const {
   }
   switch (op.inner) {
     case HloOpcode::kAllToAll:
+    case HloOpcode::kRaggedAllToAll:
     case HloOpcode::kAllGather:
     case HloOpcode::kAllReduce:
     case HloOpcode::kCollectiveBroadcast:
@@ -243,6 +244,7 @@ bool AsyncTracker::IsSupportedAsyncStart(const HloInstruction& hlo) const {
   }
   switch (op.inner) {
     case HloOpcode::kAllToAll:
+    case HloOpcode::kRaggedAllToAll:
     case HloOpcode::kAllGather:
     case HloOpcode::kAllReduce:
     case HloOpcode::kCollectiveBroadcast:
@@ -268,6 +270,8 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl(
         return ResourceType::kAllGather;
       case HloOpcode::kAllToAll:
         return ResourceType::kAllToAll;
+      case HloOpcode::kRaggedAllToAll:
+        return ResourceType::kRaggedAllToAll;
       case HloOpcode::kCollectiveBroadcast:
         return ResourceType::kCollectiveBroadcast;
       case HloOpcode::kCollectivePermute:
@@ -420,6 +424,8 @@ void AsyncTracker::SetConcurrentResourceLimits(
       config_.copy_overlap_limit;
   max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllToAll)] =
       config_.all_to_all_overlap_limit;
+  max_concurrent_resource[ResourceTypeToIndex(ResourceType::kRaggedAllToAll)] =
+      config_.ragged_all_to_all_overlap_limit;
   max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllGather)] =
       config_.all_gather_overlap_limit;
   max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllReduce)] =
@@ -453,6 +459,8 @@ absl::string_view AsyncTracker::GetResourceName(int64_t resource_type) const {
       return "kNoResource";
     case ResourceTypeToIndex(ResourceType::kAllToAll):
       return "kAllToAll";
+    case ResourceTypeToIndex(ResourceType::kRaggedAllToAll):
+      return "kRaggedAllToAll";
     case ResourceTypeToIndex(ResourceType::kAllGather):
       return "kAllGather";
    case ResourceTypeToIndex(ResourceType::kAllReduce):
@@ -2499,6 +2507,7 @@ LatencyHidingScheduler::LatencyHidingStatistics(
     kAllReduce,
     kCollectivePermute,
     kAllToAll,
+    kRaggedAllToAll,
     kReduceScatter,
     kSend,
     kRecv,
@@ -2516,6 +2525,8 @@ LatencyHidingScheduler::LatencyHidingStatistics(
         return AsyncKind::kCollectivePermute;
       case HloOpcode::kAllToAll:
         return AsyncKind::kAllToAll;
+      case HloOpcode::kRaggedAllToAll:
+        return AsyncKind::kRaggedAllToAll;
       case HloOpcode::kReduceScatter:
         return AsyncKind::kReduceScatter;
       case HloOpcode::kSend:
@@ -2614,6 +2625,8 @@ LatencyHidingScheduler::LatencyHidingStatistics(
       wasted_time_per_collective[AsyncKind::kCollectivePermute],
       /*all_to_all_wasted_cycles=*/
       wasted_time_per_collective[AsyncKind::kAllToAll],
+      /*ragged_all_to_all_wasted_cycles=*/
+      wasted_time_per_collective[AsyncKind::kRaggedAllToAll],
       /*reduce_scatter_wasted_cycles=*/
       wasted_time_per_collective[AsyncKind::kReduceScatter],
       /*send_wasted_cycles=*/wasted_time_per_collective[AsyncKind::kSend],
@@ -2640,6 +2653,7 @@ std::string LatencyHidingScheduler::SchedulerStatisticsString(
           sched_stats.collective_broadcast_wasted_cycles +
           sched_stats.collective_permute_wasted_cycles +
           sched_stats.all_to_all_wasted_cycles +
+          sched_stats.ragged_all_to_all_wasted_cycles +
           sched_stats.reduce_scatter_wasted_cycles +
           sched_stats.send_wasted_cycles +
           sched_stats.recv_wasted_cycles,
@@ -2654,6 +2668,8 @@ std::string LatencyHidingScheduler::SchedulerStatisticsString(
                   sched_stats.collective_permute_wasted_cycles, "\n");
   absl::StrAppend(&result, "Wasted cycles for all-to-all: ",
                   sched_stats.all_to_all_wasted_cycles, "\n");
+  absl::StrAppend(&result, "Wasted cycles for ragged-all-to-all: ",
+                  sched_stats.ragged_all_to_all_wasted_cycles, "\n");
   absl::StrAppend(&result, "Wasted cycles for reduce-scatter: ",
                   sched_stats.reduce_scatter_wasted_cycles, "\n");
   absl::StrAppend(&result,
diff --git a/xla/service/latency_hiding_scheduler.h b/xla/service/latency_hiding_scheduler.h
index 8d18bea0b2aad..6fb08d52ba403 100644
--- a/xla/service/latency_hiding_scheduler.h
+++ b/xla/service/latency_hiding_scheduler.h
@@ -83,6 +83,7 @@ enum class ResourceType {
   kRecvHost,
   kCollectiveBroadcast,
   kNumResources,
+  kRaggedAllToAll,
   kTargetDefinedResourceTypeBegin,
 };

@@ -126,6 +127,7 @@ struct SchedulerConfig {
   int64_t collective_broadcast_overlap_limit = 1;
   int64_t collective_permute_overlap_limit = 1;
   int64_t all_to_all_overlap_limit = 1;
+  int64_t ragged_all_to_all_overlap_limit = 1;
   int64_t all_gather_overlap_limit = 1;
   int64_t all_reduce_overlap_limit = 1;
   int64_t reduce_scatter_overlap_limit = 1;
@@ -1116,6 +1118,7 @@ class LatencyHidingScheduler : public HloModulePass {
     double collective_broadcast_wasted_cycles = 0;
     double collective_permute_wasted_cycles = 0;
     double all_to_all_wasted_cycles = 0;
+    double ragged_all_to_all_wasted_cycles = 0;
     double reduce_scatter_wasted_cycles = 0;
     double send_wasted_cycles = 0;
     double recv_wasted_cycles = 0;
diff --git a/xla/service/latency_hiding_scheduler_test.cc b/xla/service/latency_hiding_scheduler_test.cc
index 76e4e92382d5d..01ed048fff2aa 100644
--- a/xla/service/latency_hiding_scheduler_test.cc
+++ b/xla/service/latency_hiding_scheduler_test.cc
@@ -3914,4 +3914,60 @@ TEST_F(LatencyHidingSchedulerTest, CrossComputationAnnotation) {
             GetIndex(loop_instruction_sequence, "cpd1"));
 }

+TEST_F(LatencyHidingSchedulerTest, RaggedAllToAll) {
+  constexpr absl::string_view hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  async_computation {
+    input = f32[8,128,1024]{2,1,0:T(8,128)} parameter(0)
+    output = f32[8,128,1024]{2,1,0:T(8,128)} parameter(1)
+    input_offsets = s32[8]{0} parameter(2)
+    send_sizes = s32[8]{0} parameter(3)
+    output_offsets = s32[8]{0} parameter(4)
+    recv_sizes = s32[8]{0} parameter(5)
+    ROOT ra2a = f32[8,128,1024]{2,1,0:T(8,128)} ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3,4,5,6,7}}
+  }
+
+  ENTRY RA2A {
+    p0 = f32[8,128,1024]{2,1,0:T(8,128)} parameter(0)
+    c0 = f32[] constant(0)
+    output = f32[8,128,1024]{2,1,0:T(8,128)} broadcast(c0), dimensions={}
+    p1 = s32[8]{0} parameter(1)
+    p2 = s32[8]{0} parameter(2)
+    p3 = s32[8]{0} parameter(3)
+    p4 = s32[8]{0} parameter(4)
+    p5 = f32[1024,1024]{1,0:T(8,128)} parameter(5)
+    input = f32[8,128,1024]{2,1,0:T(8,128)} copy(p0)
+    input_offsets = s32[8]{0} copy(p1)
+    send_sizes = s32[8]{0} copy(p2)
+    output_offsets = s32[8]{0} copy(p3)
+    recv_sizes = s32[8]{0} copy(p4)
+    ra2a-start = ((f32[8,128,1024]{2,1,0:T(8,128)}, f32[8,128,1024]{2,1,0:T(8,128)}, s32[8]{0}, s32[8]{0}, s32[8]{0}, s32[8]{0}),
+      f32[8,128,1024]{2,1,0:T(8,128)}, u32[]{:S(2)}, u32[]{:S(2)}) async-start(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), calls=async_computation
+    ra2a-done = f32[8,128,1024]{2,1,0:T(8,128)} async-done(ra2a-start), calls=async_computation
+    d = f32[1024,1024]{1,0:T(8,128)} dot(p5, p5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    ROOT tuple = (f32[8,128,1024]{2,1,0:T(8,128)}, f32[1024,1024]{1,0:T(8,128)}) tuple(ra2a-done, d)
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
+  HloSchedule& module_schedule = hlo_module->schedule();
+  EXPECT_TRUE(hlo_module->has_entry_computation());
+  auto sched_config = GetDefaultSchedConfig();
+  EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok());
+  EXPECT_TRUE(hlo_module->has_entry_computation());
+
+  std::vector<HloInstruction*> new_instruction_sequence =
+      module_schedule.sequence(hlo_module->entry_computation()).instructions();
+  if (VLOG_IS_ON(1)) {
+    for (auto* new_i : new_instruction_sequence) {
+      VLOG(1) << new_i->ToString();
+    }
+  }
+
+  EXPECT_LT(GetIndex(new_instruction_sequence, "ra2a-start"),
+            GetIndex(new_instruction_sequence, "d"));
+  EXPECT_LT(GetIndex(new_instruction_sequence, "d"),
+            GetIndex(new_instruction_sequence, "ra2a-done"));
+}
+
 }  // namespace xla