Skip to content

Commit

Permalink
src: add print_top_logits function to embeddings.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
danbev committed Jan 15, 2025
1 parent d6fd70a commit a788d61
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions fundamentals/llama.cpp/src/embeddings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,28 @@ const char* BLUE = "\033[0;34m";
const char* ORANGE = "\033[0;33m"; // Actually yellow, but often appears as orange in many terminals
const char* RESET = "\033[0m";

void print_top_logits(llama_model* model, llama_context* ctx) {
float* logits = llama_get_logits(ctx);
printf("%sTop 5 logits:%s\n", BLUE, RESET);
std::vector<std::pair<llama_token, float>> top_logits;
for (int i = 0; i < llama_n_vocab(model); i++) {
top_logits.push_back(std::make_pair(i, logits[i]));
}
std::partial_sort(top_logits.begin(), top_logits.begin() + 5, top_logits.end(),
[](const std::pair<llama_token, float>& a,
const std::pair<llama_token, float>& b) {
return a.second > b.second;
});
for (int i = 0; i < 5; i++) {
printf("%sToken %d (%s): %f%s\n",
BLUE,
top_logits[i].first,
token_as_string(model, top_logits[i].first).c_str(),
top_logits[i].second,
RESET);
}
}

int main(int argc, char** argv) {
fprintf(stdout, "llama.cpp embedding exploration\n");
llama_model_params model_params = llama_model_default_params();
Expand Down

0 comments on commit a788d61

Please sign in to comment.