Skip to content

Commit

Permalink
[XLA:GPU] fix collective permute cycle decomposer and make it work fo…
Browse files Browse the repository at this point in the history
…r multiple cycles

PiperOrigin-RevId: 717367512
  • Loading branch information
Google-ML-Automation committed Jan 23, 2025
1 parent e9a1a97 commit 3d6322e
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 35 deletions.
51 changes: 51 additions & 0 deletions xla/service/collective_ops_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -834,6 +835,56 @@ bool IsBackwardCycle(const SourceTargetPairs& pairs) {
return true;
}

std::pair<CycleType, std::set<int>> GetCycleTypeAndIndices(
const SourceTargetPairs& pairs) {
std::set<int> seen_replica_ids;
std::set<std::pair<int64_t, int64_t>> tentative_results;
// first figure out if we're dealing with a potential forward or backward
// cycle.
int forward_edge_counter = 0;
int backward_edge_counter = 0;
for (auto pair : pairs) {
pair.first < pair.second ? forward_edge_counter++ : backward_edge_counter++;
}
bool is_forward_cycle = forward_edge_counter > backward_edge_counter;
for (int64_t i = 0; i < pairs.size(); ++i) {
const SourceTargetPair& pair = pairs[i];
if (is_forward_cycle) {
// check if the source of the current pair is smaller than the target
if (pair.first < pair.second) {
seen_replica_ids.insert(pair.first);
} else {
// the source of the current pair is larger than the target, so the
// current pair may be part of a cycle. We keep track of the target ID
// and the index of the pair in the original pairs array.
tentative_results.insert(std::make_pair(pair.second, i));
}
} else {
// The backward cycle check uses similar logic but in reverse.
if (pair.first > pair.second) {
seen_replica_ids.insert(pair.second);
} else {
tentative_results.insert(std::make_pair(pair.first, i));
}
}
}
std::set<int> final_results;
// Iterate over the tentative results and only keep the indices that form an
// actual cycle. This is done by checking if the target replica ID of the
// pair is in the set of seen replica IDs. Note that the tentative results
// array will be fairly small in practice, so this is not adding too much to
// the runtime.
for (auto& [replica_id, index] : tentative_results) {
if (seen_replica_ids.contains(replica_id)) {
final_results.insert(index);
}
}
CycleType cycle_type = final_results.empty() ? CycleType::kUnknown
: is_forward_cycle ? CycleType::kForward
: CycleType::kBackward;
return std::make_pair(cycle_type, final_results);
}

