diff --git a/xla/service/gpu/gpu_fused_mha_runner.cc b/xla/service/gpu/gpu_fused_mha_runner.cc
index 61be619b1d46d..09af1e523ab7f 100644
--- a/xla/service/gpu/gpu_fused_mha_runner.cc
+++ b/xla/service/gpu/gpu_fused_mha_runner.cc
@@ -68,29 +68,13 @@ absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream,
     dropout_rate = *params.config->dropout_rate;
   }
 
-  double scale = 1.0;
-  if (params.config->fmha_scale) {
-    scale = *params.config->fmha_scale;
-  }
-
   std::optional<int64_t> seed;
   if (params.config->seed) {
     seed = *params.config->seed;
   }
-  TF_ASSIGN_OR_RETURN(
-      se::dnn::FMHAMaskKind mask_type,
-      GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(params.config->mask_type));
-  se::dnn::FusedMHAOp::Config config{scale,
-                                     params.config->lhs_bmm1,
-                                     params.config->rhs_bmm1,
-                                     params.config->rhs_bmm2,
-                                     params.config->intermediate_lhs_bmm2,
-                                     params.config->output,
-                                     params.config->bias,
-                                     params.config->activation,
-                                     dropout_rate,
-                                     seed,
-                                     mask_type};
+
+  TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAOp::Config config,
+                      params.config->AsDnnFusedMHAOpConfig());
   TF_ASSIGN_OR_RETURN(auto *runner,
                       lazy_runner->GetOrCreateRunner(config, stream));
   return (*runner)(stream, options.profile_result, scratch_memory,
@@ -183,35 +167,13 @@ absl::Status RunFusedMHABackward(
     dropout_rate = *params.config->dropout_rate;
   }
 
-  double scale = 1.0;
-  if (params.config->fmha_scale) {
-    scale = *params.config->fmha_scale;
-  }
-
   std::optional<int64_t> seed;
   if (params.config->seed) {
     seed = *params.config->seed;
   }
 
-  TF_ASSIGN_OR_RETURN(
-      se::dnn::FMHAMaskKind mask_type,
-      GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(params.config->mask_type));
-  se::dnn::FusedMHABackwardOp::Config config{scale,
-                                             params.config->bmm1_grad_gemm1_rhs,
-                                             params.config->bmm1_grad_gemm2_rhs,
-                                             params.config->bmm2_grad_gemm1_lhs,
-                                             params.config->bmm2_grad_gemm2_rhs,
-                                             params.config->d_output,
-                                             params.config->d_bmm1_lhs,
-                                             params.config->d_bmm1_rhs,
-                                             params.config->d_bmm2_rhs,
-                                             params.config->d_s,
-                                             params.config->d_bias,
-                                             params.config->fwd_output,
-                                             params.config->bias,
-                                             dropout_rate,
-                                             seed,
-                                             mask_type};
+  TF_ASSIGN_OR_RETURN(se::dnn::FusedMHABackwardOp::Config config,
+                      params.config->AsDnnFusedMHABackwardOpConfig());
   TF_ASSIGN_OR_RETURN(auto *runner,
                       lazy_runner->GetOrCreateRunner(config, stream));
   // TODO: pass in real softmax_sum, dQ_accum, fwd_output
@@ -404,6 +366,21 @@ absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams &params,
   return config;
 }
 
+absl::StatusOr<se::dnn::FusedMHAOp::Config>
+GpufMHAConfig::AsDnnFusedMHAOpConfig() const {
+  double scale = 1.0;
+  if (fmha_scale.has_value()) {
+    scale = *fmha_scale;
+  }
+  TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type,
+                      GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type));
+
+  return se::dnn::FusedMHAOp::Config{
+      scale, lhs_bmm1, rhs_bmm1, rhs_bmm2, intermediate_lhs_bmm2,
+      output, bias, activation, dropout_rate, seed,
+      mask_type};
+}
+
 /*static*/ absl::StatusOr<GpufMHABackwardConfig> GpufMHABackwardConfig::For(
     const GpufMHABackwardDescriptor &desc) {
   // Get shapes from desc.
@@ -546,6 +523,32 @@ absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams &params,
   return config;
 }
 
+absl::StatusOr<se::dnn::FusedMHABackwardOp::Config>
+GpufMHABackwardConfig::AsDnnFusedMHABackwardOpConfig() const {
+  double scale = 1.0;
+  if (fmha_scale.has_value()) {
+    scale = *fmha_scale;
+  }
+  TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type,
+                      GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type));
+  return se::dnn::FusedMHABackwardOp::Config{scale,
+                                             bmm1_grad_gemm1_rhs,
+                                             bmm1_grad_gemm2_rhs,
+                                             bmm2_grad_gemm1_lhs,
+                                             bmm2_grad_gemm2_rhs,
+                                             d_output,
+                                             d_bmm1_lhs,
+                                             d_bmm1_rhs,
+                                             d_bmm2_rhs,
+                                             d_s,
+                                             d_bias,
+                                             fwd_output,
+                                             bias,
+                                             dropout_rate,
+                                             seed,
+                                             mask_type};
+}
+
 /*static*/ absl::StatusOr<GpufMHAParams> GpufMHAParams::For(
     const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer,
     se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer,
diff --git a/xla/service/gpu/gpu_fused_mha_runner.h b/xla/service/gpu/gpu_fused_mha_runner.h
index 6538d54298432..7ca35805be251 100644
--- a/xla/service/gpu/gpu_fused_mha_runner.h
+++ b/xla/service/gpu/gpu_fused_mha_runner.h
@@ -101,10 +101,14 @@ struct GpufMHABackwardDescriptor {
   std::optional<Shape> d_bias_shape;
   std::optional<Shape> bias_shape;
 };
+
 // Structure to describe static properties of a GPU fused Multi-Headed
 // Attention.
 struct GpufMHAConfig {
   static absl::StatusOr<GpufMHAConfig> For(const GpufMHADescriptor& fmha_desc);
+
+  absl::StatusOr<se::dnn::FusedMHAOp::Config> AsDnnFusedMHAOpConfig() const;
+
   PrimitiveType
       input_type;  // Capture the primitive type of one of the inputs of BMM1
   PrimitiveType output_type;
@@ -133,6 +137,10 @@ struct GpufMHAConfig {
 struct GpufMHABackwardConfig {
   static absl::StatusOr<GpufMHABackwardConfig> For(
       const GpufMHABackwardDescriptor& fmha_desc);
+
+  absl::StatusOr<se::dnn::FusedMHABackwardOp::Config>
+  AsDnnFusedMHABackwardOpConfig() const;
+
   PrimitiveType
       input_type;  // Capture the primitive type of one of the inputs of BMM1
   PrimitiveType output_type;
diff --git a/xla/service/gpu/runtime/fused_mha_thunk.cc b/xla/service/gpu/runtime/fused_mha_thunk.cc
index efa90bfc9feae..41613f0121e65 100644
--- a/xla/service/gpu/runtime/fused_mha_thunk.cc
+++ b/xla/service/gpu/runtime/fused_mha_thunk.cc
@@ -65,6 +65,13 @@ std::optional<se::DeviceMemoryBase> AssignBufferIfNotNull(
              : std::nullopt;
 }
 
+absl::Status FusedMHAThunk::Initialize(const InitializeParams& params) {
+  se::dnn::LazyOpRunner<se::dnn::FusedMHAOp>* lazy_runner =
+      GetOrCreateRunner(params.stream).AsFusedMHARunner();
+  TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig());
+  return lazy_runner->GetOrCreateRunner(config, params.stream).status();
+}
+
 absl::Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) {
   const auto& buffer_allocations = *params.buffer_allocations;
   se::DeviceMemoryBase lhs_bmm1_buffer =
@@ -143,6 +150,13 @@ FusedMHABackwardThunk::GetOrCreateRunner(
   return *it->second;
 }
 
+absl::Status FusedMHABackwardThunk::Initialize(const InitializeParams& params) {
+  se::dnn::LazyOpRunner<se::dnn::FusedMHABackwardOp>* lazy_runner =
+      GetOrCreateRunner(params.stream).AsFusedMHABackwardRunner();
+  TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig());
+  return lazy_runner->GetOrCreateRunner(config, params.stream).status();
+}
+
 absl::Status FusedMHABackwardThunk::ExecuteOnStream(
     const ExecuteParams& params) {
   const auto& buffer_allocations = *params.buffer_allocations;
diff --git a/xla/service/gpu/runtime/fused_mha_thunk.h b/xla/service/gpu/runtime/fused_mha_thunk.h
index 32bfdd0ecd19f..bf9cff354b502 100644
--- a/xla/service/gpu/runtime/fused_mha_thunk.h
+++ b/xla/service/gpu/runtime/fused_mha_thunk.h
@@ -53,6 +53,7 @@ class FusedMHAThunk : public Thunk {
   FusedMHAThunk(const FusedMHAThunk&) = delete;
   FusedMHAThunk& operator=(const FusedMHAThunk&) = delete;
 
+  absl::Status Initialize(const InitializeParams& params) override;
   absl::Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
@@ -101,6 +102,7 @@ class FusedMHABackwardThunk : public Thunk {
   FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete;
   FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete;
 
+  absl::Status Initialize(const InitializeParams& params) override;
   absl::Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
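
Note on the pattern above (illustration only, not part of the patch): the new FusedMHAThunk::Initialize and FusedMHABackwardThunk::Initialize overrides convert the thunk's config with AsDnnFusedMHAOpConfig()/AsDnnFusedMHABackwardOpConfig() and prime the per-stream lazy runner via GetOrCreateRunner(config, stream), so runner construction happens during thunk initialization rather than on the first ExecuteOnStream call. The sketch below shows that initialize-then-execute caching shape in isolation; MyConfig, MyRunner, and MyThunk are hypothetical stand-ins, not XLA or StreamExecutor APIs.

#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <utility>

// Hypothetical stand-ins for the real config/runner/thunk types.
struct MyConfig {
  double scale = 1.0;
};

struct MyRunner {
  explicit MyRunner(const MyConfig& c) : config(c) {}  // expensive in real life
  void Run() const { std::cout << "run, scale=" << config.scale << "\n"; }
  MyConfig config;
};

class MyThunk {
 public:
  explicit MyThunk(MyConfig config) : config_(std::move(config)) {}

  // Mirrors the role of FusedMHAThunk::Initialize: build the per-stream runner
  // eagerly so the first ExecuteOnStream does not pay for construction.
  void Initialize(int stream_id) { GetOrCreateRunner(stream_id); }

  // Mirrors ExecuteOnStream: the hot path is only a cached lookup.
  void ExecuteOnStream(int stream_id) { GetOrCreateRunner(stream_id).Run(); }

 private:
  MyRunner& GetOrCreateRunner(int stream_id) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = runners_.find(stream_id);
    if (it == runners_.end()) {
      it = runners_.emplace(stream_id, std::make_unique<MyRunner>(config_))
               .first;
    }
    return *it->second;
  }

  MyConfig config_;
  std::mutex mu_;
  std::map<int, std::unique_ptr<MyRunner>> runners_;  // one runner per stream
};

int main() {
  MyThunk thunk(MyConfig{/*scale=*/2.0});
  thunk.Initialize(/*stream_id=*/0);       // runner created here
  thunk.ExecuteOnStream(/*stream_id=*/0);  // reuses the cached runner
}

In the real thunks the cache inside GetOrCreateRunner appears to be keyed by the stream (the "return *it->second;" context above is the tail of that lookup), so Initialize and ExecuteOnStream on the same stream share one runner.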