Skip to content

Commit

Permalink
Allow customization of HloParserOptions when using ParseAndReturnVeri…
Browse files Browse the repository at this point in the history
…fiedModule.

PiperOrigin-RevId: 718697476
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Jan 27, 2025
1 parent 8591928 commit fe63f8a
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 23 deletions.
3 changes: 3 additions & 0 deletions xla/hlo/parser/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ cc_library(
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:status",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
Expand All @@ -59,6 +60,8 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:protobuf",
"@tsl//tsl/platform:status",
],
)

Expand Down
5 changes: 3 additions & 2 deletions xla/hlo/parser/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ limitations under the License.
#include "xla/types.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/protobuf.h"

namespace xla {

Expand Down Expand Up @@ -7303,8 +7304,8 @@ absl::StatusOr<Layout> ParseLayout(absl::string_view str) {
}

std::unique_ptr<HloParser> HloParser::CreateHloParserForTests(
absl::string_view str) {
return std::make_unique<HloParserImpl>(str);
absl::string_view str, const HloParserOptions& options) {
return std::make_unique<HloParserImpl>(str, options);
}

} // namespace xla
11 changes: 8 additions & 3 deletions xla/hlo/parser/hlo_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ limitations under the License.
#include <memory>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_lexer.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/layout.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/xla_data.pb.h"

namespace xla {
Expand Down Expand Up @@ -118,7 +122,8 @@ class HloParser {

private:
static std::unique_ptr<HloParser> CreateHloParserForTests(
absl::string_view str);
absl::string_view str,
const HloParserOptions& options = HloParserOptions());
friend class VerifiedHloModule;
};

Expand Down
8 changes: 3 additions & 5 deletions xla/hlo/testlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,16 @@ cc_library(
deps = [
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/service:hlo_module_config",
"//xla/service:hlo_verifier",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:test",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:test",
],
)

Expand Down
11 changes: 5 additions & 6 deletions xla/hlo/testlib/verified_hlo_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/status_macros.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/test.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/status.h"
#include "tsl/platform/test.h"

namespace xla {

absl::Status VerifiedHloModule::ParseHloStringAndVerifyModule(
absl::string_view str) {
absl::string_view str, const HloParserOptions& options) {
TF_RET_CHECK(computation_count() == 0);
auto parser = HloParser::CreateHloParserForTests(str);
auto parser = HloParser::CreateHloParserForTests(str, options);
TF_RETURN_IF_ERROR(parser->Run(this));
return Verify();
}
Expand Down
11 changes: 8 additions & 3 deletions xla/hlo/testlib/verified_hlo_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ limitations under the License.
#ifndef XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_
#define XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_

#include <cstdint>
#include <functional>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_verifier.h"
#include "xla/shape.h"
#include "xla/types.h"
#include "tsl/platform/status.h"
#include "xla/util.h"

namespace xla {

Expand All @@ -48,7 +51,9 @@ class VerifiedHloModule : public HloModule {
// builds the VerifiedHloModule in place. Before calling this method, the
// module must be empty (no computations). Finally verifies the module using
// HloVerifier and returns the status.
absl::Status ParseHloStringAndVerifyModule(absl::string_view str);
absl::Status ParseHloStringAndVerifyModule(
absl::string_view str,
const HloParserOptions& options = HloParserOptions());

// Verifies the module and flags any error with ADD_FAILURE. 'message' is
// included in the failure message.
Expand Down
1 change: 1 addition & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/testlib:test_helpers",
"//xla/hlo/testlib:verified_hlo_module",
Expand Down
7 changes: 5 additions & 2 deletions xla/tests/hlo_runner_agnostic_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "xla/error_spec.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
#include "xla/literal.h"
Expand Down Expand Up @@ -84,13 +85,15 @@ HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule(

absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config) {
absl::string_view hlo_text, const HloModuleConfig& config,
const HloParserOptions& parser_options) {
auto module = std::make_unique<VerifiedHloModule>(
TestName(), config, verifier_layout_sensitive(),
allow_mixed_precision_in_hlo_verifier(),
test_runner_->device_shape_size_fn(),
instruction_can_change_layout_func());
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
TF_RETURN_IF_ERROR(
module->ParseHloStringAndVerifyModule(hlo_text, parser_options));
UpdateEntryComputationLayout(module.get());
return std::move(module);
}
Expand Down
12 changes: 10 additions & 2 deletions xla/tests/hlo_runner_agnostic_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/parser/hlo_parser.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/testlib/test_helpers.h"
#include "xla/hlo/testlib/verified_hlo_module.h"
Expand Down Expand Up @@ -105,9 +106,16 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase {
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
int64_t replica_count = 1,
int64_t num_partitions = 1);
// Parses the given string and returns module as a VerifiedHloModule.
//
// To obtain a HloModuleConfig with a specific replica and partition count and
// no further customization, either use the overload above or use
// GetModuleConfigForTest. The latter option may be useful if you want to pass
// custom HloParserOptions as well.
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
const HloModuleConfig& config);
ParseAndReturnVerifiedModule(
absl::string_view hlo_text, const HloModuleConfig& config,
const HloParserOptions& parser_options = HloParserOptions());

HloComputation* AddEntryComputationAndUpdateEntryComputationLayout(
HloModule*, std::unique_ptr<HloComputation> computation);
Expand Down

0 comments on commit fe63f8a

Please sign in to comment.