Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Wait for #2536][application] add generate_multiple_tokens for llm @open sesame 04/05 16:14 #2540

Merged
merged 1 commit into from
May 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions Applications/LLaMA/jni/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ int const NUM_HEADS = 18;
// NOTE(review): purpose not visible in this hunk — presumably a rounding
// granularity for a layer dimension; confirm at the use site.
int const MULTIPLE_OF = 256;

// NOTE(review): presumably the epsilon passed to normalization layers —
// confirm at the use site.
float const NORM_EPS = 0.000001;
// Tolerance used when generate_multi_tokens() compares repetition_penalty
// against 1.
float const EPS = 0.000001;
// File-scope vocabulary size. generate() and generate_multi_tokens() declare
// a parameter with this same name, which shadows this constant inside them.
int const NUM_VOCAB = 96000;
// The two values below are not const, so they may be reassigned elsewhere.
// NOTE(review): presumably the maximum sequence length and the number of
// tokens to generate — confirm against the code that reads them.
int MAX_SEQ_LEN = 1024;
int NUM_TO_GENERATE = 100;
Expand Down Expand Up @@ -219,6 +220,40 @@ std::vector<int> generate(float *logits, unsigned int NUM_VOCAB = 0,
return outputs;
}

/**
 * @brief Select the highest-scoring token ids from a logits array.
 *
 * Optionally applies a repetition penalty and a bad-words penalty to
 * @a logits first (via applyRepetitionPenalty / applyBadWordsPenalty), then
 * returns the ids of the top-scoring entries in descending logit order.
 *
 * @param logits             Array of at least NUM_VOCAB scores, indexed by
 *                           token id. It is handed to the penalty helpers as
 *                           a non-const pointer, so it may be modified.
 * @param NUM_VOCAB          Number of entries in @a logits (shadows the
 *                           file-scope constant of the same name).
 * @param NUM_TARGET_TOKENS  Number of token ids requested.
 * @param repetition_penalty Penalty factor; a value of 1 (within EPS) means
 *                           "no penalty" and the helper is not called.
 * @param input_ids          Ids passed to applyRepetitionPenalty, or nullptr.
 * @param NUM_INPUT_IDS      Number of entries in @a input_ids.
 * @param bad_words_ids      Ids passed to applyBadWordsPenalty, or nullptr.
 * @param NUM_BAD_WORDS_IDS  Number of entries in @a bad_words_ids.
 * @return Up to min(NUM_TARGET_TOKENS, NUM_VOCAB) token ids, best first.
 *         Empty when @a logits is nullptr or NUM_VOCAB is 0. The relative
 *         order of ids with equal logits is unspecified.
 */
std::vector<int> generate_multi_tokens(
  float *logits, unsigned int NUM_VOCAB = 0, unsigned int NUM_TARGET_TOKENS = 1,
  float repetition_penalty = 1, unsigned int *input_ids = nullptr,
  unsigned int NUM_INPUT_IDS = 0, unsigned int *bad_words_ids = nullptr,
  unsigned int NUM_BAD_WORDS_IDS = 0) {
  std::vector<int> outputs;

  // Nothing to rank: also covers the default arguments (NUM_VOCAB == 0),
  // which previously indexed into an empty vector.
  if (logits == nullptr || NUM_VOCAB == 0) {
    return outputs;
  }

  // apply repetition penalty
  // The penalty is only meaningful when it differs from 1. The comparison is
  // written with explicit float arithmetic so that it cannot bind to the
  // integer abs() overload and truncate the difference to 0.
  const float penalty_delta = repetition_penalty - 1.0f;
  const bool penalty_requested = penalty_delta > EPS || penalty_delta < -EPS;
  if (penalty_requested && input_ids != nullptr && NUM_INPUT_IDS != 0) {
    applyRepetitionPenalty(logits, input_ids, NUM_INPUT_IDS,
                           repetition_penalty);
  }

  // apply bad words penalty
  if (bad_words_ids != nullptr && NUM_BAD_WORDS_IDS != 0) {
    applyBadWordsPenalty(logits, bad_words_ids, NUM_BAD_WORDS_IDS);
  }

  // Pair every token id with its (possibly penalized) logit.
  std::vector<std::pair<int, float>> indices_and_logits;
  indices_and_logits.reserve(NUM_VOCAB);
  for (unsigned int i = 0; i < NUM_VOCAB; ++i) {
    indices_and_logits.emplace_back(static_cast<int>(i), logits[i]);
  }

  // Never hand out more ids than the vocabulary contains.
  const unsigned int num_selected = std::min(NUM_TARGET_TOKENS, NUM_VOCAB);

  // Only the leading num_selected entries need to be ordered, so a partial
  // sort is enough; the rest of the vocabulary is left unsorted.
  std::partial_sort(
    indices_and_logits.begin(), indices_and_logits.begin() + num_selected,
    indices_and_logits.end(),
    [](const auto &a, const auto &b) { return a.second > b.second; });

  // add sampled words
  outputs.reserve(num_selected);
  for (unsigned int i = 0; i < num_selected; ++i) {
    outputs.push_back(indices_and_logits[i].first);
  }

  return outputs;
}

template <typename T>
T unwrap(std::optional<T> &&value, const std::string &error_msg) {
if (value.has_value()) {
Expand Down