Support expanding ragged all-to-all dims similar to all-to-alls.
PiperOrigin-RevId: 715831514
Google-ML-Automation committed Jan 15, 2025
1 parent 10daba6 commit 370a76e
Showing 4 changed files with 203 additions and 0 deletions.
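
In effect, the pass now treats a ragged-all-to-all whose rank is below min_array_rank_ the same way it already treats a plain all-to-all: operands 0 and 1 are reshaped up to the minimum rank by appending size-1 dimensions, the collective is cloned at the padded rank, and the result is reshaped back. A sketch in HLO text of the rank-1 case, assuming min_array_rank_ = 3 as in the new tests (the *_r3 names are illustrative only, not produced by the pass):

// Before: rank 1 is below min_array_rank_ = 3.
ra2a = s32[8]{0} ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3}}

// After: operands 0 and 1 are padded to rank 3 with size-1 dimensions; the
// offset and size operands pass through unchanged; the result is reshaped
// back to the original rank-1 shape.
input_r3 = s32[8,1,1]{2,1,0} reshape(input)
output_r3 = s32[8,1,1]{2,1,0} reshape(output)
ra2a_r3 = s32[8,1,1]{2,1,0} ragged-all-to-all(input_r3, output_r3, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3}}
result = s32[8]{0} reshape(ra2a_r3)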
17 changes: 17 additions & 0 deletions xla/service/BUILD
@@ -3037,6 +3037,7 @@ cc_library(
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/transforms/expanders:op_expander_pass",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
@@ -3080,6 +3081,22 @@ xla_cc_test(
    ],
)

xla_cc_test(
    name = "all_to_all_decomposer_test",
    srcs = ["all_to_all_decomposer_test.cc"],
    deps = [
        ":all_to_all_decomposer",
        "//xla:util",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/parser:hlo_parser",
        "//xla/hlo/utils:hlo_matchers",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//xla/tsl/platform:statusor",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "tuple_simplifier",
    hdrs = ["tuple_simplifier.h"],
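
Assuming a standard Bazel workspace, the new test target would be run with: bazel test //xla/service:all_to_all_decomposer_test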
73 changes: 73 additions & 0 deletions xla/service/all_to_all_decomposer.cc
@@ -15,14 +15,17 @@ limitations under the License.

#include "xla/service/all_to_all_decomposer.h"

#include <cstdint>
#include <optional>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
@@ -31,6 +34,19 @@ limitations under the License.
namespace xla {
bool AllToAllDecomposer::InstructionMatchesPattern(
HloInstruction* instruction) {
if (instruction->opcode() == HloOpcode::kRaggedAllToAll) {
auto* ragged_all_to_all =
DynCast<HloRaggedAllToAllInstruction>(instruction);
if (ragged_all_to_all == nullptr) {
return false;
}
// Do not attempt to change layout constrained collectives.
if (ragged_all_to_all->constrain_layout()) {
return false;
}
return ragged_all_to_all->shape().rank() < min_array_rank_;
}

auto* all_to_all = DynCast<HloAllToAllInstruction>(instruction);
if (all_to_all == nullptr) {
return false;
Expand All @@ -47,8 +63,65 @@ bool AllToAllDecomposer::InstructionMatchesPattern(
}
return all_to_all->shape().rank() < min_array_rank_;
}

absl::StatusOr<HloInstruction*> AllToAllDecomposer::ExpandRaggedAllToAll(
    HloInstruction* instruction) {
  Shape input_shape = instruction->operand(0)->shape();
  Shape aliased_output_shape = instruction->operand(1)->shape();
  Shape output_shape = instruction->shape();
  CHECK_EQ(instruction->operand_count(), 6);
  CHECK_EQ(input_shape.rank(), output_shape.rank());
  CHECK_EQ(output_shape, aliased_output_shape)
      << "Output shape must match the shape of operand 1 (which is aliased "
         "to the output).";

  Shape new_input_shape;
  Shape new_output_shape;
  new_input_shape.set_element_type(input_shape.element_type());
  new_output_shape.set_element_type(output_shape.element_type());

  // The new input and output shapes are the same as the original shapes,
  // but padded with size-1 dimensions up to min_array_rank_.
  for (int64_t i = 0; i < input_shape.rank(); ++i) {
    new_input_shape.add_dimensions(input_shape.dimensions(i));
    new_output_shape.add_dimensions(output_shape.dimensions(i));
  }
  while (new_input_shape.dimensions_size() < min_array_rank_) {
    new_input_shape.add_dimensions(1);
    new_output_shape.add_dimensions(1);
  }
  *(new_input_shape.mutable_layout()) =
      LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);
  *(new_output_shape.mutable_layout()) =
      LayoutUtil::GetDefaultLayoutForRank(min_array_rank_);

  // Reshape the input and output operands to the padded shapes; the offset
  // and size operands (2-5) are passed through unchanged.
  HloInstruction* operand_0_reshape =
      instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
          new_input_shape, instruction->mutable_operand(0)));
  instruction->SetupDerivedInstruction(operand_0_reshape);
  HloInstruction* operand_1_reshape =
      instruction->parent()->AddInstruction(HloInstruction::CreateReshape(
          new_output_shape, instruction->mutable_operand(1)));
  instruction->SetupDerivedInstruction(operand_1_reshape);
  HloInstruction* ragged_all_to_all =
      instruction->parent()->AddInstruction(instruction->CloneWithNewOperands(
          new_output_shape,
          {operand_0_reshape, operand_1_reshape,
           instruction->mutable_operand(2), instruction->mutable_operand(3),
           instruction->mutable_operand(4), instruction->mutable_operand(5)}));
  HloInstruction* output_reshape = instruction->parent()->AddInstruction(
      HloInstruction::CreateReshape(instruction->shape(), ragged_all_to_all));
  instruction->SetupDerivedInstruction(output_reshape);
  return output_reshape;
}

absl::StatusOr<HloInstruction*> AllToAllDecomposer::ExpandInstruction(
    HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kRaggedAllToAll) {
    return ExpandRaggedAllToAll(instruction);
  }

  auto* all_to_all = Cast<HloAllToAllInstruction>(instruction);
  int64_t split_dim = *all_to_all->split_dimension();
  int64_t all_to_all_group_size =
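
For context, a minimal sketch of invoking the pass directly, mirroring the tests below; the constructor parameter names here are assumptions read off the members decompose_to_tuple_ and min_array_rank_ declared in the header:

#include "absl/status/status.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/all_to_all_decomposer.h"
#include "xla/tsl/platform/statusor.h"

// Expands any (ragged-)all-to-all whose result rank is below 3, as the new
// tests do. A false `changed` simply means no instruction matched.
absl::Status ExpandSmallCollectives(xla::HloModule* module) {
  xla::AllToAllDecomposer decomposer(/*decompose_to_tuple=*/true,
                                     /*min_array_rank=*/3);
  TF_ASSIGN_OR_RETURN(bool changed, decomposer.Run(module));
  (void)changed;
  return absl::OkStatus();
}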
2 changes: 2 additions & 0 deletions xla/service/all_to_all_decomposer.h
@@ -39,6 +39,8 @@ class AllToAllDecomposer : public OpExpanderPass {
  bool InstructionMatchesPattern(HloInstruction* instruction) override;
  absl::StatusOr<HloInstruction*> ExpandInstruction(
      HloInstruction* instruction) override;
  absl::StatusOr<HloInstruction*> ExpandRaggedAllToAll(
      HloInstruction* instruction);
  bool decompose_to_tuple_;
  int64_t min_array_rank_;
};
111 changes: 111 additions & 0 deletions xla/service/all_to_all_decomposer_test.cc
@@ -0,0 +1,111 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/all_to_all_decomposer.h"

#include <memory>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"

namespace xla {
namespace {

using AllToAllDecomposerTest = HloTestBase;
using ::testing::_;
namespace op = xla::testing::opcode_matchers;

TEST_F(AllToAllDecomposerTest, RaggedAllToAllRank1) {
  const std::string module_str =
      R"(HloModule RaggedAllToAll
ENTRY AllToAll {
  p0 = s32[8]{0} parameter(0)
  c0 = s32[] constant(0)
  output = s32[8]{0} broadcast(c0), dimensions={}
  p1 = s32[4]{0} parameter(1)
  p2 = s32[4]{0} parameter(2)
  p3 = s32[4]{0} parameter(3)
  p4 = s32[4]{0} parameter(4)
  input = s32[8]{0} copy(p0)
  input_offsets = s32[4]{0} copy(p1)
  send_sizes = s32[4]{0} copy(p2)
  output_offsets = s32[4]{0} copy(p3)
  recv_sizes = s32[4]{0} copy(p4)
  ra2a = s32[8]{0} ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3}}
  ROOT copy = s32[8]{0} copy(ra2a)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(module_str));
  AllToAllDecomposer decomposer(true, 3);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(op::Reshape(op::RaggedAllToAll(
                  op::Reshape(op::Copy(op::Parameter(0))),
                  op::Reshape(op::Broadcast(op::Constant())), _, _, _, _))));
  std::vector<HloInstruction*> reshapes;
  std::vector<HloInstruction*> ragged_all_to_alls;
  for (HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kReshape) {
      reshapes.push_back(instruction);
    }
    if (instruction->opcode() == HloOpcode::kRaggedAllToAll) {
      ragged_all_to_alls.push_back(instruction);
    }
  }
  EXPECT_EQ(reshapes.size(), 3);
  EXPECT_EQ(ragged_all_to_alls.size(), 1);
  EXPECT_EQ(ragged_all_to_alls[0]->shape().rank(), 3);
}

TEST_F(AllToAllDecomposerTest, RaggedAllToAllRank3) {
  const std::string module_str =
      R"(HloModule RaggedAllToAll
ENTRY AllToAll {
  p0 = s32[8,16,256]{2,1,0} parameter(0)
  c0 = s32[] constant(0)
  output = s32[8,16,256]{2,1,0} broadcast(c0), dimensions={}
  p1 = s32[4]{0} parameter(1)
  p2 = s32[4]{0} parameter(2)
  p3 = s32[4]{0} parameter(3)
  p4 = s32[4]{0} parameter(4)
  input = s32[8,16,256]{2,1,0} copy(p0)
  input_offsets = s32[4]{0} copy(p1)
  send_sizes = s32[4]{0} copy(p2)
  output_offsets = s32[4]{0} copy(p3)
  recv_sizes = s32[4]{0} copy(p4)
  ra2a = s32[8,16,256]{2,1,0} ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3}}
  ROOT copy = s32[8,16,256]{2,1,0} copy(ra2a)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(module_str));
  AllToAllDecomposer decomposer(true, 3);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get()));
  EXPECT_FALSE(changed);
}

} // namespace
} // namespace xla
