From 519df881f2f78f66d39ab7020f37982f7e026130 Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Tue, 4 Feb 2025 14:59:31 -0800 Subject: [PATCH] Add argument to fix a random seed when generating random arguments for HLO runner. Also add OutputFormat so that literal dumps can be saved as a pb file. PiperOrigin-RevId: 723237912 --- xla/tools/multihost_hlo_runner/BUILD | 1 + .../functional_hlo_runner.cc | 73 ++++++++++++++++--- .../functional_hlo_runner.h | 19 ++++- .../functional_hlo_runner_test.cc | 19 +++++ 4 files changed, 98 insertions(+), 14 deletions(-) diff --git a/xla/tools/multihost_hlo_runner/BUILD b/xla/tools/multihost_hlo_runner/BUILD index 7556a11ec8919..d8490019f4e7d 100644 --- a/xla/tools/multihost_hlo_runner/BUILD +++ b/xla/tools/multihost_hlo_runner/BUILD @@ -224,6 +224,7 @@ xla_test( "//xla:xla_proto_cc", "//xla/hlo/testlib:filecheck", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_executable", "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/service:hlo_proto_cc", "//xla/tsl/lib/core:status_test_util", diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 14f8888266c1e..5bb5f790407d4 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -189,6 +189,37 @@ std::string AbslUnparseFlag(InputFormat input_format) { } } +bool AbslParseFlag(absl::string_view text, OutputFormat* output_format, + std::string* error) { + if (text == "text") { + *output_format = OutputFormat::kText; + return true; + } + if (text == "proto_binary") { + *output_format = OutputFormat::kProtoBinary; + return true; + } + if (text == "proto_text") { + *output_format = OutputFormat::kProtoText; + return true; + } + *error = "unknown value for enumeration"; + return false; +} + +std::string AbslUnparseFlag(OutputFormat output_format) { + switch (output_format) { + case OutputFormat::kText: + return "text"; + case OutputFormat::kProtoBinary: + return "proto_binary"; + case OutputFormat::kProtoText: + return "proto_text"; + default: + return absl::StrCat(output_format); + } +} + bool AbslParseFlag(absl::string_view text, FunctionalHloRunner::ModuleArgumentMode* argument_mode, std::string* error) { @@ -442,7 +473,7 @@ FunctionalHloRunner::CreateExecutableBuildOptionsFromExecutionOptions( absl::Status FunctionalHloRunner::DumpOutput( const FunctionalHloRunner::PerDeviceLiteralVecType& output, - absl::string_view dump_output_to, int task_id) { + absl::string_view dump_output_to, int task_id, OutputFormat output_format) { std::vector output_path_vec = absl::StrSplit(dump_output_to, '.'); std::string suffix = output_path_vec.back(); @@ -458,12 +489,30 @@ absl::Status FunctionalHloRunner::DumpOutput( for (int literal_id = 0; literal_id < literal_vec.size(); ++literal_id) { output_path_vec[literal_id_index] = absl::StrCat("literal_", literal_id); std::string literal_path = absl::StrJoin(output_path_vec, "."); - CHECK_EQ(suffix, std::string("txt")); - absl::Status write_status = - tsl::WriteStringToFile(tsl::Env::Default(), literal_path, - literal_vec[literal_id].ToString()); - if (!write_status.ok()) { - return write_status; + switch (output_format) { + case OutputFormat::kText: { + CHECK_EQ(suffix, std::string("txt")); + absl::Status write_status = + tsl::WriteStringToFile(tsl::Env::Default(), literal_path, + literal_vec[literal_id].ToString()); + if (!write_status.ok()) { + return write_status; + } + } break; + case OutputFormat::kProtoBinary: { + CHECK_EQ(suffix, std::string("pb")); + TF_RETURN_IF_ERROR( + tsl::WriteBinaryProto(tsl::Env::Default(), literal_path, + literal_vec[literal_id].ToProto())); + break; + } + case OutputFormat::kProtoText: { + CHECK_EQ(suffix, std::string("pbtxt")); + TF_RETURN_IF_ERROR( + tsl::WriteTextProto(tsl::Env::Default(), literal_path, + literal_vec[literal_id].ToProto())); + break; + } } } } @@ -545,7 +594,8 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client, const RunningOptions& running_options, absl::string_view hlo_text, InputFormat input_format, - const PerDeviceLiteralVecType& arguments) { + const PerDeviceLiteralVecType& arguments, + std::minstd_rand0* engine) { // We only support SPMD as of now, i.e., all devices are supposed // to execute the same HLO module. // Currently there is no mechanism to map the loaded arguments to @@ -577,7 +627,7 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client, return CompileAndRun( client, debug_options, preproc_options, compile_options, running_options, - hlo_module_and_arguments.hlo_module.get(), loaded_arguments); + hlo_module_and_arguments.hlo_module.get(), loaded_arguments, engine); } absl::Status FunctionalHloRunner::LoadAndCompile( @@ -723,12 +773,13 @@ FunctionalHloRunner::CompileAndRun(PjRtClient& client, const CompileOptions& compile_options, const RunningOptions& running_options, HloModule* hlo_module, - const PerDeviceLiteralVecType& arguments) { + const PerDeviceLiteralVecType& arguments, + std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(std::unique_ptr executable, Compile(client, hlo_module, debug_options, preproc_options, compile_options)); - return Run(client, executable.get(), arguments, running_options); + return Run(client, executable.get(), arguments, running_options, engine); } namespace { diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index 1f4e33dd4a355..1351e2860b9d6 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -66,6 +66,12 @@ enum class InputFormat { // in conjunction with xla_dump_as_text. }; +enum class OutputFormat : std::uint8_t { + kText, // Text format returned by Literal::ToString(). + kProtoBinary, // Protobuf binary format of an xla::LiteralProto message. + kProtoText, // Protobuf text format of an xla::LiteralProto message. +}; + // Interface for profiler plugins. If being set in RunningOptions, profiling // session will be created for the last run of the HLO module. class ProfilerInterface { @@ -134,6 +140,10 @@ bool AbslParseFlag(absl::string_view text, InputFormat* input_format, std::string* error); std::string AbslUnparseFlag(InputFormat input_format); +bool AbslParseFlag(absl::string_view text, OutputFormat* output_format, + std::string* error); +std::string AbslUnparseFlag(OutputFormat output_format); + // FunctionalHloRunner takes an HLO module as input and runs the HLO module // on a single or multiple hosts with various options (e.g. SPMD). The HLO // module can be pre- or post-optimizations. @@ -346,7 +356,8 @@ class FunctionalHloRunner { const PreprocessingOptions& preproc_options, const CompileOptions& compile_options, const RunningOptions& running_options, absl::string_view hlo_text, - InputFormat input_format, const PerDeviceLiteralVecType& arguments = {}); + InputFormat input_format, const PerDeviceLiteralVecType& arguments = {}, + std::minstd_rand0* engine = nullptr); // Loads and compiles an HLO for debugging purposes. // @@ -368,7 +379,8 @@ class FunctionalHloRunner { const PreprocessingOptions& preproc_options, const CompileOptions& compile_options, const RunningOptions& running_options, HloModule* hlo_module, - const PerDeviceLiteralVecType& arguments = {}); + const PerDeviceLiteralVecType& arguments = {}, + std::minstd_rand0* engine = nullptr); // Compiles the HLO module. static absl::StatusOr> Compile( @@ -429,7 +441,8 @@ class FunctionalHloRunner { static absl::Status DumpOutput( const FunctionalHloRunner::PerDeviceLiteralVecType& output, - absl::string_view dump_output_to, int task_id); + absl::string_view dump_output_to, int task_id, + OutputFormat output_format = OutputFormat::kText); private: // Calculates the requested number of replicas and partitions. diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index 1b8a2f7521676..8a09796e0a913 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -31,6 +32,7 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h" #include "xla/service/hlo.pb.h" #include "xla/status_macros.h" @@ -633,6 +635,23 @@ TEST_F(FunctionalHloRunnerTest, ReadHloUnoptimizedSnapshot) { hlo_module_and_arguments_from_binary.arguments.size()); } +TEST_F(FunctionalHloRunnerTest, FixFakeArguments) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetPjRtClient()); + + // Options corresponding to --num_replicas=1 --num_partitions=1 + xla::DebugOptions debug_options; + FunctionalHloRunner::PreprocessingOptions preproc_options; + CompileOptions compile_options; + FunctionalHloRunner::RunningOptions running_options; + + std::minstd_rand0 engine(42); + TF_EXPECT_OK(FunctionalHloRunner::LoadAndRun( + *client, debug_options, preproc_options, compile_options, running_options, + {GetHloPath("single_device.hlo")}, InputFormat::kText, + /*arguments=*/{}, /*engine=*/&engine)); +} + } // namespace } // namespace xla