diff --git a/xla/hlo/testlib/BUILD b/xla/hlo/testlib/BUILD index 85696f61e255d7..1351a8d78c665f 100644 --- a/xla/hlo/testlib/BUILD +++ b/xla/hlo/testlib/BUILD @@ -67,9 +67,11 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "//xla/service:computation_layout", + "//xla/service:computation_placer_hdr", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", "@com_google_absl//absl/algorithm:container", diff --git a/xla/hlo/testlib/hlo_hardware_independent_test_base.cc b/xla/hlo/testlib/hlo_hardware_independent_test_base.cc index a93733be406b6c..3d5379d1876ec8 100644 --- a/xla/hlo/testlib/hlo_hardware_independent_test_base.cc +++ b/xla/hlo/testlib/hlo_hardware_independent_test_base.cc @@ -35,12 +35,15 @@ limitations under the License. #include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/computation_placer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/shape.h" @@ -77,12 +80,23 @@ HloHardwareIndependentTestBase::CreateNewVerifiedModule( instruction_can_change_layout_func_); } +DeviceAssignment HloHardwareIndependentTestBase::GetDefaultDeviceAssignment( + int64_t replica_count, int64_t num_partitions) const { + DeviceAssignment device_assignment(replica_count, num_partitions); + device_assignment.FillIota(0); + return device_assignment; +} + absl::StatusOr> HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, int64_t replica_count, - int64_t num_partitions) const { - return ParseAndReturnVerifiedModule( - hlo_text, GetModuleConfigForTest(replica_count, num_partitions)); + absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions, + std::optional device_assignment) const { + HloModuleConfig config = + GetModuleConfigForTest(replica_count, num_partitions); + if (device_assignment.has_value()) { + config.set_static_device_assignment(device_assignment.value()); + } + return ParseAndReturnVerifiedModule(hlo_text, config); } absl::Status HloHardwareIndependentTestBase:: @@ -110,12 +124,31 @@ absl::Status HloHardwareIndependentTestBase:: absl::StatusOr> HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config) const { + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options) const { + return ParseAndReturnVerifiedModule(hlo_text, config, parser_options, + ShapeUtil::ByteSizeOfElements); +} + +absl::StatusOr> +HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options, + std::function shape_size_fn) const { + HloModuleConfig config_with_device_assignment = config; + if (!config.has_static_device_assignment()) { + default_device_assignment_ = + std::make_unique(GetDefaultDeviceAssignment( + config.replica_count(), config.num_partitions())); + config_with_device_assignment.set_static_device_assignment( + *default_device_assignment_); + } auto module = std::make_unique( - TestName(), config, verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements, + TestName(), config_with_device_assignment, verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_, shape_size_fn, instruction_can_change_layout_func_); - TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); + TF_RETURN_IF_ERROR( + module->ParseHloStringAndVerifyModule(hlo_text, parser_options)); return module; } diff --git a/xla/hlo/testlib/hlo_hardware_independent_test_base.h b/xla/hlo/testlib/hlo_hardware_independent_test_base.h index 4de7ccc1db25bf..33037552555ed4 100644 --- a/xla/hlo/testlib/hlo_hardware_independent_test_base.h +++ b/xla/hlo/testlib/hlo_hardware_independent_test_base.h @@ -35,10 +35,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/layout.h" #include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/shape_layout.h" @@ -95,14 +97,24 @@ class HloHardwareIndependentTestBase : public ::testing::Test { std::unique_ptr CreateNewVerifiedModule( const std::string& name = TestName(), int64_t replica_count = 1) const; + // + DeviceAssignment GetDefaultDeviceAssignment(int64_t replica_count, + int64_t num_partitions) const; // Parses the given string and returns module as a VerifiedHloModule. + virtual absl::StatusOr> + ParseAndReturnVerifiedModule( + absl::string_view hlo_text, int64_t replica_count = 1, + int64_t num_partitions = 1, + std::optional device_assignment = std::nullopt) const; + virtual absl::StatusOr> + ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options = HloParserOptions()) const; absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - int64_t replica_count = 1, - int64_t num_partitions = 1) const; - absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - const HloModuleConfig& config) const; + ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options, + std::function shape_size_fn) const; // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass @@ -181,13 +193,22 @@ class HloHardwareIndependentTestBase : public ::testing::Test { // options (e.g. disabling additional passes). virtual DebugOptions GetDebugOptionsForTest() const; + void TearDown() override { default_device_assignment_.reset(); } // Gets an HloModuleConfig with options appropriate for tests. - HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1, - int64_t num_partitions = 1) const { + HloModuleConfig GetModuleConfigForTest( + int64_t replica_count = 1, int64_t num_partitions = 1, + std::optional device_assignment = std::nullopt) const { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); config.set_replica_count(replica_count); config.set_num_partitions(num_partitions); + if (device_assignment.has_value()) { + config.set_static_device_assignment(std::move(device_assignment.value())); + } else { + default_device_assignment_ = std::make_unique( + GetDefaultDeviceAssignment(replica_count, num_partitions)); + config.set_static_device_assignment(*default_device_assignment_); + } return config; } @@ -269,6 +290,7 @@ class HloHardwareIndependentTestBase : public ::testing::Test { bool allow_mixed_precision_in_hlo_verifier_; HloPredicate instruction_can_change_layout_func_; std::unique_ptr hlo_verifier_; + mutable std::unique_ptr default_device_assignment_; }; } // namespace xla diff --git a/xla/hlo/transforms/collectives/BUILD b/xla/hlo/transforms/collectives/BUILD index ce143b08920694..0ec56ffa9f78a2 100644 --- a/xla/hlo/transforms/collectives/BUILD +++ b/xla/hlo/transforms/collectives/BUILD @@ -317,6 +317,44 @@ xla_cc_test( ], ) +cc_library( + name = "all_reduce_normalizer", + srcs = ["all_reduce_normalizer.cc"], + hdrs = ["all_reduce_normalizer.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/hlo/utils:hlo_sharding_util", + "//xla/service:collective_ops_utils", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "all_reduce_normalizer_test", + srcs = ["all_reduce_normalizer_test.cc"], + deps = [ + ":all_reduce_normalizer", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "all_reduce_contiguous", srcs = ["all_reduce_contiguous.cc"], diff --git a/xla/hlo/transforms/collectives/all_reduce_normalizer.cc b/xla/hlo/transforms/collectives/all_reduce_normalizer.cc new file mode 100644 index 00000000000000..dca1845f60c21a --- /dev/null +++ b/xla/hlo/transforms/collectives/all_reduce_normalizer.cc @@ -0,0 +1,228 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/collectives/all_reduce_normalizer.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/hlo/utils/hlo_sharding_util.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/shape.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +namespace { + +// Split a tupled all-reduce into individual all-reduces. +std::vector SplitAllReduces(HloAllReduceInstruction* ar) { + std::vector separate_all_reduces; + separate_all_reduces.reserve(ar->operand_count()); + for (int64_t i = 0; i < ar->operand_count(); ++i) { + HloInstruction* separate_all_reduce = + ar->parent()->AddInstruction(HloInstruction::CreateAllReduce( + ar->shape().tuple_shapes(i), {ar->mutable_operand(i)}, + ar->to_apply(), ar->device_list(), ar->constrain_layout(), + hlo_query::NextChannelId(*ar->GetModule()), + ar->use_global_device_ids())); + separate_all_reduces.push_back(separate_all_reduce); + } + return separate_all_reduces; +} + +int64_t FindLeadingNonOneDimension(const Shape& shape) { + int64_t i; + for (i = shape.rank() - 1; i > 0; --i) { + if (shape.dimensions_minor(i) > 1) { + break; + } + } + return shape.layout().minor_to_major(i); +} + +// Convert a single-operand all-reduce to all-to-all + reduce + all-gather. +absl::StatusOr NormalizeSingleOperandAllReduce( + HloInstruction* hlo, + std::function is_supported_all_reduce) { + if (is_supported_all_reduce(hlo)) { + return false; + } + HloComputation* computation = hlo->parent(); + HloAllReduceInstruction* ar = Cast(hlo); + TF_ASSIGN_OR_RETURN(auto replica_group_count_and_size, + GetReplicaGroupCountAndSize(ar)); + if (!replica_group_count_and_size.has_value()) { + return absl::InvalidArgumentError("Unsupported all-reduce with : " + + ar->ToString()); + } + const int64_t replica_group_size = replica_group_count_and_size->second; + Shape ar_shape = ar->shape(); + if (!ar_shape.has_layout()) { + *ar_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(ar_shape); + } + const int64_t leading_non_one_dim = FindLeadingNonOneDimension(ar_shape); + // 1. Pad the leading non-one dimension to the nearest multiple of replica + // group size, so that it's divisible, and reshape it to have a leading + // dimension of replica group size. + Shape padded_shape = ar_shape; + padded_shape.set_dimensions( + leading_non_one_dim, + RoundUpTo(ar_shape.dimensions(leading_non_one_dim), replica_group_size)); + *padded_shape.mutable_layout() = + LayoutUtil::GetDefaultLayoutForShape(padded_shape); + const int64_t fake_dim = 0; + Shape reshape_shape = padded_shape; + reshape_shape.set_dimensions( + leading_non_one_dim, + reshape_shape.dimensions(leading_non_one_dim) / replica_group_size); + reshape_shape.add_dimensions(replica_group_size, /*index=*/fake_dim); + *reshape_shape.mutable_layout() = + LayoutUtil::GetDefaultLayoutForShape(reshape_shape); + + std::optional kind = + MatchReductionInstruction(ar->to_apply()->root_instruction()); + if (!kind) { + return absl::InvalidArgumentError( + "Unsupported reduction type with : " + ar->ToString() + "\n" + + ar->to_apply()->ToString()); + } + std::optional reduction_identity = + GetReductionIdentity(*kind, ar->shape().element_type()); + if (!reduction_identity) { + return absl::InvalidArgumentError( + "Unsupported reduction identity with : " + ar->ToString() + "\n" + + ar->to_apply()->ToString()); + } + HloInstruction* identity = computation->AddInstruction( + HloInstruction::CreateConstant(std::move(reduction_identity.value()))); + std::vector formatting_steps( + {hlo_sharding_util::FormattingStep{.input_shape = ar_shape, + .output_shape = padded_shape, + .formatting_opcode = HloOpcode::kPad, + .padding_value = identity}, + hlo_sharding_util::FormattingStep{ + .input_shape = padded_shape, + .output_shape = reshape_shape, + .formatting_opcode = HloOpcode::kReshape}}); + + HloInstruction* formatted = hlo_sharding_util::FormatShape( + ar->mutable_operand(0), formatting_steps, computation); + + // 2. Create an all-to-all on the leading non-one dimension. + HloInstruction* all_to_all = + computation->AddInstruction(HloInstruction::CreateAllToAll( + reshape_shape, {formatted}, ar->device_list(), ar->constrain_layout(), + hlo_query::NextChannelId(*ar->GetModule()), + /*split_dimension=*/fake_dim)); + // 3. Do a local reduce and an all-gather + Shape reduce_shape = reshape_shape; + reduce_shape.DeleteDimension(fake_dim); + HloInstruction* reduce = + computation->AddInstruction(HloInstruction::CreateReduce( + reduce_shape, all_to_all, identity, + /*dimensions_to_reduce=*/{fake_dim}, ar->to_apply())); + Shape ag_operand_shape = reshape_shape; + ag_operand_shape.set_dimensions(fake_dim, 1); + HloInstruction* ag_operand = computation->AddInstruction( + HloInstruction::CreateReshape(ag_operand_shape, reduce)); + HloInstruction* ag = + computation->AddInstruction(HloInstruction::CreateAllGather( + reshape_shape, {ag_operand}, /*all_gather_dimension=*/fake_dim, + ar->device_list(), ar->constrain_layout(), + hlo_query::NextChannelId(*ar->GetModule()), + ar->use_global_device_ids())); + // 4. Reshape and slice back to the original shape. + HloInstruction* unformatted = + hlo_sharding_util::ReverseFormatShape(ag, formatting_steps, computation); + + unformatted->set_metadata(ar->metadata()); + TF_RETURN_IF_ERROR(ar->ReplaceAllUsesWith(unformatted)); + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar)); + return true; +} + +} // namespace + +absl::StatusOr AllReduceNormalizer::NormalizeAllReduce( + HloInstruction* hlo) { + HloComputation* computation = hlo->parent(); + HloAllReduceInstruction* ar = Cast(hlo); + if (ar->operand_count() > 1) { + // Tupled all-reduce's counterpart all-gathers might not be combinable, so + // we split them into individual all-reduces first and then convert to + // all-gathers and local reduces, and run all-reduce combiner after this if + // we want to combine them back. + bool changed = false; + std::vector separate_all_reduces = SplitAllReduces(ar); + for (HloInstruction* separate_all_reduce : separate_all_reduces) { + TF_ASSIGN_OR_RETURN(bool converted, + NormalizeSingleOperandAllReduce( + separate_all_reduce, is_supported_all_reduce_)); + changed |= converted; + } + if (changed) { + TF_RETURN_IF_ERROR(ar->ReplaceAllUsesWith(computation->AddInstruction( + HloInstruction::CreateTuple(separate_all_reduces)))); + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(ar)); + } else { + for (HloInstruction* separate_all_reduce : separate_all_reduces) { + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands( + separate_all_reduce)); + } + } + return changed; + } else { + return NormalizeSingleOperandAllReduce(hlo, is_supported_all_reduce_); + } + return absl::OkStatus(); +} + +absl::StatusOr AllReduceNormalizer::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (auto inst : computation->MakeInstructionPostOrder()) { + if (inst->opcode() == HloOpcode::kAllReduce) { + TF_ASSIGN_OR_RETURN(bool inst_changed, NormalizeAllReduce(inst)); + changed |= inst_changed; + } + } + } + return changed; +} + +} // namespace xla diff --git a/xla/hlo/transforms/collectives/all_reduce_normalizer.h b/xla/hlo/transforms/collectives/all_reduce_normalizer.h new file mode 100644 index 00000000000000..72474caee6fb83 --- /dev/null +++ b/xla/hlo/transforms/collectives/all_reduce_normalizer.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_NORMALIZER_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_NORMALIZER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A layout-insensitive pass that tries to convert an all-reduce to supported +// cases. +class AllReduceNormalizer : public HloModulePass { + public: + explicit AllReduceNormalizer( + std::function is_supported_all_reduce) + : is_supported_all_reduce_(is_supported_all_reduce) {} + + absl::string_view name() const override { return "all-reduce-normalizer"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + absl::StatusOr NormalizeAllReduce(HloInstruction* hlo); + + std::function is_supported_all_reduce_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_NORMALIZER_H_ diff --git a/xla/hlo/transforms/collectives/all_reduce_normalizer_test.cc b/xla/hlo/transforms/collectives/all_reduce_normalizer_test.cc new file mode 100644 index 00000000000000..68e30478fffd38 --- /dev/null +++ b/xla/hlo/transforms/collectives/all_reduce_normalizer_test.cc @@ -0,0 +1,119 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/collectives/all_reduce_normalizer.h" + +#include + +#include +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +using AllReduceNormalizerTest = HloHardwareIndependentTestBase; + +TEST_F(AllReduceNormalizerTest, Simple) { + const absl::string_view hlo_string = R"( +HloModule module + +%add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %comp { + p0 = bf16[1,256,8,256]{1,3,2,0:T(8,128)(2,1)} parameter(0) + ROOT all-reduce = bf16[1,256,8,256]{1,3,2,0:T(8,128)(2,1)} + all-reduce(bf16[1,256,8,256]{1,3,2,0:T(8,128)(2,1)} p0), channel_id=115, + replica_groups={{0,2},{4,6},{8,10},{12,14},{16,18},{20,22},{24,26},{28,30}, + {1,3},{5,7},{9,11},{13,15},{17,19},{21,23},{25,27},{29,31}}, + use_global_device_ids=true, to_apply=%add, + backend_config={"flag_configs":[], + "barrier_config":{"barrier_type":"CUSTOM","id":"16"}, + "scoped_memory_configs":[{ + "memory_space":"0","offset":"0","size":"67108864"}], + "used_scoped_memory_configs":[]} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + AllReduceNormalizer normalizer( + /*is_supported_all_reduce=*/[](const HloInstruction*) { return false; }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, normalizer.Run(module.get())); + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Slice(op::Reshape(op::AllGather(op::Reshape(op::Reduce( + op::AllToAll(op::Reshape(op::Pad(op::Parameter(0), op::Constant()))), + op::Constant())))))); +} + +TEST_F(AllReduceNormalizerTest, NonDivisible) { + const absl::string_view hlo_string = R"( +HloModule module + +%add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %comp { + p0 = bf16[1,256,7,256]{1,3,2,0:T(8,128)(2,1)} parameter(0) + ROOT all-reduce = bf16[1,256,7,256]{1,3,2,0:T(8,128)(2,1)} + all-reduce(bf16[1,256,7,256]{1,3,2,0:T(8,128)(2,1)} p0), channel_id=115, + replica_groups={{0,2},{4,6},{8,10},{12,14},{16,18},{20,22},{24,26},{28,30}, + {1,3},{5,7},{9,11},{13,15},{17,19},{21,23},{25,27},{29,31}}, + use_global_device_ids=true, to_apply=%add, + backend_config={"flag_configs":[], + "barrier_config":{"barrier_type":"CUSTOM","id":"16"}, + "scoped_memory_configs":[{ + "memory_space":"0","offset":"0","size":"67108864"}], + "used_scoped_memory_configs":[]} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + AllReduceNormalizer normalizer( + /*is_supported_all_reduce=*/[](const HloInstruction*) { return false; }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, normalizer.Run(module.get())); + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Slice(op::Reshape(op::AllGather(op::Reshape(op::Reduce( + op::AllToAll(op::Reshape(op::Pad(op::Parameter(0), op::Constant()))), + op::Constant())))))); +} + +} // namespace +} // namespace xla diff --git a/xla/service/hlo_module_config.h b/xla/service/hlo_module_config.h index cab597a53b1a2a..71d61d97f18226 100644 --- a/xla/service/hlo_module_config.h +++ b/xla/service/hlo_module_config.h @@ -271,6 +271,9 @@ class HloModuleConfig { void set_static_device_assignment(const DeviceAssignment& device_assignment) { static_device_assignment_ = device_assignment; } + void reset_static_device_assignment() { + static_device_assignment_ = std::nullopt; + } // Checks if this config has a simulated device assignment. bool has_pre_simulation_device_assignment() const { diff --git a/xla/shape.h b/xla/shape.h index 27a244c39fa317..99c0f8641871ed 100644 --- a/xla/shape.h +++ b/xla/shape.h @@ -178,6 +178,10 @@ class Shape { dimensions_.push_back(value); dynamic_dimensions_.push_back(false); } + void add_dimensions(int64_t value, int64_t index) { + dimensions_.insert(dimensions_.begin() + index, value); + dynamic_dimensions_.insert(dynamic_dimensions_.begin() + index, false); + } void clear_dimensions() { dimensions_.clear(); dynamic_dimensions_.clear(); diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 39a30e20508ff5..3bc33ec55f7392 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -245,6 +245,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/tests/hlo_runner_agnostic_test_base.cc b/xla/tests/hlo_runner_agnostic_test_base.cc index cff37063b55d85..215406f31e9a65 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/xla/tests/hlo_runner_agnostic_test_base.cc @@ -53,6 +53,7 @@ limitations under the License. #include "xla/tsl/platform/test.h" #include "xla/util.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -76,24 +77,14 @@ HloRunnerAgnosticTestBase::CreateNewVerifiedModule( instruction_can_change_layout_func()); } -absl::StatusOr> -HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions) { - return ParseAndReturnVerifiedModule( - hlo_text, GetModuleConfigForTest(replica_count, num_partitions)); -} - absl::StatusOr> HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule( absl::string_view hlo_text, const HloModuleConfig& config, - const HloParserOptions& parser_options) { - auto module = std::make_unique( - TestName(), config, verifier_layout_sensitive(), - allow_mixed_precision_in_hlo_verifier(), - test_runner_->device_shape_size_fn(), - instruction_can_change_layout_func()); - TF_RETURN_IF_ERROR( - module->ParseHloStringAndVerifyModule(hlo_text, parser_options)); + const HloParserOptions& parser_options) const { + TF_ASSIGN_OR_RETURN( + auto module, HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( + hlo_text, config, parser_options, + test_runner_->device_shape_size_fn())); UpdateEntryComputationLayout(module.get()); return std::move(module); } diff --git a/xla/tests/hlo_runner_agnostic_test_base.h b/xla/tests/hlo_runner_agnostic_test_base.h index 848074e8178a93..3641358edb4f86 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.h +++ b/xla/tests/hlo_runner_agnostic_test_base.h @@ -102,20 +102,17 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { const std::string& name = TestName(), int64_t replica_count = 1); // Parses the given string and returns module as a VerifiedHloModule. - absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - int64_t replica_count = 1, - int64_t num_partitions = 1); - // Parses the given string and returns module as a VerifiedHloModule. - // + using HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule; + // To obtain a HloModuleConfig with a specific replica and partition count and // no further customization, either use the overload above or use // GetModuleConfigForTest. The latter option may be useful if you want to pass // custom HloParserOptions as well. absl::StatusOr> - ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config, - const HloParserOptions& parser_options = HloParserOptions()); + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config, + const HloParserOptions& parser_options = + HloParserOptions()) const override; HloComputation* AddEntryComputationAndUpdateEntryComputationLayout( HloModule*, std::unique_ptr computation);