Skip to content

Commit

Permalink
[Wait for #2536][application] add generate_multiple_tokens for llm
Browse files Browse the repository at this point in the history
Added generate_multiple_tokens function for first generation on llm.

This function takes one logits and generates multiple output tokens.
To meet the purpose of the target application,
even if input are multiple logits,
only the first logits is used to generate multiple output tokens.

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <[email protected]>
  • Loading branch information
baek2sm committed Apr 26, 2024
1 parent f9a4cd4 commit 16c2324
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions Applications/LLaMA/jni/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ int const NUM_HEADS = 18;
int const MULTIPLE_OF = 256;

float const NORM_EPS = 0.000001;
float const EPS = 0.000001;
int const NUM_VOCAB = 96000;
int MAX_SEQ_LEN = 1024;
int NUM_TO_GENERATE = 100;
Expand Down Expand Up @@ -219,6 +220,39 @@ std::vector<int> generate(float *logits, unsigned int NUM_VOCAB = 0,
return outputs;
}

std::vector<int> generate_multi_tokens(
float *logits, unsigned int NUM_VOCAB = 0, unsigned int NUM_TARGET_TOKENS = 1,
float repetition_penalty = 1, unsigned int *input_ids = nullptr,
unsigned int NUM_INPUT_IDS = 0, unsigned int *bad_words_ids = nullptr,
unsigned int NUM_BAD_WORDS_IDS = 0) {
std::vector<int> outputs;
// apply repetition penalty
if (abs(repetition_penalty - 1) < EPS && input_ids != nullptr && NUM_INPUT_IDS != 0) {
applyRepetitionPenalty(logits, input_ids, NUM_INPUT_IDS,
repetition_penalty);
}

// apply bad words penalty
if (bad_words_ids != nullptr && NUM_BAD_WORDS_IDS != 0) {
applyBadWordsPenalty(logits, bad_words_ids, NUM_BAD_WORDS_IDS);
}

// Sort and generate multiple tokens
std::vector<std::pair<int, float>> top_indices_and_logits;
for (unsigned int i = 0; i < NUM_VOCAB; ++i) {
top_indices_and_logits.push_back({i, logits[i]});
}
sort(top_indices_and_logits.begin(), top_indices_and_logits.end(),
[](auto &a, auto &b) { return a.second > b.second; });

// add sampled words
for (unsigned int i = 0; i < NUM_TARGET_TOKENS; ++i) {
outputs.push_back(top_indices_and_logits[i].first);
}

return outputs;
}

template <typename T>
T unwrap(std::optional<T> &&value, const std::string &error_msg) {
if (value.has_value()) {
Expand Down

0 comments on commit 16c2324

Please sign in to comment.