Skip to content

Commit

Permalink
[xla:cpu] Add support for running multiple executions inside each HLO…
Browse files Browse the repository at this point in the history
… benchmark iteration

PiperOrigin-RevId: 720202643
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 27, 2025
1 parent 7f3cb0c commit bd69e17
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 34 deletions.
8 changes: 5 additions & 3 deletions xla/backends/cpu/benchmarks/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ cc_library(
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"//xla/service:hlo_module_config",
"//xla/tests:test_utils",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test_benchmark",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_benchmark",
],
)

Expand Down
18 changes: 10 additions & 8 deletions xla/backends/cpu/benchmarks/concatenate_benchmark_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ limitations under the License.
namespace xla::cpu {

static void BM_ConcatenateTwoR3F32(benchmark::State& state) {
auto disable_parallel_backend = !static_cast<bool>(state.range(0));
bool disable_parallel_backend = !static_cast<bool>(state.range(0));
int64_t dims[3] = {state.range(1), state.range(2), state.range(3)};
Shape shape = ShapeUtil::MakeShape(F32, dims);
int64_t axis = state.range(4);
Expand All @@ -57,14 +57,16 @@ static void BM_ConcatenateTwoR3F32(benchmark::State& state) {
auto p0 = *LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, 1.0f, 0.1f);
auto p1 = *LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, 1.0f, 0.1f);

HloBenchmarkOptions benchmark_options;
benchmark_options.disable_parallel_task_assigner = disable_parallel_backend;

std::vector<const Literal*> args = {&p0, &p1};
CHECK_OK(RunHloBenchmark(
state, hlo, args,
{{"$shape_repr", absl::StrJoin(dims, "x")},
{"$shape", absl::StrJoin(dims, ",")},
{"$out_shape", absl::StrJoin(out_dims, ",")},
{"$axis", absl::StrCat(axis)}},
/*disable_parallel_task_assigner=*/disable_parallel_backend));
CHECK_OK(RunHloBenchmark(state, hlo, args,
{{"$shape_repr", absl::StrJoin(dims, "x")},
{"$shape", absl::StrJoin(dims, ",")},
{"$out_shape", absl::StrJoin(out_dims, ",")},
{"$axis", absl::StrCat(axis)}},
benchmark_options));
}

BENCHMARK(BM_ConcatenateTwoR3F32)
Expand Down
72 changes: 54 additions & 18 deletions xla/backends/cpu/benchmarks/hlo_benchmark_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ limitations under the License.
#include <memory>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/types/span.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
Expand All @@ -33,20 +35,22 @@ limitations under the License.
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tests/test_utils.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test_benchmark.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test_benchmark.h"
#include "xla/tsl/platform/threadpool.h"

namespace xla::cpu {

absl::Status RunHloBenchmark(benchmark::State& state,
absl::string_view hlo_module,
absl::Span<const Literal* const> args,
StrToStrMapping replacements,
bool disable_parallel_task_assigner) {
xla::CpuClientOptions options;
const HloBenchmarkOptions& benchmark_options) {
xla::CpuClientOptions client_options;
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
xla::GetXlaPjrtCpuClient(options));
xla::GetXlaPjrtCpuClient(client_options));
PjRtDevice* device = client->devices().front();

TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
Expand All @@ -58,7 +62,7 @@ absl::Status RunHloBenchmark(benchmark::State& state,

// Compile HLO module to executable.
CompileOptions compile_options;
if (disable_parallel_task_assigner) {
if (benchmark_options.disable_parallel_task_assigner) {
compile_options.executable_build_options.mutable_debug_options()
->add_xla_disable_hlo_passes("cpu-parallel-task-assigner");
}
Expand Down Expand Up @@ -97,7 +101,8 @@ absl::Status RunHloBenchmark(benchmark::State& state,
}
}

// Execute in synchronous mode to avoid thread hops.
// Execute in synchronous mode to avoid thread hops, as we anyway use our own
// thread pool if we need to run multiple executions in parallel.
ExecuteOptions execute_options;
execute_options.execution_mode = ExecuteOptions::ExecutionMode::kSynchronous;

Expand All @@ -107,16 +112,47 @@ absl::Status RunHloBenchmark(benchmark::State& state,
args_ptrs.push_back(arg.get());
}

CHECK_GE(benchmark_options.num_executions, 1);
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> results(
benchmark_options.num_executions);

// Thread pool for dispatching multiple executions in parallel.
tsl::thread::ThreadPool threads(tsl::Env::Default(), "hlo_benchmark_runner",
benchmark_options.num_executions);

// Warmup executable.
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<PjRtBuffer>> results,
executable->ExecuteSharded(args_ptrs, device, execute_options));
TF_ASSIGN_OR_RETURN(results[0], executable->ExecuteSharded(args_ptrs, device,
execute_options));

