Skip to content

Commit

Permalink
[XLA:CPU] Add multithread support to FFT thunk.
Browse files Browse the repository at this point in the history
As reported in jax-ml/jax#25808, the performance of XLA's CPU FFT is dramatically reduced with the thunks runtime. This is because the intra-op thread pool wasn't properly passed through: the FFT thunk passed nullptr to the FFT runtime call instead of a thread pool. This change adds multithreading support by passing through the thread pool provided by xla::ExecuteParams (params.intra_op_threadpool), but I don't know if this is the right approach.

PiperOrigin-RevId: 720233207
  • Loading branch information
dfm authored and Google-ML-Automation committed Jan 27, 2025
1 parent d3236dd commit c85aa3b
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 10 deletions.
2 changes: 1 addition & 1 deletion xla/backends/cpu/runtime/fft_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> FftThunk::Execute(

// Args have been computed, make the call.
if (is_multi_thread_eigen_) {
__xla_cpu_runtime_DuccFft(nullptr,
__xla_cpu_runtime_DuccFft(params.intra_op_threadpool,
reinterpret_cast<float*>(output_data.opaque()),
reinterpret_cast<float*>(input_data.opaque()),
fft_type_, is_double_precision_, fft_rank,
Expand Down
2 changes: 2 additions & 0 deletions xla/service/cpu/cpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ extern const char* const kEigenConv3DF16SymbolName =
"__xla_cpu_runtime_EigenConv3DF16";
extern const char* const kEigenConv3DF32SymbolName =
"__xla_cpu_runtime_EigenConv3DF32";
extern const char* const kLegacyDuccFftSymbolName =
"__xla_cpu_runtime_LegacyDuccFft";
extern const char* const kDuccFftSymbolName = "__xla_cpu_runtime_DuccFft";
extern const char* const kDuccSingleThreadedFftSymbolName =
"__xla_cpu_runtime_DuccSingleThreadedFft";
Expand Down
1 change: 1 addition & 0 deletions xla/service/cpu/cpu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ extern const char* const kEigenConv2DF16SymbolName;
extern const char* const kEigenConv2DF32SymbolName;
extern const char* const kEigenConv3DF16SymbolName;
extern const char* const kEigenConv3DF32SymbolName;
extern const char* const kLegacyDuccFftSymbolName;
extern const char* const kDuccFftSymbolName;
extern const char* const kDuccSingleThreadedFftSymbolName;
extern const char* const kEigenSingleThreadedMatMulF16SymbolName;
Expand Down
2 changes: 1 addition & 1 deletion xla/service/cpu/ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ absl::Status IrEmitter::HandleFft(HloInstruction* fft) {
bool multi_threaded_eigen =
hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
const char* fn_name = multi_threaded_eigen
? runtime::kDuccFftSymbolName
? runtime::kLegacyDuccFftSymbolName
: runtime::kDuccSingleThreadedFftSymbolName;
auto* fft_lengths =
EmitGlobalForLiteral(LiteralUtil::CreateR1<int64_t>(fft_length));
Expand Down
19 changes: 14 additions & 5 deletions xla/service/cpu/runtime_fft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,23 @@ limitations under the License.
#include "unsupported/Eigen/CXX11/Tensor" // For ThreadPoolDevice.
#include "xla/executable_run_options.h"

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_DuccFft(
// FFT entry point that accepts an opaque xla::ExecutableRunOptions pointer.
//
// Resolves the intra-op Eigen::ThreadPoolDevice from the run options (leaving
// it null when no run options were supplied) and delegates the actual work to
// __xla_cpu_runtime_DuccFft, forwarding every other argument unchanged.
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_LegacyDuccFft(
    const void *run_options_ptr, void *out, void *operand, int32_t fft_type,
    int32_t double_precision, int32_t fft_rank, const int64_t *input_shape,
    const int64_t *fft_length) {
  const Eigen::ThreadPoolDevice *intra_op_device = nullptr;
  if (run_options_ptr != nullptr) {
    // Only dereference the run options when the caller actually provided them.
    intra_op_device =
        static_cast<const xla::ExecutableRunOptions *>(run_options_ptr)
            ->intra_op_thread_pool();
  }
  __xla_cpu_runtime_DuccFft(intra_op_device, out, operand, fft_type,
                            double_precision, fft_rank, input_shape,
                            fft_length);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_DuccFft(
const void *thread_pool_ptr, void *out, void *operand, int32_t fft_type,
int32_t double_precision, int32_t fft_rank, const int64_t *input_shape,
const int64_t *fft_length) {
bool forward = (fft_type == /*FFT*/ 0 || fft_type == /*RFFT*/ 2);
bool real = (fft_type == /*RFFT*/ 2 || fft_type == /*IRFFT*/ 3);

Expand Down Expand Up @@ -83,11 +93,10 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_DuccFft(
}
double scale = forward ? 1.0 : 1.0 / inv_scale;

const Eigen::ThreadPoolDevice *thread_pool_device =
static_cast<const Eigen::ThreadPoolDevice *>(thread_pool_ptr);
Eigen::ThreadPoolInterface *thread_pool =
run_options == nullptr ? nullptr
: run_options->intra_op_thread_pool() == nullptr
? nullptr
: run_options->intra_op_thread_pool()->getPool();
thread_pool_device == nullptr ? nullptr : thread_pool_device->getPool();

if (!real) {
if (double_precision) {
Expand Down
7 changes: 6 additions & 1 deletion xla/service/cpu/runtime_fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@ limitations under the License.

extern "C" {

extern void __xla_cpu_runtime_DuccFft(
// FFT entry point taking an opaque xla::ExecutableRunOptions pointer as its
// first argument. `run_options_ptr` may be null; the remaining arguments are
// forwarded unchanged to __xla_cpu_runtime_DuccFft.
extern void __xla_cpu_runtime_LegacyDuccFft(
const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out,
void* operand, int32_t fft_type, int32_t double_precision, int32_t fft_rank,
const int64_t* input_shape, const int64_t* fft_length);

// FFT entry point taking an opaque Eigen::ThreadPoolDevice pointer as its
// first argument. `thread_pool_ptr` may be null (the single-threaded variant
// passes nullptr). `fft_type` uses 0 for FFT, 2 for RFFT and 3 for IRFFT
// (1 is presumably IFFT -- confirm against the FFT type enum).
extern void __xla_cpu_runtime_DuccFft(
const void* /* Eigen::ThreadPoolDevice* */ thread_pool_ptr, void* out,
void* operand, int32_t fft_type, int32_t double_precision, int32_t fft_rank,
const int64_t* input_shape, const int64_t* fft_length);

} // extern "C"

#endif // XLA_SERVICE_CPU_RUNTIME_FFT_H_
4 changes: 2 additions & 2 deletions xla/service/cpu/runtime_single_threaded_fft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ limitations under the License.
#include "xla/service/cpu/runtime_fft.h"

// Single-threaded FFT entry point. The first (unnamed) pointer argument is
// ignored; the call is forwarded to __xla_cpu_runtime_DuccFft with a null
// first argument and all other arguments passed through unchanged.
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_DuccSingleThreadedFft(
const void* /*run_options_ptr*/, void* out, void* operand, int32_t fft_type,
const void* /*thread_pool_ptr*/, void* out, void* operand, int32_t fft_type,
int32_t double_precision, int32_t fft_rank, const int64_t* input_shape,
const int64_t* fft_length) {
return __xla_cpu_runtime_DuccFft(
/*run_options_ptr=*/nullptr, out, operand, fft_type, double_precision,
/*thread_pool_ptr=*/nullptr, out, operand, fft_type, double_precision,
fft_rank, input_shape, fft_length);
}
1 change: 1 addition & 0 deletions xla/service/cpu/runtime_symbol_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ static bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv2DF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConv3DF32);
REGISTER_CPU_RUNTIME_SYMBOL(LegacyDuccFft);
REGISTER_CPU_RUNTIME_SYMBOL(DuccFft);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
Expand Down

0 comments on commit c85aa3b

Please sign in to comment.