Skip to content

Commit

Permalink
[xla:cpu:xnn] Take into account operand sizes when deciding if xnn fu…
Browse files Browse the repository at this point in the history
…sion needs a thread pool

PiperOrigin-RevId: 719688159
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Jan 26, 2025
1 parent e4914bc commit 5ce2406
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions xla/backends/cpu/xnn_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,31 @@ namespace xla::cpu {
// Element-count thresholds. NOTE(review): presumably XnnShouldUseThreadPool
// compares MaxElementsCount() against kDotThreshold for dot instructions and
// against kDefaultThreshold for all other opcodes -- the comparison itself is
// not visible here, confirm against that function.
static constexpr int64_t kDotThreshold = 10 * 1000;
static constexpr int64_t kDefaultThreshold = 100 * 1000;

// We rely on a very simple heuristic to determine if thread pool is beneficial
// for XNNPACK fusions. We assume that if the HLO produces a large result,
// thread pool will be beneficial for running operation in parallel. For small
// operations, thread pool overheads are higher than the actual computation.
static int64_t MaxElementsCount(const HloInstruction* hlo) {
static int64_t MaxElementsCount(const Shape& shape) {
int64_t ret = 0;
ShapeUtil::ForEachSubshape(
hlo->shape(), [&](const Shape& shape, const ShapeIndex& index) {
shape, [&](const Shape& shape, const ShapeIndex& index) {
ret = std::max(ret, ShapeUtil::ElementsIn(shape));
});
return ret;
}

// Heuristic input for deciding whether an XNNPACK fusion benefits from a
// thread pool: the largest element count found in the result shape of `hlo`
// and, when `include_operands` is true, in the shapes of its operands. The
// assumption is that a large result (or large operands) makes running the
// operation in parallel worthwhile, while for small operations the thread pool
// overheads are higher than the actual computation.
static int64_t MaxElementsCount(const HloInstruction* hlo,
                                bool include_operands = true) {
  int64_t largest = MaxElementsCount(hlo->shape());
  if (!include_operands) {
    return largest;
  }
  for (const auto* operand : hlo->operands()) {
    const int64_t operand_elements = MaxElementsCount(operand->shape());
    if (operand_elements > largest) {
      largest = operand_elements;
    }
  }
  return largest;
}

bool XnnShouldUseThreadPool(const HloInstruction* hlo) {
switch (hlo->opcode()) {
case HloOpcode::kDot:
Expand Down

0 comments on commit 5ce2406

Please sign in to comment.