Skip to content

Commit

Permalink
[ neon ] Implement neon kernel for copy_s16_f32
Browse files Browse the repository at this point in the history
- Load s16 values, widen them to s32, convert to f32, and store.
- Add a fallback function with the same parameters to make later refactoring easier.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
  • Loading branch information
skykongkong8 committed Jan 16, 2025
1 parent 15ae292 commit 4439e05
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
11 changes: 8 additions & 3 deletions nntrainer/tensor/blas_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,13 +979,18 @@ void scopy_int8_to_float32(const unsigned int N, const int8_t *X,
}
}

/**
 * @brief Scalar element-wise conversion of int16_t values to float.
 * @param[in] N number of elements to convert
 * @param[in] X source array; elements X[0] .. X[N-1] are read
 * @param[out] Y destination array; elements Y[0] .. Y[N-1] are written
 */
static inline void copy_s16_fp32_fallback(const unsigned int N,
                                          const int16_t *X, float *Y) {
  const int16_t *src = X;
  float *dst = Y;
  // Walk both arrays in lockstep, one element per iteration.
  for (unsigned int remaining = N; remaining != 0; --remaining) {
    *dst++ = static_cast<float>(*src++);
  }
}

/**
 * @brief Convert N int16_t values to float, writing the results to Y.
 * @param[in] N number of elements to convert
 * @param[in] X source array of int16_t values
 * @param[out] Y destination array receiving the converted float values
 * @note Exactly one implementation runs per call: the NEON kernel when
 *       USE_NEON is defined, otherwise the scalar fallback. Running the
 *       fallback after the NEON kernel would only recompute the same output.
 */
void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y) {
#ifdef USE_NEON
  nntrainer::neon::copy_s16_fp32(N, X, Y);
#else
  copy_s16_fp32_fallback(N, X, Y);
#endif
}

float snrm2(const int N, const float *X, const int incX) {
Expand Down
17 changes: 16 additions & 1 deletion nntrainer/tensor/blas_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1598,8 +1598,23 @@ void copy_int8_to_fp32(const unsigned int N, const int8_t *X, float *Y) {
}

void copy_s16_fp32(const unsigned int N, const int16_t *X, float *Y) {
/// @todo implement int16_t to fp32
unsigned int idx = 0;
for (; (N - idx) >= 8; idx += 8) {
int16x8_t batch = vld1q_s16(&X[idx]);
int16x4_t low = vget_low_s16(batch);
int16x4_t high = vget_high_s16(batch);

// widen to s32
int32x4_t low_s32 = vmovl_s16(low);
int32x4_t high_s32 = vmovl_s16(high);

// convert to f32
float32x4_t low_f32 = vcvtq_f32_s32(low_s32);
float32x4_t high_f32 = vcvtq_f32_s32(high_s32);

vst1q_f32(&Y[idx], low_f32);
vst1q_f32(&Y[idx + 4], high_f32);
}
for (; (N - idx) >= 1; ++idx) {
Y[idx] = X[idx];
}
Expand Down

0 comments on commit 4439e05

Please sign in to comment.