diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index f65c7389..af9c7530 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -83,6 +83,7 @@ func_vlm_add_executable(demo_imagebind_1mod) func_vlm_add_executable(demo_phi3v) # func_vlm_add_executable(demo) +# QNN demo if(QNN) func_llm_add_executable(demo_qwen_npu) diff --git a/src/Tensor.hpp b/src/Tensor.hpp index 887031cd..8fea30eb 100644 --- a/src/Tensor.hpp +++ b/src/Tensor.hpp @@ -1760,6 +1760,9 @@ class Tensor { TensorType &xnnTensorType(); void forceResetHostPointer(void *ptr); + +public: + float i8_scale = 1.f; }; } // namespace mllm #endif // MLLM_TENSOR_H \ No newline at end of file diff --git a/src/backends/cpu/CMakeLists.txt b/src/backends/cpu/CMakeLists.txt index 4473f621..76eea4a1 100644 --- a/src/backends/cpu/CMakeLists.txt +++ b/src/backends/cpu/CMakeLists.txt @@ -24,6 +24,7 @@ endif() endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") message(STATUS "ARM detected") + add_compile_options(-march=armv8.2-a+dotprod+fp16+fp16fml) elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") message(STATUS "x86_64 detected") add_compile_options(-mavx2) diff --git a/src/backends/cpu/CPUConvolution2D.cpp b/src/backends/cpu/CPUConvolution2D.cpp index ed4227c4..9ecf8d27 100644 --- a/src/backends/cpu/CPUConvolution2D.cpp +++ b/src/backends/cpu/CPUConvolution2D.cpp @@ -2,10 +2,14 @@ #include "CPUConvolution2D.hpp" #include "compute/Convolution.hpp" +#include "compute/Matmul.hpp" +#include "compute/Im2Col.hpp" + namespace mllm { -CPUConvolution2D::CPUConvolution2D(Backend *bn, string opName, int in_channel, int out_channel, vector kernal_size, vector stride, PaddingType padding_type, bool bias, int threadCount) : thread_count(threadCount), -Op(bn, opName) { +CPUConvolution2D::CPUConvolution2D(Backend *bn, string opName, int in_channel, int out_channel, vector kernal_size, vector stride, PaddingType padding_type, bool bias, int threadCount) : + thread_count(threadCount), + Op(bn, opName) { kernel_size_[0] = kernal_size[0]; kernel_size_[1] = kernal_size[1]; stride_[0] = stride[0]; @@ -16,48 +20,81 @@ Op(bn, opName) { support_bias_ = bias; weight_.setBackend(bn); bias_.setBackend(bn); + +#ifdef __ARM_NEON + im2col_layout_.setBackend(bn); + output_not_transposed_.setBackend(bn); +#endif //! 
__ARM_NEON } ErrorCode CPUConvolution2D::reshape(vector> inputs, vector> outputs) { - //batch = batch - //sequence = out_channel - //head = height - //dimension = width + // batch = batch + // sequence = out_channel + // head = height + // dimension = width assert(in_channel_ == inputs[0]->sequence()); + + // #ifdef __ARM_NEON + // if (kernel_size_[0] == 16 && kernel_size_[1] == 16 && padding_h_ == 0 && padding_w_ == 0 && stride_[0] == 16 && stride_[1] == 16) { + // im2col_layout_.setDtype(inputs[0]->dtype()); + // im2col_layout_.reshape(inputs[0]->batch(), 1, (inputs[0]->head() / 16) * (inputs[0]->dimension() / 16), 16 * 16 * in_channel_); + // im2col_layout_.alloc(); + // output_not_transposed_.setDtype(inputs[0]->dtype()); + // output_not_transposed_.reshape(inputs[0]->batch(), 1, (inputs[0]->head() / 16) * (inputs[0]->dimension() / 16), out_channel_); + // output_not_transposed_.alloc(); + // outputs[0]->reshape(inputs[0]->batch(), (inputs[0]->head() / 16), out_channel_, (inputs[0]->dimension() / 16)); + // return Op::reshape(inputs, outputs); + // } + + // if (kernel_size_[0] == kernel_size_[1] && kernel_size_[0] == stride_[0] && kernel_size_[1] == stride_[1] && padding_h_ == 0 && padding_w_ == 0) { + // im2col_layout_.setDtype(inputs[0]->dtype()); + // im2col_layout_.reshape(inputs[0]->batch(), 1, (inputs[0]->head() / kernel_size_[0]) * (inputs[0]->dimension() / kernel_size_[0]), kernel_size_[0] * kernel_size_[0] * in_channel_); + // im2col_layout_.alloc(); + // output_not_transposed_.setDtype(inputs[0]->dtype()); + // output_not_transposed_.reshape(inputs[0]->batch(), 1, (inputs[0]->head() / kernel_size_[0]) * (inputs[0]->dimension() / kernel_size_[0]), out_channel_); + // output_not_transposed_.alloc(); + // outputs[0]->reshape(inputs[0]->batch(), (inputs[0]->head() / kernel_size_[0]), out_channel_, (inputs[0]->dimension() / kernel_size_[0])); + // return Op::reshape(inputs, outputs); + // } + // #endif + switch (padding_type_) { - case SAME:{ + case SAME: { padding_h_ = (kernel_size_[0] - 1) / 2; padding_w_ = (kernel_size_[1] - 1) / 2; const int out_height = (inputs[0]->head() + 2 * padding_h_ - kernel_size_[0]) / stride_[0] + 1; const int out_width = (inputs[0]->dimension() + 2 * padding_w_ - kernel_size_[1]) / stride_[1] + 1; - outputs[0]->reshape(inputs[0]->batch(),out_height, out_channel_, out_width); + outputs[0]->reshape(inputs[0]->batch(), out_height, out_channel_, out_width); break; - } - case VALID:{ + } + case VALID: { padding_h_ = 0; padding_w_ = 0; const int out_height = (inputs[0]->head() - kernel_size_[0]) / stride_[0] + 1; - const int out_width = (inputs[0]->dimension()- kernel_size_[1]) / stride_[1] + 1; - outputs[0]->reshape(inputs[0]->batch(),out_height, out_channel_, out_width); + const int out_width = (inputs[0]->dimension() - kernel_size_[1]) / stride_[1] + 1; + outputs[0]->reshape(inputs[0]->batch(), out_height, out_channel_, out_width); break; - } + } } return Op::reshape(inputs, outputs); } ErrorCode CPUConvolution2D::load(AbstructLoader &loader) { - weight_.setName(name() + ".weight"); weight_.reshape(out_channel_, kernel_size_[0], in_channel_, kernel_size_[1]); if (loader.getDataType(weight_.name()) != MLLM_TYPE_COUNT) { weight_.setDtype(loader.getDataType(weight_.name())); weight_.alloc(); loader.load(&weight_); + // #ifndef __ARM_NEON kernal_ = reshape_conv2d_kernal_fp32(&weight_); + // #endif } else { weight_.setDtype(MLLM_TYPE_F32); weight_.alloc(); + // #ifndef __ARM_NEON kernal_ = reshape_conv2d_kernal_fp32(&weight_); + // #endif } if 
(support_bias_) { bias_.setName(name() + ".bias"); @@ -75,14 +112,39 @@ ErrorCode CPUConvolution2D::load(AbstructLoader &loader) { } ErrorCode CPUConvolution2D::execute(vector> inputs, vector> outputs) { + // #ifdef __ARM_NEON + // if (kernel_size_[0] == 16 && kernel_size_[1] == 16 && padding_h_ == 0 && padding_w_ == 0 && stride_[0] == 16 && stride_[1] == 16) { + // auto start = std::chrono::high_resolution_clock::now(); + // im2col_fp32_src_k16x16_s16_p0_to(inputs[0]->rawHostPtr(), im2col_layout_.rawHostPtr(), inputs[0]->head(), inputs[0]->dimension(), in_channel_); + // weight_.reshape(1, 1, out_channel_, 16 * 16 * in_channel_); + // mat_mul(&im2col_layout_, &weight_, &output_not_transposed_, true, &bias_, false, true, thread_count); + // transpose_fp32(output_not_transposed_.rawHostPtr(), outputs[0]->rawHostPtr(), (inputs[0]->head() / 16) * ((inputs[0]->dimension() / 16)), out_channel_); + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start); + // std::cout << duration.count() << std::endl; + // return Op::execute(inputs, outputs); + // } + + // if (kernel_size_[0] == kernel_size_[1] && kernel_size_[0] == stride_[0] && kernel_size_[1] == stride_[1] && padding_h_ == 0 && padding_w_ == 0) { + // auto start = std::chrono::high_resolution_clock::now(); + // im2col_fp32_src_knxn_sn_p0_to(inputs[0]->rawHostPtr(), im2col_layout_.rawHostPtr(), inputs[0]->head(), inputs[0]->dimension(), in_channel_, kernel_size_[0]); + // weight_.reshape(1, 1, out_channel_, kernel_size_[0] * kernel_size_[0] * in_channel_); + // mat_mul(&im2col_layout_, &weight_, &output_not_transposed_, true, &bias_, false, true, thread_count); + // transpose_fp32(output_not_transposed_.rawHostPtr(), outputs[0]->rawHostPtr(), (inputs[0]->head() / kernel_size_[0]) * ((inputs[0]->dimension() / kernel_size_[0])), out_channel_); + // auto end = std::chrono::high_resolution_clock::now(); + // auto duration = std::chrono::duration_cast(end - start); + // std::cout << duration.count() << std::endl; + // return Op::execute(inputs, outputs); + // } + // #endif switch (padding_type_) { - case SAME:{ + case SAME: { conv2d_fp32_SAME(inputs[0].get(), outputs[0].get(), kernal_, kernel_size_[0], kernel_size_[1], support_bias_, &bias_, stride_[0], stride_[1], padding_h_, padding_w_, thread_count); break; } case VALID: { - conv2d_fp32_VALID(inputs[0].get(), outputs[0].get(), kernal_, kernel_size_[0], kernel_size_[1], support_bias_, &bias_,stride_[0], stride_[1], thread_count); + conv2d_fp32_VALID(inputs[0].get(), outputs[0].get(), kernal_, kernel_size_[0], kernel_size_[1], support_bias_, &bias_, stride_[0], stride_[1], thread_count); break; } } @@ -90,14 +152,11 @@ ErrorCode CPUConvolution2D::execute(vector> inputs, vector> inputs, vector> outputs) { - weight_.free(); return Op::free(inputs, outputs); } ErrorCode CPUConvolution2D::setUp(vector> inputs, vector> outputs) { - return Op::setUp(inputs, outputs); } } // namespace mllm - diff --git a/src/backends/cpu/CPUConvolution2D.hpp b/src/backends/cpu/CPUConvolution2D.hpp index f5bdc9b0..d712c252 100644 --- a/src/backends/cpu/CPUConvolution2D.hpp +++ b/src/backends/cpu/CPUConvolution2D.hpp @@ -9,13 +9,13 @@ namespace mllm { class CPUConvolution2D final : public Op { public: - CPUConvolution2D(Backend *bn, string opName, int in_channel, int out_channel, vector kernal_size, vector stride, PaddingType padding_type, bool bias, int threadCount); - virtual ~CPUConvolution2D() = default; - virtual ErrorCode reshape(vector> inputs, vector> 
outputs) override;
-    virtual ErrorCode load(AbstructLoader &loader) override;
-    virtual ErrorCode execute(vector> inputs, vector> outputs) override;
-    virtual ErrorCode free(vector> inputs, vector> outputs) override;
-    virtual ErrorCode setUp(vector> inputs, vector> outputs) override;
+    CPUConvolution2D(Backend *bn, string opName, int in_channel, int out_channel, vector kernal_size, vector stride, PaddingType padding_type, bool bias, int threadCount);
+    ~CPUConvolution2D() override = default;
+    ErrorCode reshape(vector> inputs, vector> outputs) override;
+    ErrorCode load(AbstructLoader &loader) override;
+    ErrorCode execute(vector> inputs, vector> outputs) override;
+    ErrorCode free(vector> inputs, vector> outputs) override;
+    ErrorCode setUp(vector> inputs, vector> outputs) override;
     Tensor &weight() {
         return weight_;
@@ -34,21 +34,25 @@ class CPUConvolution2D final : public Op {
     Tensor weight_;
     Tensor bias_;
-    float ** kernal_;
-    bool support_bias_;
+#ifdef __ARM_NEON
+    Tensor im2col_layout_;
+    Tensor output_not_transposed_;
+#endif //! __ARM_NEON
+    float **kernal_;
+    bool support_bias_;
 };
 class CPUConvolution2DCreator : public CPUBackend::Creator {
 public:
-    virtual Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const {
-        vector kernal_size = {(int)op_param["kernal_h"],(int)op_param["kernal_w"]};
-        vector stride = {(int)op_param["stride_h"],(int)op_param["stride_w"]};
+    Op *create(OpParam op_param, Backend *bn, string name, int threadCount) const override {
+        vector kernal_size = {(int)op_param["kernal_h"], (int)op_param["kernal_w"]};
+        vector stride = {(int)op_param["stride_h"], (int)op_param["stride_w"]};
         int in_channel = op_param["in_channel"];
         int out_channel = op_param["out_channel"];
         PaddingType padding_type = (PaddingType)op_param["padding"];
         bool bias = (bool)op_param["bias"];
-        return new CPUConvolution2D(bn, name, in_channel, out_channel, kernal_size, stride, padding_type, bias, threadCount);
+        return new CPUConvolution2D(bn, name, in_channel, out_channel, kernal_size, stride, padding_type, bias, threadCount);
     }
 };
diff --git a/src/backends/cpu/compute/GEMM_AArch64.cpp b/src/backends/cpu/compute/GEMM_AArch64.cpp
index 6a69e2d3..a66671a7 100644
--- a/src/backends/cpu/compute/GEMM_AArch64.cpp
+++ b/src/backends/cpu/compute/GEMM_AArch64.cpp
@@ -1,6 +1,7 @@
 #include "GEMM_AArch64.hpp"
 #include "Types.hpp"
 #include
+#include
 #include
 #include
 #include // for assert
@@ -1212,6 +1213,10 @@ void mllm_gemm_q4_0_4x4_q8_0(int n, float *__restrict s, size_t bs, const void *
                              const void *__restrict bias) {
     if (bias != nullptr) {
         _mllm_gemm_q4_0_4x4_q8_0_bias(n, s, bs, vx, vy, nr, nc, bias);
+#if defined(__ARM_NEON)
+        std::cout << "_mllm_gemm_q4_0_4x4_q8_0_bias not implemented";
+        abort();
+#endif
         return;
     }
@@ -2297,6 +2302,10 @@ void mllm_gemm_q4_0_4x8_q8_0(int n, float *__restrict s, size_t bs, const void *
                              const void *__restrict vy, int nr, int nc,
                              const void *__restrict bias) {
     if (bias != nullptr) {
+#if defined(__ARM_NEON)
+        std::cout << "_mllm_gemm_q4_0_4x8_q8_0_bias not implemented";
+        abort();
+#endif
         _mllm_gemm_q4_0_4x8_q8_0_bias(n, s, bs, vx, vy, nr, nc, bias);
         return;
     }
@@ -3258,6 +3267,10 @@ void mllm_gemm_q4_0_8x8_q8_0(int n, float *__restrict s, size_t bs, const void *
                              const void *__restrict vy, int nr, int nc,
                              const void *__restrict bias) {
     if (bias != nullptr) {
+#if defined(__ARM_NEON)
+        std::cout << "_mllm_gemm_q4_0_8x8_q8_0_bias not implemented";
+        abort();
+#endif
         _mllm_gemm_q4_0_8x8_q8_0_bias(n, s, bs, vx, vy, nr, nc, bias);
         return;
     }
diff 
--git a/src/backends/cpu/compute/Im2Col.cpp b/src/backends/cpu/compute/Im2Col.cpp new file mode 100644 index 00000000..d66fcc81 --- /dev/null +++ b/src/backends/cpu/compute/Im2Col.cpp @@ -0,0 +1,228 @@ +#include "Im2Col.hpp" + +#include +#include + +#ifdef __ARM_NEON +#include +#endif + +namespace mllm { + +void im2col_fp32_src_knxn_sn_p0_to(void *src, void *dst, int32_t H, int32_t W, int32_t C, + int32_t FILTER_N) { + auto src_ptr = (float *)src; + auto dst_ptr = (float *)dst; + + int32_t h_blocks = H / FILTER_N; + int32_t w_blocks = W / FILTER_N; + for (int32_t c = 0; c < C; ++c) { + int D_N = 0; + for (int32_t h = 0; h < h_blocks; ++h) { + auto src_line_ptr = src_ptr + c * H * W + h * FILTER_N * W; + for (int32_t w = 0; w < w_blocks; ++w) { + auto gt_ptr = dst_ptr + c * FILTER_N * FILTER_N + D_N * FILTER_N * FILTER_N * C; + for (int i = 0; i < FILTER_N; ++i) { +#pragma unroll + for (int j = 0; j < FILTER_N; ++j) { + *(gt_ptr + i * FILTER_N + j) = *(src_line_ptr + FILTER_N * w + i * W + j); + } + } + + D_N++; + } + } + } +} + +#ifdef __ARM_NEON +void transpose_fp32(void *src, void *dst, int M, int N) { + auto src_ptr = static_cast(src); + auto dst_ptr = static_cast(dst); + + int32_t m_blocks = M / 4; + int32_t n_blocks = N / 4; + int32_t m_left = M % 4; + int32_t n_left = N % 4; + + if (M > 128) { +#pragma omp parallel for num_threads(4) + for (int32_t m = 0; m < m_blocks; ++m) { + auto m_line_ptr = src_ptr + m * 4 * N; + + for (int32_t n = 0; n < n_blocks; ++n) { + auto dst_line_ptr = dst_ptr + n * 4 * M; + + auto line_0 = vld1q_f32(m_line_ptr + 4 * n); + auto line_1 = vld1q_f32(m_line_ptr + 4 * n + N); + auto line_2 = vld1q_f32(m_line_ptr + 4 * n + 2 * N); + auto line_3 = vld1q_f32(m_line_ptr + 4 * n + 3 * N); + + float32x4x2_t row01 = vtrnq_f32(line_0, line_1); + float32x4x2_t row23 = vtrnq_f32(line_2, line_3); + + vst1q_f32(dst_line_ptr + 4 * m, + vcombine_f32(vget_low_f32(row01.val[0]), vget_low_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + 4 * m + M, + vcombine_f32(vget_low_f32(row01.val[1]), vget_low_f32(row23.val[1]))); + vst1q_f32(dst_line_ptr + 4 * m + 2 * M, + vcombine_f32(vget_high_f32(row01.val[0]), vget_high_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + 4 * m + 3 * M, + vcombine_f32(vget_high_f32(row01.val[1]), vget_high_f32(row23.val[1]))); + } + + if (n_left) { + auto dst_line_ptr = dst_ptr + (n_blocks * 4 - (4 - n_left)) * M; + + auto line_0 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left))); + auto line_1 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + N); + auto line_2 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + 2 * N); + auto line_3 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + 3 * N); + + float32x4x2_t row01 = vtrnq_f32(line_0, line_1); + float32x4x2_t row23 = vtrnq_f32(line_2, line_3); + + vst1q_f32(dst_line_ptr + 4 * m, + vcombine_f32(vget_low_f32(row01.val[0]), vget_low_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + 4 * m + M, + vcombine_f32(vget_low_f32(row01.val[1]), vget_low_f32(row23.val[1]))); + vst1q_f32(dst_line_ptr + 4 * m + 2 * M, + vcombine_f32(vget_high_f32(row01.val[0]), vget_high_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + 4 * m + 3 * M, + vcombine_f32(vget_high_f32(row01.val[1]), vget_high_f32(row23.val[1]))); + } + } + } else { + for (int32_t m = 0; m < m_blocks; ++m) { + auto m_line_ptr = src_ptr + m * 4 * N; + + for (int32_t n = 0; n < n_blocks; ++n) { + auto dst_line_ptr = dst_ptr + n * 4 * M; + + auto line_0 = vld1q_f32(m_line_ptr + 4 * n); + auto line_1 = vld1q_f32(m_line_ptr + 4 * n + N); + auto 
line_2 = vld1q_f32(m_line_ptr + 4 * n + 2 * N); + auto line_3 = vld1q_f32(m_line_ptr + 4 * n + 3 * N); + + float32x4x2_t row01 = vtrnq_f32(line_0, line_1); + float32x4x2_t row23 = vtrnq_f32(line_2, line_3); + + vst1q_f32(dst_line_ptr + 4 * m, + vcombine_f32(vget_low_f32(row01.val[0]), vget_low_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + 4 * m + M, + vcombine_f32(vget_low_f32(row01.val[1]), vget_low_f32(row23.val[1]))); + vst1q_f32(dst_line_ptr + 4 * m + 2 * M, + vcombine_f32(vget_high_f32(row01.val[0]), vget_high_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + 4 * m + 3 * M, + vcombine_f32(vget_high_f32(row01.val[1]), vget_high_f32(row23.val[1]))); + } + + if (n_left) { + auto dst_line_ptr = dst_ptr + (n_blocks * 4 - (4 - n_left)) * M; + + auto line_0 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left))); + auto line_1 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + N); + auto line_2 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + 2 * N); + auto line_3 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + 3 * N); + + float32x4x2_t row01 = vtrnq_f32(line_0, line_1); + float32x4x2_t row23 = vtrnq_f32(line_2, line_3); + + vst1q_f32(dst_line_ptr + 4 * m, + vcombine_f32(vget_low_f32(row01.val[0]), vget_low_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + 4 * m + M, + vcombine_f32(vget_low_f32(row01.val[1]), vget_low_f32(row23.val[1]))); + vst1q_f32(dst_line_ptr + 4 * m + 2 * M, + vcombine_f32(vget_high_f32(row01.val[0]), vget_high_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + 4 * m + 3 * M, + vcombine_f32(vget_high_f32(row01.val[1]), vget_high_f32(row23.val[1]))); + } + } + } + + if (m_left) { + auto m_line_ptr = src_ptr + (m_blocks * 4 - (4 - m_left)) * N; + + for (int32_t n = 0; n < n_blocks; ++n) { + auto dst_line_ptr = dst_ptr + n * 4 * M; + + auto line_0 = vld1q_f32(m_line_ptr + 4 * n); + auto line_1 = vld1q_f32(m_line_ptr + 4 * n + N); + auto line_2 = vld1q_f32(m_line_ptr + 4 * n + 2 * N); + auto line_3 = vld1q_f32(m_line_ptr + 4 * n + 3 * N); + + float32x4x2_t row01 = vtrnq_f32(line_0, line_1); + float32x4x2_t row23 = vtrnq_f32(line_2, line_3); + + vst1q_f32(dst_line_ptr + (m_blocks * 4 - (4 - m_left)), + vcombine_f32(vget_low_f32(row01.val[0]), vget_low_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + (m_blocks * 4 - (4 - m_left)) + M, + vcombine_f32(vget_low_f32(row01.val[1]), vget_low_f32(row23.val[1]))); + vst1q_f32(dst_line_ptr + (m_blocks * 4 - (4 - m_left)) + 2 * M, + vcombine_f32(vget_high_f32(row01.val[0]), vget_high_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + (m_blocks * 4 - (4 - m_left)) + 3 * M, + vcombine_f32(vget_high_f32(row01.val[1]), vget_high_f32(row23.val[1]))); + } + + if (n_left) { + auto dst_line_ptr = dst_ptr + (n_blocks * 4 - (4 - n_left)) * M; + + auto line_0 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left))); + auto line_1 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + N); + auto line_2 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + 2 * N); + auto line_3 = vld1q_f32(m_line_ptr + (n_blocks * 4 - (4 - n_left)) + 3 * N); + + float32x4x2_t row01 = vtrnq_f32(line_0, line_1); + float32x4x2_t row23 = vtrnq_f32(line_2, line_3); + + vst1q_f32(dst_line_ptr + (m_blocks * 4 - (4 - m_left)), + vcombine_f32(vget_low_f32(row01.val[0]), vget_low_f32(row23.val[0]))); + vst1q_f32(dst_line_ptr + (m_blocks * 4 - (4 - m_left)) + M, + vcombine_f32(vget_low_f32(row01.val[1]), vget_low_f32(row23.val[1]))); + vst1q_f32(dst_line_ptr + (m_blocks * 4 - (4 - m_left)) + 2 * M, + vcombine_f32(vget_high_f32(row01.val[0]), 
vget_high_f32(row23.val[0])));
+            vst1q_f32(dst_line_ptr + (m_blocks * 4 - (4 - m_left)) + 3 * M,
+                      vcombine_f32(vget_high_f32(row01.val[1]), vget_high_f32(row23.val[1])));
+        }
+    }
+}
+
+void im2col_fp32_src_k16x16_s16_p0_to(void *src, void *dst, int32_t H, int32_t W, int32_t C) {
+    int32_t h_blocks = H / 16;
+    int32_t w_blocks = W / 16;
+    int32_t threads = C < 4 ? C : 4; // use at most 4 threads for the per-channel loop
+
+#pragma omp parallel for num_threads(threads)
+    for (int c = 0; c < C; ++c) {
+        auto src_ptr = (float *)src + c * H * W;
+        auto dst_ptr = (float *)dst + c * 16 * 16;
+
+        int N = 0;
+        for (int32_t h = 0; h < h_blocks; h++) {
+            auto line_ptr = src_ptr + h * 16 * W;
+
+            for (int32_t w = 0; w < w_blocks; w++) {
+                auto block16x16_ptr = line_ptr + w * 16;
+                auto dst_line_ptr = dst_ptr + N * 16 * 16 * C;
+
+// process 4 x 16 four times
+#pragma unroll
+                for (int i = 0; i < 4; ++i) {
+                    float32x4x4_t line_0 = vld4q_f32(block16x16_ptr + 4 * i * W);
+                    float32x4x4_t line_1 = vld4q_f32(block16x16_ptr + 4 * i * W + W);
+                    float32x4x4_t line_2 = vld4q_f32(block16x16_ptr + 4 * i * W + 2 * W);
+                    float32x4x4_t line_3 = vld4q_f32(block16x16_ptr + 4 * i * W + 3 * W);
+
+                    vst4q_f32(dst_line_ptr + 64 * i, line_0);
+                    vst4q_f32(dst_line_ptr + 64 * i + 16, line_1);
+                    vst4q_f32(dst_line_ptr + 64 * i + 32, line_2);
+                    vst4q_f32(dst_line_ptr + 64 * i + 48, line_3);
+                }
+
+                N++;
+            }
+        }
+    }
+}
+#endif //! __ARM_NEON
+} // namespace mllm
\ No newline at end of file
diff --git a/src/backends/cpu/compute/Im2Col.hpp b/src/backends/cpu/compute/Im2Col.hpp
new file mode 100644
index 00000000..9d6ac217
--- /dev/null
+++ b/src/backends/cpu/compute/Im2Col.hpp
@@ -0,0 +1,59 @@
+/**
+ * @file Im2Col.hpp
+ * @author chenghua wang (chenghua.wang.edu@gmail.com)
+ * @version 0.1
+ * @date 2024-11-12
+ *
+ * @copyright Copyright (c) 2024
+ *
+ */
+#pragma once
+
+#include
+
+namespace mllm {
+
+/**
+ * @brief f32 Src. Kernel NxN, Stride N, Padding 0.
+ *
+ * C * H * W -> ((H / N) * (W / N)) * (N * N * C)
+ *
+ * !!! Dst is NOT Transposed.
+ *
+ * @param src
+ * @param dst
+ * @param H
+ * @param W
+ * @param C
+ * @param FILTER_N
+ */
+void im2col_fp32_src_knxn_sn_p0_to(void *src, void *dst, int32_t H, int32_t W, int32_t C, int32_t FILTER_N);
+
+#ifdef __ARM_NEON
+/**
+ * @brief f32 Src. Kernel 16x16, Stride 16, Padding 0.
+ *
+ * C * H * W -> ((H / 16) * (W / 16)) * (16 * 16 * C)
+ *
+ * !!! Dst is NOT Transposed.
+ *
+ * @param src
+ * @param dst
+ * @param H
+ * @param W
+ * @param C
+ */
+void im2col_fp32_src_k16x16_s16_p0_to(void *src, void *dst, int32_t H, int32_t W, int32_t C);
+
+/**
+ * @brief transpose fp32 matrix
+ *
+ * @param src
+ * @param dst
+ * @param M
+ * @param N
+ */
+void transpose_fp32(void *src, void *dst, int M, int N);
+#endif //!
__ARM_NEON + +} // namespace mllm \ No newline at end of file diff --git a/src/backends/cpu/compute/Matmul.cpp b/src/backends/cpu/compute/Matmul.cpp index b434d043..0e4a1069 100644 --- a/src/backends/cpu/compute/Matmul.cpp +++ b/src/backends/cpu/compute/Matmul.cpp @@ -7,6 +7,12 @@ #include "VecDotType.hpp" // #include #include "SGEMM.hpp" +#include + +#ifdef __ARM_NEON +#include +#include +#endif #define ASSERT(x) \ do { \ @@ -640,4 +646,530 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_ } } return MLLM_NO_ERROR; -} \ No newline at end of file +} + +ErrorCode mat_mul_i8(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias, bool transpose0, bool transpose1, int thread_count, float scale1, float scale2) { + if (support_bias) { + std::cout << "Not support bias in mat_mul_i8" << std::endl; + abort(); + } + + if (!transpose1) { + std::cout << "Not support transpose1==false in mat_mul_i8" << std::endl; + abort(); + } + +#ifdef __ARM_NEON + armv8::qt8_qt8_fp32_gemm_sdot_omp(src0->rawHostPtr(), src1->rawHostPtr(), dst->rawHostPtr(), src0->sequence(), src1->sequence(), src0->dimension(), + src0->i8_scale, src1->i8_scale, transpose1); +#else + std::cout << "mat_mul_i8 is only supported in armv8.2+" << std::endl; + abort(); +#endif + + return MLLM_NO_ERROR; +} + +#ifdef __ARM_NEON +namespace mllm::armv8 { + +// This function is dropped !!! +void qt8_qt8_fp32_gemv(void *A, void *B, void *C, int32_t N, int32_t K, float SA, float SB, + bool transpose_b) { + if (!transpose_b) { + // Not Supported Yet. + return; + } + + // I tile K with 16 element per block. the 16 element will be load to 128bit vector. + if (K % 16) abort(); + + int32_t k_blocks = K / 16; + int32_t n_blocks = N / 4; + int32_t n_blocks_left = N % 4; + float scale = SA * SB; + + for (int32_t n_block = 0; n_block < n_blocks; n_block++) { + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B + 4 * K * n_block; + auto c_ptr = (float *)C; + + // final accumulattor + // float acc_line_normal_reg[4] = {0.f, 0.f, 0.f, 0.f}; + // float32x4_t acc_line = vmovq_n_f32(0); + + // accumulator + int16x8x2_t acc_line_0 = {{vmovq_n_s16(0), vmovq_n_s16(0)}}; + int16x8x2_t acc_line_1 = {{vmovq_n_s16(0), vmovq_n_s16(0)}}; + int16x8x2_t acc_line_2 = {{vmovq_n_s16(0), vmovq_n_s16(0)}}; + int16x8x2_t acc_line_3 = {{vmovq_n_s16(0), vmovq_n_s16(0)}}; + + for (int32_t k_block = 0; k_block < k_blocks; k_block++) { + // load from A + int8x16_t r0 = vld1q_s8(a_ptr + 16 * k_block); + int8x8_t r0_l = vget_low_s8(r0); + int8x8_t r0_h = vget_high_s8(r0); + + // load from block, n = 0 + int8x16_t n0 = vld1q_s8(b_ptr + 16 * k_block); + int8x8_t n0_l = vget_low_s8(n0); + int8x8_t n0_h = vget_high_s8(n0); + acc_line_0.val[0] = vmlal_s8(acc_line_0.val[0], n0_h, r0_h); + acc_line_0.val[1] = vmlal_s8(acc_line_0.val[1], n0_l, r0_l); + + // load from block, n = 1 + int8x16_t n1 = vld1q_s8(b_ptr + K + 16 * k_block); + int8x8_t n1_l = vget_low_s8(n1); + int8x8_t n1_h = vget_high_s8(n1); + acc_line_1.val[0] = vmlal_s8(acc_line_1.val[0], n1_h, r0_h); + acc_line_1.val[1] = vmlal_s8(acc_line_1.val[1], n1_l, r0_l); + + // load from block, n = 2 + int8x16_t n2 = vld1q_s8(b_ptr + 2 * K + 16 * k_block); + int8x8_t n2_l = vget_low_s8(n2); + int8x8_t n2_h = vget_high_s8(n2); + acc_line_2.val[0] = vmlal_s8(acc_line_2.val[0], n2_h, r0_h); + acc_line_2.val[1] = vmlal_s8(acc_line_2.val[1], n2_l, r0_l); + + // load from block, n = 3 + int8x16_t n3 = vld1q_s8(b_ptr + 3 * K + 16 * k_block); + int8x8_t n3_l = vget_low_s8(n3); + int8x8_t n3_h = 
vget_high_s8(n3); + acc_line_3.val[0] = vmlal_s8(acc_line_3.val[0], n3_h, r0_h); + acc_line_3.val[1] = vmlal_s8(acc_line_3.val[1], n3_l, r0_l); + } + + // accumulate i16 vector to single i32 value. And turn it to float. + int32x4_t acc_line_0_i32_sum_0 = vpaddlq_s16(acc_line_0.val[0]); + int32x4_t acc_line_0_i32_sum_1 = vpaddlq_s16(acc_line_0.val[1]); + int32_t acc_line_0_i32 = vaddvq_s32(acc_line_0_i32_sum_0) + vaddvq_s32(acc_line_0_i32_sum_1); + // acc_line = vsetq_lane_f32((float)acc_line_0_i32, acc_line, 0); + *(c_ptr + 4 * n_block + 0) = (float)acc_line_0_i32 * scale; + + int32x4_t acc_line_1_i32_sum_0 = vpaddlq_s16(acc_line_1.val[0]); + int32x4_t acc_line_1_i32_sum_1 = vpaddlq_s16(acc_line_1.val[1]); + int32_t acc_line_1_i32 = vaddvq_s32(acc_line_1_i32_sum_0) + vaddvq_s32(acc_line_1_i32_sum_1); + // acc_line = vsetq_lane_f32((float)acc_line_1_i32, acc_line, 1); + *(c_ptr + 4 * n_block + 1) = (float)acc_line_1_i32 * scale; + + int32x4_t acc_line_2_i32_sum_0 = vpaddlq_s16(acc_line_2.val[0]); + int32x4_t acc_line_2_i32_sum_1 = vpaddlq_s16(acc_line_2.val[1]); + int32_t acc_line_2_i32 = vaddvq_s32(acc_line_2_i32_sum_0) + vaddvq_s32(acc_line_2_i32_sum_1); + // acc_line = vsetq_lane_f32((float)acc_line_2_i32, acc_line, 2); + *(c_ptr + 4 * n_block + 2) = (float)acc_line_2_i32 * scale; + + int32x4_t acc_line_3_i32_sum_0 = vpaddlq_s16(acc_line_3.val[0]); + int32x4_t acc_line_3_i32_sum_1 = vpaddlq_s16(acc_line_3.val[1]); + int32_t acc_line_3_i32 = vaddvq_s32(acc_line_3_i32_sum_0) + vaddvq_s32(acc_line_3_i32_sum_1); + // acc_line = vsetq_lane_f32((float)acc_line_3_i32, acc_line, 3); + *(c_ptr + 4 * n_block + 3) = (float)acc_line_3_i32 * scale; + + // scale it. + // acc_line = vmulq_n_f32(acc_line, scale); + + // store + // vst1q_f32(c_ptr + 4 * n_block, acc_line); + } + + // perform vector dot one by one. + for (int32_t n = 0; n < n_blocks_left; n++) { + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B + 4 * (N / 4) * K; + auto c_ptr = (float *)C + 4 * (N / 4); + + int16x8x2_t acc_line_0 = {{vmovq_n_s16(0), vmovq_n_s16(0)}}; + + for (int32_t k_block = 0; k_block < k_blocks; k_block++) { + // load from A + int8x16_t r0 = vld1q_s8(a_ptr + 16 * k_block); + int8x8_t r0_l = vget_low_s8(r0); + int8x8_t r0_h = vget_high_s8(r0); + + // load from block, n = 0 + int8x16_t n0 = vld1q_s8(b_ptr + n * K + 16 * k_block); + int8x8_t n0_l = vget_low_s8(n0); + int8x8_t n0_h = vget_high_s8(n0); + acc_line_0.val[0] = vmlal_s8(acc_line_0.val[0], n0_h, r0_h); + acc_line_0.val[1] = vmlal_s8(acc_line_0.val[1], n0_l, r0_l); + } + + // accumulate i16 vector to single i32 value. And turn it to float. + int32x4_t acc_line_0_i32_sum_0 = vpaddlq_s16(acc_line_0.val[0]); + int32x4_t acc_line_0_i32_sum_1 = vpaddlq_s16(acc_line_0.val[1]); + int32_t acc_line_0_i32 = vaddvq_s32(acc_line_0_i32_sum_0) + vaddvq_s32(acc_line_0_i32_sum_1); + + // store + *(c_ptr + n) = (float)acc_line_0_i32 * scale; + } +} + +void qt8_qt8_fp32_gemv_sdot(void *A, void *B, void *C, int32_t N, int32_t K, float SA, float SB, + bool transpose_b) { + if (!transpose_b) { + // Not Supported Yet. + return; + } + + // I tile K with 16 element per block. the 16 element will be load to 128bit vector. 
+ if (K % 16) abort(); + + int32_t k_blocks = K / 16; + int32_t n_blocks = N / 4; + int32_t n_blocks_left = N % 4; + float scale = SA * SB; + + for (int32_t n_block = 0; n_block < n_blocks; n_block++) { + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B + 4 * K * n_block; + auto c_ptr = (float *)C; + + // accumulator + int32x4_t acc_line_0 = vmovq_n_s32(0); + int32x4_t acc_line_1 = vmovq_n_s32(0); + int32x4_t acc_line_2 = vmovq_n_s32(0); + int32x4_t acc_line_3 = vmovq_n_s32(0); + + // Only support v8.2+ with dotproduct enabled. + // Using SDOT. The throughput will increase 4 times compared to fp32 version. + for (int32_t k_block = 0; k_block < k_blocks; k_block++) { + // load from A + int8x16_t r0 = vld1q_s8(a_ptr + 16 * k_block); + + // load from block, n = 0 + int8x16_t n0 = vld1q_s8(b_ptr + 16 * k_block); + acc_line_0 = vdotq_s32(acc_line_0, r0, n0); + + // load from block, n = 1 + int8x16_t n1 = vld1q_s8(b_ptr + K + 16 * k_block); + acc_line_1 = vdotq_s32(acc_line_1, r0, n1); + + // load from block, n = 2 + int8x16_t n2 = vld1q_s8(b_ptr + 2 * K + 16 * k_block); + acc_line_2 = vdotq_s32(acc_line_2, r0, n2); + + // load from block, n = 3 + int8x16_t n3 = vld1q_s8(b_ptr + 3 * K + 16 * k_block); + acc_line_3 = vdotq_s32(acc_line_3, r0, n3); + } + + // reduce all and save to c_ptr + int32_t acc_line_0_i32 = vaddvq_s32(acc_line_0); + *(c_ptr + 4 * n_block + 0) = (float)acc_line_0_i32 * scale; + + int32_t acc_line_1_i32 = vaddvq_s32(acc_line_1); + *(c_ptr + 4 * n_block + 1) = (float)acc_line_1_i32 * scale; + + int32_t acc_line_2_i32 = vaddvq_s32(acc_line_2); + *(c_ptr + 4 * n_block + 2) = (float)acc_line_2_i32 * scale; + + int32_t acc_line_3_i32 = vaddvq_s32(acc_line_3); + *(c_ptr + 4 * n_block + 3) = (float)acc_line_3_i32 * scale; + } + + // perform vector dot one by one. + for (int32_t n = 0; n < n_blocks_left; n++) { + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B + 4 * (N / 4) * K; + auto c_ptr = (float *)C + 4 * (N / 4); + + int32x4_t acc_line_0 = vmovq_n_s32(0); + + for (int32_t k_block = 0; k_block < k_blocks; k_block++) { + // load from A + int8x16_t r0 = vld1q_s8(a_ptr + 16 * k_block); + + // load from block, n = 0 + int8x16_t n0 = vld1q_s8(b_ptr + n * K + 16 * k_block); + acc_line_0 = vdotq_s32(acc_line_0, r0, n0); + } + + // accumulate i32 vector to single i32 value. And turn it to float. + int32_t acc_line_0_i32 = vaddvq_s32(acc_line_0); + + // store + *(c_ptr + n) = (float)acc_line_0_i32 * scale; + } +} + +void qt8_qt8_fp32_kernel_4x4_sdot(void *A, void *B, void *C, int32_t N, int32_t K, float SA, + float SB, bool transpose_b) { + if (!transpose_b) { + // Not Supported Yet. + return; + } + + int32_t k_blocks = K / 16; + float scale = SA * SB; + + // 4 x K, K x 4. + if (K % 16) abort(); + + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B; + auto c_ptr = (float *)C; + + // accumulator should contain 16 values. 
+ int32x4x4_t acc_line_0 = {{vmovq_n_s32(0), vmovq_n_s32(0), vmovq_n_s32(0), vmovq_n_s32(0)}}; + int32x4x4_t acc_line_1 = {{vmovq_n_s32(0), vmovq_n_s32(0), vmovq_n_s32(0), vmovq_n_s32(0)}}; + int32x4x4_t acc_line_2 = {{vmovq_n_s32(0), vmovq_n_s32(0), vmovq_n_s32(0), vmovq_n_s32(0)}}; + int32x4x4_t acc_line_3 = {{vmovq_n_s32(0), vmovq_n_s32(0), vmovq_n_s32(0), vmovq_n_s32(0)}}; + + // final accumulator + int32_t acc_0[4] = {0, 0, 0, 0}; + int32_t acc_1[4] = {0, 0, 0, 0}; + int32_t acc_2[4] = {0, 0, 0, 0}; + int32_t acc_3[4] = {0, 0, 0, 0}; + + for (int k_block = 0; k_block < k_blocks; k_block++) { + // load 4 vector from A + // a1 + int8x16_t a0 = vld1q_s8(a_ptr + 16 * k_block); + int8x16_t a1 = vld1q_s8(a_ptr + 16 * k_block + K); + int8x16_t a2 = vld1q_s8(a_ptr + 16 * k_block + 2 * K); + int8x16_t a3 = vld1q_s8(a_ptr + 16 * k_block + 3 * K); + + // load 4 vector from B + int8x16_t b0 = vld1q_s8(b_ptr + 16 * k_block); + int8x16_t b1 = vld1q_s8(b_ptr + 16 * k_block + K); + int8x16_t b2 = vld1q_s8(b_ptr + 16 * k_block + 2 * K); + int8x16_t b3 = vld1q_s8(b_ptr + 16 * k_block + 3 * K); + + acc_line_0.val[0] = vdotq_s32(acc_line_0.val[0], a0, b0); + acc_line_0.val[1] = vdotq_s32(acc_line_0.val[1], a0, b1); + acc_line_0.val[2] = vdotq_s32(acc_line_0.val[2], a0, b2); + acc_line_0.val[3] = vdotq_s32(acc_line_0.val[3], a0, b3); + + acc_line_1.val[0] = vdotq_s32(acc_line_1.val[0], a1, b0); + acc_line_1.val[1] = vdotq_s32(acc_line_1.val[1], a1, b1); + acc_line_1.val[2] = vdotq_s32(acc_line_1.val[2], a1, b2); + acc_line_1.val[3] = vdotq_s32(acc_line_1.val[3], a1, b3); + + acc_line_2.val[0] = vdotq_s32(acc_line_2.val[0], a2, b0); + acc_line_2.val[1] = vdotq_s32(acc_line_2.val[1], a2, b1); + acc_line_2.val[2] = vdotq_s32(acc_line_2.val[2], a2, b2); + acc_line_2.val[3] = vdotq_s32(acc_line_2.val[3], a2, b3); + + acc_line_3.val[0] = vdotq_s32(acc_line_3.val[0], a3, b0); + acc_line_3.val[1] = vdotq_s32(acc_line_3.val[1], a3, b1); + acc_line_3.val[2] = vdotq_s32(acc_line_3.val[2], a3, b2); + acc_line_3.val[3] = vdotq_s32(acc_line_3.val[3], a3, b3); + } + + acc_0[0] = vaddvq_s32(acc_line_0.val[0]); + acc_0[1] = vaddvq_s32(acc_line_0.val[1]); + acc_0[2] = vaddvq_s32(acc_line_0.val[2]); + acc_0[3] = vaddvq_s32(acc_line_0.val[3]); + int32x4_t acc_0_vec_i32 = vld1q_s32(acc_0); + float32x4_t acc_0_vec_f32 = vcvtq_f32_s32(acc_0_vec_i32); + float32x4_t acc_0_vec_final = vmulq_n_f32(acc_0_vec_f32, scale); + vst1q_f32(c_ptr, acc_0_vec_final); + + acc_1[0] = vaddvq_s32(acc_line_1.val[0]); + acc_1[1] = vaddvq_s32(acc_line_1.val[1]); + acc_1[2] = vaddvq_s32(acc_line_1.val[2]); + acc_1[3] = vaddvq_s32(acc_line_1.val[3]); + int32x4_t acc_1_vec_i32 = vld1q_s32(acc_1); + float32x4_t acc_1_vec_f32 = vcvtq_f32_s32(acc_1_vec_i32); + float32x4_t acc_1_vec_final = vmulq_n_f32(acc_1_vec_f32, scale); + vst1q_f32(c_ptr + N, acc_1_vec_final); + + acc_2[0] = vaddvq_s32(acc_line_2.val[0]); + acc_2[1] = vaddvq_s32(acc_line_2.val[1]); + acc_2[2] = vaddvq_s32(acc_line_2.val[2]); + acc_2[3] = vaddvq_s32(acc_line_2.val[3]); + int32x4_t acc_2_vec_i32 = vld1q_s32(acc_2); + float32x4_t acc_2_vec_f32 = vcvtq_f32_s32(acc_2_vec_i32); + float32x4_t acc_2_vec_final = vmulq_n_f32(acc_2_vec_f32, scale); + vst1q_f32(c_ptr + 2 * N, acc_2_vec_final); + + acc_3[0] = vaddvq_s32(acc_line_3.val[0]); + acc_3[1] = vaddvq_s32(acc_line_3.val[1]); + acc_3[2] = vaddvq_s32(acc_line_3.val[2]); + acc_3[3] = vaddvq_s32(acc_line_3.val[3]); + int32x4_t acc_3_vec_i32 = vld1q_s32(acc_3); + float32x4_t acc_3_vec_f32 = vcvtq_f32_s32(acc_3_vec_i32); + float32x4_t 
acc_3_vec_final = vmulq_n_f32(acc_3_vec_f32, scale); + vst1q_f32(c_ptr + 3 * N, acc_3_vec_final); +} + +void qt8_qt8_fp32_vec_dot(void *A, void *B, void *C, int32_t K, float SA, float SB) { + if (K % 16) abort(); + + int32_t k_blocks = K / 16; + float scale = SA * SB; + + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B; + auto c_ptr = (float *)C; + + // accumulator + int32x4_t acc_line_0 = vmovq_n_s32(0); + + // Only support v8.2+ with dotproduct enabled. + // Using SDOT. The throughput will increase 4 times compared to fp32 version. + for (int32_t k_block = 0; k_block < k_blocks; k_block++) { + // load from A + int8x16_t r0 = vld1q_s8(a_ptr + 16 * k_block); + + // load from block, n = 0 + int8x16_t n0 = vld1q_s8(b_ptr + 16 * k_block); + acc_line_0 = vdotq_s32(acc_line_0, r0, n0); + } + + // reduce all and save to c_ptr + int32_t acc_line_0_i32 = vaddvq_s32(acc_line_0); + *(c_ptr) = (float)acc_line_0_i32 * scale; +} + +// This function is dropped !!! +void qt8_qt8_fp32_gemm(void *A, void *B, void *C, int32_t M, int32_t N, int32_t K, float SA, + float SB, bool transpose_b) { + if (M == 1) { + qt8_qt8_fp32_gemv(A, B, C, N, K, SA, SB, transpose_b); + } else { + // FIXME: tile in more efficient way. + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B; + auto c_ptr = (float *)C; + for (int m = 0; m < M; ++m) { + qt8_qt8_fp32_gemv(a_ptr + m * K, b_ptr, c_ptr + m * N, N, K, SA, SB, transpose_b); + } + } +} + +// This function is dropped !!! +void qt8_qt8_fp32_gemm_omp(void *A, void *B, void *C, int32_t M, int32_t N, int32_t K, float SA, + float SB, bool transpose_b) { + if (M == 1) { + qt8_qt8_fp32_gemv(A, B, C, N, K, SA, SB, transpose_b); + } else { + // FIXME: tile in more efficient way. + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B; + auto c_ptr = (float *)C; + if (M > 64) { +#pragma omp parallel for num_threads(4) + for (int m = 0; m < M; ++m) { + qt8_qt8_fp32_gemv(a_ptr + m * K, b_ptr, c_ptr + m * N, N, K, SA, SB, transpose_b); + } + } else { + for (int m = 0; m < M; ++m) { + qt8_qt8_fp32_gemv(a_ptr + m * K, b_ptr, c_ptr + m * N, N, K, SA, SB, transpose_b); + } + } + } +} + +void qt8_qt8_fp32_gemm_sdot(void *A, void *B, void *C, int32_t M, int32_t N, int32_t K, float SA, + float SB, bool transpose_b) { + if (M == 1) { + qt8_qt8_fp32_gemv_sdot(A, B, C, N, K, SA, SB, transpose_b); + } else { + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B; + auto c_ptr = (float *)C; + + int32_t m_blocks = M / 4; + int32_t n_blocks = N / 4; + int32_t m_left = M % 4; + int32_t n_left = N % 4; + + if (M < 4 || N < 4) { + for (int m = 0; m < M; ++m) { + qt8_qt8_fp32_gemv_sdot(a_ptr + m * K, b_ptr, c_ptr + m * N, N, K, SA, SB, transpose_b); + } + return; + } + + // main loop + for (int m = 0; m < m_blocks; ++m) { + for (int n = 0; n < n_blocks; ++n) { + qt8_qt8_fp32_kernel_4x4_sdot(a_ptr + 4 * m * K, b_ptr + 4 * n * K, + c_ptr + 4 * m * N + 4 * n, N, K, SA, SB, transpose_b); + } + + // some re cauculating may be needed + if (n_left) { + qt8_qt8_fp32_kernel_4x4_sdot(a_ptr + 4 * m * K, b_ptr + (n_blocks * 4 - (4 - n_left)) * K, + c_ptr + 4 * m * N + n_blocks * 4 - (4 - n_left), N, K, SA, SB, + transpose_b); + } + } + + if (m_left) { + for (int m = m_blocks * 4; m < M; ++m) { + qt8_qt8_fp32_gemv_sdot(a_ptr + m * K, b_ptr, c_ptr + m * N, N, K, SA, SB, transpose_b); + } + } + } +} + +void qt8_qt8_fp32_gemm_sdot_omp(void *A, void *B, void *C, int32_t M, int32_t N, int32_t K, + float SA, float SB, bool transpose_b) { + if (M == 1) { + qt8_qt8_fp32_gemv_sdot(A, B, C, N, K, SA, SB, transpose_b); + } 
else { + auto a_ptr = (int8_t *)A; + auto b_ptr = (int8_t *)B; + auto c_ptr = (float *)C; + + int32_t m_blocks = M / 4; + int32_t n_blocks = N / 4; + int32_t m_left = M % 4; + int32_t n_left = N % 4; + + if (M < 4 || N < 4) { + for (int m = 0; m < M; ++m) { + qt8_qt8_fp32_gemv_sdot(a_ptr + m * K, b_ptr, c_ptr + m * N, N, K, SA, SB, transpose_b); + } + return; + } + + if (M > 64) { + // main loop +#pragma omp parallel for num_threads(4) + for (int m = 0; m < m_blocks; ++m) { + for (int n = 0; n < n_blocks; ++n) { + qt8_qt8_fp32_kernel_4x4_sdot(a_ptr + 4 * m * K, b_ptr + 4 * n * K, + c_ptr + 4 * m * N + 4 * n, N, K, SA, SB, transpose_b); + } + + // some re cauculating may be needed + if (n_left) { + qt8_qt8_fp32_kernel_4x4_sdot(a_ptr + 4 * m * K, b_ptr + (n_blocks * 4 - (4 - n_left)) * K, + c_ptr + 4 * m * N + n_blocks * 4 - (4 - n_left), N, K, SA, + SB, transpose_b); + } + } + + if (m_left) { + for (int m = m_blocks * 4; m < M; ++m) { + qt8_qt8_fp32_gemv_sdot(a_ptr + m * K, b_ptr, c_ptr + m * N, N, K, SA, SB, transpose_b); + } + } + } else { + // main loop + for (int m = 0; m < m_blocks; ++m) { + for (int n = 0; n < n_blocks; ++n) { + qt8_qt8_fp32_kernel_4x4_sdot(a_ptr + 4 * m * K, b_ptr + 4 * n * K, + c_ptr + 4 * m * N + 4 * n, N, K, SA, SB, transpose_b); + } + + // some re cauculating may be needed + if (n_left) { + qt8_qt8_fp32_kernel_4x4_sdot(a_ptr + 4 * m * K, b_ptr + (n_blocks * 4 - (4 - n_left)) * K, + c_ptr + 4 * m * N + n_blocks * 4 - (4 - n_left), N, K, SA, + SB, transpose_b); + } + } + + if (m_left) { + for (int m = m_blocks * 4; m < M; ++m) { + qt8_qt8_fp32_gemv_sdot(a_ptr + m * K, b_ptr, c_ptr + m * N, N, K, SA, SB, transpose_b); + } + } + } + } +} +} // namespace mllm::armv8 +#endif \ No newline at end of file diff --git a/src/backends/cpu/compute/Matmul.hpp b/src/backends/cpu/compute/Matmul.hpp index 7276dbe5..5d0d41b2 100644 --- a/src/backends/cpu/compute/Matmul.hpp +++ b/src/backends/cpu/compute/Matmul.hpp @@ -22,7 +22,148 @@ ErrorCode mat_mul_fp32_q6_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo ErrorCode mat_mul_elastic(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias = nullptr, int activate_input_dim = -1, int activate_output_dim = -1, bool transpose0 = false, bool transpose1 = true, int thread_count = 4); -// smoothquant int8 matmul ErrorCode mat_mul_i8(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias = nullptr, bool transpose0 = false, bool transpose1 = false, int thread_count = 4, float scale1 = 1.0f, float scale2 = 1.0f); -ErrorCode mat_mul_fp32_i8(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias = nullptr, bool transpose0 = false, bool transpose1 = false, int thread_count = 4, float scale2 = 1.0f); + +#ifdef __ARM_NEON + +#ifndef __ARM_NEON +#error \ + "The mllm-advance Armv8 backend is enbaled but __ARM_NEON is not defined. Pls use cross-compile toolchains(such as NDK) to compile." +#endif + +#include + +namespace mllm::armv8 { + +/** + * @brief Decoding stage. qt8_qt8_fp32_gemm will accpect q(1 x K), k^T(K x N) as inputs. GEMV Like. + * + * This function is dropped !!! + * + * @param A + * @param B + * @param C + * @param N + * @param K + * @param SA + * @param SB + * @param transpose_b + */ +[[deprecated]] void qt8_qt8_fp32_gemv(void *A, void *B, void *C, int32_t N, int32_t K, float SA, + float SB, bool transpose_b = false); + +/** + * @brief Same logic as qt8_qt8_fp32_gemv. But accumulate numbers into int32 accumulator. 
+ * + * @param A + * @param B + * @param C + * @param N + * @param K + * @param SA + * @param SB + * @param transpose_b + */ +void qt8_qt8_fp32_gemv_sdot(void *A, void *B, void *C, int32_t N, int32_t K, float SA, float SB, + bool transpose_b = false); + +/** + * @brief + * + * @param A + * @param B + * @param C + * @param K + * @param SA + * @param SB + * @param transpose_b + */ +void qt8_qt8_fp32_kernel_4x4_sdot(void *A, void *B, void *C, int32_t N, int32_t K, float SA, + float SB, bool transpose_b = false); + +/** + * @brief Per-Tensor Quantized Int8 vec dot product. + * + * @param A + * @param B + * @param C + * @param K + * @param SA + * @param SB + */ +void qt8_qt8_fp32_vec_dot(void *A, void *B, void *C, int32_t K, float SA, float SB); + +/** + * @brief Per-Tensor Quantized Int8 GEMM. + * A(per-tensor signed int 8) @ B(per-tensor signed int 8) -> C(fp32) + * A(M x K), B(K x N), C(M x K) + * + * This function is dropped !!! + * + * @param A int8_t array + * @param B int8_t array + * @param C float array + * @param M + * @param N + * @param K + * @param SA Per-tensor scale for A + * @param SB Per-tensor scale for B + */ +[[deprecated]] void qt8_qt8_fp32_gemm(void *A, void *B, void *C, int32_t M, int32_t N, int32_t K, + float SA, float SB, bool transpose_b = false); + +/** + * @brief Using openmp on GEMM. GEMV always disable multithread. + * + * This function is dropped !!! + * + * @param A + * @param B + * @param C + * @param M + * @param N + * @param K + * @param SA + * @param SB + * @param transpose_b + */ +[[deprecated]] void qt8_qt8_fp32_gemm_omp(void *A, void *B, void *C, int32_t M, int32_t N, + int32_t K, float SA, float SB, bool transpose_b = false); + +/** + * @brief Per-Tensor Quantized Int8 GEMM. + * A(per-tensor signed int 8) @ B(per-tensor signed int 8) -> C(fp32) + * A(M x K), B(K x N), C(M x K) + * + * @param A int8_t array + * @param B int8_t array + * @param C float array + * @param M + * @param N + * @param K + * @param SA Per-tensor scale for A + * @param SB Per-tensor scale for B + */ +void qt8_qt8_fp32_gemm_sdot(void *A, void *B, void *C, int32_t M, int32_t N, int32_t K, float SA, + float SB, bool transpose_b = false); + +/** + * @brief Using openmp on GEMM. GEMV always disable multithread. + * + * @param A + * @param B + * @param C + * @param M + * @param N + * @param K + * @param SA + * @param SB + * @param transpose_b + */ +void qt8_qt8_fp32_gemm_sdot_omp(void *A, void *B, void *C, int32_t M, int32_t N, int32_t K, + float SA, float SB, bool transpose_b = false); + +} // namespace mllm::armv8 +#endif + #endif // MLLM_MATMUL_HPP
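
The scalar sketches below are not part of the diff; they spell out, under stated assumptions, the contracts the new code appears to be written against, and may be useful as unit-test references. First, the destination layout produced by im2col_fp32_src_knxn_sn_p0_to (and its 16x16 NEON specialization): each KxK block of the CHW source becomes one row of length K*K*C, with channel c's patch stored contiguously starting at column c*K*K, and the destination left untransposed. The helper name im2col_ref is illustrative, not something the patch defines.

#include <cstdint>

// Scalar reference for the im2col layout used by the new fast path:
// src is CHW (C x H x W), kernel NxN, stride N, padding 0.
// dst has ((H/N) * (W/N)) rows; each row is N*N*C floats, and the
// N*N patch of channel c starts at column c * N * N.
static void im2col_ref(const float *src, float *dst,
                       int32_t H, int32_t W, int32_t C, int32_t N) {
    const int32_t h_blocks = H / N;
    const int32_t w_blocks = W / N;
    for (int32_t c = 0; c < C; ++c) {
        int32_t row = 0;
        for (int32_t h = 0; h < h_blocks; ++h) {
            for (int32_t w = 0; w < w_blocks; ++w) {
                float *out = dst + row * N * N * C + c * N * N;
                const float *in = src + c * H * W + h * N * W + w * N;
                for (int32_t i = 0; i < N; ++i)
                    for (int32_t j = 0; j < N; ++j)
                        out[i * N + j] = in[i * W + j];
                ++row;
            }
        }
    }
}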
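
The commented-out blocks in CPUConvolution2D.cpp outline the intended fast path for the stride == kernel, zero-padding case: im2col the input, run a single GEMM against the reshaped weight, then transpose back into the op's output layout. A scalar reference for that flow is sketched below. It assumes the weight rows have been repacked into the same (c, kh, kw) inner order as an im2col row; the patch only reshape()s weight_ in place and does not show that repacking, so treat the layout as an assumption.

#include <cstdint>

// Scalar reference for the conv-as-GEMM fast path (kernel KxK, stride K, pad 0):
// patches : P x (K*K*C), produced by im2col_ref above, with P = (H/K) * (W/K)
// weight  : O x (K*K*C), one row per output channel, assumed to use the same
//           inner (c, kh, kw) order as an im2col row
// out     : O x P, i.e. the transposed GEMM result, matching the op's layout
static void conv2d_gemm_ref(const float *patches, const float *weight,
                            const float *bias, float *out,
                            int32_t P, int32_t O, int32_t KKC) {
    for (int32_t p = 0; p < P; ++p) {
        for (int32_t o = 0; o < O; ++o) {
            float acc = bias ? bias[o] : 0.f;
            for (int32_t k = 0; k < KKC; ++k)
                acc += patches[p * KKC + k] * weight[o * KKC + k];
            out[o * P + p] = acc; // transposed store, like transpose_fp32 after mat_mul
        }
    }
}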
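
All of the armv8::qt8_qt8_fp32_* kernels compute the same per-tensor symmetric int8 GEMM: A is M x K, B is supplied already transposed as N x K (transpose_b must be true), products are accumulated in int32, and the result is rescaled to fp32 by SA * SB, the scales that mat_mul_i8 reads from Tensor::i8_scale. Below is a scalar reference for that contract plus an illustrative per-tensor quantizer showing how such a scale is typically produced; the quantizer is a hypothetical helper, not something the patch defines.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference for the contract implemented by qt8_qt8_fp32_gemm_sdot:
// A is M x K (int8), B is N x K (int8, i.e. pre-transposed), C is M x N (fp32),
// and C = (A @ B^T) * SA * SB with int32 accumulation.
static void qt8_qt8_fp32_gemm_ref(const int8_t *A, const int8_t *B, float *C,
                                  int32_t M, int32_t N, int32_t K,
                                  float SA, float SB) {
    for (int32_t m = 0; m < M; ++m) {
        for (int32_t n = 0; n < N; ++n) {
            int32_t acc = 0; // same int32 accumulation as the SDOT path
            for (int32_t k = 0; k < K; ++k)
                acc += (int32_t)A[m * K + k] * (int32_t)B[n * K + k];
            C[m * N + n] = (float)acc * SA * SB;
        }
    }
}

// Illustrative per-tensor symmetric quantizer producing the kind of scale that
// ends up in Tensor::i8_scale. (Hypothetical helper, not part of the patch.)
static float quantize_per_tensor_i8(const std::vector<float> &x, std::vector<int8_t> &q) {
    float amax = 0.f;
    for (float v : x) amax = std::max(amax, std::fabs(v));
    const float scale = amax / 127.f;
    q.resize(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        q[i] = (int8_t)std::lround(x[i] / (scale > 0.f ? scale : 1.f));
    return scale;
}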
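
Callers should also note the hard constraints baked into these kernels: K must be a multiple of 16 (they abort() otherwise), B must be pre-transposed, and the parallel variants only spawn threads for larger shapes (M > 64 for qt8_qt8_fp32_gemm_sdot_omp, M > 128 for transpose_fp32). A hypothetical smoke test tying the reference above to the SDOT path might look like the sketch below; it assumes a translation unit that can see the declarations from src/backends/cpu/compute/Matmul.hpp and is built with the dotprod-enabled -march flag this patch adds to the CPU backend.

#ifdef __ARM_NEON
#include <cmath>
#include <cstdio>
#include <vector>
#include "Matmul.hpp" // src/backends/cpu/compute/Matmul.hpp; adjust to the build's include paths

// Hypothetical smoke test comparing the scalar reference with the SDOT kernel.
static void smoke_test_qt8_gemm() {
    const int32_t M = 8, N = 8, K = 64; // K must be a multiple of 16
    std::vector<int8_t> A(M * K, 1), B(N * K, 2);
    std::vector<float> C_ref(M * N), C_neon(M * N);
    const float SA = 0.05f, SB = 0.02f;

    qt8_qt8_fp32_gemm_ref(A.data(), B.data(), C_ref.data(), M, N, K, SA, SB);
    mllm::armv8::qt8_qt8_fp32_gemm_sdot(A.data(), B.data(), C_neon.data(),
                                        M, N, K, SA, SB, /*transpose_b=*/true);

    for (int i = 0; i < M * N; ++i)
        if (std::fabs(C_ref[i] - C_neon[i]) > 1e-3f)
            std::printf("mismatch at %d: %f vs %f\n", i, C_ref[i], C_neon[i]);
}
#endif // __ARM_NEON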