-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #88 from chenghuaWang/main
feat: Add Yi-1.5-6B support
- Loading branch information
Showing
9 changed files
with
454 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/** | ||
* @file demo_yi.cpp | ||
* @author Chenghua Wang ([email protected]) | ||
* @brief | ||
* @version 0.1 | ||
* @date 2024-07-02 | ||
* | ||
* @copyright Copyright (c) 2024 | ||
* | ||
*/ | ||
#include "cmdline.h"
#include "models/yi/configuration_yi.hpp"
#include "models/yi/modeling_yi.hpp"
#include "models/yi/tokenization_yi.hpp"
#include "processor/PostProcess.hpp"

#include <cstdio>
#include <iostream>
#include <regex>
#include <string>
#include <utility>
#include <vector>
|
||
using namespace mllm; | ||
|
||
int main(int argc, char **argv) { | ||
cmdline::parser cmdParser; | ||
cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/yi_vocab.mllm"); | ||
cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/yi-1.5-6b-chat-q4_k.mllm"); | ||
cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400); | ||
cmdParser.add<int>("thread", 't', "num of threads", false, 4); | ||
cmdParser.parse_check(argc, argv); | ||
|
||
string vocab_path = cmdParser.get<string>("vocab"); | ||
string model_path = cmdParser.get<string>("model"); | ||
int tokens_limit = cmdParser.get<int>("limits"); | ||
CPUBackend::cpu_threads = cmdParser.get<int>("thread"); | ||
|
||
auto tokenizer = YiTokenizer(vocab_path); | ||
YiConfig config(tokens_limit, "6B", RoPEType::HFHUBROPE); | ||
auto model = YiForCausalLM(config); | ||
model.load(model_path); | ||
|
||
vector<string> in_strs = { | ||
"请介绍北京邮电大学,推荐同学们报考。", | ||
}; | ||
|
||
auto processOutput = [&](std::string &text) -> std::pair<bool, std::string> { | ||
text = std::regex_replace(text, std::regex("▁"), " "); | ||
if (text == "<|endoftext|>" || text == "<|im_end|>") return {false, ""}; | ||
return {true, text}; | ||
}; | ||
|
||
for (int i = 0; i < in_strs.size(); ++i) { | ||
auto in_str = in_strs[i]; | ||
std::cout << "[Q] " << in_str << std::endl; | ||
auto input_tensor = tokenizer.tokenize(in_str, i); | ||
std::cout << "[A] " << std::flush; | ||
for (int step = 0; step < 1000; step++) { | ||
auto result = model({input_tensor}); | ||
auto outputs = tokenizer.detokenize(result[0]); | ||
auto out_string = outputs.first; | ||
auto out_token = outputs.second; | ||
auto [isOk, print_string] = processOutput(out_string); | ||
if (isOk) { | ||
std::cout << print_string << std::flush; | ||
} else { | ||
break; | ||
} | ||
chatPostProcessing(out_token, input_tensor, {}); | ||
} | ||
printf("\n"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/** | ||
* @file configuration_Yi.hpp | ||
* @author Chenghua Wang ([email protected]) | ||
* @brief | ||
* @version 0.1 | ||
* @date 2024-07-02 | ||
* | ||
* @copyright Copyright (c) 2024 | ||
* | ||
*/ | ||
#ifndef CONFIG_YI_HPP | ||
#define CONFIG_YI_HPP | ||
#include "models/transformer/configuration_transformer.hpp" | ||
|
||
using namespace mllm; | ||
|
||
class YiNameConfig : public TransformerNameConfig { | ||
public: | ||
std::string blk_name; | ||
std::string token_embd_name; | ||
std::string post_norm_name; | ||
std::string lm_head_name; | ||
std::string _gate_proj_name; | ||
|
||
void init(RoPEType type = LLAMAROPE) { | ||
switch (type) { | ||
case LLAMAROPE: { | ||
blk_name = "layers."; | ||
_attn_base_name = "attention."; | ||
_ffn_base_name = "feed_forward."; | ||
_q_proj_name = "wq"; | ||
_k_proj_name = "wk"; | ||
_v_proj_name = "wv"; | ||
_o_proj_name = "wo"; | ||
_gate_proj_name = "w1"; | ||
_up_proj_name = "w3"; | ||
_down_proj_name = "w2"; | ||
_attn_norm_name = "attention_norm"; | ||
_ffn_norm_name = "ffn_norm"; | ||
token_embd_name = "tok_embeddings"; | ||
post_norm_name = "norm"; | ||
lm_head_name = "output"; | ||
break; | ||
} | ||
case HFHUBROPE: { | ||
blk_name = "model.layers."; | ||
_attn_base_name = "self_attn."; | ||
_ffn_base_name = "mlp."; | ||
_q_proj_name = "q_proj"; | ||
_k_proj_name = "k_proj"; | ||
_v_proj_name = "v_proj"; | ||
_o_proj_name = "o_proj"; | ||
_gate_proj_name = "gate_proj"; | ||
_up_proj_name = "up_proj"; | ||
_down_proj_name = "down_proj"; | ||
_attn_norm_name = "input_layernorm"; | ||
_ffn_norm_name = "post_attention_layernorm"; | ||
token_embd_name = "model.embed_tokens"; | ||
post_norm_name = "model.norm"; | ||
lm_head_name = "lm_head"; | ||
break; | ||
} | ||
default: { | ||
throw std::runtime_error("Unsupported llama type"); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
class YiConfig { | ||
public: | ||
explicit YiConfig(int token_limit, string billions = "6B", RoPEType type = LLAMAROPE, int vocab = 64000) { | ||
names_config.init(type); | ||
vocab_size = vocab; | ||
if (!(billions == "6B" || billions == "6b")) { | ||
throw std::runtime_error("Unsupported model size"); | ||
} | ||
RoPE_type = type; | ||
cache_limit = token_limit; | ||
} | ||
|
||
public: | ||
bool attention_bias = false; | ||
float attention_drop = 0.0; | ||
int pad_token_id = 0; | ||
int bos_token_id = 1; | ||
int eos_token_id = 2; | ||
int hidden_size = 4096; | ||
float initializer_range = 0.02; | ||
int intermediate_size = 11008; | ||
int max_position_embeddings = 4096; | ||
int num_attention_heads = 32; | ||
int num_hidden_layers = 32; | ||
int num_key_value_heads = 4; | ||
int pretraining_tp = 1; | ||
float rms_norm_eps = 1e-6; | ||
float rope_theta = 5000000.0; | ||
int vocab_size = 64000; | ||
int cache_limit; | ||
RoPEType RoPE_type; | ||
YiNameConfig names_config; | ||
}; | ||
|
||
#endif //! CONFIG_YI_HPP |
Oops, something went wrong.