Commit

init

xinhaoc committed Nov 23, 2023
1 parent 4bfee96 commit 03f7e5e
Showing 9 changed files with 276 additions and 68 deletions.
7 changes: 7 additions & 0 deletions include/flexflow/batch_config.h
@@ -60,6 +60,9 @@ class BatchConfig {
static int const MAX_NUM_REQUESTS = 64;
static int const MAX_NUM_TOKENS = 1024;

// TODO change this
static int const MAX_PEFT_TOKENS = 10;

// Set by update
int num_tokens = 0, num_peft_tokens = 0, num_peft_label_tokens = 0;

@@ -72,6 +75,8 @@
request_guid = 0;
peft_model_id = PEFTModelID::NO_ID;
peft_bwd = false;
peft_fwd_tokens = 0;
peft_bwd_tokens = 0;
}
int first_token_depth_in_request;
int first_token_offset_in_batch;
@@ -81,6 +86,8 @@
// PEFT fields
PEFTModelID peft_model_id;
bool peft_bwd;
size_t peft_fwd_tokens;
size_t peft_bwd_tokens;
};
struct PerTokenInfo {
int abs_depth_in_request;
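The two new counters make chunked PEFT backward passes possible: peft_fwd_tokens records how many tokens of a finetuning request have gone through the forward pass, while peft_bwd_tokens tracks how many have already been back-propagated. A minimal sketch of a scheduler that walks a request backward in MAX_PEFT_TOKENS-sized chunks (the helper and its name are hypothetical; only the BatchConfig fields come from this commit):

#include <algorithm> // std::min

// Hypothetical driver, not part of this commit: advance one backward chunk.
// Backward proceeds from the tail of the sequence toward the front.
size_t next_peft_bwd_chunk(BatchConfig::PerRequestInfo &info) {
  size_t remaining = info.peft_fwd_tokens - info.peft_bwd_tokens;
  size_t chunk = std::min(remaining, (size_t)BatchConfig::MAX_PEFT_TOKENS);
  // This chunk covers token positions [remaining - chunk, remaining)
  // of the request.
  info.peft_bwd_tokens += chunk;
  return chunk;
}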
1 change: 1 addition & 0 deletions include/flexflow/ops/inc_multihead_self_attention.h
@@ -220,6 +220,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
// PEFT specific fields
void *softmax_activation_buffer;
void *query_activation_buffer;
void *keyGradCache, *valueGradCache;
};

}; // namespace FlexFlow
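keyGradCache and valueGradCache are persistent per-layer buffers holding running sums of key/value gradients across backward chunks (see update_key_value_gradient in the .cu diff below). A rough per-layer size estimate, with illustrative numbers that are not from this commit:

// Assume num_q_heads = 32, kProjSize = vProjSize = 128,
// max_sequence_length = 1024, half precision (2 bytes per element):
//   keyGradCache   : 32 * 128 * 1024 = 4,194,304 elements = 8 MiB
//   valueGradCache : 32 * 128 * 1024 = 4,194,304 elements = 8 MiB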
4 changes: 4 additions & 0 deletions include/flexflow/ops/kernels/softmax_kernels.h
@@ -12,6 +12,7 @@ class SoftmaxMeta : public OpMeta {
public:
SoftmaxMeta(FFHandler handle,
Softmax const *softmax,
MemoryAllocator &gpu_mem_allocator,
Legion::Domain const &input_domain);
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnTensorDescriptor_t inputTensor;
@@ -23,6 +24,9 @@
bool profiling;
bool inference_debugging;
int dim;

// PEFT partial loss
void *lm_head_cache;
};

namespace Kernels {
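Because SoftmaxMeta now receives a MemoryAllocator, every call site gains one argument. A sketch of the updated construction (variable names are illustrative; the actual call site presumably lives in the Softmax operator's init path):

SoftmaxMeta *m =
    new SoftmaxMeta(handler, softmax, gpu_mem_allocator, input_domain);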
1 change: 1 addition & 0 deletions include/flexflow/ops/softmax.h
@@ -6,6 +6,7 @@
#include "flexflow/node.h"
#include "flexflow/operator.h"
#include "flexflow/ops/softmax_params.h"
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {

129 changes: 101 additions & 28 deletions src/ops/inc_multihead_self_attention.cu
@@ -240,6 +240,31 @@ __global__ void fill_entries_above_diagonal(DT *matrix,
}
}

template <typename DT>
__global__ void update_key_value_gradient(DT *devQKVProjArray,
DT *kGradCache,
DT *vGradCache,
int update_tokens,
int hidden_size) {
CUDA_KERNEL_LOOP(i, update_tokens * hidden_size) {
int token_idx = i / hidden_size;
int offset = i % hidden_size;

size_t val_idx =
token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset;
DT kVal = devQKVProjArray[val_idx];
DT vVal = devQKVProjArray[val_idx + hidden_size];

// key cache
kGradCache[i] += kVal;
vGradCache[i] += vVal;

// for computation
devQKVProjArray[val_idx] = kGradCache[i];
devQKVProjArray[val_idx + hidden_size] = vGradCache[i];
}
}
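// Layout note (inferred from the indexing above, assuming QKV_WEIGHT_NUM == 3):
// devQKVProjArray is [num_tokens, 3, hidden_size] with Q/K/V interleaved per
// token, so for token t and channel c:
//   q_idx = (t * 3 + 0) * hidden_size + c;
//   k_idx = (t * 3 + 1) * hidden_size + c;  // val_idx in the kernel
//   v_idx = (t * 3 + 2) * hidden_size + c;
// The kernel adds the current chunk's K/V gradients into kGradCache /
// vGradCache and writes the running sums back into devQKVProjArray, so the
// GEMMs that follow consume gradients accumulated over all chunks so far.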

template <typename DT>
void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
@@ -504,11 +529,13 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
continue;
}
int num_tokens = bc->requestsInfo[i].num_tokens_in_batch;
int num_total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
bc->requestsInfo[i].num_tokens_in_batch;
int num_total_tokens = bc->requestsInfo[i].peft_fwd_tokens;
int num_processed_tokens = bc->requestsInfo[i].peft_bwd_tokens;
// int num_total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
// bc->requestsInfo[i].num_tokens_in_batch;
// Currently assume we are calculating gradients for all tokens
// of a request
assert(num_tokens == num_total_tokens);
// assert(num_tokens == num_total_tokens);
int kt_block_size = m->kProjSize;
int kt_req_block_size =
kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length();
@@ -531,9 +558,11 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
m->vProjSize * m->num_q_heads);
// matrix B: output gradients
// matrix B's layout: [num_new_tokens, oProjSize]
// DT const *B =
// output_grad_ptr +
// bc->requestsInfo[i].first_token_offset_in_batch * m->oProjSize;
DT const *B =
output_grad_ptr +
bc->requestsInfo[i].first_token_offset_in_batch * m->oProjSize;
output_grad_ptr + bc->requestsInfo[i].peft_bwd_tokens * m->oProjSize;
// matrix C: attn_heads gradients
// matrix C's layout: [num_new_tokens, num_heads, vProjSize]
DT *C = static_cast<DT *>(m->handle.workSpace);
@@ -565,20 +594,23 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
DT const *A = static_cast<DT *>(m->handle.workSpace);
// matrix B: qk_prods_softmax
// matrix B's layout: [num_heads, num_tokens, num_tokens]
DT const *B = static_cast<DT *>(m->qk_prods_softmax);
// DT const *B = static_cast<DT *>(m->qk_prods_softmax);
DT const *B = static_cast<DT *>(m->softmax_activation_buffer) +
(num_total_tokens - num_processed_tokens - num_tokens) *
num_total_tokens;
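// Offset note (inferred): the softmax activations cached during the forward
// pass cover all peft_fwd_tokens rows of the attention matrix; backward walks
// the sequence tail-first, so this chunk's rows begin at row
// (num_total_tokens - num_processed_tokens - num_tokens).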
// matrix C: gradients for value (saved as part of m->devQKVProjArray)
// matrix C's layout: [num_tokens, num_heads, qProjsize + kProjSize +
// vProjSize]
DT *C =
static_cast<DT *>(m->devQKVProjArray) + m->qProjSize + m->kProjSize;
int m_ = m->vProjSize;
int n_ = num_tokens;
int n_ = num_total_tokens;
int k_ = num_tokens;
int lda = m->vProjSize * m->num_q_heads;
int ldb = num_tokens;
int ldc = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize);
int strideA = m->vProjSize;
int strideB = num_tokens * num_tokens;
int strideB = num_tokens * num_total_tokens;
int strideC = m->qProjSize + m->kProjSize + m->vProjSize;
checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
CUBLAS_OP_T,
@@ -607,15 +639,15 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
// Step 3: compute gradients w.r.t. the qk_prods_softmax tensor
{
float alpha = 1.0f, beta = 0.0f;
int m_ = num_tokens;
int m_ = num_total_tokens;
int n_ = num_tokens;
int k_ = m->vProjSize;
int lda = m->vProjSize * m->num_q_heads;
int ldb = m->vProjSize * m->num_q_heads;
int ldc = num_tokens;
int strideA = m->vProjSize;
int strideB = m->vProjSize;
int strideC = num_tokens * num_tokens;
int strideC = num_tokens * num_total_tokens;
// matrix A: value cache
// matrix A's layout: [num_req, max_num_tokens, num_heads, vProjSize]
DT const *A = static_cast<DT *>(m->valueCache) + i * vt_req_block_size;
@@ -653,7 +685,7 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
{
float alpha = 1.0f, beta = 0.0f;
int n_param = m->num_q_heads;
int c_param = num_tokens;
int c_param = num_total_tokens;
int h_param = 1;
int w_param = num_tokens;
checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor,
@@ -674,8 +706,8 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
&beta,
m->qk_tensor,
m->qk_prods));
// TODO: fill all elements above diagonal to force causal attention
size_t entries_above_diagonal = num_tokens * (num_tokens - 1) / 2;
// fill all elements above diagonal to force causal attention
size_t entries_above_diagonal = num_tokens * (num_total_tokens - 1) / 2;
if (entries_above_diagonal > 0) {
size_t parallelism = m->num_q_heads * entries_above_diagonal;
fill_entries_above_diagonal<<<GET_BLOCKS(parallelism),
@@ -684,7 +716,7 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
0,
stream>>>(static_cast<DT *>(m->qk_prods),
num_tokens,
num_tokens,
num_total_tokens,
m->num_q_heads,
entries_above_diagonal,
DT(0.0f));
@@ -698,7 +730,9 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
}
// matrix A: query activation (in query_activation_buffer)
// matrix A's layout: [num_tokens, num_heads, m->qProjSize]
DT const *A = static_cast<DT *>(m->query_activation_buffer);
DT const *A = static_cast<DT *>(m->query_activation_buffer) +
(num_total_tokens - num_processed_tokens + num_tokens) *
m->hidden_size;
// matrix B: gradients w.r.t. qk_prods
// matrix B's layout: [num_heads, num_tokens, num_tokens]
DT const *B = static_cast<DT *>(m->qk_prods);
@@ -707,13 +741,13 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
// vProjSize]
DT *C = static_cast<DT *>(m->devQKVProjArray) + m->qProjSize;
int m_ = m->kProjSize;
int n_ = num_tokens;
int n_ = num_total_tokens;
int k_ = num_tokens;
int lda = m->num_q_heads * m->qProjSize;
int ldb = num_tokens;
int ldc = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize);
int strideA = m->qProjSize;
int strideB = num_tokens * num_tokens;
int strideB = num_tokens * num_total_tokens;
int strideC = m->qProjSize + m->kProjSize + m->vProjSize;
checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
CUBLAS_OP_N,
@@ -739,6 +773,19 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// update the key/value gradient cache
int update_tokens = num_total_tokens - num_processed_tokens;
int parallelism = m->hidden_size * num_tokens;
update_key_value_gradient<<<GET_BLOCKS(parallelism),
min(CUDA_NUM_THREADS, parallelism),
0,
stream>>>(static_cast<DT *>(m->devQKVProjArray),
static_cast<DT *>(m->keyGradCache),
static_cast<DT *>(m->valueGradCache),
update_tokens,
m->hidden_size);
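// Note (inferred): update_tokens = peft_fwd_tokens - peft_bwd_tokens is the
// still-unprocessed suffix of the sequence. Folding the running sums back
// into devQKVProjArray lets the projection backward that consumes it see
// K/V gradients accumulated over every chunk processed so far, not just the
// current one.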
// Step 6: compute gradients w.r.t query
{
float alpha = 1.0f, beta = 0.0f;
@@ -757,12 +804,12 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
DT *C = static_cast<DT *>(m->devQKVProjArray);
int m_ = m->qProjSize;
int n_ = num_tokens;
int k_ = num_tokens;
int k_ = num_total_tokens;
int lda = m->kProjSize * m->num_q_heads;
int ldb = num_tokens;
int ldc = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize);
int strideA = m->kProjSize;
int strideB = num_tokens * num_tokens;
int strideB = num_tokens * num_total_tokens;
int strideC = m->qProjSize + m->kProjSize + m->vProjSize;
checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
CUBLAS_OP_N,
@@ -927,6 +974,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta *m,
if (bc->request_completed[i]) {
continue;
}
int start_offset = bc->requestsInfo[i].first_token_depth_in_request;
assert(tokens_previous_requests ==
bc->requestsInfo[i].first_token_offset_in_batch);
int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
@@ -935,16 +983,17 @@
// Copy query to m->query_activation_buffer if we need to compute
// PEFT backward
if (bc->requestsInfo[i].peft_bwd) {
MemoryAllocator *allocator = m->handle.peft_activation_allocator;
m->query_activation_buffer = allocator->allocate_instance_untyped(
sizeof(DT) * total_tokens * m->num_q_heads * m->qProjSize);
// MemoryAllocator *allocator = m->handle.peft_activation_allocator;
// m->query_activation_buffer = allocator->allocate_instance_untyped(
// sizeof(DT) * total_tokens * m->num_q_heads * m->qProjSize);
int parallelism = m->hidden_size * num_tokens;
store_query_cache<<<GET_BLOCKS(parallelism),
min(CUDA_NUM_THREADS, parallelism),
0,
stream>>>(
static_cast<DT *>(m->devQKVProjArray),
static_cast<DT *>(m->query_activation_buffer),
static_cast<DT *>(m->query_activation_buffer) +
start_offset * m->hidden_size,
num_tokens,
m->hidden_size);
}
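// Note: query_activation_buffer is now preallocated once in the meta
// constructor (see below), so each decoding step writes its chunk at offset
// start_offset * hidden_size instead of allocating a fresh buffer per step.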
@@ -1065,10 +1114,12 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta *m,
// Copy C_softmax to m->softmax_activation_buffer if we need to compute
// PEFT backward
if (bc->requestsInfo[i].peft_bwd) {
MemoryAllocator *allocator = m->handle.peft_activation_allocator;
m->softmax_activation_buffer = allocator->allocate_instance_untyped(
sizeof(DT) * total_tokens * num_new_tokens * m->num_q_heads);
checkCUDA(cudaMemcpyAsync(m->softmax_activation_buffer,
// MemoryAllocator *allocator = m->handle.peft_activation_allocator;
// m->softmax_activation_buffer = allocator->allocate_instance_untyped(
// sizeof(DT) * total_tokens * num_new_tokens * m->num_q_heads);
DT *softmax_cache = static_cast<DT *>(m->softmax_activation_buffer) +
start_offset * bc->requestsInfo[i].peft_bwd_tokens;
checkCUDA(cudaMemcpyAsync(softmax_cache,
C_softmax,
sizeof(DT) * total_tokens * num_new_tokens *
m->num_q_heads,
@@ -1508,7 +1559,20 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
size_of_dt +
tokeninfo_size * sizeof(BatchConfig::PerTokenInfo) +
complex_size * sizeof(cuFloatComplex); // more components will
// be added here later
// assume we have only one PEFT request.
size_t key_grad_cache_size =
num_q_heads * kProjSize * BatchConfig::max_sequence_length();
size_t value_grad_cache_size =
num_q_heads * vProjSize * BatchConfig::max_sequence_length();
size_t query_activation_size =
num_q_heads * qProjSize * BatchConfig::max_sequence_length();
size_t softmax_activation_size = num_q_heads *
BatchConfig::max_sequence_length() *
BatchConfig::max_sequence_length();
totalSize += (key_grad_cache_size + value_grad_cache_size +
query_activation_size + softmax_activation_size);
if (offload) {
// assert that we have enough reserved work space left
size_t totalSharedSize =
@@ -1551,6 +1615,15 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
valueCache = gpu_mem_allocator.allocate_instance_untyped(value_cache_size *
size_of_dt);
query_activation_buffer = gpu_mem_allocator.allocate_instance_untyped(
query_activation_size * size_of_dt);
softmax_activation_buffer = gpu_mem_allocator.allocate_instance_untyped(
softmax_activation_size * size_of_dt);
keyGradCache = gpu_mem_allocator.allocate_instance_untyped(
key_grad_cache_size * size_of_dt);
valueGradCache = gpu_mem_allocator.allocate_instance_untyped(
value_grad_cache_size * size_of_dt);
if (offload) {
token_infos =
gpu_mem_allocator.allocate_reserved<BatchConfig::PerTokenInfo>(
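A worked example of the extra per-layer memory these persistent PEFT buffers add (illustrative values, not from the commit: num_q_heads = 32, qProjSize = kProjSize = vProjSize = 128, max_sequence_length = 1024, 2-byte elements):

// query activations  : 32 * 128 * 1024  =  4,194,304 elements =  8 MiB
// key grad cache     : 32 * 128 * 1024  =  4,194,304 elements =  8 MiB
// value grad cache   : 32 * 128 * 1024  =  4,194,304 elements =  8 MiB
// softmax activations: 32 * 1024 * 1024 = 33,554,432 elements = 64 MiB
// The softmax term is quadratic in max_sequence_length and dominates.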
1 change: 1 addition & 0 deletions src/ops/kernels/softmax.cpp
@@ -24,6 +24,7 @@ using Legion::Domain;

SoftmaxMeta::SoftmaxMeta(FFHandler handler,
Softmax const *softmax,
MemoryAllocator &gpu_mem_allocator,
Domain const &input_domain)
: OpMeta(handler) {
checkCUDNN(miopenCreateTensorDescriptor(&inputTensor));