forked from ggerganov/llama.cpp
Commit
llama : add simple-chat example (ggerganov#10124)
* llama : add simple-chat example

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
Showing 6 changed files with 220 additions and 4 deletions.
CMakeLists.txt (new file, @@ -0,0 +1,5 @@):

```cmake
set(TARGET llama-simple-chat)
add_executable(${TARGET} simple-chat.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
```
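With this target defined, the new binary should build like the other llama.cpp example targets, e.g. `cmake -B build && cmake --build build --target llama-simple-chat`, assuming the usual llama.cpp CMake workflow (the other changed files in this commit are not expanded in this view).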
README.md (new file, @@ -0,0 +1,7 @@):

````markdown
# llama.cpp/example/simple-chat

The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.

```bash
./llama-simple-chat -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
...
````
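The core call behind the example is `llama_chat_apply_template`, which renders the accumulated `llama_chat_message` list using the chat template stored in the model's GGUF metadata. A minimal sketch of how that call is typically wrapped (the helper name `render_prompt` and the initial buffer size are illustrative, not part of this commit):

```cpp
#include "llama.h"

#include <string>
#include <vector>

// Hypothetical helper (sketch only, not from the commit): render the accumulated
// chat messages with the template embedded in the model's GGUF metadata.
static std::string render_prompt(const llama_model * model,
                                 const std::vector<llama_chat_message> & messages) {
    std::vector<char> buf(4096);
    // nullptr selects the model's built-in chat template; `true` appends the
    // assistant prefix so the next decode starts the reply
    int len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(),
                                        true, buf.data(), buf.size());
    if (len > (int) buf.size()) {
        // buffer was too small: retry with the required size
        buf.resize(len);
        len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(),
                                        true, buf.data(), buf.size());
    }
    if (len < 0) {
        return ""; // the template could not be applied
    }
    return std::string(buf.data(), len);
}
```

simple-chat.cpp below does the same thing inline, but additionally tracks how much of the transcript was already formatted so that only the newly added portion is fed back into the model.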
simple-chat.cpp (new file, @@ -0,0 +1,197 @@):

```cpp
#include "llama.h"
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
    printf("\n");
}

int main(int argc, char ** argv) {
    std::string model_path;
    int ngl = 99;
    int n_ctx = 2048;

    // parse command line arguments
    for (int i = 1; i < argc; i++) {
        try {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_path = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-c") == 0) {
                if (i + 1 < argc) {
                    n_ctx = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    ngl = std::stoi(argv[++i]);
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                print_usage(argc, argv);
                return 1;
            }
        } catch (std::exception & e) {
            fprintf(stderr, "error: %s\n", e.what());
            print_usage(argc, argv);
            return 1;
        }
    }
    if (model_path.empty()) {
        print_usage(argc, argv);
        return 1;
    }

    // only print errors
    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
        if (level >= GGML_LOG_LEVEL_ERROR) {
            fprintf(stderr, "%s", text);
        }
    }, nullptr);

    // initialize the model
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
    if (!model) {
        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
        return 1;
    }

    // initialize the context
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_ctx;
    ctx_params.n_batch = n_ctx;

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
        return 1;
    }

    // initialize the sampler
    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // helper function to evaluate a prompt and generate a response
    auto generate = [&](const std::string & prompt) {
        std::string response;

        // tokenize the prompt
        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
        if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
            GGML_ABORT("failed to tokenize the prompt\n");
        }

        // prepare a batch for the prompt
        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
        llama_token new_token_id;
        while (true) {
            // check if we have enough space in the context to evaluate this batch
            int n_ctx = llama_n_ctx(ctx);
            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
            if (n_ctx_used + batch.n_tokens > n_ctx) {
                printf("\033[0m\n");
                fprintf(stderr, "context size exceeded\n");
                exit(0);
            }

            if (llama_decode(ctx, batch)) {
                GGML_ABORT("failed to decode\n");
            }

            // sample the next token
            new_token_id = llama_sampler_sample(smpl, ctx, -1);

            // is it an end of generation?
            if (llama_token_is_eog(model, new_token_id)) {
                break;
            }

            // convert the token to a string, print it and add it to the response
            char buf[256];
            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
            if (n < 0) {
                GGML_ABORT("failed to convert token to piece\n");
            }
            std::string piece(buf, n);
            printf("%s", piece.c_str());
            fflush(stdout);
            response += piece;

            // prepare the next batch with the sampled token
            batch = llama_batch_get_one(&new_token_id, 1);
        }

        return response;
    };

    std::vector<llama_chat_message> messages;
    std::vector<char> formatted(llama_n_ctx(ctx));
    int prev_len = 0;
    while (true) {
        // get user input
        printf("\033[32m> \033[0m");
        std::string user;
        std::getline(std::cin, user);

        if (user.empty()) {
            break;
        }

        // add the user input to the message list and format it
        messages.push_back({"user", strdup(user.c_str())});
        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        if (new_len > (int)formatted.size()) {
            formatted.resize(new_len);
            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
        }
        if (new_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }

        // remove previous messages to obtain the prompt to generate the response
        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);

        // generate a response
        printf("\033[33m");
        std::string response = generate(prompt);
        printf("\n\033[0m");

        // add the response to the messages
        messages.push_back({"assistant", strdup(response.c_str())});
        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
        if (prev_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }
    }

    // free resources
    for (auto & msg : messages) {
        free(const_cast<char *>(msg.content));
    }
    llama_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);

    return 0;
}
```
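Note the design of the chat loop: the KV cache in `ctx` persists across turns, and `prev_len`/`new_len` ensure that only the newly formatted portion of the transcript is passed to `generate()`, so earlier messages are never re-tokenized or re-decoded. The check against `llama_get_kv_cache_used_cells` ends the program once the conversation no longer fits in the context.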