diff --git a/examples/demo_qwen.cpp b/examples/demo_qwen.cpp
index 422540ec..1a90efea 100644
--- a/examples/demo_qwen.cpp
+++ b/examples/demo_qwen.cpp
@@ -52,19 +52,24 @@ int main(int argc, char **argv) {
         auto input_tensor = tokenizer.tokenize(in_str, i);
         std::cout << "[Q] " << in_str << std::endl;
         std::cout << "[A] " << std::flush;
-        for (int step = 0; step < 100; step++) {
-            auto result = model({input_tensor});
-            auto outputs = tokenizer.detokenize(result[0]);
-            auto out_string = outputs.first;
-            auto out_token = outputs.second;
+
+        LlmTextGeneratorOpts opt{
+            .max_new_tokens = 100,
+            .do_sample = true,
+            .temperature = 0.3f,
+            .top_k = 50,
+            .top_p = 0.f,
+        };
+        model.generate(input_tensor, opt, [&](unsigned int out_token) -> bool {
+            auto out_string = tokenizer.detokenize({out_token});
             auto [isOk, print_string] = processOutput(out_string);
             if (isOk) {
                 std::cout << print_string << std::flush;
             } else {
-                break;
+                return false;
             }
-            chatPostProcessing(out_token, input_tensor, {});
-        }
+            return true;
+        });
         printf("\n");
     }
 }
\ No newline at end of file
diff --git a/src/Generate.cpp b/src/Generate.cpp
new file mode 100644
index 00000000..6f6f74b1
--- /dev/null
+++ b/src/Generate.cpp
@@ -0,0 +1,123 @@
+/**
+ * @file Generate.cpp
+ * @author Chenghua Wang (chenghua.wang.edu@gmail.com)
+ * @brief The Mllm Generator Impl
+ * @version 0.1
+ * @date 2024-07-30
+ *
+ * @copyright Copyright (c) 2024
+ *
+ */
+#include "Generate.hpp"
+#include <algorithm>
+#include <numeric>
+
+namespace mllm {
+
+unsigned int _LlmTextGenerateGreedySearchMethod::generate(Tensor &t) {
+    std::vector<float> scores;
+    this->_tensor_to_vec(t, scores);
+    return std::max_element(scores.begin(), scores.end()) - scores.begin();
+}
+
+unsigned int _LlmTextGenerateTopkSamplingMethod::generate(Tensor &t) {
+    auto argmax = [](const std::vector<float> &vec) -> unsigned int {
+        return std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()));
+    };
+
+    if (m_k == 0 || m_k == 1) {
+        std::vector<float> scores;
+        this->_tensor_to_vec(t, scores);
+        return argmax(scores);
+    }
+
+    std::vector<std::pair<float, unsigned int>> scores;
+    this->_tensor_to_vec_with_idx(t, scores);
+
+    // find top k
+    std::partial_sort(scores.begin(), scores.begin() + m_k, scores.end(),
+                      [](std::pair<float, unsigned int> a, std::pair<float, unsigned int> b) { return a.first > b.first; });
+    std::vector<float> top_k_elements(m_k, 0.f);
+    std::vector<unsigned int> top_k_elements_idx(m_k, 0);
+    for (int i = 0; i < m_k; ++i) {
+        top_k_elements[i] = scores[i].first;
+        top_k_elements_idx[i] = scores[i].second;
+    }
+
+    // softmax with temperature
+    std::vector<float> softmax(top_k_elements.size(), 0.f);
+    double max_logit = top_k_elements[argmax(top_k_elements)];
+    double sum_exp = 0.f;
+
+    for (size_t i = 0; i < top_k_elements.size(); ++i) {
+        softmax[i] = exp((top_k_elements[i] - max_logit) / m_temperature);
+        sum_exp += softmax[i];
+    }
+
+    for (float &value : softmax) {
+        value /= sum_exp;
+    }
+
+    // sampling
+    float _sum = std::accumulate(softmax.begin(), softmax.end(), 0.0);
+    for (float &value : softmax) {
+        value /= _sum;
+    }
+
+    auto idx = _sample_element(top_k_elements_idx, softmax);
+    return idx;
+}
+
+unsigned int _LlmTextGenerateToppSamplingMethod::generate(Tensor &t) {
+    auto argmax = [](const std::vector<float> &vec) -> unsigned int {
+        return std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()));
+    };
+    std::vector<std::pair<float, unsigned int>> scores;
+    this->_tensor_to_vec_with_idx(t, scores);
+
+    std::sort(scores.begin(), scores.end(), [](std::pair<float, unsigned int> a, std::pair<float, unsigned int> b) { return a.first > b.first; });
+    std::vector<float> top_k_elements;
+    std::vector<unsigned int> top_k_elements_idx;
+
+    if (scores[0].first > 1.f) {
+        throw std::runtime_error("The input tensor t should go through softmax first. (0.f - 1.f is acceptable)");
+    }
+
+    float p = 0.f;
+    size_t idx = 0;
+    while (p < m_p) {
+        top_k_elements.emplace_back(scores[idx].first);
+        top_k_elements_idx.emplace_back(scores[idx].second);
+        p += scores[idx].first;
+        idx++;
+    }
+
+    if (top_k_elements.size() == 1) {
+        return top_k_elements_idx[0];
+    }
+
+    // softmax with temperature
+    std::vector<float> softmax(top_k_elements.size(), 0.f);
+    double max_logit = top_k_elements[argmax(top_k_elements)];
+    double sum_exp = 0.f;
+
+    for (size_t i = 0; i < top_k_elements.size(); ++i) {
+        softmax[i] = exp((top_k_elements[i] - max_logit) / m_temperature);
+        sum_exp += softmax[i];
+    }
+
+    for (float &value : softmax) {
+        value /= sum_exp;
+    }
+
+    // sampling
+    float _sum = std::accumulate(softmax.begin(), softmax.end(), 0.0);
+    for (float &value : softmax) {
+        value /= _sum;
+    }
+
+    auto ret = _sample_element(top_k_elements_idx, softmax);
+    return ret;
+}
+
+} // namespace mllm
\ No newline at end of file
diff --git a/src/Generate.hpp b/src/Generate.hpp
new file mode 100644
index 00000000..41f58b1b
--- /dev/null
+++ b/src/Generate.hpp
@@ -0,0 +1,164 @@
+/**
+ * @file Generate.hpp
+ * @author Chenghua Wang (chenghua.wang.edu@gmail.com)
+ * @brief The Mllm Generator
+ * @version 0.1
+ * @date 2024-07-30
+ *
+ * @copyright Copyright (c) 2024
+ *
+ */
+#pragma once
+#ifndef MLLM_GENERATE_HPP
+#define MLLM_GENERATE_HPP
+#include <cassert>
+#include <cstdint>
+#include <random>
+#include <utility>
+#include <vector>
+#include "Tensor.hpp"
+
+namespace mllm {
+
+struct LlmTextGeneratorOpts {
+    size_t max_new_tokens = 100;
+    size_t min_new_tokens = 10;
+    bool do_sample = true;
+    float temperature = 0.7;
+    int top_k = 5;
+    float top_p = 0.92;
+};
+
+template <typename T>
+T _sample_element(const std::vector<T> &elements, const std::vector<float> &probabilities) {
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::discrete_distribution<> dist(probabilities.begin(), probabilities.end());
+    size_t index = dist(gen);
+    return elements[index];
+}
+
+enum class LLmTextGeneratorType : int32_t {
+    kNone = 0,
+    kGreedySearch,
+    kTopkSampling,
+    kToppSampling,
+    KLast,
+};
+
+class _LlmTextGenerateMethod {
+public:
+    virtual ~_LlmTextGenerateMethod() = default;
+    virtual unsigned int generate(Tensor &t) = 0;
+    inline void _tensor_to_vec(Tensor &t, std::vector<float> &scores) {
+        assert(t.batch() == 1 && "Batch size of result is not 1, which is not supported for now.");
+        assert(t.head() == 1 && "The 3rd dim of result should be one. e.g.:[1, 1, seq, hidden]");
+        int _dims = t.dimension();
+        int _seq = t.sequence() - 1;
+        for (int i = 0; i < _dims; ++i) {
+            auto value = t.dataAt<float>(0, 0, _seq, i);
+            scores.push_back(value);
+        }
+    }
+
+    inline void _tensor_to_vec_with_idx(Tensor &t, std::vector<std::pair<float, unsigned int>> &scores) {
+        assert(t.batch() == 1 && "Batch size of result is not 1, which is not supported for now.");
+        assert(t.head() == 1 && "The 3rd dim of result should be one. e.g.:[1, 1, seq, hidden]");
+        int _dims = t.dimension();
+        int _seq = t.sequence() - 1;
+        for (int i = 0; i < _dims; ++i) {
+            auto value = t.dataAt<float>(0, 0, _seq, i);
+            scores.push_back(std::make_pair(value, i));
+        }
+    }
+};
+
+class _LlmTextGenerateGreedySearchMethod : public _LlmTextGenerateMethod {
+public:
+    _LlmTextGenerateGreedySearchMethod() = default;
+    ~_LlmTextGenerateGreedySearchMethod() = default;
+    unsigned int generate(Tensor &t) override;
+};
+
+class _LlmTextGenerateTopkSamplingMethod : public _LlmTextGenerateMethod {
+public:
+    ~_LlmTextGenerateTopkSamplingMethod() = default;
+    _LlmTextGenerateTopkSamplingMethod(int32_t k = 5, float temperature = 0.f) :
+        m_k(k),
+        m_temperature(temperature) {
+    }
+    unsigned int generate(Tensor &t) override;
+
+private:
+    int32_t m_k;
+    float m_temperature = 0.f;
+};
+
+class _LlmTextGenerateToppSamplingMethod : public _LlmTextGenerateMethod {
+public:
+    ~_LlmTextGenerateToppSamplingMethod() = default;
+    _LlmTextGenerateToppSamplingMethod(float p = 5, float temperature = 0.f) :
+        m_p(p),
+        m_temperature(temperature) {
+    }
+    unsigned int generate(Tensor &t) override;
+
+private:
+    float m_p;
+    float m_temperature = 0.f;
+};
+
+// Usage:
+// LlmTextGeneratorOpts opt{
+//     .max_new_tokens = 100,
+//     .do_sample = true,
+//     .temperature = 0.7f,
+//     .top_k = 50,
+//     .top_p = 0.f,
+// };
+// model.generate(input_tensor, opt, [&](unsigned int out_token) -> bool {
+//     auto out_string = tokenizer.detokenize({out_token});
+//     auto [isOk, print_string] = processOutput(out_string);
+//     if (isOk) {
+//         std::cout << print_string << std::flush;
+//     } else {
+//         return false;
+//     }
+//     return true;
+// });
+// printf("\n");
+
+class LlmTextGenerator {
+public:
+    ~LlmTextGenerator() {
+        delete m_method_class;
+    }
+
+    LlmTextGenerator(const LLmTextGeneratorType &type, const LlmTextGeneratorOpts &opt) :
+        m_type(type) {
+        switch (type) {
+        case LLmTextGeneratorType::kGreedySearch: m_method_class = new _LlmTextGenerateGreedySearchMethod(); break;
+        case LLmTextGeneratorType::kTopkSampling: m_method_class = new _LlmTextGenerateTopkSamplingMethod(opt.top_k, opt.temperature); break;
+        case LLmTextGeneratorType::kToppSampling: m_method_class = new _LlmTextGenerateToppSamplingMethod(opt.top_p, opt.temperature); break;
+        default:
+            assert(false && "NIY");
+            break;
+        }
+    }
+
+    inline unsigned int generate(Tensor &t) {
+        return m_method_class->generate(t);
+    }
+
+    inline LLmTextGeneratorType type() {
+        return m_type;
+    }
+
+private:
+    LLmTextGeneratorType m_type;
+    _LlmTextGenerateMethod *m_method_class = nullptr;
+};
+
+} // namespace mllm
+
+#endif //! MLLM_GENERATE_HPP
\ No newline at end of file
diff --git a/src/Module.hpp b/src/Module.hpp
index 42bf6ff8..b8dd5583 100644
--- a/src/Module.hpp
+++ b/src/Module.hpp
@@ -4,6 +4,7 @@
 #ifndef MODULE_HPP
 #define MODULE_HPP
 
+#include "Generate.hpp"
 #include "Tensor.hpp"
 #include "Op.hpp"
 #include "ParamLoader.hpp"
@@ -12,8 +13,10 @@
 #include "backends/cpu/CPUBackend.hpp"
 
 #include <any>
+#include <functional>
 #include <iostream>
 #include <map>
+#include <memory>
 #include <string>
 #include <utility>
 #include <vector>
@@ -23,11 +26,12 @@ namespace mllm {
 class Module {
 private:
     double load_time_;
-    int prefilling_token_size_=0;
-    int decoding_token_size_=0;
+    int prefilling_token_size_ = 0;
+    int decoding_token_size_ = 0;
     vector<double> inference_times_;
     vector<vector<int>> last_shape_bshd_;
-    
+    std::shared_ptr<LlmTextGenerator> text_generator_ = nullptr;
+
 public:
     static map<BackendType, Backend *> backends;
     static AbstructLoader *loader;
@@ -68,24 +72,23 @@
         vector<Tensor> tmps;
         int max_in_size = 5;
         for (int i = 0; i < max_in_size; ++i) {
-            Tensor::graphs["input"+std::to_string(i)] = std::make_shared<Tensor>(Module::backends[MLLM_CPU]);
-            Tensor::graphs["input"+std::to_string(i)]->setName("input"+std::to_string(i));
-            tmps.push_back(*Tensor::graphs["input"+std::to_string(i)]);
+            Tensor::graphs["input" + std::to_string(i)] = std::make_shared<Tensor>(Module::backends[MLLM_CPU]);
+            Tensor::graphs["input" + std::to_string(i)]->setName("input" + std::to_string(i));
+            tmps.push_back(*Tensor::graphs["input" + std::to_string(i)]);
         }
-        vector<std::any> alternate_args={
+        vector<std::any> alternate_args = {
             {},
             vector<int>{0, 0},
-            std::vector<std::vector<int>>(32, std::vector<int>(2))
-        };
+            std::vector<std::vector<int>>(32, std::vector<int>(2))};
         uint64_t time_start = 0;
         for (auto args : alternate_args) {
             time_start = mllm_time_us();
             try {
                 operator()(tmps, args);
                 break;
-            } catch (const std::exception& e) {
+            } catch (const std::exception &e) {
 #if not defined(__ARM_NEON)
-                if(std::string("bad any_cast") != e.what()) {
+                if (std::string("bad any_cast") != e.what()) {
                     std::cerr << e.what() << std::endl;
                     exit(0);
                 }
@@ -93,10 +96,10 @@
             } catch (...) {
                 std::cerr << "load error" << std::endl;
                 exit(0);
-            }   
+            }
         }
         uint64_t time_end = mllm_time_us();
-        load_time_ = (time_end - time_start) / 1000.0F;//ms
+        load_time_ = (time_end - time_start) / 1000.0F; // ms
         Module::doLoad = false;
         // Tensor::graphs.clear();
     }
@@ -104,15 +107,15 @@
     void load(AbstructLoader &param_loader) {
         Tensor::graphs.clear();
         Tensor::tensor_status = TENSOR_STATIC_INIT;
-        
+
         loader = &param_loader;
         Module::doLoad = true;
         vector<Tensor> tmps;
         int max_in_size = 5;
         for (int i = 0; i < max_in_size; ++i) {
-            Tensor::graphs["input"+std::to_string(i)] = std::make_shared<Tensor>(Module::backends[MLLM_CPU]);
-            Tensor::graphs["input"+std::to_string(i)]->setName("input"+std::to_string(i));
-            tmps.push_back(*Tensor::graphs["input"+std::to_string(i)]);
+            Tensor::graphs["input" + std::to_string(i)] = std::make_shared<Tensor>(Module::backends[MLLM_CPU]);
+            Tensor::graphs["input" + std::to_string(i)]->setName("input" + std::to_string(i));
+            tmps.push_back(*Tensor::graphs["input" + std::to_string(i)]);
         }
         vector<int> tmpt = {0, 0};
         operator()(tmps, tmpt);
@@ -129,30 +132,27 @@
     template <typename... Args>
    vector<Tensor> operator()(vector<Tensor> inputs, Args... args) {
         vector<std::any> anyArgs = convertArgsToAnyVector(args...);
-        if(doLoad) {
+        if (doLoad) {
             return Forward(inputs, anyArgs);
         }
         if (inputs[0].ttype() == TensorType::INPUT_TENSOR) {
-            if(prefilling_token_size_==0){ // first time init
+            if (prefilling_token_size_ == 0) { // first time init
                 // if(!Tensor::graphs.empty()){
                 //     Tensor::graphs.clear();
                 // }
                 prefilling_token_size_ = inputs[0].sequence();
-            }else if(decoding_token_size_==0){
+            } else if (decoding_token_size_ == 0) {
                 decoding_token_size_ = inputs[0].sequence();
             }
             bool need_setup = true;
             for (int i = 0; i < inputs.size(); i++) {
                 auto &input = inputs[i];
-                input.setName("input"+std::to_string(i));
+                input.setName("input" + std::to_string(i));
                 input.setTtype(TensorType::NORMAL_TENSOR);
                 Tensor::graphs[input.name()] = std::shared_ptr<Tensor>(&input, [](Tensor *) {});
-                if(inputs[0].sequence()!=1 && !last_shape_bshd_.empty()){
+                if (inputs[0].sequence() != 1 && !last_shape_bshd_.empty()) {
                     // if LLM/VLLM model, the `need_setup` should be `true`
-                    if(input.batch() == last_shape_bshd_[i][0] &
-                       input.sequence() == last_shape_bshd_[i][1] &
-                       input.head() == last_shape_bshd_[i][2] &
-                       input.dimension() == last_shape_bshd_[i][3]){
+                    if (input.batch() == last_shape_bshd_[i][0] & input.sequence() == last_shape_bshd_[i][1] & input.head() == last_shape_bshd_[i][2] & input.dimension() == last_shape_bshd_[i][3]) {
                         need_setup = false;
                     }
 
@@ -160,7 +160,7 @@
         Tensor::tensor_status = TENSOR_STATIC_INIT;
 
         uint64_t time_start = mllm_time_us();
-        if(need_setup){
+        if (need_setup) {
             Forward(inputs, anyArgs);
         }
         Tensor::tensor_status = TENSOR_STATIC_READY;
@@ -168,11 +168,11 @@
         auto output = Forward(inputs, anyArgs);
 
         uint64_t time_end = mllm_time_us();
-        double inference_time_ = (time_end - time_start) / 1000.0F;//ms
+        double inference_time_ = (time_end - time_start) / 1000.0F; // ms
         inference_times_.push_back(inference_time_);
         last_shape_bshd_.clear();
         for (auto &input : inputs) {
-            last_shape_bshd_.push_back({input.batch(), input.sequence(), 
+            last_shape_bshd_.push_back({input.batch(), input.sequence(),
                                         input.head(), input.dimension()});
         }
 
@@ -186,25 +186,25 @@
     static int runlistIdx;
 
     template <typename T>
-    static vector<T > List(int n) {
+    static vector<T> List(int n) {
         static_assert(std::is_base_of<Module, T>::value, "T must be a subclass of Module");
         listIdx = 0;
         vector<T> modules;
         for (int i = 0; i < n; i++) {
             modules.push_back(T());
-            listIdx ++;
+            listIdx++;
         }
         listIdx = 0;
         return modules;
     }
     // recursion termination function
-    template<typename T>
+    template <typename T>
     static auto change_last(T value) {
         return std::make_tuple(value + std::to_string(listIdx) + ".");
     }
     // recursive function
-    template<typename T, typename... Args>
+    template <typename T, typename... Args>
     static auto change_last(T head, Args... tail) {
         auto tail_tuple = change_last(tail...);
         return std::tuple_cat(std::make_tuple(head), tail_tuple);
     }
@@ -215,16 +215,16 @@
         listIdx = 0;
         vector<T> modules;
         for (int i = 0; i < n; i++) {
-            auto new_args = change_last(args...); // 创建新的参数包,最后一个参数被修改为原来的值+ std::to_string(listIdx)+ "."
-            modules.push_back(std::move(T(std::apply([&](auto&&... args){ return T(std::forward<decltype(args)>(args)...); }, new_args))));
+            auto new_args = change_last(args...); // build a new parameter pack whose last element is the original value + std::to_string(listIdx) + "."
+            modules.push_back(std::move(T(std::apply([&](auto &&...args) { return T(std::forward<decltype(args)>(args)...); }, new_args))));
             listIdx++;
         }
         listIdx = 0;
         return modules;
     }
-    
-    void free(){
-        Tensor::graphs.clear(); 
+
+    void free() {
+        Tensor::graphs.clear();
     }
 
     void profiling(string name = "") {
@@ -234,16 +234,16 @@
             std::cout << " " << name << std::endl;
             std::cout << "-------------------------------------------" << std::endl;
         }
-        std::cout << " Load time: " << load_time_/1000.0F << " s" << std::endl;
-        if(inference_times_.size()>1 && decoding_token_size_ != prefilling_token_size_){
+        std::cout << " Load time: " << load_time_ / 1000.0F << " s" << std::endl;
+        if (inference_times_.size() > 1 && decoding_token_size_ != prefilling_token_size_) {
             std::cout << " Prefilling speed: " << 1000 * prefilling_token_size_ / inference_times_[0] << " tokens/s" << std::endl;
-            double sum_decoding_time = std::accumulate(std::begin(inference_times_)+1, std::end(inference_times_), 0.0);
-            double mean_decoding_time = sum_decoding_time / (inference_times_.size()-1);
+            double sum_decoding_time = std::accumulate(std::begin(inference_times_) + 1, std::end(inference_times_), 0.0);
+            double mean_decoding_time = sum_decoding_time / (inference_times_.size() - 1);
             std::cout << " Decoding speed: " << 1000 / mean_decoding_time << " tokens/s" << std::endl;
-        } else{
+        } else {
             double sum_time = std::accumulate(std::begin(inference_times_), std::end(inference_times_), 0.0);
             double mean_time = sum_time / (inference_times_.size());
-            std::cout << " Inference latency: " << mean_time/1000.0F << " s" << std::endl;
+            std::cout << " Inference latency: " << mean_time / 1000.0F << " s" << std::endl;
         }
         // double sum_time = std::accumulate(std::begin(inference_times_), std::end(inference_times_), 0.0);
         // std::cout << ...
     }
 
+    void generate(Tensor &input_ids, const LlmTextGeneratorOpts &opt, const std::function<bool(unsigned int)> &call_back = [](unsigned int) -> bool { return true; }) {
+        auto chatPostProcessing = [](unsigned token_idx, Tensor &tokens_tensor, const vector<Tensor *> &clean_tensors) {
+            tokens_tensor.reshape(1, 1, 1, 1);
+            tokens_tensor.alloc();
+            tokens_tensor.setDataAt<float>(0, 0, 0, 0, token_idx);
+
+            for (auto tensor : clean_tensors) {
+                tensor->reshape(0, 0, 0, 0);
+                tensor->alloc();
+            }
+        };
+
+        if (!opt.do_sample) {
+            // fall back to greedy search
+            if (!text_generator_ || text_generator_->type() != LLmTextGeneratorType::kGreedySearch)
+                text_generator_ = std::make_shared<LlmTextGenerator>(LLmTextGeneratorType::kGreedySearch, opt);
+        } else if (opt.do_sample && !opt.top_k && opt.top_p != 0.f) {
+            // fall back to top-p sampling
+            if (!text_generator_ || text_generator_->type() != LLmTextGeneratorType::kToppSampling)
+                text_generator_ = std::make_shared<LlmTextGenerator>(LLmTextGeneratorType::kToppSampling, opt);
+        } else if (opt.do_sample && opt.top_k) {
+            // fall back to top-k sampling
+            if (!text_generator_ || text_generator_->type() != LLmTextGeneratorType::kTopkSampling)
+                text_generator_ = std::make_shared<LlmTextGenerator>(LLmTextGeneratorType::kTopkSampling, opt);
+        }
+
+        for (int step = 0; step < opt.max_new_tokens; ++step) {
+            auto _out = (*this)({input_ids});
+            auto out_token = text_generator_->generate(_out[0]);
+            if (!call_back(out_token)) break;
+            chatPostProcessing(out_token, input_ids, {});
+        }
+    }
 };
 
 } // namespace mllm
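For reference, a minimal greedy-decoding sketch against the API this patch adds (not part of the diff): it assumes the same `model`, `tokenizer`, and `input_tensor` objects as `examples/demo_qwen.cpp`, relies on `Module::generate` picking `_LlmTextGenerateGreedySearchMethod` when `do_sample` is false, and uses a hypothetical `eos_token_id` placeholder for the stop condition.

// Greedy decoding sketch (objects borrowed from examples/demo_qwen.cpp; eos_token_id is a placeholder).
LlmTextGeneratorOpts opt{
    .max_new_tokens = 64,
    .do_sample = false, // selects LLmTextGeneratorType::kGreedySearch inside Module::generate
};
model.generate(input_tensor, opt, [&](unsigned int out_token) -> bool {
    auto out_string = tokenizer.detokenize({out_token}); // detokenize one token at a time, as in the demo
    std::cout << out_string << std::flush;
    return out_token != eos_token_id; // returning false stops generation early
});
printf("\n");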