Skip to content

Commit

Permalink
[Wait for #2536][application] add generate_multiple_tokens for llm
Browse files Browse the repository at this point in the history
Added generate_multiple_tokens function for first generation on llm.

This function takes one logits and generates multiple output tokens.
To meet the purpose of the target application,
even if input are multiple logits,
only the first logits is used to generate multiple output tokens.

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <[email protected]>
  • Loading branch information
baek2sm committed Apr 5, 2024
1 parent 510f435 commit 9b08459
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions Applications/LLaMA/jni/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,39 @@ std::vector<int> generate(float *logits, unsigned int NUM_VOCAB = 0,
return outputs;
}

std::vector<int> generate_multi_tokens(
float *logits, unsigned int NUM_VOCAB = 0, unsigned int NUM_TARGET_TOKENS = 1,
float repetition_penalty = 1, unsigned int *input_ids = nullptr,
unsigned int NUM_INPUT_IDS = 0, unsigned int *bad_words_ids = nullptr,
unsigned int NUM_BAD_WORDS_IDS = 0) {
std::vector<int> outputs;
// apply repetition penalty
if (repetition_penalty != 1 && input_ids != nullptr && NUM_INPUT_IDS != 0) {
applyRepetitionPenalty(logits, input_ids, NUM_INPUT_IDS,
repetition_penalty);
}

// apply bad words penalty
if (bad_words_ids != nullptr && NUM_BAD_WORDS_IDS != 0) {
applyBadWordsPenalty(logits, bad_words_ids, NUM_BAD_WORDS_IDS);
}

// Sort and generate multiple tokens
std::vector<std::pair<int, float>> top_indices_and_logits;
for (int i = 0; i < NUM_VOCAB; ++i) {
top_indices_and_logits.push_back({i, logits[i]});
}
sort(top_indices_and_logits.begin(), top_indices_and_logits.end(),
[](auto &a, auto &b) { return a.second > b.second; });

// add sampled words
for (unsigned int i = 0; i < NUM_TARGET_TOKENS; ++i) {
outputs.push_back(top_indices_and_logits[i].first);
}

return outputs;
}

template <typename T>
T unwrap(std::optional<T> &&value, const std::string &error_msg) {
if (value.has_value()) {
Expand Down

0 comments on commit 9b08459

Please sign in to comment.