Skip to content

Commit

Permalink
[OpenCL/GPU] Optimized Blas and Attention kernels with the latest GPU…
Browse files Browse the repository at this point in the history
… Pipeline changes

Upated the kernels as per the latest buffer generalized changes.
Added unittest for Addition FP16 in unittest_blas_kernels_cl.cpp

Signed-off-by: Yash Singh <[email protected]>
  • Loading branch information
yashSingh0723 committed Jan 7, 2025
1 parent 252b2df commit 493b26c
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 110 deletions.
20 changes: 12 additions & 8 deletions nntrainer/tensor/cl_operations/attention_kernel_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
#define __ATTENTION_KERNEL_STRINGS_H__

#include <string>

// unsigned int offsetFeqsSin,
// unsigned int offsetSin
namespace nntrainer {
static const std::string rotary_emb_cl_kernel_ = R"(
Expand All @@ -34,10 +35,11 @@ __kernel void rotary_emb_cl(__global float *input,
unsigned int dim,
unsigned int half_,
unsigned int max_timestep,
unsigned int from) {
unsigned int from,
unsigned int offsetFreqsSin,
unsigned int offsetSin) {
__global float *cos_ptr = cos_;
__global float *sin_ptr = sin_;
float value = 0.0f;
float transformed_value = 0.0f;
Expand All @@ -50,7 +52,7 @@ __kernel void rotary_emb_cl(__global float *input,
unsigned idx = (from + h)*dim;
for(unsigned int i = idx; i < idx + dim; i++){
cos_ptr[i - idx] = freqs_cos[i];
sin_ptr[i - idx] = freqs_sin[i];
sin_ptr[i - idx + offsetSin] = freqs_sin[i + offsetFreqsSin];
}
}
Expand All @@ -63,7 +65,7 @@ __kernel void rotary_emb_cl(__global float *input,
} else {
transformed_value = input[b * channel * height * width + c * height * width + h * width + span - half_];
}
value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
value = value * cos_ptr[k] + transformed_value * sin_ptr[k + offsetSin];
output[b * channel * height * width + c * height * width + h * width + span] = value;
}
}
Expand All @@ -90,7 +92,9 @@ __kernel void rotary_emb_cl_fp16(__global half *input,
unsigned int dim,
unsigned int half_,
unsigned int max_timestep,
unsigned int from) {
unsigned int from,
unsigned int offsetFreqsSin,
unsigned int offsetSin) {
__global float *cos_ptr = cos_;
__global float *sin_ptr = sin_;
Expand All @@ -106,7 +110,7 @@ __kernel void rotary_emb_cl_fp16(__global half *input,
unsigned idx = (from + h)*dim;
for(int i = idx; i < idx + dim; i++ ){
cos_ptr[i - idx] = freqs_cos[i];
sin_ptr[i - idx] = freqs_sin[i];
sin_ptr[i - idx + offsetSin] = freqs_sin[i + offsetFreqsSin];
}
}
Expand All @@ -119,7 +123,7 @@ __kernel void rotary_emb_cl_fp16(__global half *input,
} else {
transformed_value = (float)input[b * channel * height * width + c * height * width + h * width + span - half_];
}
value = value * cos_ptr[k] + transformed_value * sin_ptr[k];
value = value * cos_ptr[k] + transformed_value * sin_ptr[k + offsetSin];
output[b * channel * height * width + c * height * width + h * width + span] = (half)value;
}
}
Expand Down
83 changes: 43 additions & 40 deletions nntrainer/tensor/cl_operations/attention_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,6 @@ void rotary_emb_cl(float *in, float *out,
sizeof(float) * freqs_cos_dim * dim; // max_timestep * dim
size_t dim6_size = sizeof(float) * freqs_sin_dim * dim;

opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
nullptr);

opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim2_size, true,
nullptr);

opencl::Buffer cosBuf(cl_context_ref.context_inst_, dim3_size, true,
nullptr);

opencl::Buffer sinBuf(cl_context_ref.context_inst_, dim4_size, true,
nullptr);

opencl::Buffer freqs_cosBuf(cl_context_ref.context_inst_, dim5_size, true,
nullptr);

opencl::Buffer freqs_sinBuf(cl_context_ref.context_inst_, dim6_size, true,
nullptr);

std::vector<float> freqs_cos_flat;
std::vector<float> freqs_sin_flat;
for (const auto &row : freqs_cos) {
Expand All @@ -73,81 +55,86 @@ void rotary_emb_cl(float *in, float *out,
freqs_sin_flat.insert(freqs_sin_flat.end(), row.begin(), row.end());
}

result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
result = clbuffInstance.getInBufferA()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim1_size, in);
if (!result) {
printf("Failed to write input data\n");
break;
}

result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, out);
result = clbuffInstance.getOutBufferA()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim2_size, out);
if (!result) {
printf("Failed to write output data\n");
break;
}

