diff --git a/src/backends/qnn/op/QNNRoPE.cpp b/src/backends/qnn/op/QNNRoPE.cpp
index 4e1505c3..183633a2 100644
--- a/src/backends/qnn/op/QNNRoPE.cpp
+++ b/src/backends/qnn/op/QNNRoPE.cpp
@@ -6,14 +6,19 @@ namespace mllm {
+vector<float> QNNRoPE::theta_;
+
 vector<vector<float>> QNNRoPE::sin_;
 vector<vector<float>> QNNRoPE::cos_;
 int QNNRoPE::global_pose_type_ = -1;
 int QNNRoPE::ishape_old;
-
-extern void sinusoidal_position_embedding_llama(int seq_len, int output_dim, vector<vector<float>> &sin, vector<vector<float>> &cos);
-extern void sinusoidal_position_embedding_huggingface(int seq_len, int output_dim, vector<vector<float>> &sin, vector<vector<float>> &cos, int base = 10000);
+extern void sinusoidal_position_embedding_llama(int seq_len, int output_dim, const vector<float> &theta,
+                                                vector<vector<float>> &sin, vector<vector<float>> &cos, float attention_scaling = 1.0);
+extern void sinusoidal_position_embedding_huggingface(int seq_len, int output_dim, const vector<float> &theta,
+                                                      vector<vector<float>> &sin, vector<vector<float>> &cos, float attention_scaling = 1.0);
+typedef float (*mllm_rope_init_func)(const OpParam &, std::vector<float> &);
+extern unordered_map<RoPEThetaType, mllm_rope_init_func> rope_init_func_map;
 
 QNNRoPE::QNNRoPE(Backend *bn, string opName, int pose_type) :
     QNNCommonOp(bn, opName) {
@@ -53,6 +58,25 @@ QNNRoPE::QNNRoPE(Backend *bn, string opName, int pose_type, float rope_theta, fl
     scale_.setBackend(bn);
 }
 
+QNNRoPE::QNNRoPE(Backend *bn, string opName, OpParam &config) :
+    QNNCommonOp(bn, opName) {
+    config_ = config;
+    pose_type_ = config.at("pose_type");
+    auto it = config.find("rope_theta");
+    if (it != config.end()) {
+        rope_theta_ = it->second;
+    }
+    it = config.find("partial_rotary_factor");
+    if (it != config.end()) {
+        partial_rotary_factor_ = it->second;
+    }
+    it = config.find("max_position_embeddings");
+    if (it != config.end()) {
+        pos_max_ = it->second;
+    }
+    rope_type = (RoPEThetaType)config.at("rope_type");
+}
+
 ErrorCode QNNRoPE::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
     assert(inputs.size() == 1);
     assert(outputs.size() == 1);
@@ -66,18 +90,34 @@ ErrorCode QNNRoPE::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr
+extern float _default_init_rope(const OpParam &config, vector<float> &theta);
+// from CPURoPE.cpp 2025/1/24
+extern float _compute_llama3_theta(const OpParam &config, vector<float> &theta);
+
 ErrorCode QNNRoPE::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
     // in case ishape is 0 when Op is the first one in the graph
+    const unordered_map<RoPEThetaType, mllm_rope_init_func> rope_init_func_map = {
+        {DEFAULT, _default_init_rope},
+        {LLAMA3, _compute_llama3_theta},
+    };
+
     if (sin_.empty() || ishape_old < ishape || global_pose_type_ != pose_type_) {
+        auto calc_theta = rope_init_func_map.at(rope_type);
+        auto config = config_;
+        config["base"] = (float)rope_theta_;
+        config["dim"] = ishape;
+        float attention_scaling = calc_theta(config, theta_);
+
         global_pose_type_ = pose_type_;
         ishape_old = ishape;
         if (pose_type_ == LLAMAROPE) {
-            sinusoidal_position_embedding_llama(pos_max_, ishape, sin_, cos_);
+            sinusoidal_position_embedding_llama(pos_max_, ishape, theta_, sin_, cos_, attention_scaling);
         } else if (pose_type_ == PERSIMMONROPE) {
-            sinusoidal_position_embedding_huggingface(pos_max_, ishape / 2, sin_, cos_, 25000);
+            sinusoidal_position_embedding_huggingface(pos_max_, ishape / 2, theta_, sin_, cos_, attention_scaling);
        } else if (pose_type_ == HFHUBROPE || pose_type_ == MLAROPE) {
-            sinusoidal_position_embedding_huggingface(pos_max_, ishape, sin_, cos_, rope_theta_);
+            sinusoidal_position_embedding_huggingface(pos_max_, ishape, theta_, sin_, cos_, attention_scaling);
         } else {
         }
     }
@@ -92,57 +132,46 @@ ErrorCode QNNRoPE::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Te
     if (outputs[0]->dtype() == MLLM_TYPE_F32) {
-        sinTensor_.setName(name() + ".sin");
-        sinTensor_.reshape(1, 1, pos_max_, ishape/2);
+        sinTensor_.reshape(1, 1, pos_max_, ishape / 2);
         sinTensor_.setDtype(MLLM_TYPE_F32);
         sinTensor_.alloc();
-        cosTensor_.setName(name() + ".cos");
-        cosTensor_.reshape(1, 1, pos_max_, ishape/2);
+        cosTensor_.reshape(1, 1, pos_max_, ishape / 2);
         cosTensor_.setDtype(MLLM_TYPE_F32);
         cosTensor_.alloc();
-        for (int i = 0; i < pos_max_; ++i) {
-            for (int j = 0; j < ishape/2; ++j) {
+        for (int i = 0; i < pos_max_; ++i) {
+            for (int j = 0; j < ishape / 2; ++j) {
                 sinTensor_.setDataAt<float>(0, 0, i, j, sin_[i][j] * dequantScale);
                 cosTensor_.setDataAt<float>(0, 0, i, j, cos_[i][j] * dequantScale);
             }
         }
-
-    } else if (outputs[0]->dtype() == MLLM_TYPE_F16) {
-
+
+    } else if (outputs[0]->dtype() == MLLM_TYPE_F16) {
         sinTensor_.setName(name() + ".sin");
-        sinTensor_.reshape(1, 1, pos_max_, ishape/2);
+        sinTensor_.reshape(1, 1, pos_max_, ishape / 2);
         sinTensor_.setDtype(MLLM_TYPE_F32);
         sinTensor_.alloc();
-        cosTensor_.setName(name() + ".cos");
-        cosTensor_.reshape(1, 1, pos_max_, ishape/2);
+        cosTensor_.reshape(1, 1, pos_max_, ishape / 2);
         cosTensor_.setDtype(MLLM_TYPE_F32);
         cosTensor_.alloc();
-        for (int i = 0; i < pos_max_; ++i) {
-            for (int j = 0; j < ishape/2; ++j) {
+        for (int i = 0; i < pos_max_; ++i) {
+            for (int j = 0; j < ishape / 2; ++j) {
                 sinTensor_.setDataAt<float>(0, 0, i, j, static_cast<float>(sin_[i][j]));
                 cosTensor_.setDataAt<float>(0, 0, i, j, static_cast<float>(cos_[i][j]));
             }
         }
         type = QNN_DATATYPE_FLOAT_16;
+    }
 
-    }
-
-
-
-
-
-
-
-    uint32_t sin_dimensions[] = {static_cast<uint32_t>(pos_max_), static_cast<uint32_t>(ishape/2)};
-    uint32_t cos_dimensions[] = {static_cast<uint32_t>(pos_max_), static_cast<uint32_t>(ishape/2)};
+    uint32_t sin_dimensions[] = {static_cast<uint32_t>(pos_max_), static_cast<uint32_t>(ishape / 2)};
+    uint32_t cos_dimensions[] = {static_cast<uint32_t>(pos_max_), static_cast<uint32_t>(ishape / 2)};
 
     auto sinWeightsName = name() + ".sin.weights";
@@ -243,13 +272,11 @@ ErrorCode QNNRoPE::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Te
 
 ErrorCode QNNRoPE::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
-    h_cnt_ += inputs[0]->sequence();
-    hcntTensor_.setDataAt(0,0,0,0, h_cnt_);
+    hcntTensor_.setDataAt(0, 0, 0, 0, h_cnt_);
     return QNNCommonOp::execute(inputs, outputs);
 }
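Note on the QNNRoPE.cpp change: instead of letting the embedding builders derive frequencies from an integer base, setUp() now precomputes the per-dimension frequency table theta_ once, through a function-pointer table keyed by rope_type, and hands it (plus an attention_scaling factor) to both builders. A minimal sketch of what a DEFAULT-style initializer is expected to compute, inferred from the call site above (reads "base" and "dim" from the OpParam, returns the scaling factor); the real _default_init_rope lives in CPURoPE.cpp and may differ in detail:

    // Sketch only: fill theta with base^(-2i/dim), the standard RoPE frequency
    // schedule, and return 1.0 because DEFAULT applies no attention scaling.
    #include <cmath>
    #include <map>
    #include <string>
    #include <vector>

    using OpParam = std::map<std::string, float>; // assumption: OpParam maps names to floats

    float default_init_rope_sketch(const OpParam &config, std::vector<float> &theta) {
        const float base = config.at("base");
        const int dim = static_cast<int>(config.at("dim"));
        theta.resize(dim / 2);
        for (int i = 0; i < dim / 2; ++i) {
            theta[i] = std::pow(base, -2.0f * static_cast<float>(i) / dim);
        }
        return 1.0f;
    }

The LLAMA3 entry (_compute_llama3_theta) follows the Llama 3 rope-scaling scheme, which rescales selected frequencies for long-context use; that per-variant difference is why the map dispatch exists at all.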
diff --git a/src/backends/qnn/op/QNNRoPE.hpp b/src/backends/qnn/op/QNNRoPE.hpp
index b93e6676..8663418e 100644
--- a/src/backends/qnn/op/QNNRoPE.hpp
+++ b/src/backends/qnn/op/QNNRoPE.hpp
@@ -9,6 +9,7 @@ class QNNRoPE : public QNNCommonOp {
     QNNRoPE(Backend *bn, string opName, int pose_type);
     QNNRoPE(Backend *bn, string opName, int pose_type, float rope_theta, int max_position_embeddings);
     QNNRoPE(Backend *bn, string opName, int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings);
+    QNNRoPE(Backend *bn, string opName, OpParam &config);
     virtual ~QNNRoPE() = default;
     virtual ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
     virtual ErrorCode setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
@@ -17,7 +18,7 @@ class QNNRoPE : public QNNCommonOp {
     virtual ErrorCode execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
 
 private:
-
+    static vector<float> theta_;
     static vector<vector<float>> sin_;
     static vector<vector<float>> cos_;
     static int global_pose_type_;
@@ -29,6 +30,10 @@ class QNNRoPE : public QNNCommonOp {
     int ishape;
     float partial_rotary_factor_ = 1;
 
+    OpParam config_;
+
+    RoPEThetaType rope_type = DEFAULT;
+
     Tensor hcntTensor_;
 
     Tensor sinTensor_;
@@ -40,6 +45,12 @@ class QNNRoPECreator : public QNNBackend::Creator {
 public:
     virtual Op *create(OpParam op_param, Backend *bn, string name) const {
+        // from CPURoPE.cpp 2025/1/24
+        auto it = op_param.find("rope_type");
+        if (it != op_param.end()) {
+            return new QNNRoPE(bn, name, op_param);
+        }
+
         int pose_type = op_param["pose_type"];
         if (op_param.find("rope_theta") == op_param.end()) {
             return new QNNRoPE(bn, name, pose_type);
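Note on the creator change: any OpParam carrying a "rope_type" key is routed to the new config-driven constructor, while the older positional constructors keep serving existing call sites. A hypothetical parameter map that would take the new path (key names come from the constructor above; the concrete values and the op name are placeholders, and enum values are assumed to be stored as floats in OpParam):

    OpParam op_param;
    op_param["pose_type"] = HFHUBROPE;
    op_param["rope_type"] = LLAMA3;    // presence of this key selects QNNRoPE(bn, name, op_param)
    op_param["rope_theta"] = 500000.0f;
    op_param["partial_rotary_factor"] = 1.0f;
    op_param["max_position_embeddings"] = 8192.0f;
    Op *op = QNNRoPECreator().create(op_param, bn, "model.layers.0.self_attn.rope");

Only "pose_type" and "rope_type" are mandatory on this path (both are read with at(), which throws if absent); the other three fall back to member defaults, as the find() guards in the constructor show.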
diff --git a/src/models/phonelm/modeling_phonelm_npu.hpp b/src/models/phonelm/modeling_phonelm_npu.hpp
index 26cbd071..ac5bee9b 100644
--- a/src/models/phonelm/modeling_phonelm_npu.hpp
+++ b/src/models/phonelm/modeling_phonelm_npu.hpp
@@ -10,6 +10,7 @@ using namespace mllm;
 
+namespace phonelm {
 // get the closest factors of a number, used in NPU part2 view to speed up the QNN linear
 inline pair<int, int> closestFactors(int n) {
     int root = static_cast<int>(sqrt(n));
     for (int i = root; i > 0; --i) {
         if (n % i == 0) {
@@ -20,6 +21,7 @@ inline pair<int, int> closestFactors(int n) {
     }
     return {1, n};
 }
+} // namespace phonelm
 
 // NPU QKV part
 class PhoneLMDecoderNPUPart1 final : public Module {
@@ -196,7 +198,7 @@ class PhoneLMDecoderNPUPart2 final : public Module {
         num_key_value_groups = num_heads / num_key_value_heads;
 
         // for QNN linear speed up
-        pre_oproj_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
+        pre_oproj_view = View(1, phonelm::closestFactors(chunk_size).first, phonelm::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
         out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name);
         post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize");
         post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_");
@@ -207,7 +209,7 @@ class PhoneLMDecoderNPUPart2 final : public Module {
 
         auto mlp_base_name = base_name + names._ffn_base_name;
         pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize");
-        pre_mlp_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
+        pre_mlp_view = View(1, phonelm::closestFactors(chunk_size).first, phonelm::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
         gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name);
         relu = ReLU(mlp_base_name + names._gate_proj_name + ".relu");
         up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name);
@@ -306,7 +308,7 @@ class PhoneLMDecoderNPUPart2WithShadow final : public Module {
         num_key_value_groups = num_heads / num_key_value_heads;
 
         // for QNN linear speed up
-        pre_oproj_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
+        pre_oproj_view = View(1, phonelm::closestFactors(chunk_size).first, phonelm::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
         out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name);
         post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize");
         post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_");
@@ -317,7 +319,7 @@ class PhoneLMDecoderNPUPart2WithShadow final : public Module {
 
         auto mlp_base_name = base_name + names._ffn_base_name;
         pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize");
-        pre_mlp_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
+        pre_mlp_view = View(1, phonelm::closestFactors(chunk_size).first, phonelm::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
         gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name);
         relu = ReLU(mlp_base_name + names._gate_proj_name + ".relu");
         up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name);
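Note on the namespace wrap: both NPU modeling headers define an identical inline closestFactors, so including them in one translation unit previously caused a redefinition clash; scoping each copy (phonelm:: here, qwen:: below) keeps the helpers distinct. The helper starts from sqrt(n) and walks down to the nearest divisor, so the View gets the most balanced split of a chunk. Assuming that walk-down implementation, a few worked values:

    // phonelm::closestFactors(64) == {8, 8}    sqrt(64) = 8 divides 64
    // phonelm::closestFactors(96) == {8, 12}   root 9 fails, 8 divides 96
    // phonelm::closestFactors(13) == {1, 13}   prime: falls through to {1, n}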
diff --git a/src/models/qwen/modeling_qwen_npu.hpp b/src/models/qwen/modeling_qwen_npu.hpp
index 0305f1b9..db6b19e0 100644
--- a/src/models/qwen/modeling_qwen_npu.hpp
+++ b/src/models/qwen/modeling_qwen_npu.hpp
@@ -10,6 +10,7 @@ using namespace mllm;
 
+namespace qwen {
 // get the closest factors of a number, used in NPU part2 view to speed up the QNN linear
 inline pair<int, int> closestFactors(int n) {
     int root = static_cast<int>(sqrt(n));
     for (int i = root; i > 0; --i) {
         if (n % i == 0) {
@@ -20,6 +21,7 @@ inline pair<int, int> closestFactors(int n) {
     }
     return {1, n};
 }
+} // namespace qwen
 
 // NPU QKV part
 class QwenDecoderNPUPart1 final : public Module {
@@ -189,7 +191,7 @@ class QwenDecoderNPUPart2 final : public Module {
         num_key_value_groups = num_heads / num_key_value_heads;
 
         // for QNN linear speed up
-        pre_oproj_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
+        pre_oproj_view = View(1, qwen::closestFactors(chunk_size).first, qwen::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
         out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name);
         post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize");
         post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_");
@@ -200,7 +202,7 @@ class QwenDecoderNPUPart2 final : public Module {
 
         auto mlp_base_name = base_name + names._ffn_base_name;
         pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize");
-        pre_mlp_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
+        pre_mlp_view = View(1, qwen::closestFactors(chunk_size).first, qwen::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
         gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name);
         silu = SiLU(mlp_base_name + "act");
         up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name);
@@ -298,7 +300,7 @@ class QwenDecoderNPUPart2WithShadow final : public Module {
         num_key_value_groups = num_heads / num_key_value_heads;
 
         // for QNN linear speed up
-        pre_oproj_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
+        pre_oproj_view = View(1, qwen::closestFactors(chunk_size).first, qwen::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
         out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name);
         post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize");
         post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_");
@@ -309,7 +311,7 @@ class QwenDecoderNPUPart2WithShadow final : public Module {
 
         auto mlp_base_name = base_name + names._ffn_base_name;
         pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize");
-        pre_mlp_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
+        pre_mlp_view = View(1, qwen::closestFactors(chunk_size).first, qwen::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
         gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name);
         silu = SiLU(mlp_base_name + "act");
         up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name);
@@ -531,7 +533,7 @@ class QWenModel_NPU final : public Module {
 
 class QWenForCausalLM_NPU final : public Module {
 public:
-    QWenForCausalLM_NPU(QWenConfig &config, int chunk_size) {
+    QWenForCausalLM_NPU(QWenConfig &config, int chunk_size = 64) {
         auto names = config.names_config;
         hidden_size = config.hidden_size;
         tie_embedding_words = config.tie_embedding_words;
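Note on the qwen changes: the same namespace treatment, plus a default of 64 for QWenForCausalLM_NPU's chunk_size parameter so existing constructions keep compiling. The chunk size must agree with the one used at tokenization time, which is why LibHelper below hoists a single constant and passes it to both sides (lines restated from the diff that follows):

    int chunk_size = 64;                                                          // one source of truth
    prefill_module_ = make_shared<QWenForCausalLM_NPU>(qwconfig, chunk_size);     // NPU graph built per chunk
    auto res = tokenizer->tokenizePaddingByChunk(input_str, chunk_size, 151936);  // pads prompt to a chunk multiple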
diff --git a/tools/jni/LibHelper.cpp b/tools/jni/LibHelper.cpp
index 12340578..853a856b 100644
--- a/tools/jni/LibHelper.cpp
+++ b/tools/jni/LibHelper.cpp
@@ -74,13 +74,13 @@ bool LibHelper::setUp(const std::string &base_path, std::string weights_path, st
         module_ = make_shared<QWenForCausalLM>(qwconfig);
 #ifdef USE_QNN
         if (backend_type == MLLMBackendType::QNN) {
-            prefill_module_ = make_shared<QWenForCausalLM_NPU>(qwconfig);
+            int chunk_size = 64;
+            prefill_module_ = make_shared<QWenForCausalLM_NPU>(qwconfig, chunk_size);
             prefill_module_->load(qnn_weights_path);
             auto tokenizer = dynamic_pointer_cast<QWenTokenizer>(tokenizer_);
 
             // warmup START
             std::string input_str = " ";
-            int chunk_size = 64;
             auto res = tokenizer->tokenizePaddingByChunk(input_str, chunk_size, 151936);
             auto input_tensor = res.second;
             auto real_seq_length = res.first;
@@ -98,7 +98,7 @@ bool LibHelper::setUp(const std::string &base_path, std::string weights_path, st
                 return true;
             });
             Module::isFirstChunk = false;
-            static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setSequenceLength(0);
+            static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0);
             static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT);
             static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->toggleSwitching();
             Module::isMultiChunkPrefilling = true;
@@ -146,7 +146,7 @@ bool LibHelper::setUp(const std::string &base_path, std::string weights_path, st
                 return true;
             });
             Module::isFirstChunk = false;
-            static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setSequenceLength(0);
+            static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0);
             static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT);
             static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->toggleSwitching();
             Module::isMultiChunkPrefilling = true;
@@ -186,6 +186,11 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u
         const int chunk_num = seq_length_padding / chunk_size;
         bool isSwitched = false;
 
+        // set total seq length for HeadLinear execute, which can not get the real seq length from Opts
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setTotalSequenceLength(real_seq_length);
+        // set chunk size for the HeadLinear execute, which can not get the chunk size from Opts
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setChunkSize(chunk_size);
+
         LlmTextGeneratorOpts opt{
             .max_new_tokens = 1,
             .do_sample = false,
@@ -226,7 +231,7 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u
             });
             Module::isFirstChunk = false;
         }
-        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setSequenceLength(real_seq_length);
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(real_seq_length);
         static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setExecutionType(AUTOREGRESSIVE);
         static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->toggleSwitching();
 
@@ -259,7 +264,7 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u
             if (!not_end) { return false; }
             return true;
         });
-        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setSequenceLength(0);
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0);
         static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT);
         static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->toggleSwitching();
     } else { // CPU
@@ -318,6 +323,11 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u
         const int chunk_num = seq_length_padding / chunk_size;
         bool isSwitched = false;
 
+        // set total seq length for HeadLinear execute, which can not get the real seq length from Opts
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setTotalSequenceLength(real_seq_length);
+        // set chunk size for the HeadLinear execute, which can not get the chunk size from Opts
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setChunkSize(chunk_size);
+
         LlmTextGeneratorOpts opt{
             .max_new_tokens = 1,
             .do_sample = false,
@@ -360,7 +370,7 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u
             });
             Module::isFirstChunk = false;
         }
-        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setSequenceLength(real_seq_length);
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(real_seq_length);
         static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setExecutionType(AUTOREGRESSIVE);
         static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->toggleSwitching();
 
@@ -393,7 +403,7 @@ void LibHelper::run(std::string &input_str, uint8_t *image, unsigned max_step, u
             if (!not_end) { return false; }
             return true;
         });
-        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setSequenceLength(0);
+        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setCurSequenceLength(0);
        static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->setExecutionType(PROMPT);
         static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])->toggleSwitching();
     } else { // CPU
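Note on the LibHelper changes: setSequenceLength() becomes setCurSequenceLength(), and two new setters hand HeadLinear the values it cannot recover from its op inputs during chunked prefill: the real (unpadded) prompt length and the chunk width. Together these make the CPU backend a small per-request state machine. The call order the code above drives, sketched with `cpu` standing for `static_cast<CPUBackend *>(Backend::global_backends[MLLM_CPU])` (call order from the diff; surrounding generation loop elided):

    // 1. Chunked prefill: tell the backend the real prompt length and chunk width
    //    so HeadLinear can locate the last real token inside the padded chunks.
    cpu->setTotalSequenceLength(real_seq_length);
    cpu->setChunkSize(chunk_size);
    //    ... run chunk_num prefill iterations ...

    // 2. Hand over to decode: the KV cache now holds real_seq_length tokens.
    cpu->setCurSequenceLength(real_seq_length);
    cpu->setExecutionType(AUTOREGRESSIVE);
    cpu->toggleSwitching();
    //    ... autoregressive generation ...

    // 3. Reset for the next prompt.
    cpu->setCurSequenceLength(0);
    cpu->setExecutionType(PROMPT);
    cpu->toggleSwitching();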