// Benchmark executable.
for (auto _ : state) {
TF_ASSIGN_OR_RETURN(results, executable->ExecuteSharded(args_ptrs, device,
execute_options));
tsl::testing::DoNotOptimize(results);
if (benchmark_options.num_executions == 1) {
// Single execution always runs in the caller thread.
results[0] =
executable->ExecuteSharded(args_ptrs, device, execute_options)
.value();
} else {
// Multiple executions run in parallel.
absl::BlockingCounter counter(benchmark_options.num_executions);

for (size_t i = 0; i < benchmark_options.num_executions; ++i) {
threads.Schedule([&, i]() {
results[i] =
executable->ExecuteSharded(args_ptrs, device, execute_options)
.value();
counter.DecrementCount();
});
}

counter.Wait();
}

// Wait for all results to be ready.
for (size_t i = 0; i < benchmark_options.num_executions; ++i) {
for (const auto& result : results[i]) {
CHECK_OK(result->GetReadyFuture().Await());
}
}
}

return absl::OkStatus();
Expand All @@ -125,10 +161,10 @@ absl::Status RunHloBenchmark(benchmark::State& state,
absl::Status CompileHloBenchmark(benchmark::State& state,
absl::string_view hlo_module,
StrToStrMapping replacements,
bool disable_parallel_task_assigner) {
xla::CpuClientOptions options;
const HloBenchmarkOptions& benchmark_options) {
xla::CpuClientOptions client_options;
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
xla::GetXlaPjrtCpuClient(options));
xla::GetXlaPjrtCpuClient(client_options));

TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
ParseAndReturnUnverifiedModule(
Expand All @@ -138,7 +174,7 @@ absl::Status CompileHloBenchmark(benchmark::State& state,
XlaComputation computation(module->ToProto());

CompileOptions compile_options;
if (disable_parallel_task_assigner) {
if (benchmark_options.disable_parallel_task_assigner) {
compile_options.executable_build_options.mutable_debug_options()
->add_xla_disable_hlo_passes("cpu-parallel-task-assigner");
}
Expand Down
19 changes: 14 additions & 5 deletions xla/backends/cpu/benchmarks/hlo_benchmark_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ limitations under the License.
#ifndef XLA_BACKENDS_CPU_BENCHMARKS_HLO_BENCHMARK_RUNNER_H_
#define XLA_BACKENDS_CPU_BENCHMARKS_HLO_BENCHMARK_RUNNER_H_

#include <cstdint>
#include <initializer_list>
#include <utility>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
Expand All @@ -28,6 +32,11 @@ namespace xla::cpu {
using StrToStrMapping =
std::initializer_list<std::pair<absl::string_view, absl::string_view>>;

struct HloBenchmarkOptions {
int32_t num_executions = 1;
bool disable_parallel_task_assigner = false;
};

// Runs the given HLO module as a benchmark.
//
// The HLO text can be interpolated using the given string replacements. Each
Expand All @@ -41,16 +50,16 @@ absl::Status RunHloBenchmark(benchmark::State& state,
absl::string_view hlo_module,
absl::Span<const Literal* const> args,
StrToStrMapping replacements = {},
bool disable_parallel_task_assigner = false);
const HloBenchmarkOptions& benchmark_options = {});

// Benchmarks the given HLO's compilation time.
//
// Takes the same options as RunHloBenchmark, except no arguments since the
// HLO is only compiled, not run.
absl::Status CompileHloBenchmark(benchmark::State& state,
absl::string_view hlo_module,
StrToStrMapping replacements = {},
bool disable_parallel_task_assigner = false);
absl::Status CompileHloBenchmark(
benchmark::State& state, absl::string_view hlo_module,
StrToStrMapping replacements = {},
const HloBenchmarkOptions& benchmark_options = {});

} // namespace xla::cpu

Expand Down

0 comments on commit bd69e17

Please sign in to comment.