Skip to content

Commit

Permalink
[XLA] Fix scheduling annotations to avoid creating invalid overlap of…
Browse files Browse the repository at this point in the history
… instructions

PiperOrigin-RevId: 723565430
  • Loading branch information
Marcello Maggioni authored and Google-ML-Automation committed Feb 5, 2025
1 parent 2f16279 commit c0e905e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
16 changes: 14 additions & 2 deletions xla/service/latency_hiding_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1567,18 +1567,29 @@ class AnnotationReadySetLt {
}
};
absl::StatusOr<HloGraphNode*> FindAndExtractBestAnnotatedNode(
DefaultSchedulerCore::ReadyQueueSet& annotation_ready) {
DefaultSchedulerCore::SchedulingState& sched_state,
DefaultSchedulerCore::OverlapLimitRule
scheduling_instruction_crosses_overlap_limit) {
using ScheduleCandidate = DefaultSchedulerCore::ScheduleCandidate;
using CandidateResult = DefaultSchedulerCore::CandidateResult;
AnnotationReadySetLt ready_lt;
// Construct a schedule candidate for caching.
ScheduleCandidate ready_chosen;
auto& annotation_ready = sched_state.annotation_ready;
auto chosen_it = annotation_ready.end();
// Try to pick nodes from the ready set first, as they are the ones that cause
// the most latency hiding.
for (auto ready_node_it = annotation_ready.begin(),
e = annotation_ready.end();
ready_node_it != e; ++ready_node_it) {
// If this node would cause the max_concurrent_resource count to go beyond
// the limit do not schedule it and pass to the next node.
if (scheduling_instruction_crosses_overlap_limit(sched_state,
*ready_node_it)) {
VLOG(2) << "Annotation instructions crosses overlap limit:"
<< (*ready_node_it)->GetInstr().name();
continue;
}
ScheduleCandidate ready_candidate;
ready_candidate.node = *ready_node_it;
if (ready_chosen.node == nullptr) {
Expand Down Expand Up @@ -1654,7 +1665,8 @@ absl::Status DefaultSchedulerCore::ScheduleAnnotation(
// Find the best annotated node to schedule.
TF_ASSIGN_OR_RETURN(
HloGraphNode * node,
FindAndExtractBestAnnotatedNode(sched_state->annotation_ready));
FindAndExtractBestAnnotatedNode(
*sched_state, scheduling_instruction_crosses_overlap_limit_));

TF_RET_CHECK(node != nullptr)
<< "Couldn't find an annotated node to schedule.";
Expand Down
66 changes: 66 additions & 0 deletions xla/service/latency_hiding_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3970,4 +3970,70 @@ TEST_F(LatencyHidingSchedulerTest, RaggedAllToAll) {
GetIndex(new_instruction_sequence, "ra2a-done"));
}

// Regression test for scheduling annotations creating an invalid overlap.
//
// In the module below the two all-gathers in ENTRY (ags0/agd0 and ags1/agd1),
// the collective-permute in the while body and both convolutions all carry
// the same _scheduling_group_id="1" annotation. The scheduler is run with
// all_gather_overlap_limit = 1, and the test then checks that the resulting
// entry schedule does not interleave the two all-gather start/done pairs,
// while still placing c0 between ags0 and agd0.
TEST_F(LatencyHidingSchedulerTest, InvalidAnnotationOverlap) {
constexpr absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true
while_cond {
param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
ROOT gte = pred[] get-tuple-element(param), index=2
}
while_body {
param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
gte0 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=0
gte1 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=1
gte2 = pred[] get-tuple-element(param), index=2
cps1 = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, u32[], u32[]) collective-permute-start(gte1), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}, frontend_attributes={_scheduling_group_id="1"}
cpd1 = f32[16,64,256]{2,1,0} collective-permute-done(cps1), frontend_attributes={_scheduling_group_id="1"}
c1 = f32[16,256,256]{2,1,0} convolution(gte0, gte0), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="1"}
slice = f32[16,64,256]{2,1,0} slice(c1), slice={[0:16], [0:64], [0:256]}
add = f32[16,64,256]{2,1,0} add(gte0, slice)
ROOT tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(add, cpd1, gte2)
}
ENTRY entry {
p0 = f32[256,1024]{1,0} parameter(0)
p1 = f32[16,64,256]{2,1,0} parameter(1)
p2 = f32[16,64,256]{2,1,0} parameter(2)
p3 = pred[] parameter(3)
c0 = f32[16,256,256]{2,1,0} convolution(p1, p2), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="1"}
ags0 = (f32[256,1024]{1,0}, f32[1024,1024]{1,0}) all-gather-start(p0), replica_groups={{0,1,2,3}}, dimensions={0}, frontend_attributes={_scheduling_group_id="1"}
ags1 = (f32[256,1024]{1,0}, f32[1024,1024]{1,0}) all-gather-start(p0), replica_groups={{0,1,2,3}}, dimensions={0}, frontend_attributes={_scheduling_group_id="1"}
tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(p1, p2, p3)
while = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) while(tuple), condition=while_cond, body=while_body
agd1 = f32[1024,1024]{1,0} all-gather-done(ags1), frontend_attributes={_scheduling_group_id="1"}
agd0 = f32[1024,1024]{1,0} all-gather-done(ags0), frontend_attributes={_scheduling_group_id="1"}
gte = f32[16,64,256]{2,1,0} get-tuple-element(while), index=0
ROOT tuple1 = (f32[16,64,256]{2,1,0}, f32[16,256,256]{2,1,0}, f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) tuple(gte, c0, agd0, agd1)
}
)";
// Parse the HLO text above into a module.
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
// Keep a reference to the module's schedule; it is read back after the
// scheduler has run.
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
// Start from the default scheduler configuration and set the all-gather
// overlap limit to 1 (presumably: at most one all-gather in flight at a time,
// which is what the final assertion below checks).
auto sched_config = GetDefaultSchedConfig();
sched_config.all_gather_overlap_limit = 1;
// Scheduling must succeed with the test latency estimator.
EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config,
std::make_unique<TestLatencyEstimator>())
.ok());
EXPECT_TRUE(hlo_module->has_entry_computation());

// Instruction order chosen for the entry computation.
std::vector<HloInstruction*> new_instruction_sequence =
module_schedule.sequence(hlo_module->entry_computation()).instructions();
// Dump the resulting entry schedule when verbose logging is enabled.
if (VLOG_IS_ON(1)) {
for (auto* new_i : new_instruction_sequence) {
VLOG(1) << new_i->ToString();
}
}

// The checks below assume GetIndex returns the position of the named
// instruction within the sequence (helper defined elsewhere in this file).
//
// c0 must be placed between ags0 and agd0.
EXPECT_LT(GetIndex(new_instruction_sequence, "ags0"),
GetIndex(new_instruction_sequence, "c0"));
EXPECT_LT(GetIndex(new_instruction_sequence, "c0"),
GetIndex(new_instruction_sequence, "agd0"));
// The two all-gathers must not overlap: either ags0 comes after agd1, or
// ags1 comes after agd0.
EXPECT_TRUE((GetIndex(new_instruction_sequence, "ags0") >
GetIndex(new_instruction_sequence, "agd1")) ||
(GetIndex(new_instruction_sequence, "ags1") >
GetIndex(new_instruction_sequence, "agd0")));
}

} // namespace xla

0 comments on commit c0e905e

Please sign in to comment.