From 5ce24067dd42f0480755fdbfce1d6c61e9b1de07 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Sat, 25 Jan 2025 12:11:28 -0800
Subject: [PATCH] [xla:cpu:xnn] Take into account operand sizes when deciding if xnn fusion needs a thread pool

PiperOrigin-RevId: 719688159
---
 xla/backends/cpu/xnn_fusion.cc | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/xla/backends/cpu/xnn_fusion.cc b/xla/backends/cpu/xnn_fusion.cc
index b0d12ca0d78a0a..0f0887ba3978f2 100644
--- a/xla/backends/cpu/xnn_fusion.cc
+++ b/xla/backends/cpu/xnn_fusion.cc
@@ -36,19 +36,31 @@ namespace xla::cpu {
 static constexpr int64_t kDotThreshold = 10 * 1000;
 static constexpr int64_t kDefaultThreshold = 100 * 1000;
 
-// We rely on a very simple heuristic to determine if thread pool is beneficial
-// for XNNPACK fusions. We assume that if the HLO produces a large result,
-// thread pool will be beneficial for running operation in parallel. For small
-// operations, thread pool overheads are higher than the actual computation.
-static int64_t MaxElementsCount(const HloInstruction* hlo) {
+static int64_t MaxElementsCount(const Shape& shape) {
   int64_t ret = 0;
   ShapeUtil::ForEachSubshape(
-      hlo->shape(), [&](const Shape& shape, const ShapeIndex& index) {
+      shape, [&](const Shape& shape, const ShapeIndex& index) {
         ret = std::max(ret, ShapeUtil::ElementsIn(shape));
       });
   return ret;
 }
 
+// We rely on a very simple heuristic to determine if thread pool is beneficial
+// for XNNPACK fusions. We assume that if the HLO produces a large result (or
+// has large operands), thread pool will be beneficial for running operation in
+// parallel. For small operations, thread pool overheads are higher than the
+// actual computation.
+static int64_t MaxElementsCount(const HloInstruction* hlo,
+                                bool include_operands = true) {
+  int64_t ret = MaxElementsCount(hlo->shape());
+  if (include_operands) {
+    for (auto* operand : hlo->operands()) {
+      ret = std::max(ret, MaxElementsCount(operand->shape()));
+    }
+  }
+  return ret;
+}
+
 bool XnnShouldUseThreadPool(const HloInstruction* hlo) {
   switch (hlo->opcode()) {
     case HloOpcode::kDot: