diff --git a/xla/service/collective_ops_utils.cc b/xla/service/collective_ops_utils.cc
index db05bb80b0a392..2099f8003eef91 100644
--- a/xla/service/collective_ops_utils.cc
+++ b/xla/service/collective_ops_utils.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <cstdint>
 #include <optional>
+#include <set>
 #include <string>
 #include <utility>
 #include <vector>
@@ -834,6 +835,56 @@ bool IsBackwardCycle(const SourceTargetPairs& pairs) {
   return true;
 }
 
+std::pair<CycleType, std::set<int64_t>> GetCycleTypeAndIndices(
+    const SourceTargetPairs& pairs) {
+  std::set<int64_t> seen_replica_ids;
+  std::set<std::pair<int64_t, int64_t>> tentative_results;
+  // First figure out whether we're dealing with a potential forward or
+  // backward cycle.
+  int forward_edge_counter = 0;
+  int backward_edge_counter = 0;
+  for (auto pair : pairs) {
+    pair.first < pair.second ? forward_edge_counter++ : backward_edge_counter++;
+  }
+  bool is_forward_cycle = forward_edge_counter > backward_edge_counter;
+  for (int64_t i = 0; i < pairs.size(); ++i) {
+    const SourceTargetPair& pair = pairs[i];
+    if (is_forward_cycle) {
+      // Check whether the source of the current pair is smaller than the
+      // target.
+      if (pair.first < pair.second) {
+        seen_replica_ids.insert(pair.first);
+      } else {
+        // The source of the current pair is larger than the target, so the
+        // current pair may be part of a cycle. We keep track of the target ID
+        // and the index of the pair in the original pairs array.
+        tentative_results.insert(std::make_pair(pair.second, i));
+      }
+    } else {
+      // The backward cycle check uses similar logic but in reverse.
+      if (pair.first > pair.second) {
+        seen_replica_ids.insert(pair.second);
+      } else {
+        tentative_results.insert(std::make_pair(pair.first, i));
+      }
+    }
+  }
+  std::set<int64_t> final_results;
+  // Iterate over the tentative results and only keep the indices that form an
+  // actual cycle. This is done by checking whether the target replica ID of
+  // the pair is in the set of seen replica IDs. Note that the tentative
+  // results set will be fairly small in practice, so this does not add much
+  // to the runtime.
+  for (auto& [replica_id, index] : tentative_results) {
+    if (seen_replica_ids.contains(replica_id)) {
+      final_results.insert(index);
+    }
+  }
+  CycleType cycle_type = final_results.empty() ? CycleType::kUnknown
+                         : is_forward_cycle    ? CycleType::kForward
+                                               : CycleType::kBackward;
+  return std::make_pair(cycle_type, final_results);
+}
+
 bool IsExclusivelyCrossModule(absl::Span<const ReplicaGroup> replica_groups,
                               bool use_global_ids, bool has_channel_id,
                               const DeviceAssignment& device_assignment) {
diff --git a/xla/service/collective_ops_utils.h b/xla/service/collective_ops_utils.h
index be41797c2c6c4c..1b9202891c1129 100644
--- a/xla/service/collective_ops_utils.h
+++ b/xla/service/collective_ops_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <cstdint>
 #include <optional>
+#include <set>
 #include <string>
 #include <utility>
 #include <vector>
@@ -43,6 +44,7 @@ limitations under the License.
 namespace xla {
 
 enum class ReductionKind { SUM, PRODUCT, MIN, MAX };
+enum class CycleType { kUnknown, kForward, kBackward };
 
 constexpr absl::string_view ReductionKindToString(
     ReductionKind reduction_kind) {
@@ -259,6 +261,13 @@ bool IsForwardCycle(const std::vector<std::pair<int64_t, int64_t>>& pairs);
 // pairs are ordered as they are generated by SPMD partitioning.
 bool IsBackwardCycle(const std::vector<std::pair<int64_t, int64_t>>& pairs);
 
+// Returns the cycle type and indices of the vertices that form cycles.
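+// For example, GetCycleTypeAndIndices({{0, 3}, {1, 0}, {2, 1}, {3, 2}}) returns
+// {kBackward, {0}}: the pairs form a backward cycle whose cycle-inducing edge
+// {0, 3} sits at index 0 of the input (see collective_ops_utils_test.cc).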
+// This function uses the assumption that, in practice, in forward cycles, most
+// edges will have the target replica ID greater than the source replica ID,
+// except for the back edges that form cycles (similar logic applies to
+// backward cycles).
+std::pair<CycleType, std::set<int64_t>> GetCycleTypeAndIndices(
+    const std::vector<std::pair<int64_t, int64_t>>& pairs);
+
 // Key that identifies a particular Rendezvous object in our global hashtable.
 // This determines which calls to ExecuteOnStream communicate with each other.
 // The rules are as follows.
diff --git a/xla/service/collective_ops_utils_test.cc b/xla/service/collective_ops_utils_test.cc
index c2ed1e0bea0a27..84b9b513ab0460 100644
--- a/xla/service/collective_ops_utils_test.cc
+++ b/xla/service/collective_ops_utils_test.cc
@@ -180,6 +180,28 @@ TEST(CollectiveOpsUtilsTest, IsBackwardCycle) {
       << "Out of order pairs are not a cycle";
 }
 
+TEST(CollectiveOpsUtilsTest, GetForwardCycleIndices) {
+  auto res_one_cycle = GetCycleTypeAndIndices({{0, 1}, {1, 2}, {2, 3}, {3, 0}});
+  EXPECT_EQ(res_one_cycle.first, CycleType::kForward);
+  EXPECT_THAT(res_one_cycle.second, testing::UnorderedElementsAreArray({3}));
+  auto res_two_cycles =
+      GetCycleTypeAndIndices({{0, 1}, {1, 2}, {2, 3}, {3, 0}, {4, 1}});
+  EXPECT_EQ(res_two_cycles.first, CycleType::kForward);
+  EXPECT_THAT(res_two_cycles.second,
+              testing::UnorderedElementsAreArray({3, 4}));
+}
+
+TEST(CollectiveOpsUtilsTest, GetBackwardCycleIndices) {
+  auto res_one_cycle = GetCycleTypeAndIndices({{0, 3}, {1, 0}, {2, 1}, {3, 2}});
+  EXPECT_EQ(res_one_cycle.first, CycleType::kBackward);
+  EXPECT_THAT(res_one_cycle.second, testing::UnorderedElementsAreArray({0}));
+  auto res_two_cycles =
+      GetCycleTypeAndIndices({{0, 3}, {1, 4}, {2, 1}, {3, 2}, {4, 3}, {3, 0}});
+  EXPECT_EQ(res_two_cycles.first, CycleType::kBackward);
+  EXPECT_THAT(res_two_cycles.second,
+              testing::UnorderedElementsAreArray({0, 1}));
+}
+
 TEST(IsExclusivelyCrossModuleTest, CrossReplicaNoChannelSet) {
   int64_t num_replicas = 4;
   int64_t num_partitions = 2;
diff --git a/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc b/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
index e22b1487e8b6ef..502c91dd25f61c 100644
--- a/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
+++ b/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <cstdint>
 #include <optional>
+#include <set>
 #include <string>
 #include <utility>
 #include <vector>
@@ -53,36 +54,32 @@ namespace {
 using SourceTargetPair = std::pair<int64_t, int64_t>;
 using SourceTargetPairs = std::vector<SourceTargetPair>;
 
-enum class CycleType { kUnknown, kForward, kBackward };
-
-// Returns true if the CollectivePermute instruction has a cycle in its
-// source-target pairs and should be decomposed.
-CycleType ShouldDecomposeWithCycleType(
+// Returns the cycle type and indices of the vertices that form cycles. If the
+// cycle type is kUnknown, the set of indices will be empty.
+std::pair<CycleType, std::set<int64_t>> GetCycleTypeAndIndicesArray(
     const HloCollectivePermuteInstruction& collective_permute,
     int64_t threshold_in_bytes) {
   if (collective_permute.operand_count() != 1) {
-    return CycleType::kUnknown;
+    return std::make_pair(CycleType::kUnknown, std::set<int64_t>{});
   }
   // Skip the transformation if there is any context data.
   const Shape& result_shape = collective_permute.shape();
   if (result_shape.IsTuple()) {
-    return CycleType::kUnknown;
+    return std::make_pair(CycleType::kUnknown, std::set<int64_t>{});
   }
   CHECK(result_shape.IsArray());
   if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) {
-    return CycleType::kUnknown;
+    return std::make_pair(CycleType::kUnknown, std::set<int64_t>{});
   }
   const SourceTargetPairs& pairs = collective_permute.source_target_pairs();
   if (pairs.size() == 1) {
-    return CycleType::kUnknown;
+    return std::make_pair(CycleType::kUnknown, std::set<int64_t>{});
   }
-  return IsForwardCycle(pairs)    ? CycleType::kForward
-         : IsBackwardCycle(pairs) ? CycleType::kBackward
-                                  : CycleType::kUnknown;
+  return GetCycleTypeAndIndices(pairs);
 }
 
 // Constructs the frontend attributes for the two decomposed CollectivePermute
@@ -136,29 +133,26 @@ absl::Status GetFrontendAttributes(HloCollectivePermuteInstruction* cp,
   return absl::OkStatus();
 }
 
-// Decomposes a CollectivePermute instruction with a cycle in its source-target
-// pairs into two CollectivePermute instructions.
+// Decomposes a CollectivePermute instruction with cycles in its source-target
+// pairs into cycle-free CollectivePermute instructions.
 absl::Status DecomposeCollectivePermuteCycle(
     HloCollectivePermuteInstruction* cp, HloComputation* computation,
-    HloModule* module, int64_t next_channel_id, CycleType cycle_type) {
+    HloModule* module, int64_t next_channel_id, CycleType cycle_type,
+    std::set<int64_t> indices_to_break_out) {
   const SourceTargetPairs& pairs = cp->source_target_pairs();
   const OpMetadata& metadata = cp->metadata();
   absl::string_view cp_name = cp->name();
   int64_t num_pairs = pairs.size();
   Shape shape = cp->shape();
   HloInstruction* data = cp->mutable_operand(0);
-
-  // A forward cycle has its backedge at the end as in
-  // {{0,1},{1,2},{2,3},{3,0}} while a backward cycle has its backedge at the
-  // beginning as in {{0,3},{1,0},{2,1},{3,2}}.
-  auto backedge_start = cycle_type == CycleType::kBackward
-                            ? pairs.begin()
-                            : pairs.begin() + num_pairs - 1;
-  auto other_edges_start =
-      cycle_type == CycleType::kBackward ? pairs.begin() + 1 : pairs.begin();
-  SourceTargetPairs backedge(backedge_start, backedge_start + 1);
-  SourceTargetPairs other_edges(other_edges_start,
-                                other_edges_start + num_pairs - 1);
+  SourceTargetPairs backedge, other_edges;
+  for (int i = 0; i < num_pairs; ++i) {
+    if (indices_to_break_out.contains(i)) {
+      backedge.push_back(pairs[i]);
+    } else {
+      other_edges.push_back(pairs[i]);
+    }
+  }
 
   xla::FrontendAttributes cp1_attr, cp2_attr;
   TF_RETURN_IF_ERROR(GetFrontendAttributes(cp, cycle_type, cp1_attr, cp2_attr));
@@ -241,15 +235,18 @@ absl::StatusOr<bool> CollectivePermuteCycleDecomposer::Run(
       continue;
     }
     auto collective_permute = Cast<HloCollectivePermuteInstruction>(hlo);
-    CycleType cycle_type = ShouldDecomposeWithCycleType(*collective_permute,
-                                                        threshold_in_bytes_);
+    std::pair<CycleType, std::set<int64_t>> cycle_type_and_indices =
+        GetCycleTypeAndIndicesArray(*collective_permute, threshold_in_bytes_);
+    CycleType cycle_type = cycle_type_and_indices.first;
+    std::set<int64_t> indices_to_break_out = cycle_type_and_indices.second;
     if (cycle_type != CycleType::kUnknown) {
      if (changed == false) {
        next_channel_id = hlo_query::NextChannelId(*module);
        changed = true;
      }
      TF_RETURN_IF_ERROR(DecomposeCollectivePermuteCycle(
-          collective_permute, comp, module, next_channel_id++, cycle_type));
+          collective_permute, comp, module, next_channel_id++, cycle_type,
+          indices_to_break_out));
     }
   }
 }
diff --git a/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h b/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h
index 7663d878745b2c..2cc00d550f46e9 100644
--- a/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h
+++ b/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h
@@ -27,11 +27,15 @@ limitations under the License.
 namespace xla {
 
 // CollectivePermuteCycleDecomposer is a pass that converts CollectivePermute
-// instructions with all participants forming either a forward cycle (such as
-// {{0,1},{1,2},{2,3},{3,0}) or a backward cycle (such as {{3,2},{2,1},{1,0},
-// {0,3}}) into two CollectivePermute instructions. We currently restrict
-// this transformation to CollectivePermute using partition mode, with one
-// input, without any context data. Here is an example.
+// instructions with all participants forming EITHER multiple forward cycles
+// (such as {{0,1},{1,2},{2,3},{3,0}}) OR multiple backward cycles (such as
+// {{3,2},{2,1},{1,0},{0,3}}) into two CollectivePermute instructions. The pass
+// leads to undefined behavior if
+// 1. the communication pattern contains both forward and backward cycles, or
+// 2. the communication pattern cannot be broken into two cycle-free
+//    sub-patterns (i.e., after the split, at least one cycle remains within
+//    one of the sub-patterns).
+// Here is an example.
 //
 // before transformation:
 //     start = (<rt>, <rt>) collective-permute(data),
diff --git a/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc b/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc
index 5803a34a7ae752..ae6b54dafcff90 100644
--- a/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc
+++ b/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc
@@ -178,6 +178,38 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
   )"));
 }
 
+TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleMultipleCycles) {
+  // For a forward cycle, this checks:
+  // 1. Split collectives should not have a channel-id.
+  // 2. Split collectives are combined based on replica-id.
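+  // The source-target pairs below form two independent forward cycles: one
+  // over the even replicas (0->2->4->6->0) and one over the odd replicas
+  // (1->3->5->7->1). The decomposition is expected to pull the two back edges
+  // {6,0} and {7,1} into their own collective-permute.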
+  absl::string_view hlo = R"(
+  HloModule test
+  ENTRY test_computation {
+    p = u32[8,8] parameter(0)
+    ROOT start = u32[8,8] collective-permute(p),
+      source_target_pairs={{0,2},{2,4},{4,6},{6,0},{1,3},{3,5},{5,7},{7,1}}
+  }
+  )";
+
+  std::unique_ptr<HloModule> module = Transform(hlo);
+  EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
+  // CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
+  // CHECK-DAG:    %[[replica_id:.+]] = u32[] replica-id()
+  // CHECK-DAG:    %[[c1:.+]] = u32[] constant(1)
+  // CHECK-DAG:    %[[compare:.+]] = pred[] compare(%[[replica_id]], %[[c1]]), direction=EQ
+  // CHECK-DAG:    %{{.+}} = u32[8,8] parameter(0)
+
+  // CHECK-DAG:    %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
+  // CHECK-SAME{LITERAL}: {{6,0},{7,1}}
+
+  // CHECK-DAG:    %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
+  // CHECK-SAME{LITERAL}: {{0,2},{2,4},{4,6},{1,3},{3,5},{5,7}}
+
+  // CHECK-DAG: ROOT %{{.+}} = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
+  // CHECK-DAG: }
+  )"));
+}
+
 TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
   absl::string_view hlo = R"(
     HloModule test