Skip to content

Commit

Permalink
[feature] support convolution with bias in vision_embedding.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimurk committed Jul 16, 2024
1 parent cb14e84 commit 7e226fb
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions src/ppl/nn/engines/llm_cuda/kernels/opmx/vision_embedding_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ ppl::common::RetCode VisionEmbeddingKernel::DoExecute(KernelExecContext* ctx) {
PPLNN_LLM_CUDA_REQUIRED_INPUT(pixel_values, 0);
PPLNN_LLM_CUDA_REQUIRED_INPUT(class_weight, 1);
// TODO: fix optional input to required input. refer to pmx doc
PPLNN_LLM_CUDA_OPTIONAL_INPUT(patch_weight, 2);
PPLNN_LLM_CUDA_OPTIONAL_INPUT(position_weight, 3);
PPLNN_LLM_CUDA_REQUIRED_INPUT(patch_weight, 2);
PPLNN_LLM_CUDA_REQUIRED_INPUT(position_weight, 3);
PPLNN_LLM_CUDA_OPTIONAL_INPUT(patch_bias, 4);
PPLNN_LLM_CUDA_REQUIRED_OUTPUT(vision_embeddings, 0);
PPLNN_LLM_CUDA_REQUIRED_OUTPUT(output_embeddings, 0);

PPLNN_LLM_CUDA_DEBUG_TRACE("Input [pixel_values]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(pixel_values);
Expand All @@ -47,21 +47,19 @@ ppl::common::RetCode VisionEmbeddingKernel::DoExecute(KernelExecContext* ctx) {
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(patch_weight);
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [position_weight]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(position_weight);
void* patch_bias_data = nullptr;
if (patch_bias) {
PPLNN_LLM_CUDA_DEBUG_TRACE("Input [patch_bias]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(patch_bias);
patch_bias_data = patch_bias->GetBufferPtr();
}

PPLNN_LLM_CUDA_DEBUG_TRACE("hidden_dim: %d\n", param_->hidden_dim);
PPLNN_LLM_CUDA_DEBUG_TRACE("patch_size: %d\n", param_->patch_size);

PPLNN_LLM_CUDA_RESHAPE_OUTPUTS();

PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(vision_embeddings);
PPLNN_LLM_CUDA_DEBUG_TRACE("Output [vision_embeddings]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(vision_embeddings);
PPLNN_LLM_CUDA_REALLOC_TENSOR_BUFFER(output_embeddings);
PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output_embeddings]:\n");
PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output_embeddings);

ppl::kernel::llm::cuda::pmx::vision_embedding_config config;

Expand All @@ -85,10 +83,15 @@ ppl::common::RetCode VisionEmbeddingKernel::DoExecute(KernelExecContext* ctx) {
}

// TODO: check retcode here
ppl::kernel::llm::cuda::pmx::vision_embedding_preprocessing(config);
auto pplcommon_status = ppl::kernel::llm::cuda::pmx::vision_embedding_preprocessing(config);
if (pplcommon_status != ppl::common::RC_SUCCESS) {
LOG(ERROR) << "ppl::kernel::llm::cuda::pmx::vision_embedding_preprocessing() in kernel "
<< GetName() << "] failed: " << ppl::common::GetRetCodeStr(pplcommon_status);
return pplcommon_status;
}

BufferDesc buffers_desc;
auto pplcommon_status = GetCudaDevice()->AllocTmpBuffer(config.total_buffer_size, &buffers_desc);
pplcommon_status = GetCudaDevice()->AllocTmpBuffer(config.total_buffer_size, &buffers_desc);
if (pplcommon_status != ppl::common::RC_SUCCESS) {
LOG(ERROR) << "alloc buffers size[" << config.total_buffer_size << "] for kernel["
<< GetName() << "] failed: " << ppl::common::GetRetCodeStr(pplcommon_status);
Expand All @@ -100,19 +103,29 @@ ppl::common::RetCode VisionEmbeddingKernel::DoExecute(KernelExecContext* ctx) {
config.buffer_addr = buffers_desc.addr;

// TODO: check retcode here
ppl::kernel::llm::cuda::pmx::vision_embedding(
pplcommon_status = ppl::kernel::llm::cuda::pmx::vision_embedding(
GetStream(),
config,
pixel_values->GetBufferPtr(),
patch_weight->GetBufferPtr(), // [hidden_dim, image_channel, patch_size, patch_size]
patch_bias->GetBufferPtr(), // [hidden_dim]
class_weight->GetBufferPtr(), // [hidden_dim]
position_weight->GetBufferPtr(), // [num_positions * hidden_dim]
vision_embeddings->GetBufferPtr()
patch_weight->GetBufferPtr(), // [hidden_dim, image_channel, patch_size, patch_size]
position_weight->GetBufferPtr(), // [num_positions * hidden_dim]
patch_bias->GetBufferPtr(), // [hidden_dim]
output_embeddings->GetBufferPtr()
);
if (pplcommon_status != ppl::common::RC_SUCCESS) {
LOG(ERROR) << "ppl::kernel::llm::cuda::pmx::vision_embedding() in kernel "
<< GetName() << "] failed: " << ppl::common::GetRetCodeStr(pplcommon_status);
return pplcommon_status;
}

// TODO: check retcode here
ppl::kernel::llm::cuda::pmx::vision_embedding_postprocessing(config);
pplcommon_status = ppl::kernel::llm::cuda::pmx::vision_embedding_postprocessing(config);
if (pplcommon_status != ppl::common::RC_SUCCESS) {
LOG(ERROR) << "ppl::kernel::llm::cuda::pmx::vision_embedding_postprocessing() in kernel "
<< GetName() << "] failed: " << ppl::common::GetRetCodeStr(pplcommon_status);
return pplcommon_status;
}

return ppl::common::RC_SUCCESS;
#else
Expand Down

0 comments on commit 7e226fb

Please sign in to comment.