Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support GroupQueryAttention #676

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

# Minimum CMake required
cmake_minimum_required(VERSION 3.25)
# Don't let cmake set a default value for CMAKE_CUDA_ARCHITECTURES
cmake_policy(SET CMP0104 OLD)
project(onnxruntime_extensions LANGUAGES C CXX)

# set(CMAKE_VERBOSE_MAKEFILE ON)
Expand Down Expand Up @@ -281,6 +283,7 @@ endmacro()

if(OCOS_USE_CUDA)
include(ext_cuda)
include(cutlass)
endif()

#######################################################################################################################
Expand Down Expand Up @@ -347,7 +350,7 @@ endif()

file(GLOB TARGET_SRC_CONTRIB "operators/contrib/*.cc" "operators/contrib/*.h*")
if (OCOS_USE_CUDA)
file(GLOB TARGET_SRC_CONTRIB_CUDA "operators/contrib/cuda/*.*")
file(GLOB_RECURSE TARGET_SRC_CONTRIB_CUDA "operators/contrib/cuda/*.*")
list(APPEND TARGET_SRC_CONTRIB ${TARGET_SRC_CONTRIB_CUDA})
endif()
list(APPEND TARGET_SRC ${TARGET_SRC_CONTRIB})
Expand Down Expand Up @@ -561,6 +564,10 @@ target_include_directories(ocos_operators PUBLIC
${PROJECT_SOURCE_DIR}/base
${PROJECT_SOURCE_DIR}/operators)

if (OCOS_USE_CUDA)
target_include_directories(ocos_operators PUBLIC ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
endif()

set(ocos_libraries)
set(OCOS_COMPILE_DEFINITIONS)

Expand Down
16 changes: 16 additions & 0 deletions cmake/ext_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ enable_language(CUDA)

set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
set(CMAKE_CUDA_STANDARD 17)
include(CMakeDependentOption)
cmake_dependent_option(USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32" OFF)
option(USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off flash attention and memory efficient attention since CUDA compiler version < 11.6")
set(USE_FLASH_ATTENTION OFF)
set(USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()


if(NOT CMAKE_CUDA_ARCHITECTURES)
Expand Down Expand Up @@ -69,3 +77,11 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=unsigned_co
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no_effect\"")

add_compile_definitions(USE_CUDA)
if (USE_FLASH_ATTENTION)
message( STATUS "Enable flash attention")
add_compile_definitions(USE_FLASH_ATTENTION)
endif()
if (USE_MEMORY_EFFICIENT_ATTENTION)
message( STATUS "Enable memory efficient attention")
add_compile_definitions(USE_MEMORY_EFFICIENT_ATTENTION)
endif()
11 changes: 11 additions & 0 deletions cmake/externals/cutlass.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
include(FetchContent)
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG v3.1.0
)

FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
6 changes: 6 additions & 0 deletions includes/custom_op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,12 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return INPUT_OUTPUT_OPTIONAL;
};
#endif

#if ORT_API_VERSION >= 18
OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
return 0;
};
#endif
}

const std::string op_name_;
Expand Down
14 changes: 14 additions & 0 deletions includes/onnxruntime_customop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ struct CustomOp_defined_getInputMemoryType : std::false_type {};
template <typename T>
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {};

template <typename T, typename = void>
struct CustomOp_defined_getMayInplace : std::false_type {};

template <typename T>
struct CustomOp_defined_getMayInplace<T, std::void_t<decltype(&T::GetMayInplace)>> : std::true_type {};

template <typename CustomOpKernel>
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
using ComputeFunction = decltype(&CustomOpKernel::Compute);
Expand Down Expand Up @@ -146,6 +152,14 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
};
}

#if ORT_API_VERSION >= 18
if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) {
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t {
return CustomOpKernel::GetMayInplace(input_index, output_index);
};
}
#endif

OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
if (api == nullptr) {
Expand Down
12 changes: 8 additions & 4 deletions includes/onnxruntime_no_customop.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@ class API {
static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
}
#if ORT_API_VERSION >= 15
// Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
return instance()->KernelContext_GetAllocator(context, mem_info, out);

static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) {
return instance()->ReleaseMemoryInfo(mem_info);
}

#if ORT_API_VERSION >= 18
static OrtStatusPtr KernelContextGetScratchBuffer(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, size_t count_or_bytes, void** out) {
return instance()->KernelContext_GetScratchBuffer(context, mem_info, count_or_bytes, out);
}
#endif
private:
Expand Down
148 changes: 148 additions & 0 deletions operators/contrib/attention_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace contrib {

enum AttentionMaskType {
MASK_NONE, // No mask
MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length
MASK_1D_END_START, // [2 * batch_size] with end positions and start positions
MASK_1D_KEY_SEQ_LEN_START, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
// ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ...,
// key_start[batch_size - 1], key_end[batch_size - 1]]
MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length]
MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length]
MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
MASK_UNKNOWN
};

enum AttentionQkvFormat {
UNKNOWN, // enum value not set, or depends on qkv projection implementation details
Q_K_V_BNSH, // for non-packed qkv, permuted
Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
QKV_BSN3H, // for TRT fused attention, qkv are packed
Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed.
QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed
};

enum AttentionKernelType {
AttentionKernel_Unfused,
AttentionKernel_TrtFusedAttention,
AttentionKernel_TrtFlashAttention,
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_Default
};

// Parameters deduced from node attributes and inputs/outputs.
struct AttentionParameters {
int batch_size;
int sequence_length;
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
int rotary_embedding;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
bool broadcast_res_pos_bias;
bool pass_past_in_kv;
float mask_filter_value;
float scale;
bool use_tf32;
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};

// Parameters deduced from node attributes and inputs/outputs.
struct PackedAttentionParameters {
int batch_size;
int sequence_length;
int input_hidden_size; // hidden size of input
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
float scale;
int token_count;
bool has_relative_position_bias;
bool broadcast_res_pos_bias;
bool use_tf32;
};

// Parameters deduced from node attributes and inputs/outputs.
struct GroupQueryAttentionParameters {
int batch_size;
int sequence_length; // sequence length of input query, key, value
int seqlen_past_kv_cache; // sequence length of past kv tensor
int seqlen_present_kv_cache; // sequence length of present kv tensor
int hidden_size;
int num_heads;
int head_size;
int kv_hidden_size;
int kv_num_heads;
int num_splits; // number of splits for splitkv
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
bool is_packed_qkv;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
int zeros_count;
int* zero_ptr;
};

namespace attention {
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";

// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION";

// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled).
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";

// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";

// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION";

// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

// Minimum sequence length to enable memory efficient attention in FP32.
constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;

// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention
constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV";
// Default value for the above setting.
constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;

// Environment variable to enable loading more KV data in flight in
// DecoderMaskedMultiHeadAttention/DecoderMaskedSelfAttention kernels
constexpr const char* kDecoderMaskedAttentionLoadKVDataInFlight = "ORT_DECODER_MASKED_ATTENTION_LOAD_KV_DATA_IN_FLIGHT";

} // namespace attention

} // namespace contrib
4 changes: 3 additions & 1 deletion operators/contrib/contrib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#ifdef USE_CUDA
#include "cuda/fast_gelu.h"
#include "cuda/group_query_attention.h"
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
Expand All @@ -14,7 +15,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
,
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("GroupQueryAttention", contrib::GroupQueryAttention<ortc::MFloat16>),
CustomCudaStructV2("GroupQueryAttention", contrib::GroupQueryAttention<ortc::BFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>)
#endif
Expand Down
Loading
Loading