From 14c2958dd3e0da15fe9b71248304b017983a95b7 Mon Sep 17 00:00:00 2001
From: Niklas Vangerow <nikv@google.com>
Date: Wed, 22 Jan 2025 21:38:27 -0800
Subject: [PATCH] Allow customization of HloParserOptions when using
 ParseAndReturnVerifiedModule.

PiperOrigin-RevId: 718697476
---
 xla/hlo/parser/BUILD                       |  3 +++
 xla/hlo/parser/hlo_parser.cc               |  5 +++--
 xla/hlo/parser/hlo_parser.h                | 11 ++++++++---
 xla/hlo/testlib/BUILD                      |  8 +++-----
 xla/hlo/testlib/verified_hlo_module.cc     | 11 +++++------
 xla/hlo/testlib/verified_hlo_module.h      | 11 ++++++++---
 xla/tests/BUILD                            |  1 +
 xla/tests/hlo_runner_agnostic_test_base.cc |  7 +++++--
 xla/tests/hlo_runner_agnostic_test_base.h  | 12 ++++++++++--
 9 files changed, 46 insertions(+), 23 deletions(-)

diff --git a/xla/hlo/parser/BUILD b/xla/hlo/parser/BUILD
index 8db79548d8ebf8..e5a0754485a73e 100644
--- a/xla/hlo/parser/BUILD
+++ b/xla/hlo/parser/BUILD
@@ -47,6 +47,7 @@ cc_library(
         "//xla/tsl/platform:errors",
         "//xla/tsl/platform:logging",
         "//xla/tsl/platform:status",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -59,6 +60,8 @@ cc_library(
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@eigen_archive//:eigen3",
+        "@tsl//tsl/platform:protobuf",
+        "@tsl//tsl/platform:status",
     ],
 )
 
diff --git a/xla/hlo/parser/hlo_parser.cc b/xla/hlo/parser/hlo_parser.cc
index 126a31bac8d2f0..e976ac8bcaced3 100644
--- a/xla/hlo/parser/hlo_parser.cc
+++ b/xla/hlo/parser/hlo_parser.cc
@@ -81,6 +81,7 @@ limitations under the License.
 #include "xla/types.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
+#include "tsl/platform/protobuf.h"
 
 namespace xla {
 
@@ -7303,8 +7304,8 @@ absl::StatusOr<Layout> ParseLayout(absl::string_view str) {
 }
 
 std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
-    absl::string_view str) {
-  return std::make_unique<HloParserImpl>(str);
+    absl::string_view str, const HloParserOptions& options) {
+  return std::make_unique<HloParserImpl>(str, options);
 }
 
 }  // namespace xla
diff --git a/xla/hlo/parser/hlo_parser.h b/xla/hlo/parser/hlo_parser.h
index 3d1d2f25f999f9..a79cb4ca57ae26 100644
--- a/xla/hlo/parser/hlo_parser.h
+++ b/xla/hlo/parser/hlo_parser.h
@@ -19,12 +19,16 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
-#include "absl/status/statusor.h"
+#include "absl/status/status.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
-#include "xla/hlo/parser/hlo_lexer.h"
+#include "xla/hlo/ir/hlo_sharding.h"
+#include "xla/layout.h"
+#include "xla/service/hlo_module_config.h"
+#include "xla/shape.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/xla_data.pb.h"
 
 namespace xla {
@@ -118,7 +122,8 @@ class HloParser {
 
  private:
   static std::unique_ptr<HloParser> CreateHloParserForTests(
-      absl::string_view str);
+      absl::string_view str,
+      const HloParserOptions& options = HloParserOptions());
   friend class VerifiedHloModule;
 };
 
diff --git a/xla/hlo/testlib/BUILD b/xla/hlo/testlib/BUILD
index a5094ad5eb777c..85696f61e255d7 100644
--- a/xla/hlo/testlib/BUILD
+++ b/xla/hlo/testlib/BUILD
@@ -39,18 +39,16 @@ cc_library(
     deps = [
         "//xla:shape_util",
         "//xla:status_macros",
-        "//xla:types",
         "//xla:util",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/parser:hlo_parser",
         "//xla/service:hlo_module_config",
         "//xla/service:hlo_verifier",
+        "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:logging",
+        "//xla/tsl/platform:test",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
-        "@tsl//tsl/platform:errors",
-        "@tsl//tsl/platform:logging",
-        "@tsl//tsl/platform:status",
-        "@tsl//tsl/platform:test",
     ],
 )
 
diff --git a/xla/hlo/testlib/verified_hlo_module.cc b/xla/hlo/testlib/verified_hlo_module.cc
index 044bc2f5ca40bc..5b4f69b13a90fa 100644
--- a/xla/hlo/testlib/verified_hlo_module.cc
+++ b/xla/hlo/testlib/verified_hlo_module.cc
@@ -19,18 +19,17 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "xla/hlo/parser/hlo_parser.h"
 #include "xla/status_macros.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/tsl/platform/test.h"
 #include "xla/util.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/status.h"
