Skip to content

Commit

Permalink
[xla:cpu:xnn] Take into account operand sizes when deciding if xnn fu…
Browse files Browse the repository at this point in the history
…sion needs a thread pool

PiperOrigin-RevId: 719688159
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 26, 2025
1 parent 1321030 commit 79b0e1d
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 142 deletions.
2 changes: 2 additions & 0 deletions xla/backends/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ cc_library(
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/runtime:dot_lib",
"//xla/hlo/ir:hlo",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status:statusor",
],
)
49 changes: 0 additions & 49 deletions xla/backends/cpu/runtime/work_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,6 @@ class Worker {
const Eigen::ThreadPoolDevice* device, size_t num_workers,
size_t num_tasks, ParallelTask&& parallel_task);

// Compute the number of workers that should be used for parallel operation,
// by executing the first task, measuring the compute time and estimating how
// many workers are needed, so that each worker will handle `worker_timeslice`
// amount of compute.
template <typename ParallelTask>
static std::conditional_t<
std::is_same_v<std::invoke_result_t<ParallelTask, size_t>, absl::Status>,
absl::StatusOr<size_t>, size_t>
ComputeOptimalNumWorkers(absl::Duration worker_timeslice, size_t num_threads,
size_t num_tasks, ParallelTask& parallel_task);

private:
template <typename ParallelTask>
struct ParallelizeContext;
Expand Down Expand Up @@ -345,44 +334,6 @@ ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef<tsl::Chain> Worker::Parallelize(
return execute_event;
}

template <typename ParallelTask>
std::conditional_t<
std::is_same_v<std::invoke_result_t<ParallelTask, size_t>, absl::Status>,
absl::StatusOr<size_t>, size_t>
Worker::ComputeOptimalNumWorkers(absl::Duration worker_timeslice,
size_t num_threads, size_t num_tasks,
ParallelTask& parallel_task) {
// Run first task in the caller thread, to estimate the number of parallel
// workers that should be used for parallel operation.
uint64_t start_ns = tsl::Env::Default()->NowNanos();

using R = std::invoke_result_t<ParallelTask, size_t>;
static_assert(std::is_same_v<R, absl::Status> || std::is_void_v<R>,
"Unsupported parallel task return type");

if constexpr (std::is_same_v<R, absl::Status>) {
TF_RETURN_IF_ERROR(parallel_task(0));
} else {
parallel_task(0);
}

uint64_t end_ns = tsl::Env::Default()->NowNanos();

// We assume that all tasks take roughly the same amount of compute and we
// can estimate the total workload duration by multiplying the number of
// remaining tasks by the duration of a single task.
size_t workload_ns = (num_tasks - 1) * (end_ns - start_ns);
size_t timeslice_ns = absl::ToInt64Nanoseconds(worker_timeslice);

// Get the number of workers, so that each worker will take roughly
// `worker_timeslice` amount of compute. Don't create more workers than
// the number of threads in the thread pool or the number of tasks.
size_t num_workers =
std::min(std::min(num_tasks - 1, num_threads),
tsl::MathUtil::CeilOfRatio(workload_ns, timeslice_ns));
return std::min(num_workers, size_t{std::numeric_limits<uint16_t>::max()});
}

} // namespace xla::cpu

#endif // XLA_BACKENDS_CPU_RUNTIME_WORK_QUEUE_H_
17 changes: 0 additions & 17 deletions xla/backends/cpu/runtime/work_queue_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,23 +142,6 @@ TEST(WorkQueueTest, WorkerParallelize) {
EXPECT_EQ(data, expected);
}

TEST(WorkQueueTest, ComputeOptimalNumWorkers) {
{ // Parallel task with void return type.
auto noop = [](size_t task_index) {};
size_t num_workers =
Worker::ComputeOptimalNumWorkers(absl::Nanoseconds(10), 8, 1024, noop);
EXPECT_LE(num_workers, 8);
}

{ // Parallel task with absl:Status return type.
auto noop = [](size_t task_index) { return absl::OkStatus(); };
TF_ASSERT_OK_AND_ASSIGN(
size_t num_workers,
Worker::ComputeOptimalNumWorkers(absl::Nanoseconds(10), 8, 1024, noop));
EXPECT_LE(num_workers, 8);
}
}

