[XLA] Add ragged-all-to-all support to latency hiding scheduler. #21681

Merged 1 commit on Jan 23, 2025.

xla/service/latency_hiding_scheduler.cc (16 additions, 0 deletions)
@@ -214,6 +214,7 @@ bool AsyncTracker::IsSupportedAsyncDone(const HloInstruction& hlo) const {
   }
   switch (op.inner) {
     case HloOpcode::kAllToAll:
+    case HloOpcode::kRaggedAllToAll:
     case HloOpcode::kAllGather:
     case HloOpcode::kAllReduce:
     case HloOpcode::kCollectiveBroadcast:
@@ -243,6 +244,7 @@ bool AsyncTracker::IsSupportedAsyncStart(const HloInstruction& hlo) const {
   }
   switch (op.inner) {
     case HloOpcode::kAllToAll:
+    case HloOpcode::kRaggedAllToAll:
     case HloOpcode::kAllGather:
     case HloOpcode::kAllReduce:
     case HloOpcode::kCollectiveBroadcast:
@@ -268,6 +270,8 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl(
       return ResourceType::kAllGather;
     case HloOpcode::kAllToAll:
       return ResourceType::kAllToAll;
+    case HloOpcode::kRaggedAllToAll:
+      return ResourceType::kRaggedAllToAll;
     case HloOpcode::kCollectiveBroadcast:
       return ResourceType::kCollectiveBroadcast;
     case HloOpcode::kCollectivePermute:
@@ -420,6 +424,8 @@ void AsyncTracker::SetConcurrentResourceLimits(
       config_.copy_overlap_limit;
   max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllToAll)] =
       config_.all_to_all_overlap_limit;
+  max_concurrent_resource[ResourceTypeToIndex(ResourceType::kRaggedAllToAll)] =
+      config_.ragged_all_to_all_overlap_limit;
   max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllGather)] =
       config_.all_gather_overlap_limit;
   max_concurrent_resource[ResourceTypeToIndex(ResourceType::kAllReduce)] =
@@ -453,6 +459,8 @@ absl::string_view AsyncTracker::GetResourceName(int64_t resource_type) const {
       return "kNoResource";
     case ResourceTypeToIndex(ResourceType::kAllToAll):
       return "kAllToAll";
+    case ResourceTypeToIndex(ResourceType::kRaggedAllToAll):
+      return "kRaggedAllToAll";
     case ResourceTypeToIndex(ResourceType::kAllGather):
       return "kAllGather";
     case ResourceTypeToIndex(ResourceType::kAllReduce):
@@ -2499,6 +2507,7 @@ LatencyHidingScheduler::LatencyHidingStatistics(
     kAllReduce,
     kCollectivePermute,
     kAllToAll,
+    kRaggedAllToAll,
     kReduceScatter,
     kSend,
     kRecv,
@@ -2516,6 +2525,8 @@ LatencyHidingScheduler::LatencyHidingStatistics(
       return AsyncKind::kCollectivePermute;
     case HloOpcode::kAllToAll:
       return AsyncKind::kAllToAll;
+    case HloOpcode::kRaggedAllToAll:
+      return AsyncKind::kRaggedAllToAll;
     case HloOpcode::kReduceScatter:
       return AsyncKind::kReduceScatter;
     case HloOpcode::kSend:
@@ -2614,6 +2625,8 @@ LatencyHidingScheduler::LatencyHidingStatistics(
       wasted_time_per_collective[AsyncKind::kCollectivePermute],
       /*all_to_all_wasted_cycles=*/
       wasted_time_per_collective[AsyncKind::kAllToAll],
+      /*ragged_all_to_all_wasted_cycles=*/
+      wasted_time_per_collective[AsyncKind::kRaggedAllToAll],
       /*reduce_scatter_wasted_cycles=*/
       wasted_time_per_collective[AsyncKind::kReduceScatter],
       /*send_wasted_cycles=*/wasted_time_per_collective[AsyncKind::kSend],
@@ -2640,6 +2653,7 @@ std::string LatencyHidingScheduler::SchedulerStatisticsString(
           sched_stats.collective_broadcast_wasted_cycles +
           sched_stats.collective_permute_wasted_cycles +
           sched_stats.all_to_all_wasted_cycles +
+          sched_stats.ragged_all_to_all_wasted_cycles +
           sched_stats.reduce_scatter_wasted_cycles +
           sched_stats.send_wasted_cycles +
           sched_stats.recv_wasted_cycles,
@@ -2654,6 +2668,8 @@ std::string LatencyHidingScheduler::SchedulerStatisticsString(
                   sched_stats.collective_permute_wasted_cycles, "\n");
   absl::StrAppend(&result, "Wasted cycles for all-to-all: ",
                   sched_stats.all_to_all_wasted_cycles, "\n");
+  absl::StrAppend(&result, "Wasted cycles for ragged-all-to-all: ",
+                  sched_stats.ragged_all_to_all_wasted_cycles, "\n");
   absl::StrAppend(&result, "Wasted cycles for reduce-scatter: ",
                   sched_stats.reduce_scatter_wasted_cycles, "\n");
   absl::StrAppend(&result,
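
The .cc changes above give ragged-all-to-all its own tracked resource, a printable name, and a wasted-cycles line in the statistics. A minimal sketch of the name lookup follows; it assumes the pre-existing `AsyncTracker(const SchedulerConfig&)` constructor and that `ResourceTypeToIndex` is visible from the header (neither is added by this PR):

```cpp
#include <iostream>

#include "xla/service/latency_hiding_scheduler.h"

int main() {
  xla::SchedulerConfig config;
  xla::AsyncTracker tracker(config);
  // Resolves through the GetResourceName case added in this diff and
  // prints "kRaggedAllToAll".
  std::cout << tracker.GetResourceName(xla::ResourceTypeToIndex(
                   xla::ResourceType::kRaggedAllToAll))
            << "\n";
  return 0;
}
```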

xla/service/latency_hiding_scheduler.h (3 additions, 0 deletions)
@@ -83,6 +83,7 @@ enum class ResourceType {
   kRecvHost,
   kCollectiveBroadcast,
   kNumResources,
+  kRaggedAllToAll,
   kTargetDefinedResourceTypeBegin,
 };

@@ -126,6 +127,7 @@ struct SchedulerConfig {
   int64_t collective_broadcast_overlap_limit = 1;
   int64_t collective_permute_overlap_limit = 1;
   int64_t all_to_all_overlap_limit = 1;
+  int64_t ragged_all_to_all_overlap_limit = 1;
   int64_t all_gather_overlap_limit = 1;
   int64_t all_reduce_overlap_limit = 1;
   int64_t reduce_scatter_overlap_limit = 1;
@@ -1116,6 +1118,7 @@ class LatencyHidingScheduler : public HloModulePass {
   double collective_broadcast_wasted_cycles = 0;
   double collective_permute_wasted_cycles = 0;
   double all_to_all_wasted_cycles = 0;
+  double ragged_all_to_all_wasted_cycles = 0;
   double reduce_scatter_wasted_cycles = 0;
   double send_wasted_cycles = 0;
   double recv_wasted_cycles = 0;
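
The new `ragged_all_to_all_overlap_limit` knob mirrors the existing per-collective limits and defaults to 1. A usage sketch (illustrative only, not part of this PR — the helper name is hypothetical): raising it lets the scheduler keep two ragged-all-to-alls in flight while the other collectives keep their defaults.

```cpp
#include "xla/service/latency_hiding_scheduler.h"

xla::SchedulerConfig MakeRa2aFriendlyConfig() {
  xla::SchedulerConfig config;
  // SetConcurrentResourceLimits (see the .cc diff above) copies this value
  // into the concurrency limit for ResourceType::kRaggedAllToAll.
  config.ragged_all_to_all_overlap_limit = 2;
  return config;
}
```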

xla/service/latency_hiding_scheduler_test.cc (56 additions, 0 deletions)
@@ -3914,4 +3914,60 @@ TEST_F(LatencyHidingSchedulerTest, CrossComputationAnnotation) {
             GetIndex(loop_instruction_sequence, "cpd1"));
 }
 
+TEST_F(LatencyHidingSchedulerTest, RaggedAllToAll) {
+  constexpr absl::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+async_computation {
+  input = f32[8,128,1024]{2,1,0:T(8,128)} parameter(0)
+  output = f32[8,128,1024]{2,1,0:T(8,128)} parameter(1)
+  input_offsets = s32[8]{0} parameter(2)
+  send_sizes = s32[8]{0} parameter(3)
+  output_offsets = s32[8]{0} parameter(4)
+  recv_sizes = s32[8]{0} parameter(5)
+  ROOT ra2a = f32[8,128,1024]{2,1,0:T(8,128)} ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3,4,5,6,7}}
+}
+
+ENTRY RA2A {
+  p0 = f32[8,128,1024]{2,1,0:T(8,128)} parameter(0)
+  c0 = f32[] constant(0)
+  output = f32[8,128,1024]{2,1,0:T(8,128)} broadcast(c0), dimensions={}
+  p1 = s32[8]{0} parameter(1)
+  p2 = s32[8]{0} parameter(2)
+  p3 = s32[8]{0} parameter(3)
+  p4 = s32[8]{0} parameter(4)
+  p5 = f32[1024,1024]{1,0:T(8,128)} parameter(5)
+  input = f32[8,128,1024]{2,1,0:T(8,128)} copy(p0)
+  input_offsets = s32[8]{0} copy(p1)
+  send_sizes = s32[8]{0} copy(p2)
+  output_offsets = s32[8]{0} copy(p3)
+  recv_sizes = s32[8]{0} copy(p4)
+  ra2a-start = ((f32[8,128,1024]{2,1,0:T(8,128)}, f32[8,128,1024]{2,1,0:T(8,128)}, s32[8]{0}, s32[8]{0}, s32[8]{0}, s32[8]{0}), f32[8,128,1024]{2,1,0:T(8,128)}, u32[]{:S(2)}, u32[]{:S(2)}) async-start(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), calls=async_computation
+  ra2a-done = f32[8,128,1024]{2,1,0:T(8,128)} async-done(ra2a-start), calls=async_computation
+  d = f32[1024,1024]{1,0:T(8,128)} dot(p5, p5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT tuple = (f32[8,128,1024]{2,1,0:T(8,128)}, f32[1024,1024]{1,0:T(8,128)}) tuple(ra2a-done, d)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
+  HloSchedule& module_schedule = hlo_module->schedule();
+  EXPECT_TRUE(hlo_module->has_entry_computation());
+  auto sched_config = GetDefaultSchedConfig();
+  EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config).ok());
+  EXPECT_TRUE(hlo_module->has_entry_computation());
+
+  std::vector<HloInstruction*> new_instruction_sequence =
+      module_schedule.sequence(hlo_module->entry_computation()).instructions();
+  if (VLOG_IS_ON(1)) {
+    for (auto* new_i : new_instruction_sequence) {
+      VLOG(1) << new_i->ToString();
+    }
+  }
+
+  EXPECT_LT(GetIndex(new_instruction_sequence, "ra2a-start"),
+            GetIndex(new_instruction_sequence, "d"));
+  EXPECT_LT(GetIndex(new_instruction_sequence, "d"),
+            GetIndex(new_instruction_sequence, "ra2a-done"));
+}
+
 }  // namespace xla
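
Reading the test's assertions: the scheduler must place the independent dot `d` strictly between `ra2a-start` and `ra2a-done`, so the matmul executes inside the collective's async window, which is exactly the overlap that registering ragged-all-to-all as a schedulable resource enables.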