align backward of o_proj, attn_heads, qk_prods_softmax, and v_proj with huggingface
goliaro committed Nov 30, 2023
1 parent a122e30 commit 53e737b
Showing 1 changed file with 58 additions and 25 deletions.
83 changes: 58 additions & 25 deletions src/ops/inc_multihead_self_attention.cu
@@ -894,6 +894,26 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
// compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
// }
// #endif
+ std::string op_name_without_uid = std::string(m->op_name);
+ size_t last_underscore = op_name_without_uid.length() - 1;
+ for (int i = op_name_without_uid.length() - 1; i > 0; i--) {
+ if (!(std::isdigit(m->op_name[i]) || m->op_name[i] == '_')) {
+ break;
+ } else if (m->op_name[i] == '_') {
+ last_underscore = i;
+ }
+ }
+ op_name_without_uid.erase(last_underscore);
+ std::string base_filepath =
+ "./inference_tensors/model_" + std::to_string(m->layer_guid.model_id) +
+ "_bwd-step_" + std::to_string(m->bwd_step) +
+ "_layer-num_" + std::to_string(m->layer_guid.transformer_layer_id) +
+ "_layer-name_" + op_name_without_uid + "_shard-id_" +
+ std::to_string(shard_id);
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
@@ -913,30 +933,31 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
int vt_block_size = m->vProjSize;
int vt_req_block_size =
vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length();
+ assert(m->qProjSize == m->kProjSize && m->kProjSize == m->vProjSize);
// Step 1: compute gradients before final projection
{
int m_ = m->vProjSize * m->num_q_heads;
int n_ = num_tokens;
int k_ = m->oProjSize;
- int lda = k_;
+ int lda = m_;
int ldb = k_;
int ldc = m_;
float alpha = 1.0f, beta = 0.0f;
// matrix A: output projection weight
- // matrix A's layout: [num_heads, vProjSize, oProjSize]
+ // matrix A's layout: [vProjSize * num_heads, oProjSize]
DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads +
m->kProjSize * m->num_q_heads +
m->vProjSize * m->num_q_heads);
// matrix B: output gradients
- // matrix B's layout: [num_new_tokens, oProjSize]
+ // matrix B's layout: [oProjSize, num_new_tokens]
DT const *B =
output_grad_ptr +
bc->requestsInfo[i].first_token_offset_in_batch * m->oProjSize;
// matrix C: attn_heads gradients
- // matrix C's layout: [num_new_tokens, num_heads, vProjSize]
+ // matrix C's layout: [vProjSize * num_heads, num_new_tokens]
DT *C = static_cast<DT *>(m->handle.workSpace);
checkCUDA(cublasGemmEx(m->handle.blas,
- CUBLAS_OP_T,
+ CUBLAS_OP_N,
CUBLAS_OP_N,
m_,
n_,
@@ -954,33 +975,38 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
ldc,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ // save result to file for checking
+ std::string filename = base_filepath + "_o_proj_in_grad";
+ std::cout << "FILENAME: " << filename << std::endl;
+ save_tensor(C, m_*n_, filename.c_str());
}
// Step 2: compute gradients w.r.t. value
{
float alpha = 1.0f, beta = 0.0f;
- // matrix A: attn_heads gradients
- // matrix A's layout: [num_tokens, num_heads, vProjSize]
- DT const *A = static_cast<DT *>(m->handle.workSpace);
- // matrix B: qk_prods_softmax
- // matrix B's layout: [num_heads, num_tokens, num_tokens]
- DT const *B = static_cast<DT *>(m->qk_prods_softmax);
+ // matrix A: qk_prods_softmax
+ // matrix A's layout: [num_new_tokens, total_tokens, num_heads]
+ DT const *A = static_cast<DT *>(m->qk_prods_softmax);
+ // matrix B: attn_heads gradients
+ // matrix B's layout: [vProjSize * num_heads, num_new_tokens]
+ DT const *B = static_cast<DT *>(m->handle.workSpace);
// matrix C: gradients for value (saved as part of m->devQKVProjArray)
- // matrix C's layout: [num_tokens, num_heads, qProjsize + kProjSize +
- // vProjSize]
- DT *C =
- static_cast<DT *>(m->devQKVProjArray) + m->qProjSize + m->kProjSize;
- int m_ = m->vProjSize;
- int n_ = num_tokens;
- int k_ = num_tokens;
- int lda = m->vProjSize * m->num_q_heads;
- int ldb = num_tokens;
- int ldc = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize);
- int strideA = m->vProjSize;
- int strideB = num_tokens * num_tokens;
- int strideC = m->qProjSize + m->kProjSize + m->vProjSize;
+ // matrix C's layout: [num_tokens, qProjsize * num_heads, 3]
+ DT *C = static_cast<DT *>(m->devQKVProjArray) + 2*(m->qProjSize * m->num_q_heads); // skip over regions reserved for Q and K gradients
+ // after transpositions
+ int m_ = num_tokens; // total_tokens
+ int n_ = m->vProjSize; // num_new_tokens
+ int k_ = num_tokens; // num_new_tokens
+ // before transpositions
+ int lda = num_tokens; // num_new_tokens
+ int ldb = m->vProjSize * m->num_q_heads;
+ int ldc = num_tokens; // total_tokens
+ // N.B. strides are applied before transpose operations
+ int strideA = num_tokens * num_tokens; // num_new_tokens * total_tokens
+ int strideB = m->vProjSize;
+ int strideC = num_tokens * m->vProjSize;
checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
CUBLAS_OP_T,
- CUBLAS_OP_N,
+ CUBLAS_OP_T,
m_,
n_,
k_,
@@ -1001,6 +1027,13 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
m->num_q_heads,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ // save result to file for checking
+ std::string filename = base_filepath + "_v_proj_in_grad";
+ std::cout << "FILENAME: " << filename << std::endl;
+ save_tensor(C, m_*n_*m->num_q_heads, filename.c_str());
+ std::string filename2 = base_filepath + "_qk_prods_softmax";
+ std::cout << "FILENAME: " << filename2 << std::endl;
+ save_tensor(A, m_*k_*m->num_q_heads, filename2.c_str());
}
// Step 3: compute gradients w.r.t. the qk_prods_softmax tensor
{
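A note on the new tensor-dump code in the first hunk: it strips the trailing numeric unique-id from m->op_name, then builds a per-model, per-step, per-layer, per-shard base path under ./inference_tensors/, to which Step 1 and Step 2 append suffixes such as _o_proj_in_grad, _v_proj_in_grad, and _qk_prods_softmax. A minimal standalone sketch of that naming scheme follows; the sample operator name and the hard-coded ids are hypothetical, not taken from the commit.

#include <cctype>
#include <iostream>
#include <string>

int main() {
  std::string op_name = "layers_11_attention_100"; // hypothetical stand-in for m->op_name
  // Strip the trailing "_<uid>" the same way the kernel code does.
  std::string op_name_without_uid = op_name;
  size_t last_underscore = op_name_without_uid.length() - 1;
  for (int i = op_name_without_uid.length() - 1; i > 0; i--) {
    if (!(std::isdigit(op_name[i]) || op_name[i] == '_')) {
      break;
    } else if (op_name[i] == '_') {
      last_underscore = i;
    }
  }
  op_name_without_uid.erase(last_underscore); // -> "layers_11_attention"
  std::string base_filepath =
      "./inference_tensors/model_0_bwd-step_0_layer-num_11_layer-name_" +
      op_name_without_uid + "_shard-id_0";
  // Prints .../layer-name_layers_11_attention_shard-id_0_o_proj_in_grad
  std::cout << base_filepath + "_o_proj_in_grad" << std::endl;
  return 0;
}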

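The Step 1 change (lda switching from k_ to m_ and the first operand from CUBLAS_OP_T to CUBLAS_OP_N) follows from cuBLAS's column-major convention: with the o_proj weight stored as [vProjSize * num_heads, oProjSize] and the incoming output gradient stored as [oProjSize, num_new_tokens], the attn_heads gradient is a plain no-transpose product of the two. Below is a host-side, column-major reference loop for what that cublasGemmEx call computes; it is only a shape/indexing sketch with a hypothetical function name, not the kernel's actual implementation.

// C (m_ x n_) = A (m_ x k_) * B (k_ x n_), all column-major, opA = opB = CUBLAS_OP_N,
// with m_ = vProjSize * num_heads, n_ = num_new_tokens, k_ = oProjSize.
template <typename DT>
void o_proj_in_grad_reference(DT const *A, // weight, lda = m_
                              DT const *B, // output gradient, ldb = k_
                              DT *C,       // attn_heads gradient, ldc = m_
                              int m_, int n_, int k_) {
  for (int n = 0; n < n_; n++) {
    for (int m = 0; m < m_; m++) {
      DT acc = static_cast<DT>(0);
      for (int k = 0; k < k_; k++) {
        // column-major: element (row, col) of a matrix with leading dimension ld
        // is stored at [row + col * ld]
        acc += A[m + k * m_] * B[k + n * k_];
      }
      C[m + n * m_] = acc;
    }
  }
}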
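In Step 2, the "N.B. strides are applied before transpose operations" comment captures how cublasGemmStridedBatchedEx behaves: strideA/strideB/strideC advance each head's A, B, and C pointers in the stored (pre-transpose) layouts, and only then are transa/transb applied. With both operands transposed, each head reads its [num_new_tokens x total_tokens] slice of qk_prods_softmax and its vProjSize-row slab of the attn_heads gradient, and writes a [total_tokens x vProjSize] value gradient. The per-head reference loop below sketches that shape bookkeeping under those assumptions; the function name and parameters are hypothetical, and it mirrors the call's arguments rather than the real devQKVProjArray memory layout.

template <typename DT>
void v_proj_in_grad_reference(
    DT const *qk_prods_softmax, // per head: [num_new_tokens x total_tokens], lda = num_new_tokens
    DT const *attn_heads_grad,  // [vProjSize * num_heads x num_new_tokens], ldb = vProjSize * num_heads
    DT *v_grad,                 // per head: [total_tokens x vProjSize], ldc = total_tokens
    int num_new_tokens, int total_tokens, int vProjSize, int num_heads) {
  for (int h = 0; h < num_heads; h++) {
    // batch strides select each head's slab in the stored (pre-transpose) layouts
    DT const *A = qk_prods_softmax + (size_t)h * num_new_tokens * total_tokens; // strideA
    DT const *B = attn_heads_grad + (size_t)h * vProjSize;                      // strideB
    DT *C = v_grad + (size_t)h * total_tokens * vProjSize;                      // strideC
    for (int t = 0; t < total_tokens; t++) {       // rows of op(A) = A^T
      for (int v = 0; v < vProjSize; v++) {        // cols of op(B) = B^T
        DT acc = static_cast<DT>(0);
        for (int s = 0; s < num_new_tokens; s++) { // shared num_new_tokens dimension
          acc += A[s + (size_t)t * num_new_tokens] *
                 B[v + (size_t)s * vProjSize * num_heads];
        }
        C[t + (size_t)v * total_tokens] = acc;
      }
    }
  }
}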