diff --git a/xla/hlo/testlib/BUILD b/xla/hlo/testlib/BUILD index 85696f61e255d7..1351a8d78c665f 100644 --- a/xla/hlo/testlib/BUILD +++ b/xla/hlo/testlib/BUILD @@ -67,9 +67,11 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "//xla/service:computation_layout", + "//xla/service:computation_placer_hdr", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", "@com_google_absl//absl/algorithm:container", diff --git a/xla/hlo/testlib/hlo_hardware_independent_test_base.cc b/xla/hlo/testlib/hlo_hardware_independent_test_base.cc index a93733be406b6c..3d5379d1876ec8 100644 --- a/xla/hlo/testlib/hlo_hardware_independent_test_base.cc +++ b/xla/hlo/testlib/hlo_hardware_independent_test_base.cc @@ -35,12 +35,15 @@ limitations under the License. #include "absl/types/span.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/computation_placer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/shape.h" @@ -77,12 +80,23 @@ HloHardwareIndependentTestBase::CreateNewVerifiedModule( instruction_can_change_layout_func_); } +DeviceAssignment HloHardwareIndependentTestBase::GetDefaultDeviceAssignment( + int64_t replica_count, int64_t num_partitions) const { + DeviceAssignment device_assignment(replica_count, num_partitions); + device_assignment.FillIota(0); + return device_assignment; +} + absl::StatusOr> HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, int64_t replica_count, - int64_t num_partitions) const { - return ParseAndReturnVerifiedModule( - hlo_text, GetModuleConfigForTest(replica_count, num_partitions)); + absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions, + std::optional device_assignment) const { + HloModuleConfig config = + GetModuleConfigForTest(replica_count, num_partitions); + if (device_assignment.has_value()) { + config.set_static_device_assignment(device_assignment.value()); + } + return ParseAndReturnVerifiedModule(hlo_text, config); } absl::Status HloHardwareIndependentTestBase:: @@ -110,12 +124,31 @@ absl::Status HloHardwareIndependentTestBase:: absl::StatusOr> HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config) const { + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options) const { + return ParseAndReturnVerifiedModule(hlo_text, config, parser_options, + ShapeUtil::ByteSizeOfElements); +} + +absl::StatusOr> +HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options, + std::function shape_size_fn) const { + HloModuleConfig config_with_device_assignment = config; + if (!config.has_static_device_assignment()) { + default_device_assignment_ = + std::make_unique(GetDefaultDeviceAssignment( + config.replica_count(), config.num_partitions())); + config_with_device_assignment.set_static_device_assignment( + *default_device_assignment_); + } auto module = std::make_unique( - TestName(), config, verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements, + TestName(), config_with_device_assignment, verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_, shape_size_fn, instruction_can_change_layout_func_); - TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); + TF_RETURN_IF_ERROR( + module->ParseHloStringAndVerifyModule(hlo_text, parser_options)); return module; } diff --git a/xla/hlo/testlib/hlo_hardware_independent_test_base.h b/xla/hlo/testlib/hlo_hardware_independent_test_base.h index 4de7ccc1db25bf..33037552555ed4 100644 --- a/xla/hlo/testlib/hlo_hardware_independent_test_base.h +++ b/xla/hlo/testlib/hlo_hardware_independent_test_base.h @@ -35,10 +35,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/layout.h" #include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/shape_layout.h" @@ -95,14 +97,24 @@ class HloHardwareIndependentTestBase : public ::testing::Test { std::unique_ptr CreateNewVerifiedModule( const std::string& name = TestName(), int64_t replica_count = 1) const; + // + DeviceAssignment GetDefaultDeviceAssignment(int64_t replica_count, + int64_t num_partitions) const; // Parses the given string and returns module as a VerifiedHloModule. + virtual absl::StatusOr> + ParseAndReturnVerifiedModule( + absl::string_view hlo_text, int64_t replica_count = 1, + int64_t num_partitions = 1, + std::optional device_assignment = std::nullopt) const; + virtual absl::StatusOr> + ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options = HloParserOptions()) const; absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - int64_t replica_count = 1, - int64_t num_partitions = 1) const; - absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - const HloModuleConfig& config) const; + ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options, + std::function shape_size_fn) const; // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass @@ -181,13 +193,22 @@ class HloHardwareIndependentTestBase : public ::testing::Test { // options (e.g. disabling additional passes). virtual DebugOptions GetDebugOptionsForTest() const; + void TearDown() override { default_device_assignment_.reset(); } // Gets an HloModuleConfig with options appropriate for tests. - HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1, - int64_t num_partitions = 1) const { + HloModuleConfig GetModuleConfigForTest( + int64_t replica_count = 1, int64_t num_partitions = 1, + std::optional device_assignment = std::nullopt) const { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); config.set_replica_count(replica_count); config.set_num_partitions(num_partitions); + if (device_assignment.has_value()) { + config.set_static_device_assignment(std::move(device_assignment.value())); + } else { + default_device_assignment_ = std::make_unique( + GetDefaultDeviceAssignment(replica_count, num_partitions)); + config.set_static_device_assignment(*default_device_assignment_); + } return config; } @@ -269,6 +290,7 @@ class HloHardwareIndependentTestBase : public ::testing::Test { bool allow_mixed_precision_in_hlo_verifier_; HloPredicate instruction_can_change_layout_func_; std::unique_ptr hlo_verifier_; + mutable std::unique_ptr default_device_assignment_; }; } // namespace xla diff --git a/xla/service/hlo_module_config.h b/xla/service/hlo_module_config.h index cab597a53b1a2a..71d61d97f18226 100644 --- a/xla/service/hlo_module_config.h +++ b/xla/service/hlo_module_config.h @@ -271,6 +271,9 @@ class HloModuleConfig { void set_static_device_assignment(const DeviceAssignment& device_assignment) { static_device_assignment_ = device_assignment; } + void reset_static_device_assignment() { + static_device_assignment_ = std::nullopt; + } // Checks if this config has a simulated device assignment. bool has_pre_simulation_device_assignment() const { diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 39a30e20508ff5..3bc33ec55f7392 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -245,6 +245,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/tests/hlo_runner_agnostic_test_base.cc b/xla/tests/hlo_runner_agnostic_test_base.cc index cff37063b55d85..215406f31e9a65 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/xla/tests/hlo_runner_agnostic_test_base.cc @@ -53,6 +53,7 @@ limitations under the License. #include "xla/tsl/platform/test.h" #include "xla/util.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -76,24 +77,14 @@ HloRunnerAgnosticTestBase::CreateNewVerifiedModule( instruction_can_change_layout_func()); } -absl::StatusOr> -HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions) { - return ParseAndReturnVerifiedModule( - hlo_text, GetModuleConfigForTest(replica_count, num_partitions)); -} - absl::StatusOr> HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule( absl::string_view hlo_text, const HloModuleConfig& config, - const HloParserOptions& parser_options) { - auto module = std::make_unique( - TestName(), config, verifier_layout_sensitive(), - allow_mixed_precision_in_hlo_verifier(), - test_runner_->device_shape_size_fn(), - instruction_can_change_layout_func()); - TF_RETURN_IF_ERROR( - module->ParseHloStringAndVerifyModule(hlo_text, parser_options)); + const HloParserOptions& parser_options) const { + TF_ASSIGN_OR_RETURN( + auto module, HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( + hlo_text, config, parser_options, + test_runner_->device_shape_size_fn())); UpdateEntryComputationLayout(module.get()); return std::move(module); } diff --git a/xla/tests/hlo_runner_agnostic_test_base.h b/xla/tests/hlo_runner_agnostic_test_base.h index 848074e8178a93..3641358edb4f86 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.h +++ b/xla/tests/hlo_runner_agnostic_test_base.h @@ -102,20 +102,17 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { const std::string& name = TestName(), int64_t replica_count = 1); // Parses the given string and returns module as a VerifiedHloModule. - absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - int64_t replica_count = 1, - int64_t num_partitions = 1); - // Parses the given string and returns module as a VerifiedHloModule. - // + using HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule; + // To obtain a HloModuleConfig with a specific replica and partition count and // no further customization, either use the overload above or use // GetModuleConfigForTest. The latter option may be useful if you want to pass // custom HloParserOptions as well. absl::StatusOr> - ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config, - const HloParserOptions& parser_options = HloParserOptions()); + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config, + const HloParserOptions& parser_options = + HloParserOptions()) const override; HloComputation* AddEntryComputationAndUpdateEntryComputationLayout( HloModule*, std::unique_ptr computation);