Skip to content

Commit

Permalink
[application] add repetition_penalty to generate func
Browse files Browse the repository at this point in the history
add some options to 'generate' function of llm

- add naive repetition_penalty option
- add bad_words option

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <[email protected]>
  • Loading branch information
baek2sm committed Apr 6, 2024
1 parent 345ca62 commit f1e9165
Showing 1 changed file with 46 additions and 5 deletions.
51 changes: 46 additions & 5 deletions Applications/LLaMA/jni/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,53 @@ float applyTKP(float *logits, int len, float temperature, unsigned int top_k,
return top_indices_and_logits[0].second;
}

std::vector<int> generate(float *logits, bool do_sample = false,
/**
* @brief Apply repetition penalty to logits
*/
void applyRepetitionPenalty(float *logits, unsigned int *input_ids,
unsigned int NUM_INPUT_IDS,
float repetition_penalty = 1) {
for (unsigned int i = 0; i < NUM_INPUT_IDS; ++i) {
if (logits[input_ids[i]] < 0) {
logits[input_ids[i]] *= repetition_penalty;
} else {
logits[input_ids[i]] /= repetition_penalty;
}
}
}

/**
* @brief Apply bad words penalty
*/
void applyBadWordsPenalty(float *logits, unsigned int *bad_words_ids,
unsigned int NUM_BAD_WORDS_IDS) {
for (unsigned int i = 0; i < NUM_BAD_WORDS_IDS; ++i) {
logits[bad_words_ids[i]] = -INFINITY;
}
}

std::vector<int> generate(float *logits, unsigned int NUM_VOCAB = 0,
unsigned int NUM_BATCH = 1, bool do_sample = false,
float temperature = 1, unsigned int top_k = 1,
float top_p = 0, unsigned int NUM_BATCH = 1,
unsigned int NUM_VOCAB = 96000) {
float top_p = 1, float repetition_penalty = 1,
unsigned int *input_ids = nullptr,
unsigned int NUM_INPUT_IDS = 0,
unsigned int *bad_words_ids = nullptr,
unsigned int NUM_BAD_WORDS_IDS = 0) {
std::vector<int> outputs;
for (unsigned int iteration = 0; iteration < NUM_BATCH; ++iteration) {

// apply repetition penalty
if (repetition_penalty != 1 && input_ids != nullptr && NUM_INPUT_IDS != 0) {
applyRepetitionPenalty(logits, input_ids, NUM_INPUT_IDS,
repetition_penalty);
}

// apply bad words penalty
if (bad_words_ids != nullptr && NUM_BAD_WORDS_IDS != 0) {
applyBadWordsPenalty(logits, bad_words_ids, NUM_BAD_WORDS_IDS);
}

// return argmax if do_sample is false
if (do_sample == false) {
int argmax_idx =
Expand All @@ -148,7 +189,6 @@ std::vector<int> generate(float *logits, bool do_sample = false,
} else {
// apply temperature & top-k & top-p to logits
float max_logits = applyTKP(logits, NUM_VOCAB, temperature, top_k, top_p);

// transform logits to softmax
float sum_exp_logits = 0;
for (unsigned int i = 0; i < NUM_VOCAB; i++) {
Expand All @@ -173,6 +213,7 @@ std::vector<int> generate(float *logits, bool do_sample = false,

// set batch offset
logits = logits + NUM_VOCAB;
input_ids = input_ids + batch_size;
}

return outputs;
Expand Down Expand Up @@ -554,7 +595,7 @@ void run(std::string text, bool apply_temperature) {
for (unsigned int i = input_len + 1; i < input_len + NUM_TO_GENERATE; ++i) {
auto output_interval =
g_model->incremental_inference(1, input, label, MAX_SEQ_LEN, i - 1, i);
unsigned int ids = generate(output[0], true, 1, 1, 0, 1, NUM_VOCAB)[0];
unsigned int ids = generate(output[0], NUM_VOCAB)[0];

if (i < input_len) {
input_sample[0] = static_cast<float>(init_input[i]);
Expand Down

0 comments on commit f1e9165

Please sign in to comment.