[xla:cpu:xnn] Add ParallelTask structs to improve performance debugging experience

Named structs instead of lambdas give a much better debugging experience.

PiperOrigin-RevId: 719688160
ezhulenev authored and Google-ML-Automation committed Jan 27, 2025
1 parent e4914bc commit b0c7aae
Showing 3 changed files with 78 additions and 28 deletions.
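Why named structs beat lambdas here: profilers and stack traces label a lambda with a compiler-generated closure type (something like `$_0::operator()` or `{lambda(unsigned long)#1}`), while a named functor shows up as `ParallelTask1D::operator()`. A minimal, self-contained sketch of the pattern; the `Task1D` alias and `RunTask` are illustrative stand-ins, not the actual XLA types:

#include <cstddef>
#include <cstdio>

// Illustrative stand-in for the runner's Task1D callback type.
using Task1D = void (*)(size_t);

// Named functor: profilers report frames as ParallelTask1D::operator().
struct ParallelTask1D {
  void operator()(size_t task_index) const { task(task_index); }
  Task1D task;
};

void RunTask(size_t i) { std::printf("task %zu\n", i); }

int main() {
  // Before: an anonymous closure; its mangled type name is what shows up
  // in profiles and stack traces.
  auto as_lambda = [](size_t i) { RunTask(i); };
  as_lambda(0);

  // After: the same call, but under a human-readable symbol name.
  ParallelTask1D as_struct{&RunTask};
  as_struct(1);
  return 0;
}

In the real code the call operator is marked ABSL_ATTRIBUTE_ALWAYS_INLINE, so the wrapper adds no call overhead.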
76 changes: 54 additions & 22 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
@@ -214,6 +214,17 @@ static Task3DTile2DIndex Delinearize(size_t task_index, size_t range_i,
 //
 // (2) If done event is not available, we have to overwrite it with a new one
 // that will be set to concrete state after the task is executed.
+//
+// We wrap all tasks into structs conforming to the `ParallelTask` API, so that
+// in profiles we can see human-readable names of the tasks instead of lambdas.
+
+struct ParallelLoopRunner::ParallelTask1D {
+  ABSL_ATTRIBUTE_ALWAYS_INLINE void operator()(size_t task_index) const {
+    task(task_index);
+  }
+
+  Task1D task;
+};
 
 void ParallelLoopRunner::Parallelize(size_t range, Task1D task) {
   DCHECK(done_event_) << "Parallel loop runner is in moved-from state";
@@ -232,9 +243,20 @@ void ParallelLoopRunner::Parallelize(size_t range, Task1D task) {
     return;
   }
 
-  ScheduleAll(range, std::move(task));
+  ScheduleAll(range, ParallelTask1D{std::move(task)});
 }
 
+struct ParallelLoopRunner::ParallelTask1DTile1D {
+  ABSL_ATTRIBUTE_ALWAYS_INLINE void operator()(size_t task_index) const {
+    auto x = Delinearize(task_index, range, tile);
+    task(x.offset, x.extent);
+  }
+
+  size_t range;
+  size_t tile;
+  Task1DTile1D task;
+};
+
 void ParallelLoopRunner::Parallelize(size_t range, size_t tile,
                                      Task1DTile1D task) {
   DCHECK(done_event_) << "Parallel loop runner is in moved-from state";
@@ -255,15 +277,21 @@ void ParallelLoopRunner::Parallelize(size_t range, size_t tile,
     return;
   }
 
-  auto parallel_task = [range, tile,
-                        task = std::move(task)](size_t task_index) {
-    auto x = Delinearize(task_index, range, tile);
-    task(x.offset, x.extent);
-  };
-
-  ScheduleAll(num_tasks, std::move(parallel_task));
+  ScheduleAll(num_tasks, ParallelTask1DTile1D{range, tile, std::move(task)});
 }
 
+struct ParallelLoopRunner::ParallelTask2DTile1D {
+  ABSL_ATTRIBUTE_ALWAYS_INLINE void operator()(size_t task_index) const {
+    auto x = Delinearize(task_index, range_i, range_j, tile_j);
+    task(x.i, x.offset_j, x.extent_j);
+  }
+
+  size_t range_i;
+  size_t range_j;
+  size_t tile_j;
+  Task2DTile1D task;
+};
+
 void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
                                      size_t tile_j, Task2DTile1D task) {
   DCHECK(done_event_) << "Parallel loop runner is in moved-from state";
@@ -282,15 +310,24 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
     return;
   }
 
-  auto parallel_task = [range_i, range_j, tile_j,
-                        task = std::move(task)](size_t task_index) {
-    auto x = Delinearize(task_index, range_i, range_j, tile_j);
-    task(x.i, x.offset_j, x.extent_j);
-  };
-
-  ScheduleAll(num_tasks, std::move(parallel_task));
+  ScheduleAll(num_tasks,
+              ParallelTask2DTile1D{range_i, range_j, tile_j, std::move(task)});
 }
 
+struct ParallelLoopRunner::ParallelTask3DTile2D {
+  ABSL_ATTRIBUTE_ALWAYS_INLINE void operator()(size_t task_index) const {
+    auto x = Delinearize(task_index, range_i, range_j, range_k, tile_j, tile_k);
+    task(x.i, x.offset_j, x.offset_k, x.extent_j, x.extent_k);
+  }
+
+  size_t range_i;
+  size_t range_j;
+  size_t range_k;
+  size_t tile_j;
+  size_t tile_k;
+  Task3DTile2D task;
+};
+
 void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
                                      size_t range_k, size_t tile_j,
                                      size_t tile_k, Task3DTile2D task) {
@@ -312,13 +349,8 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
     return;
   }
 
-  auto parallel_task = [range_i, range_j, range_k, tile_j, tile_k,
-                        task = std::move(task)](size_t task_index) {
-    auto x = Delinearize(task_index, range_i, range_j, range_k, tile_j, tile_k);
-    task(x.i, x.offset_j, x.offset_k, x.extent_j, x.extent_k);
-  };
-
-  ScheduleAll(num_tasks, std::move(parallel_task));
+  ScheduleAll(num_tasks, ParallelTask3DTile2D{range_i, range_j, range_k, tile_j,
+                                              tile_k, std::move(task)});
 }
 
 }  // namespace xla::cpu
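The `Delinearize` overloads referenced above are defined earlier in this file and are not shown in the diff. A plausible reading of the 1D-tile case, as a self-contained sketch; the `offset`/`extent` field names match the diff, but the exact rounding logic here is an assumption:

#include <algorithm>
#include <cassert>
#include <cstddef>

// Assumed shape of the 1D-tile index computation: task k covers the
// half-open range [k * tile, min(range, (k + 1) * tile)).
struct Task1DTile1DIndex {
  size_t offset;
  size_t extent;
};

static Task1DTile1DIndex Delinearize(size_t task_index, size_t range,
                                     size_t tile) {
  size_t offset = task_index * tile;
  size_t extent = std::min(range - offset, tile);
  return {offset, extent};
}

int main() {
  // range = 10, tile = 4 => 3 tasks covering [0,4), [4,8), [8,10).
  size_t range = 10, tile = 4;
  size_t num_tasks = (range + tile - 1) / tile;  // ceil division
  assert(num_tasks == 3);
  assert(Delinearize(2, range, tile).extent == 2);
  // The 2D/3D variants add outer indices i (and k) in row-major order, with
  // num_tasks = range_i * ceil(range_j / tile_j) * ceil(range_k / tile_k).
  return 0;
}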
6 changes: 6 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
@@ -121,6 +121,12 @@ class ParallelLoopRunner {
   size_t num_threads() const;
 
  private:
+  // Forward declarations of the parallel tasks.
+  struct ParallelTask1D;
+  struct ParallelTask1DTile1D;
+  struct ParallelTask2DTile1D;
+  struct ParallelTask3DTile2D;
+
   // Schedules `task` as the AndThen callback of the `done_event_`. Updates
   // `done_event_` to the new completion event.
   template <typename Task>
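The forward declarations follow a standard C++ idiom: member structs are declared as private nested types in the header and defined only in the .cc file, keeping the functor types out of the class's public surface. A minimal sketch of the same idiom, with illustrative names:

#include <cstddef>

// Sketch of the header: callers of Runner never see the functor's definition.
class Runner {
 public:
  void Parallelize(size_t range);

 private:
  struct ParallelTask1D;  // defined out-of-line, as in the .cc file
};

// Sketch of the .cc file: the full definition lives next to its only user.
struct Runner::ParallelTask1D {
  void operator()(size_t task_index) const { (void)task_index; /* one task */ }
};

void Runner::Parallelize(size_t range) {
  ParallelTask1D task{};
  for (size_t i = 0; i < range; ++i) task(i);
}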
24 changes: 18 additions & 6 deletions xla/backends/cpu/xnn_fusion.cc
@@ -36,19 +36,31 @@ namespace xla::cpu {
 static constexpr int64_t kDotThreshold = 10 * 1000;
 static constexpr int64_t kDefaultThreshold = 100 * 1000;
 
-// We rely on a very simple heuristic to determine if thread pool is beneficial
-// for XNNPACK fusions. We assume that if the HLO produces a large result,
-// thread pool will be beneficial for running operation in parallel. For small
-// operations, thread pool overheads are higher than the actual computation.
-static int64_t MaxElementsCount(const HloInstruction* hlo) {
+static int64_t MaxElementsCount(const Shape& shape) {
   int64_t ret = 0;
   ShapeUtil::ForEachSubshape(
-      hlo->shape(), [&](const Shape& shape, const ShapeIndex& index) {
+      shape, [&](const Shape& shape, const ShapeIndex& index) {
         ret = std::max(ret, ShapeUtil::ElementsIn(shape));
       });
   return ret;
 }
 
+// We rely on a very simple heuristic to determine if a thread pool is
+// beneficial for XNNPACK fusions. We assume that if the HLO produces a large
+// result (or has large operands), a thread pool will be beneficial for
+// running the operation in parallel. For small operations, thread pool
+// overheads are higher than the actual computation.
+static int64_t MaxElementsCount(const HloInstruction* hlo,
+                                bool include_operands = true) {
+  int64_t ret = MaxElementsCount(hlo->shape());
+  if (include_operands) {
+    for (auto* operand : hlo->operands()) {
+      ret = std::max(ret, MaxElementsCount(operand->shape()));
+    }
+  }
+  return ret;
+}
+
 bool XnnShouldUseThreadPool(const HloInstruction* hlo) {
   switch (hlo->opcode()) {
     case HloOpcode::kDot:
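The visible constants suggest how `XnnShouldUseThreadPool` gates parallelism, but the switch body is truncated in this diff. The sketch below reconstructs the likely shape of the heuristic with standalone stand-in types; the gating logic and `FakeHlo` are assumptions, and only the max-over-result-and-operands behavior of `MaxElementsCount` comes from the diff:

#include <algorithm>
#include <cstdint>
#include <vector>

// Thresholds copied from the visible constants in the hunk above.
constexpr int64_t kDotThreshold = 10 * 1000;
constexpr int64_t kDefaultThreshold = 100 * 1000;

// Standalone analog of an HLO instruction; illustrative, not XLA's type.
struct FakeHlo {
  bool is_dot;
  int64_t result_elements;
  std::vector<int64_t> operand_elements;
};

// Mirrors the new MaxElementsCount: max over the result and all operands.
int64_t MaxElementsCount(const FakeHlo& hlo) {
  int64_t ret = hlo.result_elements;
  for (int64_t n : hlo.operand_elements) ret = std::max(ret, n);
  return ret;
}

// Assumed gating logic: use the thread pool only when the largest shape
// involved exceeds a per-opcode threshold (the real switch is truncated).
bool ShouldUseThreadPool(const FakeHlo& hlo) {
  int64_t threshold = hlo.is_dot ? kDotThreshold : kDefaultThreshold;
  return MaxElementsCount(hlo) > threshold;
}

int main() {
  // A dot with a small result but a large operand now qualifies, which is
  // the point of including operands in MaxElementsCount.
  FakeHlo dot{/*is_dot=*/true, /*result_elements=*/512,
              /*operand_elements=*/{512 * 512, 512}};
  return ShouldUseThreadPool(dot) ? 0 : 1;
}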
