Skip to content

Commit

Permalink
[feature] add vision_embedding bias param
Browse files Browse the repository at this point in the history
  • Loading branch information
Alcanderian committed Jul 10, 2024
1 parent dc9a5b4 commit 8fbceda
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions src/ppl/nn/engines/llm_cuda/kernels/opmx/vision_embedding_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,27 @@ ppl::common::RetCode VisionEmbeddingKernel::DoExecute(KernelExecContext* ctx) {

#ifdef PPLNN_CUDA_ENABLE_CUDNN
PPLNN_LLM_CUDA_REQUIRED_INPUT(pixel_values, 0);
PPLNN_LLM_CUDA_REQUIRED_INPUT(cls_emb_weight, 1);
PPLNN_LLM_CUDA_OPTIONAL_INPUT(patch_emb_weight, 2);
PPLNN_LLM_CUDA_OPTIONAL_INPUT(pos_emb_weight, 3);
PPLNN_LLM_CUDA_REQUIRED_INPUT(class_weight, 1);
// TODO: fix optional input to required input. refer to pmx doc
PPLNN_LLM_CUDA_OPTIONAL_INPUT(patch_weight, 2);
PPLNN_LLM_CUDA_OPTIONAL_INPUT(position_weight, 3);
PPLNN_LLM_CUDA_OPTIONAL_INPUT(patch_bias, 4);
PPLNN_LLM_CUDA_REQUIRED_OUTPUT(vision_embeddings, 0);

PPLNN_LLM_CUDA_DEBUG_TRACE("Input [pixel_values]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(pixel_values);
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [cls_emb_weight]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(cls_emb_weight);
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [patch_emb_weight]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(patch_emb_weight);
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [pos_emb_weight]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(pos_emb_weight);
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [class_weight]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(class_weight);
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [patch_weight]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(patch_weight);
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [position_weight]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(position_weight);
void* patch_bias_data = nullptr;
if (patch_bias) {
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [patch_bias]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(patch_bias);
patch_bias_data = patch_bias->GetBufferPtr();
}

PPLNN_LLM_CUDA_DEBUG_TRACE("hidden_dim: %d\n", param_->hidden_dim);
PPLNN_LLM_CUDA_DEBUG_TRACE("patch_size: %d\n", param_->patch_size);
Expand All @@ -64,6 +72,7 @@ ppl::common::RetCode VisionEmbeddingKernel::DoExecute(KernelExecContext* ctx) {
return ppl::common::RC_OTHER_ERROR;
}

config.bias_term = patch_bias != nullptr;
config.hidden_dim = param_->hidden_dim;
config.patch_size = param_->patch_size;
auto image_shape = pixel_values->GetShape();
Expand All @@ -75,6 +84,7 @@ ppl::common::RetCode VisionEmbeddingKernel::DoExecute(KernelExecContext* ctx) {
return ppl::common::RC_UNSUPPORTED;
}

// TODO: check retcode here
ppl::kernel::llm::cuda::pmx::vision_embedding_preprocessing(config);

BufferDesc buffers_desc;
Expand All @@ -89,16 +99,19 @@ ppl::common::RetCode VisionEmbeddingKernel::DoExecute(KernelExecContext* ctx) {
});
config.buffer_addr = buffers_desc.addr;

// TODO: check retcode here
ppl::kernel::llm::cuda::pmx::vision_embedding(
GetStream(),
config,
pixel_values->GetBufferPtr(),
patch_emb_weight->GetBufferPtr(), // [hidden_dim, image_channel, patch_size, patch_size]
cls_emb_weight->GetBufferPtr(), // [hidden_dim]
pos_emb_weight->GetBufferPtr(), // [num_positions * hidden_dim]
patch_weight->GetBufferPtr(), // [hidden_dim, image_channel, patch_size, patch_size]
patch_bias->GetBufferPtr(), // [hidden_dim]
class_weight->GetBufferPtr(), // [hidden_dim]
position_weight->GetBufferPtr(), // [num_positions * hidden_dim]
vision_embeddings->GetBufferPtr()
);

// TODO: check retcode here
ppl::kernel::llm::cuda::pmx::vision_embedding_postprocessing(config);

return ppl::common::RC_SUCCESS;
Expand Down

0 comments on commit 8fbceda

Please sign in to comment.