//===----------------------------------------------------------------------===//
// Performance benchmarks.
//===----------------------------------------------------------------------===//
Expand Down
71 changes: 10 additions & 61 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,73 +91,22 @@ ABSL_ATTRIBUTE_ALWAYS_INLINE void ParallelLoopRunner::ScheduleAll(
size_t num_tasks, ParallelTask&& parallel_task) {
DCHECK_GT(num_tasks, 1) << "Expected at least two task";

// If done event is already available and we have a worker timeslice, we can
// compute the optimal number of workers for the parallel operation and
// potentially avoid allocating count down counter altogether.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete() && worker_timeslice_)) {
size_t optimal_num_workers = Worker::ComputeOptimalNumWorkers(
*worker_timeslice_, num_threads(), num_tasks, parallel_task);

// Execute remaining tasks in the caller thread if we have a single worker.
if (ABSL_PREDICT_TRUE(optimal_num_workers == 1)) {
for (size_t i = 1; i < num_tasks; ++i) {
parallel_task(i);
}
return;
}

tsl::CountDownAsyncValueRef<tsl::Chain> count_down(optimal_num_workers);
done_event_ = count_down.AsRef();

// Parallelize the remaining tasks (skip the first task that was executed
// when we were computing the number of workers).
Worker::Parallelize(
device_, std::move(count_down), num_tasks - 1,
[parallel_task = std::forward<ParallelTask>(parallel_task)](
size_t task_index) { parallel_task(task_index + 1); });
return;
}

// If `done_event_` is not available, we start with at most `num_threads()`
// workers as we can't run more parallel workers than the number of threads in
// the thread pool. Later we might adjust the number of workers when it's safe
// to execute the first task to measure the execution time.
// Use at most `num_threads()` workers as we can't run more parallel workers
// than the number of threads in the thread pool.
size_t num_workers = std::min(std::min(num_tasks, num_threads()),
size_t{std::numeric_limits<uint16_t>::max()});

tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_workers);
auto count_down_done = count_down.AsRef();

auto schedule_all =
[this, num_workers, num_tasks, count_down = std::move(count_down),
parallel_task = std::forward<ParallelTask>(parallel_task)]() mutable {
// If we don't have a worker timeslice, we can parallelize the task
// immediately using pre-computed number of workers.
if (ABSL_PREDICT_FALSE(!worker_timeslice_)) {
Worker::Parallelize(device_, std::move(count_down), num_tasks,
std::move(parallel_task));
return;
}

// Compute the optimal number of workers by executing the first task.
size_t optimal_num_workers = Worker::ComputeOptimalNumWorkers(
*worker_timeslice_, num_threads(), num_tasks, parallel_task);
DCHECK_GT(optimal_num_workers, 0);
DCHECK_LE(optimal_num_workers, num_workers);

// Count down for the workers that we don't need.
count_down.CountDown(num_workers - optimal_num_workers);

// Parallelize the remaining tasks (skip the first task that was
// executed when we were computing the number of workers).
Worker::Parallelize(
device_, std::move(count_down), num_tasks - 1,
[parallel_task = std::move(parallel_task)](size_t task_index) {
parallel_task(task_index + 1);
});
};

done_event_.AndThen(std::move(schedule_all));
auto parallelize = [this, num_tasks, count_down = std::move(count_down),
parallel_task =
std::forward<ParallelTask>(parallel_task)] {
Worker::Parallelize(device_, std::move(count_down), num_tasks,
std::move(parallel_task));
};

done_event_.AndThen(std::move(parallelize));
done_event_ = std::move(count_down_done);
}

Expand Down
8 changes: 4 additions & 4 deletions xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class XnnDotThunkTest : public testing::TestWithParam<bool> {
};

TEST_P(XnnDotThunkTest, SimpleDot) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());

auto lhs = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto rhs = LiteralUtil::CreateR2<float>({{4.0, 3.0}, {2.0, 1.0}});
auto out = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
Expand All @@ -61,10 +65,6 @@ TEST_P(XnnDotThunkTest, SimpleDot) {
{"dot"}, dot_dimensions, lhs_slice, shape,
rhs_slice, shape, out_slice, shape));

tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());

Thunk::ExecuteParams params;
params.buffer_allocations = &allocations;
params.intra_op_threadpool = use_threadpool() ? &device : nullptr;
Expand Down
8 changes: 4 additions & 4 deletions xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class XnnFusionThunkTest : public testing::TestWithParam<bool> {
};

TEST_P(XnnFusionThunkTest, ElementwiseAdd) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());

auto lhs = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
auto rhs = LiteralUtil::CreateR1<float>({4.0, 3.0, 2.0, 1.0});
auto out = LiteralUtil::CreateR1<float>({0.0, 0.0, 0.0, 0.0});
Expand All @@ -110,10 +114,6 @@ TEST_P(XnnFusionThunkTest, ElementwiseAdd) {
XnnFusionThunk::Options{use_threadpool()}, {"fusion"},
{lhs_arg, rhs_arg}, {out_res}, &CreateBinaryAdd));

tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());

Thunk::ExecuteParams params;
params.buffer_allocations = &allocations;
params.intra_op_threadpool = use_threadpool() ? &device : nullptr;
Expand Down
53 changes: 53 additions & 0 deletions xla/backends/cpu/xnn_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,67 @@ limitations under the License.

