Merge pull request #105 from chenghuaWang/main
feat: topk/topp sampling
Showing 4 changed files with 381 additions and 55 deletions.
Generate.cpp
@@ -0,0 +1,123 @@
/**
 * @file Generate.cpp
 * @author Chenghua Wang ([email protected])
 * @brief The Mllm Generator Impl
 * @version 0.1
 * @date 2024-07-30
 *
 * @copyright Copyright (c) 2024
 *
 */
#include "Generate.hpp"
#include <algorithm>
#include <cmath>
#include <numeric>
#include <stdexcept>

namespace mllm {

unsigned int _LlmTextGenerateGreedySearchMethod::generate(Tensor &t) {
    std::vector<float> scores;
    this->_tensor_to_vec(t, scores);
    return std::max_element(scores.begin(), scores.end()) - scores.begin();
}

unsigned int _LlmTextGenerateTopkSamplingMethod::generate(Tensor &t) {
    auto argmax = [](const std::vector<float> &vec) -> unsigned int {
        return std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()));
    };

    // k == 0 or k == 1 degenerates to greedy search
    if (m_k == 0 || m_k == 1) {
        std::vector<float> scores;
        this->_tensor_to_vec(t, scores);
        return argmax(scores);
    }

    std::vector<std::pair<float, unsigned int>> scores;
    this->_tensor_to_vec_with_idx(t, scores);

    // find the top-k logits and their token indices
    std::partial_sort(scores.begin(), scores.begin() + m_k, scores.end(),
                      [](std::pair<float, unsigned int> a, std::pair<float, unsigned int> b) { return a.first > b.first; });
    std::vector<float> top_k_elements(m_k, 0.f);
    std::vector<unsigned int> top_k_elements_idx(m_k, 0);
    for (int i = 0; i < m_k; ++i) {
        top_k_elements[i] = scores[i].first;
        top_k_elements_idx[i] = scores[i].second;
    }

    // softmax with temperature
    std::vector<float> softmax(top_k_elements.size(), 0.f);
    double max_logit = top_k_elements[argmax(top_k_elements)];
    double sum_exp = 0.0;

    for (size_t i = 0; i < top_k_elements.size(); ++i) {
        softmax[i] = std::exp((top_k_elements[i] - max_logit) / m_temperature);
        sum_exp += softmax[i];
    }

    for (float &value : softmax) {
        value /= sum_exp;
    }

    // renormalize and sample
    float _sum = std::accumulate(softmax.begin(), softmax.end(), 0.0);
    for (float &value : softmax) {
        value /= _sum;
    }

    auto idx = _sample_element(top_k_elements_idx, softmax);
    return idx;
}

unsigned int _LlmTextGenerateToppSamplingMethod::generate(Tensor &t) {
    auto argmax = [](const std::vector<float> &vec) -> unsigned int {
        return std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()));
    };
    std::vector<std::pair<float, unsigned int>> scores;
    this->_tensor_to_vec_with_idx(t, scores);

    std::sort(scores.begin(), scores.end(), [](std::pair<float, unsigned int> a, std::pair<float, unsigned int> b) { return a.first > b.first; });
    std::vector<float> top_k_elements;
    std::vector<unsigned int> top_k_elements_idx;

    if (scores[0].first > 1.f) {
        throw std::runtime_error("The input tensor t should go through softmax first. (values in 0.f - 1.f are acceptable)");
    }

    // keep the smallest prefix whose cumulative probability reaches p
    float p = 0.f;
    size_t idx = 0;
    while (p < m_p) {
        top_k_elements.emplace_back(scores[idx].first);
        top_k_elements_idx.emplace_back(scores[idx].second);
        p += scores[idx].first;
        idx++;
    }

    if (top_k_elements.size() == 1) {
        return top_k_elements_idx[0];
    }

    // softmax with temperature
    std::vector<float> softmax(top_k_elements.size(), 0.f);
    double max_logit = top_k_elements[argmax(top_k_elements)];
    double sum_exp = 0.0;

    for (size_t i = 0; i < top_k_elements.size(); ++i) {
        softmax[i] = std::exp((top_k_elements[i] - max_logit) / m_temperature);
        sum_exp += softmax[i];
    }

    for (float &value : softmax) {
        value /= sum_exp;
    }

    // renormalize and sample
    float _sum = std::accumulate(softmax.begin(), softmax.end(), 0.0);
    for (float &value : softmax) {
        value /= _sum;
    }

    auto ret = _sample_element(top_k_elements_idx, softmax);
    return ret;
}

} // namespace mllm
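
For a quick sanity check of the top-p path, the following standalone sketch (not part of this commit) replays the same cumulative-probability cutoff and temperature softmax on a hand-made, already-softmaxed distribution, so it needs no mllm Tensor. The vocabulary size, probability values, top_p, and temperature below are arbitrary illustrations.

// Standalone sketch mirroring the top-p cutoff and temperature softmax above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

int main() {
    // hypothetical post-softmax scores for a 6-token vocabulary: {prob, token id}
    std::vector<std::pair<float, unsigned int>> scores = {
        {0.40f, 0}, {0.25f, 1}, {0.15f, 2}, {0.10f, 3}, {0.06f, 4}, {0.04f, 5}};
    const float top_p = 0.8f;       // cumulative-probability threshold
    const float temperature = 0.7f; // same role as m_temperature above

    std::sort(scores.begin(), scores.end(),
              [](auto a, auto b) { return a.first > b.first; });

    // keep the smallest prefix whose cumulative probability reaches top_p
    std::vector<float> kept;
    std::vector<unsigned int> kept_idx;
    float cum = 0.f;
    for (size_t i = 0; i < scores.size() && cum < top_p; ++i) {
        kept.push_back(scores[i].first);
        kept_idx.push_back(scores[i].second);
        cum += scores[i].first;
    }

    // temperature softmax over the kept candidates (subtract the max for stability)
    std::vector<double> probs(kept.size());
    double max_v = *std::max_element(kept.begin(), kept.end());
    double sum = 0.0;
    for (size_t i = 0; i < kept.size(); ++i) {
        probs[i] = std::exp((kept[i] - max_v) / temperature);
        sum += probs[i];
    }
    for (double &v : probs) v /= sum;

    // sample a token id, as _sample_element does via std::discrete_distribution
    std::mt19937 gen(std::random_device{}());
    std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
    unsigned int token = kept_idx[dist(gen)];
    std::printf("kept %zu candidates, sampled token id %u\n", kept.size(), token);
    return 0;
}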
Generate.hpp
@@ -0,0 +1,164 @@
/**
 * @file Generate.hpp
 * @author Chenghua Wang ([email protected])
 * @brief The Mllm Generator
 * @version 0.1
 * @date 2024-07-30
 *
 * @copyright Copyright (c) 2024
 *
 */
#pragma once
#ifndef MLLM_GENERATE_HPP
#define MLLM_GENERATE_HPP
#include <cstdint>
#include <cassert>
#include <vector>
#include <random>
#include <utility>
#include "Tensor.hpp"

namespace mllm {

struct LlmTextGeneratorOpts {
    size_t max_new_tokens = 100;
    size_t min_new_tokens = 10;
    bool do_sample = true;
    float temperature = 0.7;
    int top_k = 5;
    float top_p = 0.92;
};

template <typename T>
T _sample_element(const std::vector<T> &elements, const std::vector<float> &probabilities) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::discrete_distribution<> dist(probabilities.begin(), probabilities.end());
    size_t index = dist(gen);
    return elements[index];
}

enum class LLmTextGeneratorType : int32_t {
    kNone = 0,
    kGreedySearch,
    kTopkSampling,
    kToppSampling,
    KLast,
};

class _LlmTextGenerateMethod {
public:
    virtual ~_LlmTextGenerateMethod() = default;
    virtual unsigned int generate(Tensor &t) = 0;
    inline void _tensor_to_vec(Tensor &t, std::vector<float> &scores) {
        assert(t.batch() == 1 && "Batch size of result is not 1. Which is not supported for now.");
        assert(t.head() == 1 && "The 3rd dim of result should be one. e.g.:[1, 1, seq, hidden]");
        int _dims = t.dimension();
        int _seq = t.sequence() - 1;
        for (int i = 0; i < _dims; ++i) {
            auto value = t.dataAt<float>(0, 0, _seq, i);
            scores.push_back(value);
        }
    }

    inline void _tensor_to_vec_with_idx(Tensor &t, std::vector<std::pair<float, unsigned int>> &scores) {
        assert(t.batch() == 1 && "Batch size of result is not 1. Which is not supported for now.");
        assert(t.head() == 1 && "The 3rd dim of result should be one. e.g.:[1, 1, seq, hidden]");
        int _dims = t.dimension();
        int _seq = t.sequence() - 1;
        for (int i = 0; i < _dims; ++i) {
            auto value = t.dataAt<float>(0, 0, _seq, i);
            scores.push_back(std::make_pair(value, i));
        }
    }
};

class _LlmTextGenerateGreedySearchMethod : public _LlmTextGenerateMethod {
public:
    _LlmTextGenerateGreedySearchMethod() = default;
    ~_LlmTextGenerateGreedySearchMethod() = default;
    unsigned int generate(Tensor &t) override;
};

class _LlmTextGenerateTopkSamplingMethod : public _LlmTextGenerateMethod {
public:
    ~_LlmTextGenerateTopkSamplingMethod() = default;
    _LlmTextGenerateTopkSamplingMethod(int32_t k = 5, float temperature = 0.f) :
        m_k(k),
        m_temperature(temperature) {
    }
    unsigned int generate(Tensor &t) override;

private:
    int32_t m_k;
    float m_temperature = 0.f;
};

class _LlmTextGenerateToppSamplingMethod : public _LlmTextGenerateMethod {
public:
    ~_LlmTextGenerateToppSamplingMethod() = default;
    _LlmTextGenerateToppSamplingMethod(float p = 5, float temperature = 0.f) :
        m_p(p),
        m_temperature(temperature) {
    }
    unsigned int generate(Tensor &t) override;

private:
    float m_p;
    float m_temperature = 0.f;
};

// Usage:
// LlmTextGeneratorOpts opt{
//     .max_new_tokens = 100,
//     .do_sample = true,
//     .temperature = 0.7f,
//     .top_k = 50,
//     .top_p = 0.f,
// };
// model.generate(input_tensor, opt, [&](unsigned int out_token) -> bool {
//     auto out_string = tokenizer.detokenize({out_token});
//     auto [isOk, print_string] = processOutput(out_string);
//     if (isOk) {
//         std::cout << print_string << std::flush;
//     } else {
//         return false;
//     }
//     return true;
// });
// printf("\n");

class LlmTextGenerator {
public:
    ~LlmTextGenerator() {
        delete m_method_class;
    }

    LlmTextGenerator(const LLmTextGeneratorType &type, const LlmTextGeneratorOpts &opt) :
        m_type(type) {
        switch (type) {
        case LLmTextGeneratorType::kGreedySearch: m_method_class = new _LlmTextGenerateGreedySearchMethod(); break;
        case LLmTextGeneratorType::kTopkSampling: m_method_class = new _LlmTextGenerateTopkSamplingMethod(opt.top_k, opt.temperature); break;
        case LLmTextGeneratorType::kToppSampling: m_method_class = new _LlmTextGenerateToppSamplingMethod(opt.top_p, opt.temperature); break;
        default:
            assert(false && "NIY");
            break;
        }
    }

    inline unsigned int generate(Tensor &t) {
        return m_method_class->generate(t);
    }

    inline LLmTextGeneratorType type() {
        return m_type;
    }

private:
    LLmTextGeneratorType m_type;
    _LlmTextGenerateMethod *m_method_class = nullptr;
};

} // namespace mllm

#endif //! MLLM_GENERATE_HPP
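
For orientation, a minimal sketch of how the pieces declared in this header compose, assuming the caller already has a logits Tensor shaped [1, 1, seq, vocab] as the asserts in _tensor_to_vec expect. sample_next_token is a hypothetical wrapper, not part of the commit, and the option values are arbitrary.

// Hypothetical wrapper: build a generator from this header and sample one token.
#include "Generate.hpp"

unsigned int sample_next_token(mllm::Tensor &logits) {
    mllm::LlmTextGeneratorOpts opt{
        .max_new_tokens = 100,
        .do_sample = true,
        .temperature = 0.7f,
        .top_k = 50,
        .top_p = 0.92f,
    };
    // pick the top-k method; kToppSampling would use opt.top_p instead
    mllm::LlmTextGenerator generator(mllm::LLmTextGeneratorType::kTopkSampling, opt);
    return generator.generate(logits); // returns the sampled token id
}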