[XLA] Add ragged-all-to-all support to latency hiding scheduler.
PiperOrigin-RevId: 718991621
Google-ML-Automation committed Jan 23, 2025
1 parent 427aaaa commit 290a673
Showing 3 changed files with 75 additions and 0 deletions.
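
For readers skimming the diff: the change introduces a kRaggedAllToAll scheduling resource plus a ragged_all_to_all_overlap_limit field on SchedulerConfig, mirroring the existing all-to-all handling, so async ragged-all-to-all start/done pairs can be overlapped with independent work. The sketch below shows how a caller might raise the new limit; the struct and field names come from this diff, while the helper function, include, and namespace qualification are assumptions for illustration rather than code from this commit.

    #include "xla/service/latency_hiding_scheduler.h"

    // Hypothetical helper (not part of this commit): builds a scheduler config
    // that lets two ragged-all-to-all collectives stay in flight at once.
    xla::SchedulerConfig MakeRaggedAllToAllFriendlyConfig() {
      xla::SchedulerConfig config;
      // New knob added by this change; it defaults to 1.
      config.ragged_all_to_all_overlap_limit = 2;
      // Pre-existing knob for dense all-to-all, shown for comparison.
      config.all_to_all_overlap_limit = 1;
      return config;
    }

How the config is threaded into the LatencyHidingScheduler pass depends on the backend and is unchanged by this commit.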
16 changes: 16 additions & 0 deletions xla/service/latency_hiding_scheduler.cc
@@ -214,6 +214,7 @@ bool AsyncTracker::IsSupportedAsyncDone(const HloInstruction& hlo) const {
}
switch (op.inner) {
case HloOpcode::kAllToAll:
case HloOpcode::kRaggedAllToAll:
case HloOpcode::kAllGather:
case HloOpcode::kAllReduce:
case HloOpcode::kCollectiveBroadcast:
@@ -243,6 +244,7 @@ bool AsyncTracker::IsSupportedAsyncStart(const HloInstruction& hlo) const {
}
switch (op.inner) {
case HloOpcode::kAllToAll:
case HloOpcode::kRaggedAllToAll:
case HloOpcode::kAllGather:
case HloOpcode::kAllReduce:
case HloOpcode::kCollectiveBroadcast:
@@ -268,6 +270,8 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl(
return ResourceType::kAllGather;
case HloOpcode::kAllToAll:
return ResourceType::kAllToAll;
case HloOpcode::kRaggedAllToAll:
return ResourceType::kRaggedAllToAll;
case HloOpcode::kCollectiveBroadcast:
return ResourceType::kCollectiveBroadcast;
case HloOpcode::kCollectivePermute:
@@ -420,6 +424,8 @@ void AsyncTracker::SetConcurrentResourceLimits(
config_.copy_overlap_limit;
max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllToAll)] =
config_.all_to_all_overlap_limit;
max_concurrent_resource[ResourceTypeToIndex(ResourceType::kRaggedAllToAll)] =
config_.ragged_all_to_all_overlap_limit;
max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllGather)] =
config_.all_gather_overlap_limit;
max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllReduce)] =
@@ -453,6 +459,8 @@ absl::string_view AsyncTracker::GetResourceName(int64_t resource_type) const {
return "kNoResource";
case ResourceTypeToIndex(ResourceType::kAllToAll):
return "kAllToAll";
case ResourceTypeToIndex(ResourceType::kRaggedAllToAll):
return "kRaggedAllToAll";
case ResourceTypeToIndex(ResourceType::kAllGather):
return "kAllGather";
case ResourceTypeToIndex(ResourceType::kAllReduce):
@@ -2499,6 +2507,7 @@ LatencyHidingScheduler::LatencyHidingStatistics(
kAllReduce,
kCollectivePermute,
kAllToAll,
kRaggedAllToAll,
kReduceScatter,
kSend,
kRecv,
@@ -2516,6 +2525,8 @@ LatencyHidingScheduler::LatencyHidingStatistics(
return AsyncKind::kCollectivePermute;
case HloOpcode::kAllToAll:
return AsyncKind::kAllToAll;
case HloOpcode::kRaggedAllToAll:
return AsyncKind::kRaggedAllToAll;
case HloOpcode::kReduceScatter:
return AsyncKind::kReduceScatter;
case HloOpcode::kSend:
@@ -2614,6 +2625,8 @@ LatencyHidingScheduler::LatencyHidingStatistics(
wasted_time_per_collective[AsyncKind::kCollectivePermute],
/*all_to_all_wasted_cycles=*/
wasted_time_per_collective[AsyncKind::kAllToAll],
/*ragged_all_to_all_wasted_cycles=*/
wasted_time_per_collective[AsyncKind::kRaggedAllToAll],
/*reduce_scatter_wasted_cycles=*/
wasted_time_per_collective[AsyncKind::kReduceScatter],
/*send_wasted_cycles=*/wasted_time_per_collective[AsyncKind::kSend],
@@ -2640,6 +2653,7 @@ std::string LatencyHidingScheduler::SchedulerStatisticsString(
sched_stats.collective_broadcast_wasted_cycles +
sched_stats.collective_permute_wasted_cycles +
sched_stats.all_to_all_wasted_cycles +
sched_stats.ragged_all_to_all_wasted_cycles +
sched_stats.reduce_scatter_wasted_cycles +
sched_stats.send_wasted_cycles +
sched_stats.recv_wasted_cycles,
@@ -2654,6 +2668,8 @@ std::string LatencyHidingScheduler::SchedulerStatisticsString(
sched_stats.collective_permute_wasted_cycles, "\n");
absl::StrAppend(&result, "Wasted cycles for all-to-all: ",
sched_stats.all_to_all_wasted_cycles, "\n");
absl::StrAppend(&result, "Wasted cycles for ragged-all-to-all: ",
sched_stats.ragged_all_to_all_wasted_cycles, "\n");
absl::StrAppend(&result, "Wasted cycles for reduce-scatter: ",
sched_stats.reduce_scatter_wasted_cycles, "\n");
absl::StrAppend(&result,
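With the StrAppend calls above, SchedulerStatisticsString gains one extra line between the all-to-all and reduce-scatter entries. A purely illustrative excerpt (cycle counts invented for the example, not captured from a real run):

    Wasted cycles for all-to-all: 0
    Wasted cycles for ragged-all-to-all: 0
    Wasted cycles for reduce-scatter: 0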
3 changes: 3 additions & 0 deletions xla/service/latency_hiding_scheduler.h
@@ -83,6 +83,7 @@ enum class ResourceType {
kRecvHost,
kCollectiveBroadcast,
kNumResources,
kRaggedAllToAll,
kTargetDefinedResourceTypeBegin,
};

@@ -126,6 +127,7 @@ struct SchedulerConfig {
int64_t collective_broadcast_overlap_limit = 1;
int64_t collective_permute_overlap_limit = 1;
int64_t all_to_all_overlap_limit = 1;
int64_t ragged_all_to_all_overlap_limit = 1;
int64_t all_gather_overlap_limit = 1;
int64_t all_reduce_overlap_limit = 1;
int64_t reduce_scatter_overlap_limit = 1;
@@ -1116,6 +1118,7 @@ class LatencyHidingScheduler : public HloModulePass {
double collective_broadcast_wasted_cycles = 0;
double collective_permute_wasted_cycles = 0;
double all_to_all_wasted_cycles = 0;
double ragged_all_to_all_wasted_cycles = 0;
double reduce_scatter_wasted_cycles = 0;
double send_wasted_cycles = 0;
double recv_wasted_cycles = 0;
56 changes: 56 additions & 0 deletions xla/service/latency_hiding_scheduler_test.cc
@@ -3914,4 +3914,60 @@ TEST_F(LatencyHidingSchedulerTest, CrossComputationAnnotation) {
GetIndex(loop_instruction_sequence, "cpd1"));
}

TEST_F(LatencyHidingSchedulerTest, RaggedAllToAll) {
constexpr absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true
async_computation {
input = f32[8,128,1024]{2,1,0:T(8,128)} parameter(0)
output = f32[8,128,1024]{2,1,0:T(8,128)} parameter(1)
input_offsets = s32[8]{0} parameter(2)
send_sizes = s32[8]{0} parameter(3)
output_offsets = s32[8]{0} parameter(4)
recv_sizes = s32[8]{0} parameter(5)
ROOT ra2a = f32[8,128,1024]{2,1,0:T(8,128)} ragged-all-to-all(input, output,input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3,4,5,6,7}}
}
ENTRY RA2A {
p0 = f32[8,128,1024]{2,1,0:T(8,128)} parameter(0)
c0 = f32[] constant(0)
output = f32[8,128,1024]{2,1,0:T(8,128)} broadcast(c0), dimensions={}
p1 = s32[8]{0} parameter(1)
p2 = s32[8]{0} parameter(2)
p3 = s32[8]{0} parameter(3)
p4 = s32[8]{0} parameter(4)
p5 = f32[1024, 1024]{1,0:T(8,128)} parameter(5)
input = f32[8,128,1024]{2,1,0:T(8,128)} copy(p0)
input_offsets = s32[8]{0} copy(p1)
send_sizes = s32[8]{0} copy(p2)
output_offsets = s32[8]{0} copy(p3)
recv_sizes = s32[8]{0} copy(p4)
ra2a-start = ((f32[8,128,1024]{2,1,0:T(8,128)}, f32[8,128,1024]{2,1,0:T(8,128)}, s32[8]{0}, s32[8]{0}, s32[8]{0}, s32[8]{0}),
f32[8,128,1024]{2,1,0:T(8,128)}, u32[]{:S(2)}, u32[]{:S(2)}) async-start(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), calls=async_computation
ra2a-done = f32[8,128,1024]{2,1,0:T(8,128)} async-done(ra2a-start), calls=async_computation
d = f32[1024,1024]{1,0:T(8,128)} dot(p5, p5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT tuple = (f32[8,128,1024]{2,1,0:T(8,128)}, f32[1024,1024]{1,0:T(8,128)}) tuple(ra2a-done, d)
})";

TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok());
EXPECT_TRUE(hlo_module->has_entry_computation());

std::vector<HloInstruction*> new_instruction_sequence =
module_schedule.sequence(hlo_module->entry_computation()).instructions();
if (VLOG_IS_ON(1)) {
for (auto* new_i : new_instruction_sequence) {
VLOG(1) << new_i->ToString();
}
}

EXPECT_LT(GetIndex(new_instruction_sequence, "ra2a-start"),
GetIndex(new_instruction_sequence, "d"));
EXPECT_LT(GetIndex(new_instruction_sequence, "d"),
GetIndex(new_instruction_sequence, "ra2a-done"));
}

} // namespace xla
