diff --git a/fundamentals/llama.cpp/src/embeddings.cpp b/fundamentals/llama.cpp/src/embeddings.cpp
index 139e875..7269c3a 100644
--- a/fundamentals/llama.cpp/src/embeddings.cpp
+++ b/fundamentals/llama.cpp/src/embeddings.cpp
@@ -70,6 +70,28 @@ const char* BLUE = "\033[0;34m";
 const char* ORANGE = "\033[0;33m"; // Actually yellow, but often appears as orange in many terminals
 const char* RESET = "\033[0m";
 
+// Print the 5 highest-valued logits from the context's last decode, with
+// their token ids and decoded token strings (colored via the BLUE/RESET
+// escape codes defined above).
+void print_top_logits(llama_model* model, llama_context* ctx) {
+    float* logits = llama_get_logits(ctx);
+    printf("%sTop 5 logits:%s\n", BLUE, RESET);
+    // Pair each vocab index with its logit so we can sort by value while
+    // remembering which token it belongs to.
+    std::vector<std::pair<int, float>> top_logits;
+    for (int i = 0; i < llama_n_vocab(model); i++) {
+        top_logits.push_back(std::make_pair(i, logits[i]));
+    }
+    // Only the top 5 entries are needed, so partial_sort (descending by
+    // logit value) avoids sorting the full vocabulary.
+    std::partial_sort(top_logits.begin(), top_logits.begin() + 5, top_logits.end(),
+        [](const std::pair<int, float>& a,
+           const std::pair<int, float>& b) {
+            return a.second > b.second;
+        });
+    for (int i = 0; i < 5; i++) {
+        printf("%sToken %d (%s): %f%s\n",
+            BLUE,
+            top_logits[i].first,
+            token_as_string(model, top_logits[i].first).c_str(),
+            top_logits[i].second,
+            RESET);
+    }
+}
+
 int main(int argc, char** argv) {
     fprintf(stdout, "llama.cpp embedding exploration\n");
     llama_model_params model_params = llama_model_default_params();