bool IsExclusivelyCrossModule(absl::Span<const ReplicaGroup> replica_groups,
bool use_global_ids, bool has_channel_id,
const DeviceAssignment& device_assignment) {
Expand Down
9 changes: 9 additions & 0 deletions xla/service/collective_ops_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>
Expand All @@ -43,6 +44,7 @@ limitations under the License.
namespace xla {

enum class ReductionKind { SUM, PRODUCT, MIN, MAX };
enum class CycleType { kUnknown, kForward, kBackward };

constexpr absl::string_view ReductionKindToString(
ReductionKind reduction_kind) {
Expand Down Expand Up @@ -259,6 +261,13 @@ bool IsForwardCycle(const std::vector<std::pair<int64_t, int64_t>>& pairs);
// pairs are ordered as they are generated by SPMD partitioning.
bool IsBackwardCycle(const std::vector<std::pair<int64_t, int64_t>>& pairs);

// Returns the cycle type and indices of the vertices that form cycles. This
// function uses the assumption that, in practice, in forward cycles, most edges
// will have the target replica ID greater than the source replica ID except for
// the back edges that form cycles (similar logic applies to backward cycles).
std::pair<CycleType, std::set<int>> GetCycleTypeAndIndices(
const std::vector<std::pair<int64_t, int64_t>>& pairs);

// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
// The rules are as follows.
Expand Down
22 changes: 22 additions & 0 deletions xla/service/collective_ops_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ TEST(CollectiveOpsUtilsTest, IsBackwardCycle) {
<< "Out of order pairs are not a cycle";
}

TEST(CollectiveOpsUtilsTest, GetForwardCycleIndices) {
auto res_one_cycle = GetCycleTypeAndIndices({{0, 1}, {1, 2}, {2, 3}, {3, 0}});
EXPECT_EQ(res_one_cycle.first, CycleType::kForward);
EXPECT_THAT(res_one_cycle.second, testing::UnorderedElementsAreArray({3}));
auto res_two_cycles =
GetCycleTypeAndIndices({{0, 1}, {1, 2}, {2, 3}, {3, 0}, {4, 1}});
EXPECT_EQ(res_two_cycles.first, CycleType::kForward);
EXPECT_THAT(res_two_cycles.second,
testing::UnorderedElementsAreArray({3, 4}));
}

TEST(CollectiveOpsUtilsTest, GetBackwardCycleIndices) {
auto res_one_cycle = GetCycleTypeAndIndices({{0, 3}, {1, 0}, {2, 1}, {3, 2}});
EXPECT_EQ(res_one_cycle.first, CycleType::kBackward);
EXPECT_THAT(res_one_cycle.second, testing::UnorderedElementsAreArray({0}));
auto res_two_cycles =
GetCycleTypeAndIndices({{0, 3}, {1, 4}, {2, 1}, {3, 2}, {4, 3}, {3, 0}});
EXPECT_EQ(res_two_cycles.first, CycleType::kBackward);
EXPECT_THAT(res_two_cycles.second,
testing::UnorderedElementsAreArray({0, 1}));
}

TEST(IsExclusivelyCrossModuleTest, CrossReplicaNoChannelSet) {
int64_t num_replicas = 4;
int64_t num_partitions = 2;
Expand Down
57 changes: 27 additions & 30 deletions xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -53,36 +54,32 @@ namespace {
using SourceTargetPair = std::pair<int64_t, int64_t>;
using SourceTargetPairs = std::vector<SourceTargetPair>;

enum class CycleType { kUnknown, kForward, kBackward };

// Returns true if the CollectivePermute instruction has a cycle in its
// source-target pairs and should be decomposed.
CycleType ShouldDecomposeWithCycleType(
// Returns the cycle type and indices of the vertices that form cycles. If the
// cycle type is kUnknown, the set of indices will be empty.
std::pair<CycleType, std::set<int>> GetCycleTypeAndIndicesArray(
const HloCollectivePermuteInstruction& collective_permute,
int64_t threshold_in_bytes) {
if (collective_permute.operand_count() != 1) {
return CycleType::kUnknown;
return std::make_pair(CycleType::kUnknown, std::set<int>{});
}

// Skip the transformation if there is any context data.
const Shape& result_shape = collective_permute.shape();
if (result_shape.IsTuple()) {
return CycleType::kUnknown;
return std::make_pair(CycleType::kUnknown, std::set<int>{});
}

CHECK(result_shape.IsArray());
if (ShapeUtil::ByteSizeOf(result_shape) < threshold_in_bytes) {
return CycleType::kUnknown;
return std::make_pair(CycleType::kUnknown, std::set<int>{});
}

const SourceTargetPairs& pairs = collective_permute.source_target_pairs();
if (pairs.size() == 1) {
return CycleType::kUnknown;
return std::make_pair(CycleType::kUnknown, std::set<int>{});
}

return IsForwardCycle(pairs) ? CycleType::kForward
: IsBackwardCycle(pairs) ? CycleType::kBackward
: CycleType::kUnknown;
return GetCycleTypeAndIndices(pairs);
}

// Constructs the frontend attributes for the two decomposed CollectivePermute
Expand Down Expand Up @@ -136,29 +133,26 @@ absl::Status GetFrontendAttributes(HloCollectivePermuteInstruction* cp,
return absl::OkStatus();
}

// Decomposes a CollectivePermute instruction with a cycle in its source-target
// pairs into two CollectivePermute instructions.
// Decomposes a CollectivePermute instruction with cycles in its source-target
// pairs into cycle-free CollectivePermute instructions.
absl::Status DecomposeCollectivePermuteCycle(
HloCollectivePermuteInstruction* cp, HloComputation* computation,
HloModule* module, int64_t next_channel_id, CycleType cycle_type) {
HloModule* module, int64_t next_channel_id, CycleType cycle_type,
std::set<int> indices_to_break_out) {
const SourceTargetPairs& pairs = cp->source_target_pairs();
const OpMetadata& metadata = cp->metadata();
absl::string_view cp_name = cp->name();
int64_t num_pairs = pairs.size();
Shape shape = cp->shape();
HloInstruction* data = cp->mutable_operand(0);

// A forward cycle has its backedge at the end as in
// {{0,1},{1,2},{2,3},{3,0}} while a backward cycle has its backedge at the
// beginning as in {{0,3},{1,0},{2,1},{3,2}}.
auto backedge_start = cycle_type == CycleType::kBackward
? pairs.begin()
: pairs.begin() + num_pairs - 1;
auto other_edges_start =
cycle_type == CycleType::kBackward ? pairs.begin() + 1 : pairs.begin();
SourceTargetPairs backedge(backedge_start, backedge_start + 1);
SourceTargetPairs other_edges(other_edges_start,
other_edges_start + num_pairs - 1);
SourceTargetPairs backedge, other_edges;
for (int i = 0; i < num_pairs; ++i) {
if (indices_to_break_out.contains(i)) {
backedge.push_back(pairs[i]);
} else {
other_edges.push_back(pairs[i]);
}
}

xla::FrontendAttributes cp1_attr, cp2_attr;
TF_RETURN_IF_ERROR(GetFrontendAttributes(cp, cycle_type, cp1_attr, cp2_attr));
Expand Down Expand Up @@ -241,15 +235,18 @@ absl::StatusOr<bool> CollectivePermuteCycleDecomposer::Run(
continue;
}
auto collective_permute = Cast<HloCollectivePermuteInstruction>(hlo);
CycleType cycle_type = ShouldDecomposeWithCycleType(*collective_permute,
threshold_in_bytes_);
std::pair<CycleType, std::set<int>> cycle_type_and_indices =
GetCycleTypeAndIndicesArray(*collective_permute, threshold_in_bytes_);
CycleType cycle_type = cycle_type_and_indices.first;
std::set<int> indices_to_break_out = cycle_type_and_indices.second;
if (cycle_type != CycleType::kUnknown) {
if (changed == false) {
next_channel_id = hlo_query::NextChannelId(*module);
changed = true;
}
TF_RETURN_IF_ERROR(DecomposeCollectivePermuteCycle(
collective_permute, comp, module, next_channel_id++, cycle_type));
collective_permute, comp, module, next_channel_id++, cycle_type,
indices_to_break_out));
}
}
}
Expand Down
14 changes: 9 additions & 5 deletions xla/service/gpu/transforms/collective_permute_cycle_decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ limitations under the License.
namespace xla {

// CollectivePermuteCycleDecomposer is a pass that converts CollectivePermute
// instructions with all participants forming either a forward cycle (such as
// {{0,1},{1,2},{2,3},{3,0}) or a backward cycle (such as {{3,2},{2,1},{1,0},
// {0,3}}) into two CollectivePermute instructions. We currently restrict
// this transformation to CollectivePermute using partition mode, with one
// input, without any context data. Here is an example.
// instructions with all participants forming EITHER multiple forward cycles
// (such as {{0,1},{1,2},{2,3},{3,0}}) OR multiple backward cycles (such as
// {{3,2},{2,1},{1,0}, {0,3}}) into two CollectivePermute instructions. The pass
// leads to undefined behavior if
// 1. A communication pattern contains both forward and backward cycles, or
// 2. if the communication pattern cannot be broken into two cycle-free
// sub-patterns (i.e. after the initial pass, we still have at least one
// cycle within one or more of the sub patterns).
// Here is an example.
//
// before transformation:
// start = (<rt>, <rt>) collective-permute(data),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,38 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) {
)"));
}

TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleMultipleCycles) {
// For a forward cycle, this checks:
// 1. Split collectives should not have channel-id
// 2. Split collectives are combined based on replica-id.
absl::string_view hlo = R"(
HloModule test
ENTRY test_computation {
p = u32[8,8] parameter(0)
ROOT start = u32[8,8] collective-permute(p),
source_target_pairs={{0,2},{2,4},{4,6},{6,0},{1,3},{3,5},{5,7},{7,1}}
}
)";

std::unique_ptr<HloModule> module = Transform(hlo);
EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"(
// CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] {
// CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id()
// CHECK-DAG: %[[c1:.+]] = u32[] constant(1)
// CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[replica_id]], %[[c1]]), direction=EQ
// CHECK-DAG: %{{.+}} = u32[8,8] parameter(0)
// CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
// CHECK-SAME{LITERAL}: {{6,0},{7,1}}
// CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs=
// CHECK-SAME{LITERAL}: {{0,2},{2,4},{4,6},{1,3},{3,5},{5,7}}
// CHECK-DAG: ROOT %{{.+}} = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]])
// CHECK-DAG: }
)"));
}

TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) {
absl::string_view hlo = R"(
HloModule test
Expand Down

0 comments on commit 3d6322e

Please sign in to comment.