
Fix QNN Op and LibHelper bugs #229

Merged · 43 commits · Jan 25, 2025

Commits
e428bb5
dev: qnn multi input inference developing
oreomaker Nov 8, 2024
fca8584
Merge branch 'main' into develop-zh
oreomaker Nov 12, 2024
f94937d
feat: qnn multi chunk prefilling in new frontend
oreomaker Nov 13, 2024
83fefd3
Merge branch 'main' into develop-zh
yirongjie Nov 14, 2024
5ff02cd
fix: clearCache in RoPE
yirongjie Nov 14, 2024
68ddb99
fix: generate with Padding
yirongjie Nov 14, 2024
28811fe
feat: make chunk_size configurable in demo_phonelm_npu
oreomaker Nov 14, 2024
d930d14
Merge branch 'main' into develop-zh
oreomaker Nov 16, 2024
9406c02
chore: android submodule diff
oreomaker Nov 16, 2024
caf4f9c
fix: qwen npu demos execution type set
oreomaker Nov 18, 2024
dfac9d5
feat: KVCacheNPU GQA
oreomaker Nov 18, 2024
2b20472
refactor: simplify kvcache npu
oreomaker Nov 21, 2024
fb95bcd
fix: kvcache npu seq
oreomaker Nov 21, 2024
865dcd6
Merge branch 'main' into develop-zh
oreomaker Nov 27, 2024
3ec8d77
chore: android diff
oreomaker Nov 27, 2024
3b0042d
dev: pipeline class init
oreomaker Dec 4, 2024
05bcd77
refactor: configurable chunk_size in qwen
oreomaker Dec 4, 2024
c3a0c4e
chore: remove qnn getBuildId
oreomaker Dec 7, 2024
9a56c4c
chore: clean qnn backend include
oreomaker Dec 7, 2024
f627381
dev: qnn new frontend pipeline(wrap implement)
oreomaker Dec 7, 2024
bec0a63
fix: kvcache nrep and stage switching bug in old frontend
oreomaker Dec 11, 2024
38fff58
feat: new frontend pipeline
oreomaker Jan 6, 2025
8d07588
chore: qnn qwen executable change
oreomaker Jan 6, 2025
693700b
feat: qnn prefill optimization, only do 1 seq lm_head
oreomaker Jan 7, 2025
be8258f
refactor: main qwen npu token post process
oreomaker Jan 7, 2025
f8def41
Merge branch 'main' into develop-zh
oreomaker Jan 7, 2025
7e20e8b
fix: qnn old frontend modeling backend assign
oreomaker Jan 7, 2025
1d526c1
fix: qnn total length and cur length conflict
oreomaker Jan 7, 2025
6070a4a
feat: configurable chunk size for HeadLinear
oreomaker Jan 7, 2025
a459c74
Update CMakeLists.txt
yirongjie Jan 8, 2025
0f03af2
fix:
yirongjie Jan 8, 2025
79151f2
fix
yirongjie Jan 8, 2025
c5973f1
feat: add gemma 2 model
oreomaker Jan 10, 2025
c19a0ad
chore: update gemma2 vocab
oreomaker Jan 10, 2025
2f0c3bf
Merge branch 'main' into develop-zh
yirongjie Jan 10, 2025
249ab14
doc: README
yirongjie Jan 10, 2025
04ceeab
fix: qnn demo libHelper setSeqLength
oreomaker Jan 17, 2025
2c8206c
Merge branch 'main' into develop-zh
yirongjie Jan 18, 2025
2baf222
Merge branch 'main' into develop-zh
oreomaker Jan 18, 2025
ebd050e
Merge branch 'develop-zh' of github.com:liang1232018/mllm into develo…
oreomaker Jan 18, 2025
ab2c4c2
fix: QNNRoPE rope extern func in CPURoPE
oreomaker Jan 25, 2025
73abeb8
fix: qwen_npu and phonelm_npu closestFactor func conflict
oreomaker Jan 25, 2025
a3866fd
fix: LibHelper qwen npu chunk_size arg
oreomaker Jan 25, 2025
94 changes: 60 additions & 34 deletions src/backends/qnn/op/QNNRoPE.cpp
@@ -6,14 +6,19 @@

namespace mllm {

vector<float> QNNRoPE::theta_;

vector<vector<float>> QNNRoPE::sin_;
vector<vector<float>> QNNRoPE::cos_;
int QNNRoPE::global_pose_type_ = -1;
int QNNRoPE::ishape_old;


-extern void sinusoidal_position_embedding_llama(int seq_len, int output_dim, vector<vector<float>> &sin, vector<vector<float>> &cos);
-extern void sinusoidal_position_embedding_huggingface(int seq_len, int output_dim, vector<vector<float>> &sin, vector<vector<float>> &cos, int base = 10000);
+extern void sinusoidal_position_embedding_llama(int seq_len, int output_dim, const vector<float> &theta,
+                                                vector<vector<float>> &sin, vector<vector<float>> &cos, float attention_scaling = 1.0);
+extern void sinusoidal_position_embedding_huggingface(int seq_len, int output_dim, const vector<float> &theta,
+                                                      vector<vector<float>> &sin, vector<vector<float>> &cos, float attention_scaling = 1.0);
+typedef float (*mllm_rope_init_func)(const OpParam &, std::vector<float> &);
+extern unordered_map<RoPEThetaType, mllm_rope_init_func> rope_init_func_map;

QNNRoPE::QNNRoPE(Backend *bn, string opName, int pose_type) :
QNNCommonOp(bn, opName) {
@@ -53,6 +58,25 @@ QNNRoPE::QNNRoPE(Backend *bn, string opName, int pose_type, float rope_theta, fl
scale_.setBackend(bn);
}

QNNRoPE::QNNRoPE(Backend *bn, string opName, OpParam &config) :
QNNCommonOp(bn, opName) {
config_ = config;
pose_type_ = config.at("pose_type");
auto it = config.find("rope_theta");
if (it != config.end()) {
rope_theta_ = it->second;
}
it = config.find("partial_rotary_factor");
if (it != config.end()) {
partial_rotary_factor_ = it->second;
}
it = config.find("max_position_embeddings");
if (it != config.end()) {
pos_max_ = it->second;
}
rope_type = (RoPEThetaType)config.at("rope_type");
}

ErrorCode QNNRoPE::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
@@ -66,18 +90,34 @@ ErrorCode QNNRoPE::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<
return Op::reshape(inputs, outputs);
}

// from CPURoPE.cpp 2025/1/24
extern float _default_init_rope(const OpParam &config, vector<float> &theta);
// from CPURoPE.cpp 2025/1/24
extern float _compute_llama3_theta(const OpParam &config, vector<float> &theta);

ErrorCode QNNRoPE::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
// in case ishape is 0 when Op is the first one in the graph

const unordered_map<RoPEThetaType, mllm_rope_init_func> rope_init_func_map = {
{DEFAULT, _default_init_rope},
{LLAMA3, _compute_llama3_theta},
};

if (sin_.empty() || ishape_old < ishape || global_pose_type_ != pose_type_) {
auto calc_theta = rope_init_func_map.at(rope_type);
auto config = config_;
config["base"] = (float)rope_theta_;
config["dim"] = ishape;
float attention_scaling = calc_theta(config, theta_);

global_pose_type_ = pose_type_;
ishape_old = ishape;
if (pose_type_ == LLAMAROPE) {
-    sinusoidal_position_embedding_llama(pos_max_, ishape, sin_, cos_);
+    sinusoidal_position_embedding_llama(pos_max_, ishape, theta_, sin_, cos_, attention_scaling);
} else if (pose_type_ == PERSIMMONROPE) {
-    sinusoidal_position_embedding_huggingface(pos_max_, ishape / 2, sin_, cos_, 25000);
+    sinusoidal_position_embedding_huggingface(pos_max_, ishape / 2, theta_, sin_, cos_, attention_scaling);
} else if (pose_type_ == HFHUBROPE || pose_type_ == MLAROPE) {
-    sinusoidal_position_embedding_huggingface(pos_max_, ishape, sin_, cos_, rope_theta_);
+    sinusoidal_position_embedding_huggingface(pos_max_, ishape, theta_, sin_, cos_, attention_scaling);
} else {
}
}
@@ -92,57 +132,46 @@ ErrorCode QNNRoPE::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Te

auto type = QNN_DATATYPE_FLOAT_32;
if (outputs[0]->dtype() == MLLM_TYPE_F32) {
sinTensor_.setName(name() + ".sin");
sinTensor_.reshape(1, 1, pos_max_, ishape / 2);
sinTensor_.setDtype(MLLM_TYPE_F32);
sinTensor_.alloc();

cosTensor_.setName(name() + ".cos");
cosTensor_.reshape(1, 1, pos_max_, ishape / 2);
cosTensor_.setDtype(MLLM_TYPE_F32);
cosTensor_.alloc();

for (int i = 0; i < pos_max_; i++) {
for (int j = 0; j < ishape / 2; j++) {
sinTensor_.setDataAt<float>(0, 0, i, j, sin_[i][j] * dequantScale);
cosTensor_.setDataAt<float>(0, 0, i, j, cos_[i][j] * dequantScale);
}
}
} else if (outputs[0]->dtype() == MLLM_TYPE_F16) {
sinTensor_.setName(name() + ".sin");
sinTensor_.reshape(1, 1, pos_max_, ishape / 2);
sinTensor_.setDtype(MLLM_TYPE_F32);
sinTensor_.alloc();

cosTensor_.setName(name() + ".cos");
cosTensor_.reshape(1, 1, pos_max_, ishape / 2);
cosTensor_.setDtype(MLLM_TYPE_F32);
cosTensor_.alloc();

for (int i = 0; i < pos_max_; i++) {
for (int j = 0; j < ishape / 2; j++) {
sinTensor_.setDataAt<float>(0, 0, i, j, static_cast<float>(sin_[i][j]));
cosTensor_.setDataAt<float>(0, 0, i, j, static_cast<float>(cos_[i][j]));
}
}

type = QNN_DATATYPE_FLOAT_16;
}

uint32_t sin_dimensions[] = {static_cast<uint32_t>(pos_max_), static_cast<uint32_t>(ishape / 2)};
uint32_t cos_dimensions[] = {static_cast<uint32_t>(pos_max_), static_cast<uint32_t>(ishape / 2)};

auto sinWeightsName = name() + ".sin.weights";

@@ -243,13 +272,11 @@ ErrorCode QNNRoPE::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Te
}

ErrorCode QNNRoPE::load(AbstructLoader &loader) {
hcntTensor_.setName(name() + ".hcnt.tensor");
hcntTensor_.reshape(1, 1, 1, 1);
hcntTensor_.setDtype(MLLM_TYPE_I32);
hcntTensor_.alloc();

string scaleName = name();
string scaleTypeName = "output_scale";

@@ -273,9 +300,8 @@ ErrorCode QNNRoPE::free(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Ten
}

ErrorCode QNNRoPE::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {

h_cnt_ += inputs[0]->sequence();
hcntTensor_.setDataAt(0, 0, 0, 0, h_cnt_);

return QNNCommonOp::execute(inputs, outputs);
}
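A note on the dispatch above: setUp() now fills theta_ through a RoPEThetaType-keyed function table and passes the returned attention_scaling into the sin/cos builders, instead of hard-coding a base per pose type. The block below is a minimal sketch of what the DEFAULT initializer plausibly computes, assuming the standard RoPE frequency schedule and a float-valued OpParam; it is an illustration, not the actual _default_init_rope from CPURoPE.cpp.

// Sketch only: standard RoPE frequencies, theta_i = base^(-2i/dim) for i in [0, dim/2).
// DEFAULT applies no extra scaling, so attention_scaling is 1.0f; the LLAMA3 variant
// would rescale theta per frequency band and may return a different factor.
#include <cmath>
#include <map>
#include <string>
#include <vector>

using OpParam = std::map<std::string, float>; // assumption about mllm's OpParam layout

float default_init_rope_sketch(const OpParam &config, std::vector<float> &theta) {
    const float base = config.at("base");               // setUp() stores rope_theta_ here
    const int dim = static_cast<int>(config.at("dim")); // setUp() stores ishape here
    theta.resize(dim / 2);
    for (int i = 0; i < dim / 2; ++i) {
        theta[i] = std::pow(base, -2.0f * static_cast<float>(i) / static_cast<float>(dim));
    }
    return 1.0f; // attention_scaling consumed by sinusoidal_position_embedding_*
}

With base = 10000 and dim = 128 this yields theta_0 = 1.0 down to theta_63 = 10000^(-126/128), the usual geometric frequency ladder.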
13 changes: 12 additions & 1 deletion src/backends/qnn/op/QNNRoPE.hpp
@@ -9,6 +9,7 @@ class QNNRoPE : public QNNCommonOp {
QNNRoPE(Backend *bn, string opName, int pose_type);
QNNRoPE(Backend *bn, string opName, int pose_type, float rope_theta, int max_position_embeddings);
QNNRoPE(Backend *bn, string opName, int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings);
QNNRoPE(Backend *bn, string opName, OpParam &config);
virtual ~QNNRoPE() = default;
virtual ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
virtual ErrorCode setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
@@ -17,7 +18,7 @@ class QNNRoPE : public QNNCommonOp {
virtual ErrorCode execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

private:

static vector<float> theta_;
static vector<vector<float>> sin_;
static vector<vector<float>> cos_;
static int global_pose_type_;
@@ -29,6 +30,10 @@ class QNNRoPE : public QNNCommonOp {
int ishape;
float partial_rotary_factor_ = 1;

OpParam config_;

RoPEThetaType rope_type = DEFAULT;

Tensor hcntTensor_;

Tensor sinTensor_;
@@ -40,6 +45,12 @@
class QNNRoPECreator : public QNNBackend::Creator {
public:
virtual Op *create(OpParam op_param, Backend *bn, string name) const {
// from CPURoPE.cpp 2025/1/24
auto it = op_param.find("rope_type");
if (it != op_param.end()) {
return new QNNRoPE(bn, name, op_param);
}

int pose_type = op_param["pose_type"];
if (op_param.find("rope_theta") == op_param.end()) {
return new QNNRoPE(bn, name, pose_type);
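Construction-wise, the creator now keys on the presence of "rope_type": if the param map carries that key, the whole map is forwarded to the new config-driven constructor; otherwise the old positional constructors are picked by which keys exist. A hedged sketch of the new path follows; the backend pointer, layer name, and parameter values are placeholders, while the key names and the create() signature come from the diff.

// Hypothetical call site; assumes mllm headers and an initialized QNN backend (qnnBackend).
OpParam p;
p["pose_type"] = (float)HFHUBROPE;        // read by both the old and the new path
p["rope_theta"] = 1000000.0f;             // optional: overrides the default base
p["max_position_embeddings"] = 32768.0f;  // optional: overrides pos_max_
p["rope_type"] = (float)LLAMA3;           // presence of this key selects QNNRoPE(bn, name, op_param)
Op *rope = QNNRoPECreator().create(p, qnnBackend, "model.layers.0.q_rope");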
10 changes: 6 additions & 4 deletions src/models/phonelm/modeling_phonelm_npu.hpp
@@ -10,6 +10,7 @@

using namespace mllm;

namespace phonelm {
// get the closest factors of a number, used in NPU part2 view to speed up the QNN linear
inline pair<int, int> closestFactors(int n) {
int root = static_cast<int>(sqrt(n));
@@ -20,6 +21,7 @@ inline pair<int, int> closestFactors(int n) {
}
return {1, n};
}
} // namespace phonelm
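Wrapping the helper in a namespace resolves the closestFactors clash between qwen_npu and phonelm_npu fixed in commit 73abeb8: both model headers can now be included together. For intuition, the function splits a chunk of n tokens into the most square (first x second) grid available, which is the shape the part-2 View ops hand to the QNN linear. A self-contained sketch follows; the loop body folded out of the diff view is filled in here with the obvious descent from sqrt(n).

#include <cmath>
#include <cstdio>
#include <utility>

std::pair<int, int> closestFactorsSketch(int n) {
    int root = static_cast<int>(std::sqrt(n));
    for (int i = root; i > 0; --i) { // first divisor at or below sqrt(n)
        if (n % i == 0) return {i, n / i};
    }
    return {1, n}; // unreachable for n >= 1, kept for parity with the diff
}

int main() {
    auto a = closestFactorsSketch(64); // a chunk of 64 tokens -> 8 x 8 view
    auto b = closestFactorsSketch(96); // 96 -> 8 x 12
    std::printf("%d x %d, %d x %d\n", a.first, a.second, b.first, b.second);
}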

// NPU QKV part
class PhoneLMDecoderNPUPart1 final : public Module {
@@ -196,7 +198,7 @@ class PhoneLMDecoderNPUPart2 final : public Module {
num_key_value_groups = num_heads / num_key_value_heads;

// for QNN linear speed up
-pre_oproj_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
+pre_oproj_view = View(1, phonelm::closestFactors(chunk_size).first, phonelm::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name);
post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize");
post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_");
@@ -207,7 +209,7 @@

auto mlp_base_name = base_name + names._ffn_base_name;
pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize");
-pre_mlp_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
+pre_mlp_view = View(1, phonelm::closestFactors(chunk_size).first, phonelm::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name);
relu = ReLU(mlp_base_name + names._gate_proj_name + ".relu");
up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name);
@@ -306,7 +308,7 @@ class PhoneLMDecoderNPUPart2WithShadow final : public Module {
num_key_value_groups = num_heads / num_key_value_heads;

// for QNN linear speed up
-pre_oproj_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
+pre_oproj_view = View(1, phonelm::closestFactors(chunk_size).first, phonelm::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name);
post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize");
post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_");
@@ -317,7 +319,7 @@

auto mlp_base_name = base_name + names._ffn_base_name;
pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize");
-pre_mlp_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
+pre_mlp_view = View(1, phonelm::closestFactors(chunk_size).first, phonelm::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name);
relu = ReLU(mlp_base_name + names._gate_proj_name + ".relu");
up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name);
12 changes: 7 additions & 5 deletions src/models/qwen/modeling_qwen_npu.hpp
@@ -10,6 +10,7 @@

using namespace mllm;

namespace qwen {
// get the closest factors of a number, used in NPU part2 view to speed up the QNN linear
inline pair<int, int> closestFactors(int n) {
int root = static_cast<int>(sqrt(n));
@@ -20,6 +21,7 @@ inline pair<int, int> closestFactors(int n) {
}
return {1, n};
}
} // namespace qwen

// NPU QKV part
class QwenDecoderNPUPart1 final : public Module {
@@ -189,7 +191,7 @@ class QwenDecoderNPUPart2 final : public Module {
num_key_value_groups = num_heads / num_key_value_heads;

// for QNN linear speed up
-pre_oproj_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
+pre_oproj_view = View(1, qwen::closestFactors(chunk_size).first, qwen::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name);
post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize");
post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_");
@@ -200,7 +202,7 @@

auto mlp_base_name = base_name + names._ffn_base_name;
pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize");
-pre_mlp_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
+pre_mlp_view = View(1, qwen::closestFactors(chunk_size).first, qwen::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name);
silu = SiLU(mlp_base_name + "act");
up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name);
@@ -298,7 +300,7 @@ class QwenDecoderNPUPart2WithShadow final : public Module {
num_key_value_groups = num_heads / num_key_value_heads;

// for QNN linear speed up
-pre_oproj_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
+pre_oproj_view = View(1, qwen::closestFactors(chunk_size).first, qwen::closestFactors(chunk_size).second, head_dim * num_heads, base_name + names._attn_base_name + "or_split-00_view_");
out_proj = Linear(hidden_size, hidden_size, false, base_name + names._attn_base_name + names._o_proj_name);
post_oproj_dequantize = Dequantize(true, base_name + names._attn_base_name + names._o_proj_name + ".dequantize");
post_oproj_view = View(1, 1, chunk_size, hidden_size, base_name + names._attn_base_name + names._o_proj_name + ".dequantize-00_view_");
@@ -309,7 +311,7 @@

auto mlp_base_name = base_name + names._ffn_base_name;
pre_mlp_quantize = Quantize(true, mlp_base_name + names._up_proj_name + ".quantize");
-pre_mlp_view = View(1, closestFactors(chunk_size).first, closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
+pre_mlp_view = View(1, qwen::closestFactors(chunk_size).first, qwen::closestFactors(chunk_size).second, hidden_size, mlp_base_name + names._up_proj_name + ".quantize-00_view_");
gate_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._gate_proj_name);
silu = SiLU(mlp_base_name + "act");
up_proj = Linear(hidden_size, intermediate_size, false, mlp_base_name + names._up_proj_name);
@@ -531,7 +533,7 @@ class QWenModel_NPU final : public Module {

class QWenForCausalLM_NPU final : public Module {
public:
-QWenForCausalLM_NPU(QWenConfig &config, int chunk_size) {
+QWenForCausalLM_NPU(QWenConfig &config, int chunk_size = 64) {
auto names = config.names_config;
hidden_size = config.hidden_size;
tie_embedding_words = config.tie_embedding_words;
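The new default argument keeps call sites that predate configurable chunking compiling unchanged. A hedged usage sketch; the QWenConfig setup is a placeholder and only the constructor shape comes from the diff.

// Hypothetical call sites; config construction elided.
QWenForCausalLM_NPU model(config);      // chunk_size defaults to 64
QWenForCausalLM_NPU wide(config, 256);  // explicit chunk size; the part-2 views
                                        // reshape it via qwen::closestFactors(256) -> 16 x 16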