diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index d625985552..e2903c4d11 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -59,6 +59,9 @@ class BatchConfig { // Set by update int num_tokens; + // number of tokens in prompt phase, start offset of tokens in inc_decoding + // phase. num_tokens - num_prompt_tokens = num_generation_tokens; + int num_generation_tokens; struct PerRequestInfo { int first_token_depth_in_request; diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index 5ff0942fff..43dc527bc8 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -29,7 +29,7 @@ class IncMultiHeadSelfAttention : public Op { IncMultiHeadSelfAttention(FFModel &model, LayerID const &layer_guid, - const ParallelTensor _input, + ParallelTensor const _input, int _embed_dim, int _num_q_heads, int _num_kv_heads, @@ -50,8 +50,8 @@ class IncMultiHeadSelfAttention : public Op { int _tensor_parallelism_degree, char const *name); IncMultiHeadSelfAttention(FFModel &model, - const ParallelTensor _input, - const ParallelTensor _weight, + ParallelTensor const _input, + ParallelTensor const _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, @@ -73,7 +73,7 @@ class IncMultiHeadSelfAttention : public Op { char const *name); IncMultiHeadSelfAttention(FFModel &model, IncMultiHeadSelfAttention const &other, - const ParallelTensor input, + ParallelTensor const input, bool allocate_weights); IncMultiHeadSelfAttention(FFModel &model, Params const ¶ms, @@ -192,9 +192,11 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { void *attn_heads; char *quantized_weight_ptr; BatchConfig::PerTokenInfo *token_infos; + BatchConfig::PerRequestInfo *request_infos; DataType quantization_type; bool offload; #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) + // cudaStream_t task_local_stream; cudnnTensorDescriptor_t qk_tensor; cuFloatComplex *complex_input; #elif defined(FF_USE_HIP_ROCM) diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h index 763f654e28..9bf2f581e2 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h @@ -14,6 +14,22 @@ namespace FlexFlow { namespace Kernels { namespace IncMultiHeadAttention { +template +void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + DT *output_ptr, + ffStream_t stream); + +template +void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + DT *output_ptr, + DT const *weight_ptr, + DT const *bias_ptr, + int num_tokens, + ffStream_t stream); + template __global__ void apply_position_bias_qkprd(DT *input_ptr, int num_tokens, diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh new file mode 100644 index 0000000000..c128c1a126 --- /dev/null +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -0,0 +1,524 @@ +#ifndef _FLEXFLOW_OPS_KERNELS_INC_MULTIHEAD_SELF_UTILS_H +#define _FLEXFLOW_OPS_KERNELS_INC_MULTIHEAD_SELF_UTILS_H + +#include "flexflow/inference.h" + +namespace FlexFlow { + +////////////////basic datatype////////////////////// +struct half4 { + half x; + 
half y; + half z; + half w; +}; + +struct half8 { + half x; + half y; + half z; + half w; + half a; + half b; + half c; + half d; +}; +struct float8 { + float x; + float y; + float z; + float w; + float a; + float b; + float c; + float d; +}; + +////////////////data type/////////////// +template +struct VEC_K {}; +template <> +struct VEC_K { + using Type = float; +}; +template <> +struct VEC_K { + using Type = float2; +}; +template <> +struct VEC_K { + using Type = float4; +}; +template <> +struct VEC_K { + using Type = half; +}; +template <> +struct VEC_K { + using Type = half2; +}; +template <> +struct VEC_K { + using Type = half4; +}; + +// data type for QK production +template +struct Vec_fp32_ {}; + +template <> +struct Vec_fp32_ { + using Type = float; +}; +template <> +struct Vec_fp32_ { + using Type = float2; +}; +template <> +struct Vec_fp32_ { + using Type = float4; +}; +template <> +struct Vec_fp32_ { + using Type = float; +}; +template <> +struct Vec_fp32_ { + using Type = float2; +}; +template <> +struct Vec_fp32_ { + using Type = float4; +}; +template <> +struct Vec_fp32_ { + using Type = float8; +}; + +template +struct VEC_V {}; +template <> +struct VEC_V { + using Type = float4; +}; +template <> +struct VEC_V { + using Type = half8; +}; + +////////////////data structures half/////////////// + +////////////////////////////////////floating point +/// operations/////////////////////////////////////////// + +template +inline __device__ Acc mul(A a, B b) { + return Acc{}; // for compile +} +template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +// template <> +// inline __device__ float4 mul(half4 a, half4 b) { +// float4 c; +// c.x = a.x * b.x; +// c.y = a.y * b.y; +// c.z = a.z * b.z; +// c.w = a.w * b.w; +// return c; +// } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ 
float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ float8 fma(float a, float8 f1, float8 f2) { + float8 res; + res.x = fma(a, f1.x, f2.x); + res.y = fma(a, f1.y, f2.y); + res.z = fma(a, f1.z, f2.z); + res.w = fma(a, f1.w, f2.w); + res.a = fma(a, f1.a, f2.a); + res.b = fma(a, f1.b, f2.b); + res.c = fma(a, f1.c, f2.c); + res.d = fma(a, f1.d, f2.d); + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float add(float a, float b) { + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float8 add(float8 f1, float8 f2) { + float8 res; + res.x = add(f1.x, f2.x); + res.y = add(f1.y, f2.y); + res.z = add(f1.z, f2.z); + res.w = add(f1.w, f2.w); + res.a = add(f1.a, f2.a); + res.b = add(f1.b, f2.b); + res.c = add(f1.c, f2.c); + res.d = add(f1.d, f2.d); + return res; +} + +inline __device__ float sum(float v) { + return v; +} + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +inline __device__ float cast_to_float(float u) { + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) { + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) { + return u; +} + +inline __device__ float cast_to_float(half u) { + return __half2float(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(half2 u) { + float2 tmp; + tmp.x = __half2float(u.x); + tmp.y = __half2float(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(half4 u) { + float4 tmp; + tmp.x = __half2float(u.x); + tmp.y = __half2float(u.y); + tmp.z = __half2float(u.z); + tmp.w = __half2float(u.w); + return tmp; +} +inline __device__ float8 cast_to_float(half8 u) { + float8 tmp; + tmp.x = __half2float(u.x); + tmp.y = __half2float(u.y); + tmp.z = __half2float(u.z); + tmp.w = __half2float(u.w); + tmp.a = __half2float(u.a); + tmp.b = __half2float(u.b); + tmp.c = __half2float(u.c); + tmp.d = __half2float(u.d); + return tmp; +} + +inline __device__ void convert_from_float(float4 &dst, float4 src) { + dst = src; +} +inline __device__ void convert_from_float(float &dst, float src) { + dst = src; +} +inline __device__ void 
convert_from_float(float2 &dst, float2 src) { + dst = src; +} +inline __device__ void convert_from_float(float8 &dst, float8 src) { + dst = src; +} + +inline __device__ void convert_from_float(half4 &dst, float4 src) { + dst.x = __float2half(src.x); + dst.y = __float2half(src.y); + dst.z = __float2half(src.z); + dst.w = __float2half(src.w); +} + +inline __device__ void convert_from_float(half8 &dst, float8 src) { + dst.x = __float2half(src.x); + dst.y = __float2half(src.y); + dst.z = __float2half(src.z); + dst.w = __float2half(src.w); + dst.a = __float2half(src.a); + dst.b = __float2half(src.b); + dst.c = __float2half(src.c); + dst.d = __float2half(src.d); +} +inline __device__ void convert_from_float(half2 &dst, float2 src) { + dst.x = __float2half(src.x); + dst.y = __float2half(src.y); +} +inline __device__ void convert_from_float(half &dst, float src) { + dst = __float2half(src); +} + +//////////////////////////////////////utils/////////////////////////////////////////////// + +template +inline __device__ void zero(T &dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +template +inline __device__ float qk_dot_(K_vec const (&q)[N], K_vec const (&k)[N]) { + // use float32 to get better accuracy + using Vec_sum = typename Vec_fp32_::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + Vec_sum qk_vec = + mul(cast_to_float(q[0]), cast_to_float(k[0])); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = FlexFlow::fma(cast_to_float(q[ii]), cast_to_float(k[ii]), qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} +template +struct Qk_dot { + template + static inline __device__ float dot(K_vec const (&q)[N], K_vec const (&k)[N]) { + return qk_dot_(q, k); + } +}; + +template +inline __device__ float block_sum(float *red_smem, float sum) { + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +template +inline size_t smem_size_in_bytes(int hidden_size_per_head, + int max_sequence_length, + int threads_per_value, + int threads_per_block) { + // The amount of shared memory needed to store the Q*K^T values in float. + + size_t qk_sz = div_up(max_sequence_length + 1, 4) * 16; + size_t logits_sz = qk_sz; + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + size_t q_size = hidden_size_per_head * sizeof(DT); + + // The number of partial rows to reduce in the final reduction. 
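+  // Illustrative example (assumed values, not taken from any call site): with
+  // hidden_size_per_head = 128, max_sequence_length = 512, float32 values,
+  // threads_per_block = 128 and threads_per_value = 32, qk_sz = 129 * 16 =
+  // 2064 B, softmax_sz = 4128 B, rows_per_red = 4, red_sz = 1024 B, so the
+  // function returns max(4128, 1024) + 512 = 4640 bytes of dynamic shared
+  // memory.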
+ int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(float) / 2; + // The max. + return max(softmax_sz, red_sz) + q_size; +} + +template +inline void smem_size_in_bytes_tree(int hidden_size_per_head, + int max_sequence_length, + int threads_per_value, + int threads_per_block, + TreeVerifyBatchConfig const *bc, + int shared_mem[]) { + + int max_query_length = 0; + int max_total_length = 0; + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i]) { + continue; + } + max_query_length = + max(max_query_length, bc->requestsInfo[i].num_tokens_in_batch); + max_total_length = max(max_total_length, + bc->requestsInfo[i].first_token_depth_in_request + + bc->requestsInfo[i].num_tokens_in_batch); + } + + // todo fix this + int max_qk_length = max_query_length * max_total_length; + + // The amount of shared memory needed to store the Q*K^T values in float. + size_t qk_sz = div_up(max_qk_length + 1, 4) * 16; + + size_t logits_sz = qk_sz; + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + size_t q_size = hidden_size_per_head * sizeof(DT); + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + // use 4 + size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(float) / 2; + // The max. + shared_mem[0] = qk_sz; + shared_mem[1] = softmax_sz + red_sz + q_size; +} + +template +struct threads_per_value_t { + static int const value = Dh * sizeof(T) / 16; +}; + +} // namespace FlexFlow +#endif // _FLEXFLOW_OPS_KERNELS_INC_MULTIHEAD_SELF_UTILS_H \ No newline at end of file diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index 363776cdb0..56bb2bd80d 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -140,7 +140,6 @@ class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { public: Realm::RegionInstance beam_search_reserve_inst; - BatchConfig::PerRequestInfo *request_infos; BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos; BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos; }; diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index cff5550c85..20f7d64936 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -19,6 +19,7 @@ #include "flexflow/ops/inc_multihead_self_attention.h" #include "flexflow/ops/kernels/decompress_kernels.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" +#include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" #include "flexflow/utils/cuda_helper.h" namespace FlexFlow { @@ -27,9 +28,277 @@ namespace FlexFlow { using Legion::coord_t; using Legion::Memory; +#define WARP_SIZE 32 + namespace Kernels { namespace IncMultiHeadAttention { +// gridDim = num_heads +// blockDim = num_tokens/num_request * head_size +// QKV tensor layout: |QKV| * num_new_tokens. 
|Q=K=V=head_size * num_heads| +// one thread process one head_size +template +__global__ void compute_attention_kernel_generation_kernel( + DT const *query, + DT const *key_cache, + DT const *value_cache, + DT *output_ptr, + float const scale, + int max_seq_length, + int per_head_size, + int hidden_size, + BatchConfig::PerRequestInfo *request_infos, + bool is_beam, + int max_beam_width) { + + // q, k + using Q_vec = typename VEC_K::Type; + using K_vec = typename VEC_K::Type; + using V_vec = typename VEC_V
<DT>
::Type; + using Out_sum = typename Vec_fp32_::Type; + + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // eg. if head_size = 128, thread_per_key = 4, with float32 precision + // then K_VEC_SIZE = 1, QK_VEC_SIZE = 4 + // K_ELTS_PER_THREAD = 128 / 4 = 32 + // K_VECS_PER_THREAD = 32 / 1 = 32 + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); + // constexpr int QK_VEC_SIZE = 16 / sizeof(DT); + // // constexpr int QK_VEC_SIZE = sizeof(Qk_vec_k) / sizeof(DT); + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); + + // thread id + int const tidx = threadIdx.x; + // head id + int const head_idx = blockIdx.x; + // request idx + int const request_idx = blockIdx.y; + + int const beam_request_idx = + is_beam ? request_idx / max_beam_width : request_idx; + int const beam_sub_request_idx = is_beam ? request_idx % max_beam_width : 0; + + int const first_step = 0; + + int const tlength = + request_infos[beam_request_idx].first_token_depth_in_request + + request_infos[beam_request_idx].num_tokens_in_batch; + + // shared memory objects + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + float *out_smem = reinterpret_cast(smem_); + + float qk_max = -FLT_MAX; + + // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + const DT *q_ptr = query + beam_request_idx * hidden_size * QKV_WEIGHT_NUM + + head_idx * per_head_size; + __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; + // DT const *q_ptr = + // query + request_idx * Dh * QKV_WEIGHT_NUM + head_idx * per_head_size; + + // q tensor in this thread + // if THREADS_PER_KEY is 4, first thread load 0, 4, 8, 12..., total + // K_VECS_PER_THREAD elements + // QK_vec_k: 32->1, 64->2, 128->4... head_size + // K_vec_k: 4->1, 2->2, 1->4 threads_per_key + + // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + int ki_o = tidx % THREADS_PER_KEY; + // the first key's offset for this thread + // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... + int ko = tidx / THREADS_PER_KEY; + // load q tensor + Q_vec q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vecs[ki_o][ii] = *reinterpret_cast( + q_ptr + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); + } + __syncthreads(); + // first iter = 128 / 4 = 32 + // K_VECS_PER_THREAD = 32 + // K_PER_ITER how many keys in this loop + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + DT const *k_cache_batch = + key_cache + + (beam_request_idx * max_beam_width + beam_sub_request_idx) * + max_seq_length * hidden_size + + ki; + + int ti_end = + div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + // get k, perform qk proj + + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; + int const ti_circ = ti % max_seq_length; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; + if (ti < tlength) { + k[ii] = *reinterpret_cast(k_cache_batch + + ti_circ * hidden_size + + head_idx * per_head_size + jj); + } + // Compute dot product. + // This includes a reduction across the threads in the same thread group. 
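+      // (The actual multiply-accumulate and the intra-group shuffle reduction
+      // happen in Qk_dot::dot right after this loop, once all
+      // K_VECS_PER_THREAD key fragments for this timestep are loaded.)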
+ } + float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); + // // todo add positional embedding to the qk production + // // Store the product to shared memory. There's one qk value per + // timestep. + // // Update the max. + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + // todo add alobi here + bool const mask = ti_circ >= tlength; + if (mask) { + assert(false); + } + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = mask ? 0.f : qk; + } + } + + __syncthreads(); + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + int const warp = tidx / WARP_SIZE; + int const lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + float exp_sum = 0.f; + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + float logit = __expf(qk_smem[ti - first_step] - qk_max); + exp_sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + + // softmax + float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + qk_smem[ti - first_step] *= inv_sum; + } + + __syncthreads(); + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("softmax %.10f\n", qk_smem[0]); + // } + + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // A vector of V elements for the current timestep. + // using V_vec_k = typename V_vec_k_::Type; + // using V_vec_acum = typename V_vec_acum_fp32_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + Out_sum out; + zero(out); + + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + + (beam_request_idx * max_beam_width + beam_sub_request_idx) * + max_seq_length * hidden_size + + vi; + + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + // Load the values from the cache. + int const ti_circ = ti % max_seq_length; + + V_vec v = *reinterpret_cast( + v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + float logit = qk_smem[ti - first_step]; + out = FlexFlow::fma(logit, cast_to_float(v), out); + } + } + + // // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different + // partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. 
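+        // Each iteration halves the number of active value groups: groups in
+        // [midpoint, active_groups) publish their partial sums to out_smem,
+        // groups in [0, midpoint) accumulate them, until group 0 holds the
+        // fully reduced output for this head.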
+ if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { + *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = + out; + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(out_smem + vo * Dh + vi), + out); + } + __syncthreads(); + } + } + + // Output the final values. + if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { + convert_from_float( + *reinterpret_cast(output_ptr + beam_request_idx * hidden_size + + head_idx * per_head_size + vi), + out); + } +} + // only used by MPT model. https://arxiv.org/abs/2108.12409 template __global__ void apply_position_bias_qkprd(DT *input_ptr, @@ -350,6 +619,117 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, } } +template +void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + DT *output_ptr, + DT const *weight_ptr, + DT const *bias_ptr, + int num_tokens, + cudaStream_t stream) { + cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); + cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); + assert(data_type_size(m->output_type[0]) == sizeof(DT)); +#if CUDA_VERSION >= 11000 + // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance + cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; +#else + cudaDataType_t compute_type = cublas_data_type; +#endif + // Project to output, save result directly on output tensor + DT alpha = 1.0f, beta = 0.0f; + // int num_tokens = bc->num_active_tokens(); + int m_ = m->oProjSize; + int k = m->vProjSize * m->num_q_heads; + int n = num_tokens; + int lda = k, ldb = k, ldc = m_; + DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + + m->kProjSize * m->num_q_heads + + m->vProjSize * m->num_q_heads); + DT const *B = static_cast
<DT *>
(m->attn_heads); + DT *C = static_cast
<DT *>
(output_ptr); + + checkCUDA(cublasGemmEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + B, + cublas_data_type, + ldb, + &beta, + C, + cublas_data_type, + ldc, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + if (*m->final_bias && shard_id == 0) { + int parallelism = m->oProjSize * num_tokens; + int qkv_weight_size = m->qProjSize * m->global_num_q_heads + + m->kProjSize * m->global_num_q_heads + + m->vProjSize * m->global_num_q_heads; + apply_proj_bias_w<<>>( + output_ptr, bias_ptr, num_tokens, qkv_weight_size, m->oProjSize); + } +} + +#define LAUNCH_ATTENTION_SCORE_KERNEL( \ + DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ + smem_sz = smem_size_in_bytes
<DT>
(m->qProjSize, \ + BatchConfig::max_sequence_length(), \ + THREADS_PER_VALUE, \ + THDS_PER_BLOCK); \ + compute_attention_kernel_generation_kernel \ + <<>>( \ + static_cast
<DT *>
(m->devQKVProjArray), \ + static_cast
<DT *>
(m->keyCache), \ + static_cast
<DT *>
(m->valueCache), \ + output_ptr, \ + scale, \ + BatchConfig::max_sequence_length(), \ + m->qProjSize, \ + m->hidden_size, \ + m->request_infos, \ + false, \ + 0) + +template +void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + DT *output_ptr, + cudaStream_t stream) { + dim3 grid(m->num_q_heads, bc->num_active_requests()); + int const per_head_size = m->qProjSize; + float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; + size_t smem_sz; + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); + } +} + template void pre_build_weight_kernel(IncMultiHeadSelfAttentionMeta const *m, GenericTensorAccessorR const weight, @@ -419,18 +799,26 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, DT *output_ptr, DT const *bias_ptr, cudaStream_t stream) { - // here because we need position info in inference 1 if (m->offload && m->biasSize > 0) { cudaMemcpyAsync( m->bias_ptr, bias_ptr, m->biasSize, cudaMemcpyHostToDevice, stream); bias_ptr = static_cast
<DT *>
(m->bias_ptr); } + + // todo Xinhao copy how many requests if requests are not continous? cudaMemcpyAsync(m->token_infos, &(bc->tokensInfo), bc->num_active_tokens() * sizeof(BatchConfig::PerTokenInfo), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(m->request_infos, + &(bc->requestsInfo), + bc->max_requests_per_batch() * + sizeof(BatchConfig::PerRequestInfo), + cudaMemcpyHostToDevice, + stream); + // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, bc, @@ -440,14 +828,24 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, static_cast
<DT *>
(m->devQKVProjArray), bias_ptr, stream); - - // phase 2: Update key/val cache update_kv_cache_kernel
<DT>
(m, bc, stream); - // phase 3: Compute attention score - // 3 kernels for pahse 3: matmul1 - softmax - matmal2 - compute_attention_kernel( - m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); + if (bc->num_generation_tokens > 0) { + // phase 3: Compute attention score for generation tokens + compute_attention_kernel_generation
<DT>
( + m, bc, static_cast
<DT *>
(m->attn_heads), stream); + } + + if (bc->num_tokens > bc->num_generation_tokens) { + // phase 4: Compute attention score for prompt tokens; + compute_attention_kernel_prompt( + m, bc, shard_id, bias_ptr, weight_ptr, stream); + } + + // compute output production and bias together for all tokens + int num_tokens = bc->num_active_tokens(); + compute_o_prod_bias( + m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); } } // namespace IncMultiHeadAttention @@ -501,13 +899,12 @@ __global__ void fill_entries_above_diagonal(DT *matrix, } template -void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, - cudaStream_t stream) { +void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + DT const *bias_ptr, + DT const *weight_ptr, + cudaStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); @@ -675,8 +1072,11 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, B = C_softmax; // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous // requests + + // store the result attn heads, also skip the genration tokens C = static_cast
<DT *>
(m->attn_heads) + - tokens_previous_requests * m->num_q_heads * m->vProjSize; + (tokens_previous_requests + bc->num_generation_tokens) * + m->num_q_heads * m->vProjSize; checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, @@ -702,52 +1102,6 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); tokens_previous_requests += num_new_tokens; } - - // Project to output, save result directly on output tensor - DT alpha = 1.0f, beta = 0.0f; - int m_ = m->oProjSize; - int k = m->vProjSize * m->num_q_heads; - int n = bc->num_active_tokens(); - int lda = k, ldb = k, ldc = m_; - DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - DT const *B = static_cast
<DT *>
(m->attn_heads); - DT *C = static_cast
<DT *>
(output_ptr); - - checkCUDA(cublasGemmEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - if (*m->final_bias && shard_id == 0) { - int parallelism = m->oProjSize * num_tokens; - int qkv_weight_size = m->qProjSize * m->global_num_q_heads + - m->kProjSize * m->global_num_q_heads + - m->vProjSize * m->global_num_q_heads; - - apply_proj_bias_w<<>>( - output_ptr, bias_ptr, num_tokens, qkv_weight_size, m->oProjSize); - } - assert(tokens_previous_requests == num_tokens); } @@ -811,6 +1165,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( } else { assert(false && "Unspported data type"); } + if (m->profiling) { cudaEventRecord(t_end, stream); checkCUDA(cudaEventSynchronize(t_end)); @@ -819,38 +1174,6 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventDestroy(t_start); cudaEventDestroy(t_end); printf("IncMultiHeadSelfAttention forward time = %.9fms\n", elapsed); - - // if (input.data_type == DT_HALF) { - // print_tensor(input.get_half_ptr(), - // 32, - // "[IncMultiHeadSelfAttention:forward:input]"); - // print_tensor(weight.get_half_ptr(), - // 32, - // "[IncMultiHeadSelfAttention:forward:weight]"); - // print_tensor(output.get_half_ptr(), - // 32, - // "[IncMultiHeadSelfAttention:forward:output]"); - // print_tensor( - // bias.get_half_ptr(), 32, - // "[IncMultiHeadSelfAttention:forward:bias]"); - // } else { - // print_tensor(input.get_float_ptr(), - // 32, - // "[IncMultiHeadSelfAttention:forward:input]"); - // print_tensor(weight.get_float_ptr(), - // 32, - // "[IncMultiHeadSelfAttention:forward:weight]"); - // print_tensor(output.get_float_ptr(), - // 32, - // "[IncMultiHeadSelfAttention:forward:output]"); - // print_tensor( - // bias.get_float_ptr(), 32, - // "[IncMultiHeadSelfAttention:forward:bias]"); - // } - - // print_tensor<3, float>(acc_query.ptr, acc_query.rect, - // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, - // acc_output.rect, "[Attention:forward:output]"); } } @@ -1013,6 +1336,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( default: assert(false && "Unkown inference mode"); } + size_t requestinfo_size = BatchConfig::max_requests_per_batch(); size_t tokeninfo_size = max_tokens_per_batch; size_t qk_prod_size = max_tokens_per_batch * BatchConfig::max_sequence_length() * num_q_heads; @@ -1025,8 +1349,10 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( 2 * qk_prod_size + attn_heads_size) * size_of_dt + tokeninfo_size * sizeof(BatchConfig::PerTokenInfo) + - complex_size * sizeof(cuFloatComplex); // more components will - // be added here later + complex_size * sizeof(cuFloatComplex) + + requestinfo_size * + sizeof(BatchConfig::PerRequestInfo); // more components will + // be added here later if (offload) { // assert that we have enough reserved work space left size_t totalSharedSize = @@ -1086,6 +1412,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( complex_input = gpu_mem_allocator.allocate_reserved(complex_size); // offset += complex_size * sizeof(cuFloatComplex); + request_infos = + gpu_mem_allocator.allocate_reserved( + requestinfo_size); } else { token_infos = gpu_mem_allocator.allocate_instance( @@ -1098,6 +1427,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_of_dt); complex_input = gpu_mem_allocator.allocate_instance(complex_size); + request_infos = + 
gpu_mem_allocator.allocate_instance( + requestinfo_size); } // allocate more size for quantization data @@ -1131,5 +1463,4 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( GenericTensorAccessorR const weight, DataType data_type, cudaStream_t stream); - }; // namespace FlexFlow diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 52e083889e..6dad1c6de9 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -17,6 +17,7 @@ #endif #include "flexflow/ffconst_utils.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" +#include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" #include "flexflow/ops/spec_inc_multihead_self_attention.h" #include "flexflow/utils/cuda_helper.h" @@ -203,13 +204,13 @@ __global__ void spec_fill_entries_above_diagonal(DT *matrix, } template -void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, - cudaStream_t stream) { +void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + int shard_id, + DT *output_ptr, + DT const *bias_ptr, + DT const *weight_ptr, + cudaStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); @@ -228,7 +229,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, #endif // int num_requests = bc->num_active_requests(); int num_tokens = bc->num_active_tokens(); - // int tokens_previous_requests = 0; + int tokens_previous_requests = 0; int tokens_prev_requests_squares = 0; // int qkv_block_size = // (m->qProjSize + m->kProjSize + m->vProjSize) * num_tokens; @@ -399,8 +400,8 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous // requests C = static_cast
<DT *>
(m->attn_heads) + - bc->requestsInfo[i].first_token_offset_in_batch * m->num_q_heads * - m->vProjSize; + (tokens_previous_requests + bc->num_generation_tokens) * + m->num_q_heads * m->vProjSize; checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, @@ -425,54 +426,11 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // tokens_previous_requests += num_new_tokens; + tokens_previous_requests += num_new_tokens; tokens_prev_requests_squares += num_new_tokens * total_tokens; } } - // Project to output, save result directly on output tensor - DT alpha = 1.0f, beta = 0.0f; - int m_ = m->oProjSize; - int k = m->vProjSize * m->num_q_heads; - int n = bc->num_active_tokens(); - int lda = k, ldb = k, ldc = m_; - DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads + - m->kProjSize * m->num_q_heads + - m->vProjSize * m->num_q_heads); - DT const *B = static_cast
<DT *>
(m->attn_heads); - DT *C = static_cast
<DT *>
(output_ptr); - - checkCUDA(cublasGemmEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - B, - cublas_data_type, - ldb, - &beta, - C, - cublas_data_type, - ldc, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - if (*m->final_bias && shard_id == 0) { - int parallelism = m->oProjSize * num_tokens; - int qkv_weight_size = m->qProjSize * m->global_num_q_heads + - m->kProjSize * m->global_num_q_heads + - m->vProjSize * m->global_num_q_heads; - apply_proj_bias_w<<>>( - output_ptr, bias_ptr, num_tokens, qkv_weight_size, m->oProjSize); - } - // assert(tokens_previous_requests == num_tokens); } @@ -520,11 +478,23 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, stream); // phase 2: Update key/val cache update_kv_cache_kernel
<DT>
(m, bc, stream); - + if (bc->num_generation_tokens > 0) { + compute_attention_kernel_generation
<DT>
( + m, bc, static_cast
<DT *>
(m->attn_heads), stream); + } // phase 3: Compute attention score // 3 kernels for pahse 3: matmul1 - softmax - matmal2 - compute_attention_kernel( - m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); + if (bc->num_tokens > bc->num_generation_tokens) { + compute_attention_kernel_prompt( + m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); + } + + // compute output production and bias together for all tokens + int num_tokens = + bc->num_active_tokens() * BeamSearchBatchConfig::MAX_BEAM_WIDTH; + + compute_o_prod_bias( + m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); } } // namespace SpecIncMultiHeadAttention @@ -643,7 +613,6 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( size_t beam_requestinfo_size = BeamSearchBatchConfig::max_requests_per_batch(); size_t total_size = - requestinfo_size * sizeof(BatchConfig::PerRequestInfo) + beam_tokeninfo_size * sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo) + beam_requestinfo_size * @@ -660,10 +629,6 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( beam_tokeninfo_size); // offset += beam_tokeninfo_size * // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo); - request_infos = - gpu_mem_allocator.allocate_instance( - requestinfo_size); - // offset += requestinfo_size * sizeof(BatchConfig::PerRequestInfo); beam_request_infos = gpu_mem_allocator .allocate_instance( diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 0aa50f605c..bc7d1017b7 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -17,6 +17,7 @@ #endif #include "flexflow/ffconst_utils.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" +#include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" #include "flexflow/ops/tree_inc_multihead_self_attention.h" #include "flexflow/utils/cuda_helper.h" @@ -26,11 +27,251 @@ namespace FlexFlow { using Legion::coord_t; using Legion::Memory; +#define WARP_SIZE 32 + using namespace Kernels::IncMultiHeadAttention; namespace Kernels { namespace TreeIncMultiHeadAttention { +template +__global__ void compute_attention_kernel_fused_kernel( + DT const *query, + DT const *key_cache, + DT const *value_cache, + DT *output_ptr, + float const scale, + int const max_seq_length, + int const max_token_per_batch, + int per_head_size, + int hidden_size, + BatchConfig::PerRequestInfo *request_infos, + int num_heads, + int num_requests, + int qk_smem_sz) { + + // q, k + using Q_vec = typename VEC_K::Type; + using K_vec = typename VEC_K::Type; + using V_vec = typename VEC_V
<DT>
::Type; + using Out_sum = typename Vec_fp32_::Type; + + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); + + // thread id + int const tidx = threadIdx.x; + // head id + int const head_idx = blockIdx.x; + // request idx + int const request_idx = blockIdx.y; + + int const first_step = 0; + + int const tlength = request_infos[request_idx].first_token_depth_in_request + + request_infos[request_idx].num_tokens_in_batch; + int const qlength = request_infos[request_idx].num_tokens_in_batch; + + int first_token_idx = 0; + for (int r = 0; r < request_idx; r++) { + first_token_idx += request_infos[request_idx].num_tokens_in_batch; + } + + // shared memory objects + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + float *out_smem = reinterpret_cast(smem_ + qk_smem_sz); + + float qk_max = -FLT_MAX; + + // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + const DT *q_ptr = query + first_token_idx * hidden_size * QKV_WEIGHT_NUM + + head_idx * per_head_size; + __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; + + // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + int ki_o = tidx % THREADS_PER_KEY; + // the first key's offset for this thread + // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... + int ko = tidx / THREADS_PER_KEY; + // load q tensor + Q_vec q_vec[K_VECS_PER_THREAD]; + + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + DT const *k_cache_batch = + key_cache + request_idx * max_seq_length * hidden_size + ki; + + int ti_end = + div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + for (int qi = 0; qi < qlength; qi += 1) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vecs[ki_o][ii] = *reinterpret_cast( + q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + + ii * THREADS_PER_KEY * K_VEC_SIZE); + } + __syncthreads(); + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; + int const ti_circ = ti % max_seq_length; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; + if (ti < tlength) { + k[ii] = *reinterpret_cast( + k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + + jj); + } + } + float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); + + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + bool const mask = ti_circ >= tlength; + if (mask) { + assert(false); + } + + int pos = ti * qlength + qi; + if (((pos / qlength) % tlength) > (pos % qlength + tlength - qlength)) { + qk = -FLT_MAX; + } + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + + qk_smem[pos] = mask ? 0.f : qk; + } + } + __syncthreads(); + + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + int const warp = tidx / WARP_SIZE; + int const lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. 
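+    // red_smem holds one partial qk_max per warp; after the barrier below,
+    // the cross-warp reduction mirrors the intra-warp shuffle reduction
+    // performed above.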
+ __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; + + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + float exp_sum = 0.f; + + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + float logit = __expf(qk_smem[ti * qlength + qi] - qk_max); + exp_sum += logit; + qk_smem[ti * qlength + qi] = logit; + } + + // Compute the sum. + exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + + // softmax + float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); + + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + qk_smem[ti * qlength + qi] *= inv_sum; + } + + __syncthreads(); + } + + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + Out_sum out; + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + request_idx * max_seq_length * hidden_size + vi; + + for (int qi = 0; qi < qlength; qi++) { + zero(out); + __syncthreads(); + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + // Load the values from the cache. + int const ti_circ = ti % max_seq_length; + + V_vec v = *reinterpret_cast( + v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + float logit = qk_smem[ti * qlength + qi]; + out = FlexFlow::fma(logit, cast_to_float(v), out); + } + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different + // partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { + *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = + out; + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(out_smem + vo * Dh + vi), + out); + } + __syncthreads(); + } + } + + // Output the final values. 
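+    // Only the first value group (vo == 0) holds the fully reduced result; it
+    // converts the fp32 accumulator back to DT and writes the
+    // [vi, vi + V_VEC_SIZE) slice of this head for token (first_token_idx + qi).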
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { + convert_from_float(*reinterpret_cast( + output_ptr + (first_token_idx + qi) * hidden_size + + head_idx * per_head_size + vi), + out); + } + } +} + template __global__ void commit_tokens_kernel( DT const *devQKVProjArray, @@ -128,6 +369,37 @@ __global__ void update_tree_branch_kv_cache( } } +template +__global__ void update_tree_branch_kv_cache_fused( + DT const *devQKVProjArray, + DT *kCache_ptr, + DT *vCache_ptr, + TreeVerifyBatchConfig::PerTokenInfo const *tokenInfos, + int qProjSize, + int kProjSize, + int vProjSize, + int num_new_tokens, + int max_seq_len, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_new_tokens * hidden_size) { + + int token_idx = i / hidden_size; + int offset = i % hidden_size; + size_t val_idx = + token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; + + DT kVal = devQKVProjArray[val_idx]; + DT vVal = devQKVProjArray[val_idx + hidden_size]; + + int const req_id = tokenInfos[token_idx].request_index; + int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + offset] = kVal; + vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + offset] = vVal; + } +} + template __global__ void tree_fill_entries_above_diagonal(DT *matrix, size_t new_tokens, @@ -200,6 +472,9 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, num_new_tokens++; } + std::cout << "num_new_tokens: " << num_new_tokens << "\n"; + assert(false); + int total_tokens_in_request = bc->tokensInfo[j].abs_depth_in_request + 1; assert(num_new_tokens >= 1 && total_tokens_in_request >= num_new_tokens); { @@ -438,6 +713,79 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, assert(processed_tokens_in_batch == bc->num_active_tokens()); } +#define LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( \ + DT, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ + smem_size_in_bytes_tree
<DT>
(m->qProjSize, \ + BatchConfig::max_sequence_length(), \ + THDS_PER_VALUE, \ + THDS_PER_BLOCK, \ + bc, \ + smem_sz); \ + compute_attention_kernel_fused_kernel \ + <<>>( \ + static_cast
<DT *>
(m->devQKVProjArray), \ + static_cast
<DT *>
(m->keyCache), \ + static_cast
<DT *>
(m->valueCache), \ + output_ptr, \ + scale, \ + BatchConfig::max_sequence_length(), \ + BatchConfig::max_tokens_per_batch(), \ + m->qProjSize, \ + m->hidden_size, \ + m->request_infos, \ + m->num_q_heads, \ + bc->num_active_requests(), \ + smem_sz[0]) + +template +void compute_attention_kernel_fused(IncMultiHeadSelfAttentionMeta const *m, + TreeVerifyBatchConfig const *bc, + DT *output_ptr, + cudaStream_t stream) { + + // update the kv cache + // update K-V cache + int num_new_tokens = bc->num_active_tokens(); + int parallelism = m->hidden_size * num_new_tokens; + update_tree_branch_kv_cache_fused<<>>( + static_cast
<DT *>
(m->devQKVProjArray), + static_cast
<DT *>
(m->keyCache), + static_cast
<DT *>
(m->valueCache), + m->token_infos, + m->qProjSize, + m->kProjSize, + m->vProjSize, + num_new_tokens, + BatchConfig::max_sequence_length(), + m->hidden_size); + + dim3 grid(m->num_q_heads, bc->num_active_requests()); + int const per_head_size = m->qProjSize; + float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; + + // 0->qk production size, 1->total shared size + int smem_sz[2]; + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); + } +} + template void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, TreeVerifyBatchConfig const *bc, @@ -463,6 +811,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, bias_ptr = static_cast
<DT *>
(m->bias_ptr); } } + // copy committed tokens info to GPU for the commit_tokens kernel // Note that m->num_active_tokens stores the number of active // tokens in the previous batch, which is needed for committing @@ -491,6 +840,12 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, sizeof(TreeVerifyBatchConfig::PerTokenInfo), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(m->request_infos, + &(bc->requestsInfo), + bc->max_requests_per_batch() * + sizeof(BatchConfig::PerRequestInfo), + cudaMemcpyHostToDevice, + stream); // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, bc, @@ -504,11 +859,20 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // phase 2: No need to update key/val cache // IncMultiHeadSelfAttention::update_kv_cache_kernel( // m, bc, stream); + // use the new kernel + compute_attention_kernel_fused
<DT>
( + m, bc, static_cast
<DT *>
(m->attn_heads), stream); + + int processed_tokens_in_batch = bc->num_active_tokens(); - // phase 3: Compute attention score - // 3 kernels for pahse 3: matmul1 - softmax - matmal2 - compute_attention_kernel( - m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); + compute_o_prod_bias(m, + bc, + shard_id, + output_ptr, + weight_ptr, + bias_ptr, + processed_tokens_in_batch, + stream); } } // namespace TreeIncMultiHeadAttention @@ -583,10 +947,6 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); cudaEventDestroy(t_start); cudaEventDestroy(t_end); - printf("TreeIncMultiHeadSelfAttention forward time = %.2fms\n", elapsed); - // print_tensor<3, float>(acc_query.ptr, acc_query.rect, - // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, - // acc_output.rect, "[Attention:forward:output]"); } } diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index f1164d3c49..7c37f3391e 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -357,6 +357,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, // log_req_mgr.print("Output: %s", output.c_str()); } } + int num_generation_tokens = 0; // Step 2: prepare the next batch for existing requests BatchConfig new_bc; @@ -450,6 +451,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, request.tokens.size()) { // Incremental phase new_bc.requestsInfo[i].num_tokens_in_batch = 1; + num_generation_tokens++; } else { // Prompt phase new_bc.requestsInfo[i].num_tokens_in_batch = @@ -471,6 +473,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, } } } + new_bc.num_generation_tokens = num_generation_tokens; // Step 3: add new requests to the next batch for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { @@ -563,6 +566,8 @@ BeamSearchBatchConfig new_bc.model_id = model_id; int result_index = 0; + int num_generation_tokens = 0; + for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { if (old_bc.request_completed[i]) { continue; @@ -889,6 +894,7 @@ BeamSearchBatchConfig } } } + new_bc.num_generation_tokens = num_generation_tokens; if (verbose) { std::cout << "prepare_next_batch_init OLD vs NEW batchconfigs below:" @@ -951,6 +957,7 @@ BeamSearchBatchConfig BeamSearchBatchConfig new_bc; new_bc.model_id = old_bc.model_id; // std::cout << "old_bc.model_id: " << old_bc.model_id << "\n"; + int num_generation_tokens = 0; // Add incremental tokens to the batch for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { @@ -1155,11 +1162,13 @@ BeamSearchBatchConfig new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = k; new_bc.num_tokens++; + num_generation_tokens++; } } } } + new_bc.num_generation_tokens = num_generation_tokens; if (verbose) { std::cout << "prepare_next_batch_beam OLD vs NEW batchconfigs:" << std::endl;
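
For reference, the short host-side sketch below (not part of the diff) reproduces the launch-configuration arithmetic that this change introduces for the generation-phase attention kernel: the threads_per_value_t choice, the smem_size_in_bytes formula, and the (num_q_heads x num_active_requests) grid. The concrete sizes it uses (head size 128, fp16 values, max sequence length 512, 8 query heads, 2 active requests) are illustrative assumptions, not values taken from the patch.

#include <algorithm>
#include <cstdio>

constexpr int THREADS_PER_BLOCK = 128;
constexpr int THREADS_PER_KEY = 4;

static int div_up(int m, int n) {
  return (m + n - 1) / n;
}

// Mirrors threads_per_value_t: one 16-byte vector load per thread along the
// head dimension.
static int threads_per_value(int dh, int elem_size) {
  return dh * elem_size / 16;
}

// Mirrors smem_size_in_bytes: softmax scratch vs. output-reduction scratch,
// plus the per-head query staging area.
static size_t smem_bytes(int dh, int max_seq_len, int tpv, int elem_size) {
  size_t qk_sz = div_up(max_seq_len + 1, 4) * 16;
  size_t softmax_sz = 2 * qk_sz; // qk values + logits
  size_t q_size = (size_t)dh * elem_size;
  int rows_per_red = THREADS_PER_BLOCK / tpv;
  size_t red_sz = (size_t)rows_per_red * dh * sizeof(float) / 2;
  return std::max(softmax_sz, red_sz) + q_size;
}

int main() {
  int dh = 128, elem_size = 2 /* fp16 */, max_seq_len = 512;
  int num_q_heads = 8, num_active_requests = 2;

  int tpv = threads_per_value(dh, elem_size); // 128 * 2 / 16 = 16
  size_t smem = smem_bytes(dh, max_seq_len, tpv, elem_size);

  // One block handles one head of one decoding token, matching the kernel's
  // use of blockIdx.x (head) and blockIdx.y (request).
  std::printf("grid = (%d, %d), block = %d threads\n",
              num_q_heads, num_active_requests, THREADS_PER_BLOCK);
  std::printf("THREADS_PER_KEY = %d, THREADS_PER_VALUE = %d\n",
              THREADS_PER_KEY, tpv);
  std::printf("dynamic shared memory = %zu bytes\n", smem);
  return 0;
}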