diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8d060daa..d9e0bdb2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -459,6 +459,17 @@ else ()
     target_link_libraries(demo_mistral PUBLIC MLLM_CPU)
 endif ()
 
+add_executable(demo_yi ${PROJECT_SOURCE_DIR}/examples/demo_yi.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC}
+        src/tokenizers/Tokenizer.cpp
+        src/tokenizers/BPE/Bpe.cpp
+        src/processor/PreProcess.cpp
+)
+if (ARM AND NOT APK)
+    target_compile_options(demo_yi PRIVATE -fopenmp)
+    target_link_libraries(demo_yi PUBLIC MLLM_CPU -fopenmp -static-openmp)
+else ()
+    target_link_libraries(demo_yi PUBLIC MLLM_CPU)
+endif ()
 
 # add_executable(demo_deepseek ${PROJECT_SOURCE_DIR}/examples/demo_deepseek.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC}
 #         src/tokenizers/Tokenizer.cpp
diff --git a/README.md b/README.md
index e0f743fa..08d72a49 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,8 @@ Wait.. why on-device multimodal LLM?
 - It's a key building block for [intelligent personal agent](https://arxiv.org/pdf/2401.05459.pdf), text-based image searching/retrieval, screen VQA, and many more exciting mobile apps, without giving away your private data (chat history, screenshots, taken photos, etc).
 
 ## Recent update
-- [:fire::fire:Comming soon] Supporting Qualcomm NPU: [>1000 tokens/second prefilling!](https://arxiv.org/pdf/2407.05858v1)
+- [🔥🔥Coming soon] Supporting Qualcomm NPU: [>1000 tokens/second prefilling!](https://arxiv.org/pdf/2407.05858v1)
+- [2024 July 2] Support new model: Yi V1.5 6B https://github.com/UbiquitousLearning/mllm/pull/88
 - [2024 May 29] Support new model: Mistral V0.2 7B https://github.com/UbiquitousLearning/mllm/pull/83
 - [2024 May 4] Support new model: QWen V1.5 0.5B https://github.com/UbiquitousLearning/mllm/pull/79
 - [2024 April 9] Support new model: Gemma 2B https://github.com/UbiquitousLearning/mllm/pull/75
@@ -75,6 +76,7 @@ Wait.. why on-device multimodal LLM? - It's a key building block for [intelligen
 | [Gemma 2B](https://github.com/google/gemma_pytorch) | [✔️](https://huggingface.co/mllmTeam/gemma-2b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/gemma-2b-mllm/tree/main) |
 | [Qwen 0.5B](https://github.com/QwenLM/Qwen) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-0.5b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-0.5b-mllm/tree/main) |
 | [Mistral 7B](https://github.com/mistralai/mistral-src) | [✔️](https://huggingface.co/mllmTeam/mistral-7b-instruct-v0.2-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/mistral-7b-instruct-v0.2-mllm/tree/main) |
+| [Yi 6B](https://huggingface.co/01-ai/Yi-1.5-6B) | [✔️](https://huggingface.co/mllmTeam/yi-1.5-6b-chat-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/yi-1.5-6b-chat-mllm/tree/main) |
 
 ## Quick Start
 
@@ -246,16 +248,16 @@ cd tools/convertor
 pip install -r ./requirements.txt
 
 # for one file pytorch model
-python convert.py --input_model=model.pth --output_model=model.mllm --type=torch
+python converter.py --input_model=model.pth --output_model=model.mllm --type=torch
 
 # for multi-file pytorch model
-python convert.py --input_model=pytorch_model.bin.index.json --output_model=model.mllm --type=torch
+python converter.py --input_model=pytorch_model.bin.index.json --output_model=model.mllm --type=torch
 
 # for one file safetensor model
-python convert.py --input_model=model.bin --output_model=model.mllm --type=safetensor
+python converter.py --input_model=model.bin --output_model=model.mllm --type=safetensor
 
 # for multi-file safetensor model
-python convert.py --input_model=model.safetensors.index.json --output_model=model.mllm --type=safetensor
+python converter.py --input_model=model.safetensors.index.json --output_model=model.mllm --type=safetensor
 ```
 
 ### Convert vocabulary
@@ -274,7 +276,7 @@ mllm only support two quantize modes: Q4_0 and Q4_K.
 
 ```bash
 cd bin
-./quantize model.mllm model_q4_0.mllm Q4_K
+./quantize model.mllm model_q4_k.mllm Q4_K
 ```
 
 ## Roadmap
diff --git a/examples/demo_yi.cpp b/examples/demo_yi.cpp
new file mode 100644
index 00000000..7de3a784
--- /dev/null
+++ b/examples/demo_yi.cpp
@@ -0,0 +1,67 @@
+/**
+ * @file demo_yi.cpp
+ * @author Chenghua Wang (chenghua.wang.edu@gmail.com)
+ * @brief
+ * @version 0.1
+ * @date 2024-07-02
+ *
+ * @copyright Copyright (c) 2024
+ *
+ */
+#include "cmdline.h"
+#include "models/yi/configuration_yi.hpp"
+#include "models/yi/modeling_yi.hpp"
+#include "models/yi/tokenization_yi.hpp"
+#include "processor/PostProcess.hpp"
+
+using namespace mllm;
+
+int main(int argc, char **argv) {
+    cmdline::parser cmdParser;
+    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/yi_vocab.mllm");
+    cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/yi-1.5-6b-chat-q4_k.mllm");
+    cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
+    cmdParser.add<int>("thread", 't', "num of threads", false, 4);
+    cmdParser.parse_check(argc, argv);
+
+    string vocab_path = cmdParser.get<string>("vocab");
+    string model_path = cmdParser.get<string>("model");
+    int tokens_limit = cmdParser.get<int>("limits");
+    CPUBackend::cpu_threads = cmdParser.get<int>("thread");
+
+    auto tokenizer = YiTokenizer(vocab_path);
+    YiConfig config(tokens_limit, "6B", RoPEType::HFHUBROPE);
+    auto model = YiForCausalLM(config);
+    model.load(model_path);
+
+    vector<string> in_strs = {
+        "请介绍北京邮电大学,推荐同学们报考。",
+    };
+
+    auto processOutput = [&](std::string &text) -> std::pair<bool, std::string> {
+        text = std::regex_replace(text, std::regex("▁"), " ");
+        if (text == "<|endoftext|>" || text == "<|im_end|>") return {false, ""};
+        return {true, text};
+    };
+
+    for (int i = 0; i < in_strs.size(); ++i) {
+        auto in_str = in_strs[i];
+        std::cout << "[Q] " << in_str << std::endl;
+        auto input_tensor = tokenizer.tokenize(in_str, i);
+        std::cout << "[A] " << std::flush;
+        for (int step = 0; step < 1000; step++) {
+            auto result = model({input_tensor});
+            auto outputs = tokenizer.detokenize(result[0]);
+            auto out_string = outputs.first;
+            auto out_token = outputs.second;
+            auto [isOk, print_string] = processOutput(out_string);
+            if (isOk) {
+                std::cout << print_string << std::flush;
+            } else {
+                break;
+            }
+            chatPostProcessing(out_token, input_tensor, {});
+        }
+        printf("\n");
+    }
+}
diff --git a/src/models/yi/configuration_yi.hpp b/src/models/yi/configuration_yi.hpp
new file mode 100644
index 00000000..a31bbc53
--- /dev/null
+++ b/src/models/yi/configuration_yi.hpp
@@ -0,0 +1,104 @@
+/**
+ * @file configuration_yi.hpp
+ * @author Chenghua Wang (chenghua.wang.edu@gmail.com)
+ * @brief
+ * @version 0.1
+ * @date 2024-07-02
+ *
+ * @copyright Copyright (c) 2024
+ *
+ */
+#ifndef CONFIG_YI_HPP
+#define CONFIG_YI_HPP
+#include "models/transformer/configuration_transformer.hpp"
+
+using namespace mllm;
+
+class YiNameConfig : public TransformerNameConfig {
+public:
+    std::string blk_name;
+    std::string token_embd_name;
+    std::string post_norm_name;
+    std::string lm_head_name;
+    std::string _gate_proj_name;
+
+    void init(RoPEType type = LLAMAROPE) {
+        switch (type) {
+        case LLAMAROPE: {
+            blk_name = "layers.";
+            _attn_base_name = "attention.";
+            _ffn_base_name = "feed_forward.";
+            _q_proj_name = "wq";
+            _k_proj_name = "wk";
+            _v_proj_name = "wv";
+            _o_proj_name = "wo";
+            _gate_proj_name = "w1";
+            _up_proj_name = "w3";
+            _down_proj_name = "w2";
+            _attn_norm_name = "attention_norm";
+            _ffn_norm_name = "ffn_norm";
+            token_embd_name = "tok_embeddings";
post_norm_name = "norm"; + lm_head_name = "output"; + break; + } + case HFHUBROPE: { + blk_name = "model.layers."; + _attn_base_name = "self_attn."; + _ffn_base_name = "mlp."; + _q_proj_name = "q_proj"; + _k_proj_name = "k_proj"; + _v_proj_name = "v_proj"; + _o_proj_name = "o_proj"; + _gate_proj_name = "gate_proj"; + _up_proj_name = "up_proj"; + _down_proj_name = "down_proj"; + _attn_norm_name = "input_layernorm"; + _ffn_norm_name = "post_attention_layernorm"; + token_embd_name = "model.embed_tokens"; + post_norm_name = "model.norm"; + lm_head_name = "lm_head"; + break; + } + default: { + throw std::runtime_error("Unsupported llama type"); + } + } + } +}; + +class YiConfig { +public: + explicit YiConfig(int token_limit, string billions = "6B", RoPEType type = LLAMAROPE, int vocab = 64000) { + names_config.init(type); + vocab_size = vocab; + if (!(billions == "6B" || billions == "6b")) { + throw std::runtime_error("Unsupported model size"); + } + RoPE_type = type; + cache_limit = token_limit; + } + +public: + bool attention_bias = false; + float attention_drop = 0.0; + int pad_token_id = 0; + int bos_token_id = 1; + int eos_token_id = 2; + int hidden_size = 4096; + float initializer_range = 0.02; + int intermediate_size = 11008; + int max_position_embeddings = 4096; + int num_attention_heads = 32; + int num_hidden_layers = 32; + int num_key_value_heads = 4; + int pretraining_tp = 1; + float rms_norm_eps = 1e-6; + float rope_theta = 5000000.0; + int vocab_size = 64000; + int cache_limit; + RoPEType RoPE_type; + YiNameConfig names_config; +}; + +#endif //! CONFIG_YI_HPP \ No newline at end of file diff --git a/src/models/yi/modeling_yi.hpp b/src/models/yi/modeling_yi.hpp new file mode 100644 index 00000000..699955dd --- /dev/null +++ b/src/models/yi/modeling_yi.hpp @@ -0,0 +1,195 @@ +/** + * @file modeling_Yi.hpp + * @author Chenghua Wang (chenghua.wang.edu@gmail.com) + * @brief + * @version 0.1 + * @date 2024-07-02 + * + * @copyright Copyright (c) 2024 + * + */ +#ifndef MODELING_YI_HPP +#define MODELING_YI_HPP + +#include "Backend.hpp" +#include "Layer.hpp" +#include "Module.hpp" +#include "Tensor.hpp" +#include "configuration_yi.hpp" +#include +using namespace mllm; + +class YiMLP final : public Module { +public: + YiMLP() = default; + YiMLP(int hidden_size, int intermediate_size, const YiNameConfig &names, const std::string &base_name) { + gate_proj = Linear(hidden_size, intermediate_size, false, base_name + names._gate_proj_name); + silu = SiLU(base_name + "act"); + up_proj = Linear(hidden_size, intermediate_size, false, base_name + names._up_proj_name); + down_proj = Linear(intermediate_size, hidden_size, false, base_name + names._down_proj_name); + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto x = gate_proj(inputs[0]); + x = silu(x); + auto y = up_proj(inputs[0]); + x = x * y; + x = down_proj(x); + return {x}; + } + +private: + Layer gate_proj; + Layer up_proj; + Layer down_proj; + + Layer silu; +}; + +class YiAttention final : public Module { +public: + YiAttention() = default; + YiAttention(const YiConfig &config, const YiNameConfig &names, const string &base_name) { + hidden_size = config.hidden_size; + num_heads = config.num_attention_heads; + head_dim = config.hidden_size / num_heads; + num_key_value_heads = config.num_key_value_heads; + num_key_value_groups = num_heads / num_key_value_heads; + + // init layers + q_proj = Linear(hidden_size, num_heads * head_dim, false, base_name + names._q_proj_name); + k_proj = Linear(hidden_size, 
num_key_value_heads * head_dim, false, base_name + names._k_proj_name); + v_proj = Linear(hidden_size, num_key_value_heads * head_dim, false, base_name + names._v_proj_name); + o_proj = Linear(num_heads * head_dim, hidden_size, false, base_name + names._o_proj_name); + q_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "q_rope"); + k_rope = RoPE(config.RoPE_type, config.rope_theta, config.max_position_embeddings, base_name + "k_rope"); + k_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "k_cache"); + v_cache = KVCache(num_key_value_groups, config.cache_limit, base_name + "v_cache"); + mask = Causalmask(base_name + "mask"); + softmax = Softmax(DIMENSION, base_name + "softmax"); + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto query_states = q_proj(inputs[0]); + auto key_states = k_proj(inputs[1]); + auto value_states = v_proj(inputs[2]); + + // [batch, heads, sequence, dims] + query_states = query_states.view(-1, num_heads, -1, head_dim); + key_states = key_states.view(-1, num_key_value_heads, -1, head_dim); + value_states = value_states.view(-1, num_key_value_heads, -1, head_dim); + + // embedding + query_states = q_rope(query_states); + key_states = k_rope(key_states); + + // kv cache + key_states = k_cache(key_states); + value_states = v_cache(value_states); + + // attention weight + auto atten_weight = Tensor::mm(query_states, key_states.transpose(Chl::SEQUENCE, Chl::DIMENSION)) / std::sqrt(head_dim); + atten_weight = mask(atten_weight); + atten_weight = softmax(atten_weight); + + // attention output + auto atten_output = Tensor::mm(atten_weight, value_states); + atten_output = atten_output.view(-1, 1, -1, head_dim * num_heads); + atten_output = o_proj(atten_output); + return {atten_output}; + } + +private: + int hidden_size; + int num_heads; + int head_dim; + int num_key_value_heads; + int num_key_value_groups; + Layer q_proj; + Layer k_proj; + Layer v_proj; + Layer o_proj; + Layer q_rope; + Layer k_rope; + Layer k_cache; + Layer v_cache; + Layer mask; + Layer softmax; +}; + +class YiDecoder final : public Module { +public: + YiDecoder() = default; + YiDecoder(const YiConfig &config, const YiNameConfig &names, const string &base_name) { + self_atten = YiAttention(config, names, base_name + names._attn_base_name); + mlp = YiMLP(config.hidden_size, config.intermediate_size, names, base_name + names._ffn_base_name); + input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._attn_norm_name); + post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, base_name + names._ffn_norm_name); + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto x = input_layernorm(inputs[0]); + x = self_atten({x, x, x})[0]; + auto tmp = x + inputs[0]; + x = post_attention_layernorm(tmp); + x = mlp({x})[0]; + x = x + tmp; + return {x}; + } + +private: + YiAttention self_atten; + YiMLP mlp; + Layer input_layernorm; + Layer post_attention_layernorm; +}; + +class YiModel final : public Module { +public: + YiModel() = default; + YiModel(const YiConfig &config, const YiNameConfig &names, const string &base_name) { + blocks = List(config.num_hidden_layers, config, names, base_name); + norm = RMSNorm(config.hidden_size, config.rms_norm_eps, names.post_norm_name); + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto x = inputs[0]; + for (auto &block : blocks) { + x = block({x})[0]; + } + x = norm(x); + return {x}; + 
+    }
+
+private:
+    std::vector<YiDecoder> blocks;
+    Layer norm;
+};
+
+class YiForCausalLM final : public Module {
+public:
+    YiForCausalLM(YiConfig &config) {
+        auto names = config.names_config;
+        hidden_size = config.hidden_size;
+        embedding = Embedding(config.vocab_size, config.hidden_size, names.token_embd_name);
+        model = YiModel(config, names, names.blk_name);
+        lm_head = Linear(hidden_size, config.vocab_size, false, names.lm_head_name);
+    }
+
+    std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override {
+        auto x = embedding(inputs[0]);
+
+        // go through model
+        auto outputs = model({x})[0];
+        outputs = lm_head(outputs);
+        return {outputs};
+    }
+
+private:
+    int hidden_size;
+    Layer embedding;
+    Layer lm_head;
+    YiModel model;
+};
+
+#endif //! MODELING_YI_HPP
\ No newline at end of file
diff --git a/src/models/yi/tokenization_yi.hpp b/src/models/yi/tokenization_yi.hpp
new file mode 100644
index 00000000..5fa657c1
--- /dev/null
+++ b/src/models/yi/tokenization_yi.hpp
@@ -0,0 +1,63 @@
+/**
+ * @file tokenization_yi.hpp
+ * @author Chenghua Wang (chenghua.wang.edu@gmail.com)
+ * @brief
+ * @version 0.1
+ * @date 2024-07-02
+ *
+ * @copyright Copyright (c) 2024
+ *
+ */
+#ifndef TOKENIZATION_YI_HPP
+#define TOKENIZATION_YI_HPP
+
+#include "tokenizers/BPE/Bpe.hpp"
+
+using namespace mllm;
+
+class YiTokenizer final {
+    BPETokenizer *tokenizer;
+
+    unsigned int argmax(const std::vector<float> &scores) {
+        if (scores.empty()) {
+            throw std::invalid_argument("Input vector is empty");
+        }
+        return std::max_element(scores.begin(), scores.end()) - scores.begin();
+    }
+
+public:
+    explicit YiTokenizer(const std::string &vocab_file) {
+        Module::initBackend(MLLM_CPU);
+        tokenizer = new BPETokenizer(vocab_file);
+    }
+
+    Tensor tokenize(std::string &text, int str_i = 0) const {
+        if (text[0] != ' ') {
+            text = ' ' + text;
+        }
+        auto tokens_id = vector<token_id_t>();
+        tokenizer->tokenize(text, tokens_id, false);
+        if (str_i > 0) {
+            tokens_id[0] = 13;
+        }
+        return BPETokenizer::tokens2Input(tokens_id);
+    }
+
+    std::string detokenize(const std::vector<token_id_t> &tokens) {
+        return tokenizer->detokenize(tokens);
+    }
+
+    std::pair<std::string, unsigned> detokenize(Tensor &result) {
+        assert(result.batch() == 1);
+        assert(result.head() == 1);
+        vector<float> scores;
+        for (int i = 0; i < result.dimension(); ++i) {
+            auto value = result.dataAt<float>(0, 0, result.sequence() - 1, i);
+            scores.push_back(value);
+        }
+        auto token_idx = this->argmax(scores);
+        return {tokenizer->detokenize({token_idx}), token_idx};
+    }
+};
+
+#endif // !TOKENIZATION_YI_HPP
diff --git a/tools/jni/LibHelper.cpp b/tools/jni/LibHelper.cpp
index 87fe4ba5..4cd4f38e 100644
--- a/tools/jni/LibHelper.cpp
+++ b/tools/jni/LibHelper.cpp
@@ -38,14 +38,14 @@ unsigned int LibHelper::postProcessing(shared_ptr<Tensor> result, shared_ptr<Te
         net_->convert(c->sub_param_, BackendType::MLLM_CPU);
-        tokenizer_ = new BPETokenizer(vacab_path);
+        tokenizer_ = new BPETokenizer(vocab_path);
         eos_id_ = 2;
         break;
     }
@@ -80,7 +80,7 @@ bool LibHelper::setUp(const std::string &base_path, std::string weights_path, st
         int patch_size = 30;
         Fuyu(c, vocab_size, patch_size, 3, hidden_dim, ffn_hidden_dim, mutil_head_size);
         net_->convert(c->sub_param_, BackendType::MLLM_CPU);
-        tokenizer_ = new UnigramTokenizer(vacab_path);
+        tokenizer_ = new UnigramTokenizer(vocab_path);
         eos_id_ = 71013;
         break;
     }
diff --git a/tools/jni/LibHelper.hpp b/tools/jni/LibHelper.hpp
index d5f71698..da846577 100644
--- a/tools/jni/LibHelper.hpp
+++ b/tools/jni/LibHelper.hpp
@@ -53,7 +53,7 @@ class LibHelper {
     bool is_first_run_cond_ = true;
     unsigned postProcessing(std::shared_ptr<Tensor> result, std::shared_ptr<Tensor> &out_result) const;
 public:
-    bool setUp(const std::string &base_path, std::string weights_path, std::string vacab_path, PreDefinedModel model, MLLMBackendType backend_type = MLLMBackendType::CPU);
+    bool setUp(const std::string &base_path, std::string weights_path, std::string vocab_path, PreDefinedModel model, MLLMBackendType backend_type = MLLMBackendType::CPU);
     void setCallback(callback_t callback);
     void run(std::string &input_str, uint8_t *image, unsigned max_step, unsigned image_length) ;
     ~LibHelper();
diff --git a/vocab/yi_vocab.mllm b/vocab/yi_vocab.mllm
new file mode 100644
index 00000000..024caf2b
Binary files /dev/null and b/vocab/yi_vocab.mllm differ
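Build-and-run sketch for trying the new demo locally (not part of the patch above): it assumes the desktop build flow from the README (`scripts/build.sh`) and that the converted, quantized model and the vocab file sit at the demo's default paths; the `-m/-v/-l/-t` options are the ones registered in `examples/demo_yi.cpp`.

```bash
# assumption: scripts/build.sh is the desktop build entry point described in the README
cd scripts
./build.sh

# run the Yi demo from bin/; the paths below are the defaults baked into demo_yi.cpp
cd ../bin
./demo_yi -m ../models/yi-1.5-6b-chat-q4_k.mllm -v ../vocab/yi_vocab.mllm -l 400 -t 4
```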