chore: remove redundant code
chenzhuofu committed Jul 5, 2024
1 parent dfe4bec commit e08f06e
Showing 2 changed files with 10 additions and 91 deletions.
15 changes: 1 addition & 14 deletions src/ops/inc_multihead_self_attention.cu
@@ -1452,20 +1452,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
BatchConfig::max_sequence_length() * num_q_heads;
break;
}
case TREE_SEARCH_MODE: {
query_tmp_size =
num_q_heads * qProjSize * BatchConfig::max_tokens_per_batch();
// a K-ary tree max node is (k^n - 1) / 2
key_cache_size = num_q_heads * kProjSize *
BatchConfig::max_requests_per_batch() * max_num_pages *
kPagesize;
value_cache_size = num_q_heads * vProjSize *
BatchConfig::max_requests_per_batch() *
max_num_pages * kPagesize;
qk_prod_size = BatchConfig::max_sequence_length() * max_num_pages *
kPagesize * num_q_heads;
break;
}
case TREE_SEARCH_MODE:
case TREE_VERIFY_MODE: {
query_tmp_size =
num_q_heads * qProjSize * BatchConfig::max_tokens_per_batch();
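
Only the first lines of the retained TREE_VERIFY_MODE body are visible above, so the sketch below is a minimal, self-contained reading of the consolidated switch: TREE_SEARCH_MODE now falls through to TREE_VERIFY_MODE, which is only a pure deduplication if the two modes size their scratch buffers identically (the removed block suggests they do). Every constant here is a placeholder standing in for the real BatchConfig values and projection sizes, not FlexFlow code.

// Standalone sketch (not FlexFlow code): placeholder values stand in for
// num_q_heads, qProjSize, BatchConfig::max_tokens_per_batch(), etc.
#include <cstddef>
#include <cstdio>

enum InferenceMode { INC_DECODING_MODE, TREE_SEARCH_MODE, TREE_VERIFY_MODE };

int main() {
  size_t const num_q_heads = 32, qProjSize = 128, kProjSize = 128,
               vProjSize = 128;
  size_t const max_tokens_per_batch = 64, max_requests_per_batch = 8;
  size_t const max_num_pages = 16, kPagesize = 64, max_sequence_length = 1024;

  InferenceMode mode = TREE_SEARCH_MODE;
  size_t query_tmp_size = 0, key_cache_size = 0, value_cache_size = 0,
         qk_prod_size = 0;

  switch (mode) {
    case INC_DECODING_MODE: {
      // Incremental-decoding sizing elided; not part of this change.
      break;
    }
    // After the commit, search mode simply falls through: both speculative
    // modes size their temporary buffers with the same formulas.
    case TREE_SEARCH_MODE:
    case TREE_VERIFY_MODE: {
      query_tmp_size = num_q_heads * qProjSize * max_tokens_per_batch;
      key_cache_size = num_q_heads * kProjSize * max_requests_per_batch *
                       max_num_pages * kPagesize;
      value_cache_size = num_q_heads * vProjSize * max_requests_per_batch *
                         max_num_pages * kPagesize;
      qk_prod_size = max_sequence_length * max_num_pages * kPagesize *
                     num_q_heads;
      break;
    }
  }
  printf("query_tmp=%zu key_cache=%zu value_cache=%zu qk_prod=%zu\n",
         query_tmp_size, key_cache_size, value_cache_size, qk_prod_size);
  return 0;
}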
86 changes: 9 additions & 77 deletions src/runtime/request_manager.cu
@@ -218,7 +218,7 @@ void RequestManager::load_batch_config_task(
total_copy_size += sizeof(BatchConfig::request_available);

// load speculative metadata
if (batch_config->get_mode() == TREE_SEARCH_MODE) {
if (batch_config->get_mode() == TREE_SEARCH_MODE or batch_config->get_mode() == TREE_VERIFY_MODE) {
for (int request_idx = 0;
request_idx < BatchConfig::max_requests_per_batch();
request_idx++) {
@@ -288,86 +288,18 @@ void RequestManager::load_batch_config_task(
batch_size);
}
}
} else if (batch_config->get_mode() == TREE_VERIFY_MODE) {
for (int request_idx = 0;
request_idx < BatchConfig::max_requests_per_batch();
request_idx++) {
      if (batch_config->request_available[request_idx]) {
        checkCUDA(cudaMemcpyAsync(
            static_cast<char *>(handle.batch_config_metadata) +
                total_copy_size + request_idx * sizeof(BatchConfig::BitMask),
            &(batch_config->causalMask[request_idx]),
            sizeof(BatchConfig::BitMask),
            cudaMemcpyHostToDevice,
            stream));
      }
    }
    total_copy_size += sizeof(BatchConfig::causalMask);
    if (batch_config->get_mode() == TREE_VERIFY_MODE) {
      if (batch_config->num_tokens_to_commit > 0) {
        checkCUDA(cudaMemcpyAsync(
            static_cast<char *>(handle.batch_config_metadata) + total_copy_size,
            &(batch_config->committed_tokens),
            batch_config->num_tokens_to_commit *
                sizeof(BatchConfig::CommittedTokensInfo),
            cudaMemcpyHostToDevice,
            stream));
      }
    }
    total_copy_size += sizeof(BatchConfig::committed_tokens);
// calculate the attention meta data
{
BatchConfig::PerRequestInfo *request_infos = reinterpret_cast<BatchConfig::PerRequestInfo *>(
static_cast<char *>(handle.batch_config_metadata) +
sizeof(BatchConfig::tokensInfo));
bool *request_available = reinterpret_cast<bool *>(
static_cast<char *>(handle.batch_config_metadata) +
sizeof(BatchConfig::tokensInfo) +
sizeof(BatchConfig::requestsInfo));
BatchConfig::BitMask *causalMask = reinterpret_cast<BatchConfig::BitMask *>(
static_cast<char *>(handle.batch_config_metadata) +
sizeof(BatchConfig::tokensInfo) +
sizeof(BatchConfig::requestsInfo) +
sizeof(BatchConfig::request_available));
int batch_size = batch_config->num_active_requests();
uint32_t const max_num_pages = (BatchConfig::max_sequence_length() +
BatchConfig::max_spec_tree_token_num() + kPagesize - 1) / kPagesize;
int parallelism = batch_size;
prepare_inference_params_kernel<<<GET_BLOCKS(parallelism),
min(CUDA_NUM_THREADS, parallelism),
0,
stream>>>(batch_size,
request_infos,
request_available,
max_num_pages,
handle.attention_metadata.q_indptr,
handle.attention_metadata.kv_indptr,
handle.attention_metadata.kv_indices,
handle.attention_metadata.kv_last_page_len,
handle.attention_metadata.qk_indptr);
// Update gpu-side custom mask referring to CausalMask
if (!batch_config->prompt_phase) {
int parallelism = 0;
for (int req_idx = 0; req_idx < batch_config->max_requests_per_batch(); req_idx++) {
if (batch_config->request_available[req_idx]) {
int q_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch;
int kv_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch +
batch_config->requestsInfo[req_idx].first_token_index_in_request;
parallelism += (q_len * kv_len + 7) / 8;
}
}
update_custom_mask_kernel<<<GET_BLOCKS(parallelism),
min(CUDA_NUM_THREADS, parallelism),
0,
stream>>>(handle.attention_metadata.custom_mask,
handle.attention_metadata.qk_indptr,
causalMask,
request_infos,
request_available,
batch_size);
}
}
}
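
The interleaved rendering above makes the before/after hard to follow, so here is a host-only sketch of what the merged branch in load_batch_config_task appears to do after this commit: a single branch handles both speculative modes, the per-request causal masks are copied for either mode, and only the committed-token copy stays behind a verify-mode check. std::memcpy stands in for cudaMemcpyAsync, the structs are simplified stand-ins for BatchConfig, and the offset bookkeeping is illustrative rather than taken from the new file.

// Sketch only: simplified stand-ins for BatchConfig and the device buffer.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

enum InferenceMode { INC_DECODING_MODE, TREE_SEARCH_MODE, TREE_VERIFY_MODE };

struct BitMask {             // placeholder for BatchConfig::BitMask
  uint64_t bits[4];
};
struct CommittedTokensInfo { // placeholder for the real struct
  int token_index;
};

struct MiniBatchConfig {
  static constexpr int kMaxRequests = 8;
  InferenceMode mode = TREE_SEARCH_MODE;
  bool request_available[kMaxRequests] = {};
  BitMask causalMask[kMaxRequests] = {};
  int num_tokens_to_commit = 0;
  CommittedTokensInfo committed_tokens[16] = {};
};

// Append the speculative metadata to one flat buffer: causal masks for both
// speculative modes, committed tokens only when verifying.
size_t load_speculative_metadata(MiniBatchConfig const &bc, char *metadata) {
  size_t total_copy_size = 0;
  if (bc.mode == TREE_SEARCH_MODE || bc.mode == TREE_VERIFY_MODE) {
    for (int r = 0; r < MiniBatchConfig::kMaxRequests; r++) {
      if (bc.request_available[r]) {
        std::memcpy(metadata + total_copy_size + r * sizeof(BitMask),
                    &bc.causalMask[r],
                    sizeof(BitMask));
      }
    }
    total_copy_size += sizeof(bc.causalMask);
    if (bc.mode == TREE_VERIFY_MODE && bc.num_tokens_to_commit > 0) {
      std::memcpy(metadata + total_copy_size,
                  bc.committed_tokens,
                  bc.num_tokens_to_commit * sizeof(CommittedTokensInfo));
    }
    total_copy_size += sizeof(bc.committed_tokens);
  }
  return total_copy_size;
}

int main() {
  MiniBatchConfig bc;
  bc.mode = TREE_VERIFY_MODE;
  bc.request_available[0] = true;
  bc.num_tokens_to_commit = 3;
  std::vector<char> metadata(sizeof(bc.causalMask) + sizeof(bc.committed_tokens));
  size_t copied = load_speculative_metadata(bc, metadata.data());
  printf("metadata bytes reserved: %zu\n", copied);
  return 0;
}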
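One detail worth noting in the removed block: the update_custom_mask_kernel launch sizes its grid from (q_len * kv_len + 7) / 8 per request, which reads as one bit per (query, key) pair packed eight to a byte, so the thread count equals the number of packed mask bytes. A tiny check of that arithmetic follows; the bit-packing interpretation is inferred from the expression, not stated in the source.

#include <cassert>
#include <cstdio>

// Bytes needed for a bit-packed q_len x kv_len attention mask (8 bits/byte).
int packed_mask_bytes(int q_len, int kv_len) {
  return (q_len * kv_len + 7) / 8; // round up to whole bytes
}

int main() {
  assert(packed_mask_bytes(3, 5) == 2);  // 15 bits -> 2 bytes
  assert(packed_mask_bytes(4, 16) == 8); // 64 bits -> 8 bytes
  assert(packed_mask_bytes(1, 1) == 1);  // 1 bit   -> 1 byte
  printf("packed mask byte counts check out\n");
  return 0;
}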

0 comments on commit e08f06e
