Skip to content

Commit

Permalink
Add argument to fix a random seed when generating random arguments for HLO runner.
Browse files Browse the repository at this point in the history
Also add OutputFormat so that literal dumps can be saved as a pb file.

PiperOrigin-RevId: 723237912
  • Loading branch information
hanrach9 authored and Google-ML-Automation committed Feb 4, 2025
1 parent 680904e commit 519df88
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 14 deletions.
1 change: 1 addition & 0 deletions xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ xla_test(
"//xla:xla_proto_cc",
"//xla/hlo/testlib:filecheck",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_executable",
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
"//xla/service:hlo_proto_cc",
"//xla/tsl/lib/core:status_test_util",
Expand Down
73 changes: 62 additions & 11 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,37 @@ std::string AbslUnparseFlag(InputFormat input_format) {
}
}

// Abseil flag parser for OutputFormat.
//
// Recognized spellings are "text", "proto_binary" and "proto_text". On a
// match, `*output_format` receives the corresponding enumerator and true is
// returned. Otherwise `*error` is set to a fixed message, `*output_format` is
// left untouched, and false is returned.
bool AbslParseFlag(absl::string_view text, OutputFormat* output_format,
                   std::string* error) {
  // Spelling -> enumerator table; scanned linearly since it has three entries.
  struct NamedFormat {
    const char* name;
    OutputFormat format;
  };
  static constexpr NamedFormat kKnownFormats[] = {
      {"text", OutputFormat::kText},
      {"proto_binary", OutputFormat::kProtoBinary},
      {"proto_text", OutputFormat::kProtoText},
  };

  for (const NamedFormat& candidate : kKnownFormats) {
    if (text == candidate.name) {
      *output_format = candidate.format;
      return true;
    }
  }

  *error = "unknown value for enumeration";
  return false;
}

// Abseil flag unparser for OutputFormat: maps each known enumerator to the
// spelling accepted on the command line ("text", "proto_binary",
// "proto_text"). Any other value is handed to absl::StrCat as a fallback.
std::string AbslUnparseFlag(OutputFormat output_format) {
  if (output_format == OutputFormat::kText) {
    return "text";
  }
  if (output_format == OutputFormat::kProtoBinary) {
    return "proto_binary";
  }
  if (output_format == OutputFormat::kProtoText) {
    return "proto_text";
  }
  // Not one of the named enumerators.
  return absl::StrCat(output_format);
}