#include "xla/backends/cpu/xnn_fusion.h"

#include <algorithm>
#include <cstdint>

#include "absl/algorithm/container.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/dot_lib.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/xla_data.pb.h"

namespace xla::cpu {

// Thresholds for when to use thread pool for XNNPACK fusions for different
// HLOs. These numbers picked up randomly and need benchmarks to tune.
static constexpr int64_t kDotThreshold = 10 * 1000;
static constexpr int64_t kDefaultThreshold = 100 * 1000;

static int64_t MaxElementsCount(const Shape& shape) {
int64_t ret = 0;
ShapeUtil::ForEachSubshape(
shape, [&](const Shape& shape, const ShapeIndex& index) {
ret = std::max(ret, ShapeUtil::ElementsIn(shape));
});
return ret;
}

// We rely on a very simple heuristic to determine if thread pool is beneficial
// for XNNPACK fusions. We assume that if the HLO produces a large result (or
// has large operands), thread pool will be beneficial for running operation in
// parallel. For small operations, thread pool overheads are higher than the
// actual computation.
static int64_t MaxElementsCount(const HloInstruction* hlo,
bool include_operands = true) {
int64_t ret = MaxElementsCount(hlo->shape());
if (include_operands) {
for (auto* operand : hlo->operands()) {
ret = std::max(ret, MaxElementsCount(operand->shape()));
}
}
return ret;
}

bool XnnShouldUseThreadPool(const HloInstruction* hlo) {
switch (hlo->opcode()) {
case HloOpcode::kDot:
return MaxElementsCount(hlo) > kDotThreshold;
default:
return MaxElementsCount(hlo) > kDefaultThreshold;
}
}

bool XnnShouldUseThreadPool(const HloComputation* computation) {
return absl::c_any_of(
computation->instructions(),
[](const HloInstruction* hlo) { return XnnShouldUseThreadPool(hlo); });
}

absl::StatusOr<bool> IsXnnDotSupported(
const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
const Shape& rhs_shape, const Shape& out_shape) {
Expand Down
8 changes: 8 additions & 0 deletions xla/backends/cpu/xnn_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,19 @@ limitations under the License.
#define XLA_BACKENDS_CPU_XNN_FUSION_H_

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/shape.h"
#include "xla/xla_data.pb.h"

namespace xla::cpu {

// Returns true if XNNPACK should use thread pool to execute given HLO
// instruction or computation. We rely on simple heuristics to determine if
// thread pool is beneficial.
bool XnnShouldUseThreadPool(const HloInstruction* hlo);
bool XnnShouldUseThreadPool(const HloComputation* computation);

// Returns true if the dot operation is supported by XNNPACK. Returns an error
// if the dot operation shape is invalid.
absl::StatusOr<bool> IsXnnDotSupported(
Expand Down
2 changes: 1 addition & 1 deletion xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ cc_library(
"//xla/service:pattern_matcher",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
Expand All @@ -895,7 +896,6 @@ cc_library(
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:JITLink",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
14 changes: 8 additions & 6 deletions xla/service/cpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/statusor.h"

namespace xla::cpu {

Expand Down Expand Up @@ -846,8 +846,9 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitDotThunk(
}

if (use_xnn) {
XnnDotThunk::Options options = {XnnShouldUseThreadPool(instruction)};
return ThunkSequence::Of<XnnDotThunk>(
XnnDotThunk::Options{}, ThunkInfo(instruction), dnums, lhs_slice,
std::move(options), ThunkInfo(instruction), dnums, lhs_slice,
lhs->shape(), rhs_slice, rhs->shape(), out_slice,
instruction->shape());
} else {
Expand Down Expand Up @@ -1184,13 +1185,14 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitXnnFusionThunk(
results.push_back(XnnFusionThunk::Result{slice, indexed.shape});
}

const HloComputation* computation = fusion->fused_instructions_computation();

// Construct XNNPACK subgraph builder from the fusion computation.
TF_ASSIGN_OR_RETURN(
auto builder,
EmitXnnFusionBuilder(fusion->fused_instructions_computation()));
TF_ASSIGN_OR_RETURN(auto builder, EmitXnnFusionBuilder(computation));

XnnFusionThunk::Options options = {XnnShouldUseThreadPool(computation)};
return ThunkSequence::Of<XnnFusionThunk>(
XnnFusionThunk::Options{}, ThunkInfo(instruction), std::move(arguments),
std::move(options), ThunkInfo(instruction), std::move(arguments),
std::move(results),
[b = std::move(builder)](auto, auto) mutable { return b(); });
}
Expand Down

0 comments on commit 79b0e1d

Please sign in to comment.