-#include "tsl/platform/test.h"
 
 namespace xla {
 
 absl::Status VerifiedHloModule::ParseHloStringAndVerifyModule(
-    absl::string_view str) {
+    absl::string_view str, const HloParserOptions& options) {
   TF_RET_CHECK(computation_count() == 0);
-  auto parser = HloParser::CreateHloParserForTests(str);
+  auto parser = HloParser::CreateHloParserForTests(str, options);
   TF_RETURN_IF_ERROR(parser->Run(this));
   return Verify();
 }
diff --git a/xla/hlo/testlib/verified_hlo_module.h b/xla/hlo/testlib/verified_hlo_module.h
index 6c8f03a1c01df3..0abc5eba956695 100644
--- a/xla/hlo/testlib/verified_hlo_module.h
+++ b/xla/hlo/testlib/verified_hlo_module.h
@@ -15,15 +15,18 @@ limitations under the License.
 #ifndef XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_
 #define XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_
 
+#include <cstdint>
 #include <functional>
+#include <string>
 
+#include "absl/status/status.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/service/hlo_verifier.h"
 #include "xla/shape.h"
-#include "xla/types.h"
-#include "tsl/platform/status.h"
+#include "xla/util.h"
 
 namespace xla {
 
@@ -48,7 +51,9 @@ class VerifiedHloModule : public HloModule {
   // builds the VerifiedHloModule in place. Before calling this method, the
   // module must be empty (no computations). Finally verifies the module using
   // HloVerifier and returns the status.
-  absl::Status ParseHloStringAndVerifyModule(absl::string_view str);
+  absl::Status ParseHloStringAndVerifyModule(
+      absl::string_view str,
+      const HloParserOptions& options = HloParserOptions());
 
   // Verifies the module and flags any error with ADD_FAILURE. 'message' is
   // included in the failure message.
diff --git a/xla/tests/BUILD b/xla/tests/BUILD
index 3257198a319fdd..39a30e20508ff5 100644
--- a/xla/tests/BUILD
+++ b/xla/tests/BUILD
@@ -219,6 +219,7 @@ cc_library(
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
+        "//xla/hlo/parser:hlo_parser",
         "//xla/hlo/testlib:hlo_hardware_independent_test_base",
         "//xla/hlo/testlib:test_helpers",
         "//xla/hlo/testlib:verified_hlo_module",
diff --git a/xla/tests/hlo_runner_agnostic_test_base.cc b/xla/tests/hlo_runner_agnostic_test_base.cc
index 341ab477e55995..cff37063b55d85 100644
--- a/xla/tests/hlo_runner_agnostic_test_base.cc
+++ b/xla/tests/hlo_runner_agnostic_test_base.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "xla/error_spec.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
 #include "xla/hlo/testlib/verified_hlo_module.h"
 #include "xla/literal.h"
@@ -84,13 +85,15 @@ HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule(
 
 absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
 HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule(
-    absl::string_view hlo_text, const HloModuleConfig& config) {
+    absl::string_view hlo_text, const HloModuleConfig& config,
+    const HloParserOptions& parser_options) {
   auto module = std::make_unique<VerifiedHloModule>(
       TestName(), config, verifier_layout_sensitive(),
       allow_mixed_precision_in_hlo_verifier(),
       test_runner_->device_shape_size_fn(),
       instruction_can_change_layout_func());
-  TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
+  TF_RETURN_IF_ERROR(
+      module->ParseHloStringAndVerifyModule(hlo_text, parser_options));
   UpdateEntryComputationLayout(module.get());
   return std::move(module);
 }
diff --git a/xla/tests/hlo_runner_agnostic_test_base.h b/xla/tests/hlo_runner_agnostic_test_base.h
index efbf871fa28208..848074e8178a93 100644
--- a/xla/tests/hlo_runner_agnostic_test_base.h
+++ b/xla/tests/hlo_runner_agnostic_test_base.h
@@ -34,6 +34,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
 #include "xla/hlo/testlib/test_helpers.h"
 #include "xla/hlo/testlib/verified_hlo_module.h"
@@ -105,9 +106,16 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase {
   ParseAndReturnVerifiedModule(absl::string_view hlo_text,
                                int64_t replica_count = 1,
                                int64_t num_partitions = 1);
+  // Parses the given string and returns module as a VerifiedHloModule.
+  //
+  // To obtain a HloModuleConfig with a specific replica and partition count and
+  // no further customization, either use the overload above or use
+  // GetModuleConfigForTest. The latter option may be useful if you want to pass
+  // custom HloParserOptions as well.
   absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
-  ParseAndReturnVerifiedModule(absl::string_view hlo_text,
-                               const HloModuleConfig& config);
+  ParseAndReturnVerifiedModule(
+      absl::string_view hlo_text, const HloModuleConfig& config,
+      const HloParserOptions& parser_options = HloParserOptions());
 
   HloComputation* AddEntryComputationAndUpdateEntryComputationLayout(
       HloModule*, std::unique_ptr<HloComputation> computation);