Skip to content

Commit

Permalink
[XLA:GPU] Always run under profiles in the multihost runner
Browse files Browse the repository at this point in the history
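Adds a `--profile_execution` flag (default: false). When it is set, execution profiles are collected for each input file and the ns-precision execution duration of every repeat is printed to stdout.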

PiperOrigin-RevId: 718864097
  • Loading branch information
mooskagh authored and Google-ML-Automation committed Jan 23, 2025
1 parent 5acf7da commit 4a9c615
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion xla/tools/multihost_hlo_runner/hlo_runner_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ struct HloRunnerConfig {
std::string execution_options_path = "";
int64_t gpu_client_initialization_timeout_sec = 300;
float gpu_client_mem_fraction = xla::GpuAllocatorConfig{}.memory_fraction;
bool profile_execution = false;
};

} // namespace
Expand Down Expand Up @@ -249,8 +250,14 @@ static absl::Status RunMultihostHloRunner(int argc, char** argv,
}
CHECK(env.client != nullptr);

std::vector<ExecutionProfile> execution_profiles;
if (opts.profile_execution) {
running_options.execution_profiles = &execution_profiles;
}

for (int c = 1; c < argc; c++) {
const char* filename = argv[c];
execution_profiles.clear();
std::cout << "\n** Running " << filename << " **\n";
if (opts.should_run) {
TF_RETURN_IF_ERROR(xla::FunctionalHloRunner::LoadAndRunAndDump(
Expand All @@ -262,6 +269,11 @@ static absl::Status RunMultihostHloRunner(int argc, char** argv,
*env.client, GetDebugOptionsFromFlags(), preproc_options,
raw_compile_options, argv[c], opts.input_format, opts.task_id));
}
for (int i = 0; i < execution_profiles.size(); ++i) {
std::cout << "## Execution time, file=" << filename << " repeat=" << i
<< " duration=" << execution_profiles[i].compute_time_ns()
<< "ns" << std::endl;
}
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -337,7 +349,10 @@ int main(int argc, char** argv) {
tsl::Flag("gpu_client_mem_fraction", &opts.gpu_client_mem_fraction,
"The maximum fraction of available memory to allocate in range "
"of (0.0, 1.0). Same as XLA_CLIENT_MEM_FRACTION in the Python "
"client. Only used with the BFC allocator.")};
"client. Only used with the BFC allocator."),
tsl::Flag(
"profile_execution", &opts.profile_execution,
"If set, we will profile the execution and print the results.")};

xla::AppendDebugOptionsFlags(&flag_list);

Expand Down

0 comments on commit 4a9c615

Please sign in to comment.