[ Wait for #2876 ] [ neon ] Implement neon kernel for copy_s16_fp32 #2881

Open · wants to merge 2 commits into base: main
14 changes: 14 additions & 0 deletions nntrainer/tensor/blas_interface.cpp
@@ -979,6 +979,20 @@ void scopy_int8_to_float32(const unsigned int N, const int8_t *X,
  }
}

static inline void copy_s16_fp32_fallback(const unsigned int N,
                                          const int16_t *X, float *Y) {
  for (unsigned int idx = 0; idx < N; ++idx) {
    Y[idx] = (float)X[idx];
  }
}

void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y) {
#ifdef USE_NEON
  nntrainer::neon::copy_s16_fp32(N, X, Y);
#else
  copy_s16_fp32_fallback(N, X, Y);
#endif
}

float snrm2(const int N, const float *X, const int incX) {
#ifdef USE_BLAS
#ifdef BLAS_NUM_THREADS
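Reviewer note: every int16_t value is exactly representable in a float (|x| ≤ 2^15 < 2^24), so the NEON path and the scalar fallback must agree bit-for-bit. A minimal spot-check for the new entry point, assuming it is linked from blas_interface (hypothetical harness, not part of this PR):

```cpp
// Hypothetical spot-check: 10 elements exercise one full 8-lane
// iteration plus a 2-element scalar tail on NEON builds.
#include <cassert>
#include <cstdint>
#include <vector>

// Declared in nntrainer/tensor/blas_interface.h (namespace
// qualification may be required depending on the build).
void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y);

int main() {
  std::vector<int16_t> src = {-32768, -1000, -1, 0, 1, 7, 999, 32767, 42, -42};
  std::vector<float> dst(src.size(), 0.0f);
  copy_s16_fp32(src.size(), src.data(), dst.data());
  for (size_t i = 0; i < src.size(); ++i) {
    assert(dst[i] == (float)src[i]); // exact widening, no rounding
  }
  return 0;
}
```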
8 changes: 8 additions & 0 deletions nntrainer/tensor/blas_interface.h
@@ -320,6 +320,14 @@ void scopy_int8_to_float32(const unsigned int N, const uint8_t *X,
void scopy_int8_to_float32(const unsigned int N, const int8_t *X,
                           const int incX, float *Y, const int intY);

/**
 * @brief copy function : Y = X
 * @param[in] N number of elements in X
 * @param[in] X int16_t * for Vector X
 * @param[in] Y float * for Vector Y
 */
void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y);

/**
 * @brief sdot computation : sum of all X * Y
 * @param[in] N number of elements in Y
23 changes: 23 additions & 0 deletions nntrainer/tensor/blas_neon.cpp
@@ -1597,6 +1597,29 @@ void copy_int8_to_fp32(const unsigned int N, const int8_t *X, float *Y) {
  }
}

void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y) {
  unsigned int idx = 0;
  // process 8 elements per iteration
  for (; (N - idx) >= 8; idx += 8) {
    int16x8_t batch = vld1q_s16(&X[idx]);
    int16x4_t low = vget_low_s16(batch);
    int16x4_t high = vget_high_s16(batch);

    // widen to s32
    int32x4_t low_s32 = vmovl_s16(low);
    int32x4_t high_s32 = vmovl_s16(high);

    // convert to f32
    float32x4_t low_f32 = vcvtq_f32_s32(low_s32);
    float32x4_t high_f32 = vcvtq_f32_s32(high_s32);

    vst1q_f32(&Y[idx], low_f32);
    vst1q_f32(&Y[idx + 4], high_f32);
  }
  // scalar loop for the remaining N % 8 elements
  for (; (N - idx) >= 1; ++idx) {
    Y[idx] = (float)X[idx];
  }
}

void copy_fp16_to_fp32(const unsigned int N, const __fp16 *X, float *Y) {
  unsigned int idx = 0;

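The low/high split in the kernel above is needed because NEON has no single s16-to-f32 conversion: each 8-lane batch is sign-extended to two int32x4_t halves with vmovl_s16, then converted with vcvtq_f32_s32 (exact, since every int16 fits inside float's 24-bit significand). A standalone sketch of that data flow, assuming an AArch64 toolchain with arm_neon.h (illustrative, not part of the PR):

```cpp
// Widen-then-convert pattern for one 8-element block.
#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

int main() {
  const int16_t x[8] = {-4, -3, -2, -1, 1, 2, 3, 4};
  float y[8];

  int16x8_t batch = vld1q_s16(x);
  // Sign-extend each 4-lane half from s16 to s32.
  int32x4_t lo_s32 = vmovl_s16(vget_low_s16(batch));  // lanes 0..3
  int32x4_t hi_s32 = vmovl_s16(vget_high_s16(batch)); // lanes 4..7
  // Convert s32 -> f32 and store both halves.
  vst1q_f32(y, vcvtq_f32_s32(lo_s32));
  vst1q_f32(y + 4, vcvtq_f32_s32(hi_s32));

  for (int i = 0; i < 8; ++i)
    printf("%d -> %.1f\n", x[i], y[i]);
  return 0;
}
```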
9 changes: 9 additions & 0 deletions nntrainer/tensor/blas_neon.h
@@ -88,6 +88,15 @@ void copy_int8_or_int4(const unsigned int N, const uint8_t *X, uint8_t *Y);
 * @param[in] Y int8_t * for Vector Y
 */
void copy_int8(const unsigned int N, const int8_t *X, int8_t *Y);

/**
 * @brief copy function with neon: Y = X
 * @param[in] N number of elements in X
 * @param[in] X int16_t * for Vector X
 * @param[in] Y float * for Vector Y
 */
void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y);

/**
 * @brief sine with neon: Y = sin(alpha * X)
 * @param[in] N number of elements in X
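Since the point of the kernel is throughput, a reviewer may want to time it against the scalar loop. A hypothetical micro-benchmark sketch (not part of the PR; the declaration mirrors blas_neon.h, and the buffer size and repetition count are arbitrary):

```cpp
// Hypothetical timing harness: NEON kernel vs. plain scalar loop.
// Build for AArch64 with -O2 and link against nntrainer.
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <vector>

namespace nntrainer {
namespace neon {
void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y);
} // namespace neon
} // namespace nntrainer

static void scalar_ref(unsigned int N, const int16_t *X, float *Y) {
  for (unsigned int i = 0; i < N; ++i)
    Y[i] = (float)X[i];
}

int main() {
  const unsigned int N = 1u << 20;
  std::vector<int16_t> x(N, 123);
  std::vector<float> y(N);

  auto time_ms = [&](auto &&fn) {
    auto t0 = std::chrono::steady_clock::now();
    for (int rep = 0; rep < 100; ++rep)
      fn(N, x.data(), y.data());
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(t1 - t0).count();
  };

  printf("scalar: %.2f ms\n", time_ms(scalar_ref));
  printf("neon  : %.2f ms\n", time_ms(nntrainer::neon::copy_s16_fp32));
  return 0;
}
```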
3 changes: 3 additions & 0 deletions nntrainer/tensor/float_tensor.cpp
@@ -763,6 +763,9 @@ void FloatTensor::copyData(const Tensor &from) {
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
    break;
  case ml::train::TensorDim::DataType::QINT16:
    copy_s16_fp32(from.size(), from.getData<int16_t>(), (float *)getData());
    break;
  case ml::train::TensorDim::DataType::QINT8:
    scopy_int8_to_float32(from.size(), from.getData<int8_t>(), 1,
                          (float *)getData(), 1);