bool AbslParseFlag(absl::string_view text,
FunctionalHloRunner::ModuleArgumentMode* argument_mode,
std::string* error) {
Expand Down Expand Up @@ -442,7 +473,7 @@ FunctionalHloRunner::CreateExecutableBuildOptionsFromExecutionOptions(

absl::Status FunctionalHloRunner::DumpOutput(
const FunctionalHloRunner::PerDeviceLiteralVecType& output,
absl::string_view dump_output_to, int task_id) {
absl::string_view dump_output_to, int task_id, OutputFormat output_format) {
std::vector<std::string> output_path_vec =
absl::StrSplit(dump_output_to, '.');
std::string suffix = output_path_vec.back();
Expand All @@ -458,12 +489,30 @@ absl::Status FunctionalHloRunner::DumpOutput(
for (int literal_id = 0; literal_id < literal_vec.size(); ++literal_id) {
output_path_vec[literal_id_index] = absl::StrCat("literal_", literal_id);
std::string literal_path = absl::StrJoin(output_path_vec, ".");
CHECK_EQ(suffix, std::string("txt"));
absl::Status write_status =
tsl::WriteStringToFile(tsl::Env::Default(), literal_path,
literal_vec[literal_id].ToString());
if (!write_status.ok()) {
return write_status;
switch (output_format) {
case OutputFormat::kText: {
CHECK_EQ(suffix, std::string("txt"));
absl::Status write_status =
tsl::WriteStringToFile(tsl::Env::Default(), literal_path,
literal_vec[literal_id].ToString());
if (!write_status.ok()) {
return write_status;
}
} break;
case OutputFormat::kProtoBinary: {
CHECK_EQ(suffix, std::string("pb"));
TF_RETURN_IF_ERROR(
tsl::WriteBinaryProto(tsl::Env::Default(), literal_path,
literal_vec[literal_id].ToProto()));
break;
}
case OutputFormat::kProtoText: {
CHECK_EQ(suffix, std::string("pbtxt"));
TF_RETURN_IF_ERROR(
tsl::WriteTextProto(tsl::Env::Default(), literal_path,
literal_vec[literal_id].ToProto()));
break;
}
}
}
}
Expand Down Expand Up @@ -545,7 +594,8 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client,
const RunningOptions& running_options,
absl::string_view hlo_text,
InputFormat input_format,
const PerDeviceLiteralVecType& arguments) {
const PerDeviceLiteralVecType& arguments,
std::minstd_rand0* engine) {
// We only support SPMD as of now, i.e., all devices are supposed
// to execute the same HLO module.
// Currently there is no mechanism to map the loaded arguments to
Expand Down Expand Up @@ -577,7 +627,7 @@ FunctionalHloRunner::LoadAndRun(PjRtClient& client,

return CompileAndRun(
client, debug_options, preproc_options, compile_options, running_options,
hlo_module_and_arguments.hlo_module.get(), loaded_arguments);
hlo_module_and_arguments.hlo_module.get(), loaded_arguments, engine);
}

absl::Status FunctionalHloRunner::LoadAndCompile(
Expand Down Expand Up @@ -723,12 +773,13 @@ FunctionalHloRunner::CompileAndRun(PjRtClient& client,
const CompileOptions& compile_options,
const RunningOptions& running_options,
HloModule* hlo_module,
const PerDeviceLiteralVecType& arguments) {
const PerDeviceLiteralVecType& arguments,
std::minstd_rand0* engine) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
Compile(client, hlo_module, debug_options,
preproc_options, compile_options));

return Run(client, executable.get(), arguments, running_options);
return Run(client, executable.get(), arguments, running_options, engine);
}

namespace {
Expand Down
19 changes: 16 additions & 3 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ enum class InputFormat {
// in conjunction with xla_dump_as_text.
};

// Format in which output literals are written when they are dumped to files.
enum class OutputFormat : std::uint8_t {
  // Text format returned by Literal::ToString().
  kText = 0,
  // Protobuf binary format of an xla::LiteralProto message.
  kProtoBinary = 1,
  // Protobuf text format of an xla::LiteralProto message.
  kProtoText = 2,
};

// Interface for profiler plugins. If being set in RunningOptions, profiling
// session will be created for the last run of the HLO module.
class ProfilerInterface {
Expand Down Expand Up @@ -134,6 +140,10 @@ bool AbslParseFlag(absl::string_view text, InputFormat* input_format,
std::string* error);
std::string AbslUnparseFlag(InputFormat input_format);

// Abseil flag support for OutputFormat.
//
// AbslParseFlag converts the flag text into an OutputFormat; it returns false
// and fills `*error` when the text is not a recognized value.
// AbslUnparseFlag converts an OutputFormat back into its flag text.
bool AbslParseFlag(absl::string_view text, OutputFormat* output_format,
std::string* error);
std::string AbslUnparseFlag(OutputFormat output_format);

// FunctionalHloRunner takes an HLO module as input and runs the HLO module
// on a single or multiple hosts with various options (e.g. SPMD). The HLO
// module can be pre- or post-optimizations.
Expand Down Expand Up @@ -346,7 +356,8 @@ class FunctionalHloRunner {
const PreprocessingOptions& preproc_options,
const CompileOptions& compile_options,
const RunningOptions& running_options, absl::string_view hlo_text,
InputFormat input_format, const PerDeviceLiteralVecType& arguments = {});
InputFormat input_format, const PerDeviceLiteralVecType& arguments = {},
std::minstd_rand0* engine = nullptr);

// Loads and compiles an HLO for debugging purposes.
//
Expand All @@ -368,7 +379,8 @@ class FunctionalHloRunner {
const PreprocessingOptions& preproc_options,
const CompileOptions& compile_options,
const RunningOptions& running_options, HloModule* hlo_module,
const PerDeviceLiteralVecType& arguments = {});
const PerDeviceLiteralVecType& arguments = {},
std::minstd_rand0* engine = nullptr);

// Compiles the HLO module.
static absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> Compile(
Expand Down Expand Up @@ -429,7 +441,8 @@ class FunctionalHloRunner {

static absl::Status DumpOutput(
const FunctionalHloRunner::PerDeviceLiteralVecType& output,
absl::string_view dump_output_to, int task_id);
absl::string_view dump_output_to, int task_id,
OutputFormat output_format = OutputFormat::kText);

private:
// Calculates the requested number of replicas and partitions.
Expand Down
19 changes: 19 additions & 0 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <cstdlib>
#include <memory>
#include <random>
#include <string>
#include <vector>

Expand All @@ -31,6 +32,7 @@ limitations under the License.
#include "xla/debug_options_flags.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/plugin/xla_gpu/xla_gpu_client_options.h"
#include "xla/service/hlo.pb.h"
#include "xla/status_macros.h"
Expand Down Expand Up @@ -633,6 +635,23 @@ TEST_F(FunctionalHloRunnerTest, ReadHloUnoptimizedSnapshot) {
hlo_module_and_arguments_from_binary.arguments.size());
}

// Runs single_device.hlo through LoadAndRun with no explicit arguments and a
// caller-supplied random engine seeded with a fixed value, and expects the
// call to succeed.
// NOTE(review): only the returned status is checked; the test does not compare
// two runs to confirm that the seed makes the generated arguments reproducible.
TEST_F(FunctionalHloRunnerTest, FixFakeArguments) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::PjRtClient> pjrt_client,
                          GetPjRtClient());

  // Default-constructed options, corresponding to
  // --num_replicas=1 --num_partitions=1.
  xla::DebugOptions debug_opts;
  FunctionalHloRunner::PreprocessingOptions preproc_opts;
  CompileOptions compile_opts;
  FunctionalHloRunner::RunningOptions run_opts;

  // Fixed seed so the engine handed to the runner starts from a known state.
  std::minstd_rand0 seeded_engine(42);
  const auto hlo_path = GetHloPath("single_device.hlo");

  TF_EXPECT_OK(FunctionalHloRunner::LoadAndRun(
      *pjrt_client, debug_opts, preproc_opts, compile_opts, run_opts, hlo_path,
      InputFormat::kText,
      /*arguments=*/{}, /*engine=*/&seeded_engine));
}

} // namespace
} // namespace xla

Expand Down

0 comments on commit 519df88

Please sign in to comment.