From fe63f8aa897852d8fcaeeadc7967aa3c4901bf7f Mon Sep 17 00:00:00 2001 From: Niklas Vangerow Date: Wed, 22 Jan 2025 21:38:27 -0800 Subject: [PATCH] Allow customization of HloParserOptions when using ParseAndReturnVerifiedModule. PiperOrigin-RevId: 718697476 --- xla/hlo/parser/BUILD | 3 +++ xla/hlo/parser/hlo_parser.cc | 5 +++-- xla/hlo/parser/hlo_parser.h | 11 ++++++++--- xla/hlo/testlib/BUILD | 8 +++----- xla/hlo/testlib/verified_hlo_module.cc | 11 +++++------ xla/hlo/testlib/verified_hlo_module.h | 11 ++++++++--- xla/tests/BUILD | 1 + xla/tests/hlo_runner_agnostic_test_base.cc | 7 +++++-- xla/tests/hlo_runner_agnostic_test_base.h | 12 ++++++++++-- 9 files changed, 46 insertions(+), 23 deletions(-) diff --git a/xla/hlo/parser/BUILD b/xla/hlo/parser/BUILD index 8db79548d8ebf8..e5a0754485a73e 100644 --- a/xla/hlo/parser/BUILD +++ b/xla/hlo/parser/BUILD @@ -47,6 +47,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:status", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", @@ -59,6 +60,8 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@tsl//tsl/platform:protobuf", + "@tsl//tsl/platform:status", ], ) diff --git a/xla/hlo/parser/hlo_parser.cc b/xla/hlo/parser/hlo_parser.cc index 126a31bac8d2f0..e976ac8bcaced3 100644 --- a/xla/hlo/parser/hlo_parser.cc +++ b/xla/hlo/parser/hlo_parser.cc @@ -81,6 +81,7 @@ limitations under the License. #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/protobuf.h" namespace xla { @@ -7303,8 +7304,8 @@ absl::StatusOr ParseLayout(absl::string_view str) { } std::unique_ptr HloParser::CreateHloParserForTests( - absl::string_view str) { - return std::make_unique(str); + absl::string_view str, const HloParserOptions& options) { + return std::make_unique(str, options); } } // namespace xla diff --git a/xla/hlo/parser/hlo_parser.h b/xla/hlo/parser/hlo_parser.h index 3d1d2f25f999f9..a79cb4ca57ae26 100644 --- a/xla/hlo/parser/hlo_parser.h +++ b/xla/hlo/parser/hlo_parser.h @@ -19,12 +19,16 @@ limitations under the License. #include #include -#include "absl/status/statusor.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/parser/hlo_lexer.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" namespace xla { @@ -118,7 +122,8 @@ class HloParser { private: static std::unique_ptr CreateHloParserForTests( - absl::string_view str); + absl::string_view str, + const HloParserOptions& options = HloParserOptions()); friend class VerifiedHloModule; }; diff --git a/xla/hlo/testlib/BUILD b/xla/hlo/testlib/BUILD index a5094ad5eb777c..85696f61e255d7 100644 --- a/xla/hlo/testlib/BUILD +++ b/xla/hlo/testlib/BUILD @@ -39,18 +39,16 @@ cc_library( deps = [ "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:test", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:status", - "@tsl//tsl/platform:test", ], ) diff --git a/xla/hlo/testlib/verified_hlo_module.cc b/xla/hlo/testlib/verified_hlo_module.cc index 044bc2f5ca40bc..5b4f69b13a90fa 100644 --- a/xla/hlo/testlib/verified_hlo_module.cc +++ b/xla/hlo/testlib/verified_hlo_module.cc @@ -19,18 +19,17 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/test.h" namespace xla { absl::Status VerifiedHloModule::ParseHloStringAndVerifyModule( - absl::string_view str) { + absl::string_view str, const HloParserOptions& options) { TF_RET_CHECK(computation_count() == 0); - auto parser = HloParser::CreateHloParserForTests(str); + auto parser = HloParser::CreateHloParserForTests(str, options); TF_RETURN_IF_ERROR(parser->Run(this)); return Verify(); } diff --git a/xla/hlo/testlib/verified_hlo_module.h b/xla/hlo/testlib/verified_hlo_module.h index 6c8f03a1c01df3..0abc5eba956695 100644 --- a/xla/hlo/testlib/verified_hlo_module.h +++ b/xla/hlo/testlib/verified_hlo_module.h @@ -15,15 +15,18 @@ limitations under the License. #ifndef XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_ #define XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_ +#include #include +#include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/shape.h" -#include "xla/types.h" -#include "tsl/platform/status.h" +#include "xla/util.h" namespace xla { @@ -48,7 +51,9 @@ class VerifiedHloModule : public HloModule { // builds the VerifiedHloModule in place. Before calling this method, the // module must be empty (no computations). Finally verifies the module using // HloVerifier and returns the status. - absl::Status ParseHloStringAndVerifyModule(absl::string_view str); + absl::Status ParseHloStringAndVerifyModule( + absl::string_view str, + const HloParserOptions& options = HloParserOptions()); // Verifies the module and flags any error with ADD_FAILURE. 'message' is // included in the failure message. diff --git a/xla/tests/BUILD b/xla/tests/BUILD index 3257198a319fdd..39a30e20508ff5 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -219,6 +219,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:test_helpers", "//xla/hlo/testlib:verified_hlo_module", diff --git a/xla/tests/hlo_runner_agnostic_test_base.cc b/xla/tests/hlo_runner_agnostic_test_base.cc index 341ab477e55995..cff37063b55d85 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/xla/tests/hlo_runner_agnostic_test_base.cc @@ -35,6 +35,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" @@ -84,13 +85,15 @@ HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule( absl::StatusOr> HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config) { + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options) { auto module = std::make_unique( TestName(), config, verifier_layout_sensitive(), allow_mixed_precision_in_hlo_verifier(), test_runner_->device_shape_size_fn(), instruction_can_change_layout_func()); - TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); + TF_RETURN_IF_ERROR( + module->ParseHloStringAndVerifyModule(hlo_text, parser_options)); UpdateEntryComputationLayout(module.get()); return std::move(module); } diff --git a/xla/tests/hlo_runner_agnostic_test_base.h b/xla/tests/hlo_runner_agnostic_test_base.h index efbf871fa28208..848074e8178a93 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.h +++ b/xla/tests/hlo_runner_agnostic_test_base.h @@ -34,6 +34,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/test_helpers.h" #include "xla/hlo/testlib/verified_hlo_module.h" @@ -105,9 +106,16 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { ParseAndReturnVerifiedModule(absl::string_view hlo_text, int64_t replica_count = 1, int64_t num_partitions = 1); + // Parses the given string and returns module as a VerifiedHloModule. + // + // To obtain a HloModuleConfig with a specific replica and partition count and + // no further customization, either use the overload above or use + // GetModuleConfigForTest. The latter option may be useful if you want to pass + // custom HloParserOptions as well. absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - const HloModuleConfig& config); + ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config, + const HloParserOptions& parser_options = HloParserOptions()); HloComputation* AddEntryComputationAndUpdateEntryComputationLayout( HloModule*, std::unique_ptr computation);