diff --git a/xla/backends/cpu/xnn_fusion.cc b/xla/backends/cpu/xnn_fusion.cc
index b0d12ca0d78a0a..0f0887ba3978f2 100644
--- a/xla/backends/cpu/xnn_fusion.cc
+++ b/xla/backends/cpu/xnn_fusion.cc
@@ -36,19 +36,31 @@ namespace xla::cpu {
 static constexpr int64_t kDotThreshold = 10 * 1000;
 static constexpr int64_t kDefaultThreshold = 100 * 1000;
 
-// We rely on a very simple heuristic to determine if thread pool is beneficial
-// for XNNPACK fusions. We assume that if the HLO produces a large result,
-// thread pool will be beneficial for running operation in parallel. For small
-// operations, thread pool overheads are higher than the actual computation.
-static int64_t MaxElementsCount(const HloInstruction* hlo) {
+static int64_t MaxElementsCount(const Shape& shape) {
   int64_t ret = 0;
   ShapeUtil::ForEachSubshape(
-      hlo->shape(), [&](const Shape& shape, const ShapeIndex& index) {
+      shape, [&](const Shape& shape, const ShapeIndex& index) {
         ret = std::max(ret, ShapeUtil::ElementsIn(shape));
       });
   return ret;
 }
 
+// We rely on a very simple heuristic to determine if thread pool is beneficial
+// for XNNPACK fusions. We assume that if the HLO produces a large result (or
+// has large operands), thread pool will be beneficial for running operation in
+// parallel. For small operations, thread pool overheads are higher than the
+// actual computation.
+static int64_t MaxElementsCount(const HloInstruction* hlo,
+                                bool include_operands = true) {
+  int64_t ret = MaxElementsCount(hlo->shape());
+  if (include_operands) {
+    for (auto* operand : hlo->operands()) {
+      ret = std::max(ret, MaxElementsCount(operand->shape()));
+    }
+  }
+  return ret;
+}
+
 bool XnnShouldUseThreadPool(const HloInstruction* hlo) {
   switch (hlo->opcode()) {
     case HloOpcode::kDot:
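
For context, here is a minimal standalone sketch of the heuristic this change extends: the largest element count across the result shape (and, after this change, the operand shapes) is compared against an opcode-specific threshold to decide whether the thread pool is worth its overhead. The `FakeShape`/`FakeHlo` types are stand-ins for XLA's `Shape`/`HloInstruction` (the real code walks tuple subshapes via `ShapeUtil::ForEachSubshape`), and the exact comparison in `XnnShouldUseThreadPool` is an assumption; only the threshold constants and `MaxElementsCount` logic are taken from the diff above.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Stand-in for xla::Shape; a single array shape, no tuples in this sketch.
struct FakeShape {
  std::vector<int64_t> dims;
  int64_t elements() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;
  }
};

// Stand-in for xla::HloInstruction.
struct FakeHlo {
  bool is_dot;
  FakeShape result;
  std::vector<FakeShape> operands;
};

// Thresholds copied from the diff above.
constexpr int64_t kDotThreshold = 10 * 1000;
constexpr int64_t kDefaultThreshold = 100 * 1000;

// Mirrors the new MaxElementsCount(hlo, include_operands): the max element
// count over the result and, optionally, all operand shapes.
int64_t MaxElementsCount(const FakeHlo& hlo, bool include_operands = true) {
  int64_t ret = hlo.result.elements();
  if (include_operands) {
    for (const FakeShape& operand : hlo.operands) {
      ret = std::max(ret, operand.elements());
    }
  }
  return ret;
}

// Assumed shape of the decision: a large-enough op amortizes thread-pool
// overhead; small ops run inline.
bool ShouldUseThreadPool(const FakeHlo& hlo) {
  const int64_t threshold = hlo.is_dot ? kDotThreshold : kDefaultThreshold;
  return MaxElementsCount(hlo) > threshold;
}

int main() {
  // A dot reducing large operands to a single element: the result alone is
  // tiny, but the operand-aware count still clears kDotThreshold.
  FakeHlo dot{/*is_dot=*/true,
              /*result=*/{{1}},
              /*operands=*/{{{1, 100000}}, {{100000, 1}}}};
  return ShouldUseThreadPool(dot) ? 0 : 1;
}
```

The `main` above illustrates what the change presumably targets, per the updated comment "(or has large operands)": before it, an op whose operands are large but whose result is small would have been scored only by its result and kept off the thread pool.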