result = freqs_cosBuf.WriteData(cl_context_ref.command_queue_inst_,
freqs_cos_flat.data());
result = clbuffInstance.getInBufferB()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim5_size, freqs_cos_flat.data());
if (!result) {
printf("Failed to write freqs cos data\n");
break;
}

result = freqs_sinBuf.WriteData(cl_context_ref.command_queue_inst_,
freqs_sin_flat.data());
result = clbuffInstance.getInBufferB()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim6_size, freqs_sin_flat.data(), 0,
dim5_size);
if (!result) {
printf("Failed to write freqs sin data\n");
break;
}

result = cosBuf.WriteData(cl_context_ref.command_queue_inst_, cos_.data());
result = clbuffInstance.getInBufferC()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim3_size, cos_.data());
if (!result) {
printf("Failed to write cos data\n");
break;
}

result = sinBuf.WriteData(cl_context_ref.command_queue_inst_, sin_.data());
result = clbuffInstance.getInBufferC()->WriteDataRegion(
cl_context_ref.command_queue_inst_, dim4_size, sin_.data(), 0, dim3_size);
if (!result) {
printf("Failed to write sin data\n");
break;
}

result =
kernel_rotaryEmb_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
result = kernel_rotaryEmb_ptr->SetKernelArguments(
0, clbuffInstance.getInBufferA(), sizeof(cl_mem));
if (!result) {
printf("Failed to set inputA argument\n");
break;
}

result =
kernel_rotaryEmb_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
result = kernel_rotaryEmb_ptr->SetKernelArguments(
1, clbuffInstance.getOutBufferA(), sizeof(cl_mem));
if (!result) {
printf("Failed to set inOutRes argument\n");
break;
}

result = kernel_rotaryEmb_ptr->SetKernelArguments(2, &freqs_cosBuf,
sizeof(cl_mem));
result = kernel_rotaryEmb_ptr->SetKernelArguments(
2, clbuffInstance.getInBufferB(), sizeof(cl_mem));
if (!result) {
printf("Failed to set freqs_cosBuf argument\n");
break;
}

result = kernel_rotaryEmb_ptr->SetKernelArguments(3, &freqs_sinBuf,
sizeof(cl_mem));
result = kernel_rotaryEmb_ptr->SetKernelArguments(
3, clbuffInstance.getInBufferB(), sizeof(cl_mem));
if (!result) {
printf("Failed to set freqs_sinBuf argument\n");
break;
}

result =
kernel_rotaryEmb_ptr->SetKernelArguments(4, &cosBuf, sizeof(cl_mem));
result = kernel_rotaryEmb_ptr->SetKernelArguments(
4, clbuffInstance.getInBufferC(), sizeof(cl_mem));
if (!result) {
printf("Failed to set cosBuf argument\n");
break;
}

result =
kernel_rotaryEmb_ptr->SetKernelArguments(5, &sinBuf, sizeof(cl_mem));
result = kernel_rotaryEmb_ptr->SetKernelArguments(
5, clbuffInstance.getInBufferC(), sizeof(cl_mem));
if (!result) {
printf("Failed to set sinBuf argument\n");
break;
Expand Down Expand Up @@ -202,6 +189,22 @@ void rotary_emb_cl(float *in, float *out,
break;
}

unsigned int offsetFreqsSin = freqs_cos_dim * dim;
result = kernel_rotaryEmb_ptr->SetKernelArguments(14, &offsetFreqsSin,
sizeof(int));
if (!result) {
printf("Failed to set offsetFreqsSin argument\n");
break;
}

unsigned int offsetSin = cos_dim;
result =
kernel_rotaryEmb_ptr->SetKernelArguments(15, &offsetSin, sizeof(int));
if (!result) {
printf("Failed to set offsetSin argument\n");
break;
}

const int work_groups_count[3] = {(int)batch, (int)channel, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value
result = cl_context_ref.command_queue_inst_.DispatchCommand(
Expand All @@ -211,12 +214,12 @@ void rotary_emb_cl(float *in, float *out,
break;
}

result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, out);
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
cl_context_ref.command_queue_inst_, dim2_size, out);
if (!result) {
printf("Failed to read data\n");
break;
}

} while (false);
}
} // namespace nntrainer
3 changes: 3 additions & 0 deletions nntrainer/tensor/cl_operations/attention_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@
#ifndef __ATTENTION_KERNELS_H__
#define __ATTENTION_KERNELS_H__

#include <cl_buffer_manager.h>
#include <cl_context.h>
#include <opencl_buffer.h>
#include <opencl_kernel.h>

#include <string>

namespace nntrainer {

// get global cl_context to use in kernels
static ClContext cl_context_ref;
static ClBufferManager &clbuffInstance = ClBufferManager::getInstance();

/**
* @brief Rotary Embedding process
Expand Down
Loading

0 comments on commit 493b26c

Please sign in to comment.