diff --git a/Makefile b/Makefile
index 87fe795aa8432..0604a10a4bf86 100644
--- a/Makefile
+++ b/Makefile
@@ -34,6 +34,7 @@ BUILD_TARGETS = \
 	llama-server \
 	llama-simple \
 	llama-simple-chat \
+	llama-ramalama-core \
 	llama-speculative \
 	llama-tokenize \
 	llama-vdot \
@@ -1382,6 +1383,11 @@ llama-infill: examples/infill/infill.cpp \
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
+llama-ramalama-core: examples/ramalama-core/ramalama-core.cpp \
+	$(OBJ_ALL)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
 llama-simple: examples/simple/simple.cpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index d63a96c1c2547..524dbfe23d9f6 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -47,6 +47,7 @@ else()
         add_subdirectory(sycl)
     endif()
     add_subdirectory(save-load-state)
+    add_subdirectory(ramalama-core)
     add_subdirectory(simple)
     add_subdirectory(simple-chat)
     add_subdirectory(speculative)
diff --git a/examples/ramalama-core/CMakeLists.txt b/examples/ramalama-core/CMakeLists.txt
new file mode 100644
index 0000000000000..57b3312abdccb
--- /dev/null
+++ b/examples/ramalama-core/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-ramalama-core)
+add_executable(${TARGET} ramalama-core.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/ramalama-core/README.md b/examples/ramalama-core/README.md
new file mode 100644
index 0000000000000..16797b7f64d92
--- /dev/null
+++ b/examples/ramalama-core/README.md
@@ -0,0 +1,7 @@
+# llama.cpp/example/ramalama-core
+
+The purpose of this example is to demonstrate minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.
+
+```bash
+./llama-ramalama-core -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
+...
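For reviewers who want to try the new target locally, a minimal build-and-run sketch. The target name comes from the Makefile and CMakeLists.txt changes above, and the `-m`/`-c`/`-ngl` flags from `print_usage` in the new example; the `build` directory, the `build/bin` binary location, and the model filename are assumptions, not part of the patch.

```bash
# Makefile build (target added to BUILD_TARGETS above)
make llama-ramalama-core

# or CMake build (target defined in examples/ramalama-core/CMakeLists.txt)
cmake -B build
cmake --build build --target llama-ramalama-core

# run the chat example (model path and context size are placeholders)
./build/bin/llama-ramalama-core -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048 -ngl 99
```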
diff --git a/examples/ramalama-core/ramalama-core.cpp b/examples/ramalama-core/ramalama-core.cpp
new file mode 100644
index 0000000000000..dafcd385566ec
--- /dev/null
+++ b/examples/ramalama-core/ramalama-core.cpp
@@ -0,0 +1,356 @@
+#include <climits>
+#include <cstdio>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "llama.h"
+
+// Add a message to `messages` and store its content in `owned_content`
+static void add_message(const std::string & role, const std::string & text, std::vector<llama_chat_message> & messages,
+                        std::vector<std::unique_ptr<char[]>> & owned_content) {
+    auto content = std::unique_ptr<char[]>(new char[text.size() + 1]);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role.c_str(), content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Function to apply the chat template and resize `formatted` if needed
+static int apply_chat_template(const llama_model * model, const std::vector<llama_chat_message> & messages,
+                               std::vector<char> & formatted, const bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
+                                           formatted.size());
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
+                                           formatted.size());
+    }
+
+    return result;
+}
+
+// Function to tokenize the prompt
+static int tokenize_prompt(const llama_model * model, const std::string & prompt,
+                           std::vector<llama_token> & prompt_tokens) {
+    const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    prompt_tokens.resize(n_prompt_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
+        GGML_ABORT("failed to tokenize the prompt\n");
+    }
+
+    return n_prompt_tokens;
+}
+
+// Check if we have enough space in the context to evaluate this batch
+static int check_context_size(const llama_context * ctx, const llama_batch & batch) {
+    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+    if (n_ctx_used + batch.n_tokens > n_ctx) {
+        printf("\033[0m\n");
+        fprintf(stderr, "context size exceeded\n");
+        return 1;
+    }
+
+    return 0;
+}
+
+// convert the token to a string
+static int convert_token_to_string(const llama_model * model, const llama_token token_id, std::string & piece) {
+    char buf[256];
+    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
+    if (n < 0) {
+        GGML_ABORT("failed to convert token to piece\n");
+    }
+
+    piece = std::string(buf, n);
+    return 0;
+}
+
+static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
+    printf("%s", piece.c_str());
+    fflush(stdout);
+    response += piece;
+}
+
+// helper function to evaluate a prompt and generate a response
+static int generate(const llama_model * model, llama_sampler * smpl, llama_context * ctx, const std::string & prompt,
+                    std::string & response) {
+    std::vector<llama_token> prompt_tokens;
+    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
+    if (n_prompt_tokens < 0) {
+        return 1;
+    }
+
+    // prepare a batch for the prompt
+    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+    llama_token new_token_id;
+    while (true) {
+        check_context_size(ctx, batch);
+        if (llama_decode(ctx, batch)) {
+            GGML_ABORT("failed to decode\n");
+        }
+
+        // sample the next token and check whether it is an end of generation
+        new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        if (llama_token_is_eog(model, new_token_id)) {
+            break;
+        }
+
+        std::string piece;
+        if (convert_token_to_string(model, new_token_id, piece)) {
+            return 1;
+        }
+
+        print_word_and_concatenate_to_response(piece, response);
+
+        // prepare the next batch with the sampled token
+        batch = llama_batch_get_one(&new_token_id, 1);
+    }
+
+    return 0;
+}
+
+static void print_usage(int, const char ** argv) {
+    printf("\nexample usage:\n");
+    printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
+    printf("\n");
+}
+
+static int parse_int_arg(const char * arg, int & value) {
+    char * end;
+    long val = std::strtol(arg, &end, 10);
+    if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
+        value = static_cast<int>(val);
+        return 0;
+    }
+
+    return 1;
+}
+
+static int handle_model_path(const int argc, const char ** argv, int & i, std::string & model_path) {
+    if (i + 1 < argc) {
+        model_path = argv[++i];
+        return 0;
+    }
+
+    print_usage(argc, argv);
+    return 1;
+}
+
+static int handle_n_ctx(const int argc, const char ** argv, int & i, int & n_ctx) {
+    if (i + 1 < argc) {
+        if (parse_int_arg(argv[++i], n_ctx) == 0) {
+            return 0;
+        } else {
+            fprintf(stderr, "error: invalid value for -c: %s\n", argv[i]);
+            print_usage(argc, argv);
+        }
+    } else {
+        print_usage(argc, argv);
+    }
+
+    return 1;
+}
+
+static int handle_ngl(const int argc, const char ** argv, int & i, int & ngl) {
+    if (i + 1 < argc) {
+        if (parse_int_arg(argv[++i], ngl) == 0) {
+            return 0;
+        } else {
+            fprintf(stderr, "error: invalid value for -ngl: %s\n", argv[i]);
+            print_usage(argc, argv);
+        }
+    } else {
+        print_usage(argc, argv);
+    }
+
+    return 1;
+}
+
+static int parse_arguments(const int argc, const char ** argv, std::string & model_path, int & n_ctx, int & ngl) {
+    for (int i = 1; i < argc; ++i) {
+        if (strcmp(argv[i], "-m") == 0) {
+            if (handle_model_path(argc, argv, i, model_path)) {
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-c") == 0) {
+            if (handle_n_ctx(argc, argv, i, n_ctx)) {
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-ngl") == 0) {
+            if (handle_ngl(argc, argv, i, ngl)) {
+                return 1;
+            }
+        } else {
+            print_usage(argc, argv);
+            return 1;
+        }
+    }
+
+    if (model_path.empty()) {
+        print_usage(argc, argv);
+        return 1;
+    }
+
+    return 0;
+}
+
+static int read_user_input(std::string & user_input) {
+    // Use unique_ptr with free as the deleter
+    std::unique_ptr<char, decltype(&free)> buffer(nullptr, &free);
+
+    size_t buffer_size = 0;
+    char * raw_buffer = nullptr;
+
+    // Use getline to dynamically allocate the buffer and get input
+    const ssize_t line_size = getline(&raw_buffer, &buffer_size, stdin);
+
+    // Transfer ownership to unique_ptr
+    buffer.reset(raw_buffer);
+
+    if (line_size > 0) {
+        // Remove the trailing newline character if present
+        if (buffer.get()[line_size - 1] == '\n') {
+            buffer.get()[line_size - 1] = '\0';
+        }
+
+        user_input = std::string(buffer.get());
+
+        return 0;  // Success
+    }
+
+    user_input.clear();
+
+    return 1;  // Indicate an error or empty input
+}
+
+// Function to generate a response based on the prompt
+static int generate_response(llama_model * model, llama_sampler * sampler, llama_context * context,
+                             const std::string & prompt, std::string & response) {
+    // Set response color
+    printf("\033[33m");
+    if (generate(model, sampler, context, prompt, response)) {
+        fprintf(stderr, "failed to generate response\n");
+        return 1;
+    }
+
+    // End response with color reset and newline
+    printf("\n\033[0m");
+    return 0;
+}
+
+// The main chat loop where user inputs are processed and responses generated.
+static int chat_loop(llama_model * model, llama_sampler * sampler, llama_context * context,
+                     std::vector<llama_chat_message> & messages) {
+    std::vector<std::unique_ptr<char[]>> owned_content;
+    std::vector<char> formatted(llama_n_ctx(context));
+    int prev_len = 0;
+
+    while (true) {
+        // Print prompt for user input
+        printf("\033[32m> \033[0m");
+        std::string user;
+        if (read_user_input(user)) {
+            break;
+        }
+
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
+        if (new_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+
+        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        std::string response;
+        if (generate_response(model, sampler, context, prompt, response)) {
+            return 1;
+        }
+
+        add_message("assistant", response, messages, owned_content);
+        prev_len = apply_chat_template(model, messages, formatted, false);
+        if (prev_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
+static void log_callback(const enum ggml_log_level level, const char * text, void *) {
+    if (level == GGML_LOG_LEVEL_ERROR) {
+        fprintf(stderr, "%s", text);
+    }
+}
+
+// Initializes the model and returns a unique pointer to it.
+static std::unique_ptr<llama_model, decltype(&llama_free_model)> initialize_model(const std::string & model_path,
+                                                                                  int ngl) {
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = ngl;
+
+    auto model = std::unique_ptr<llama_model, decltype(&llama_free_model)>(
+        llama_load_model_from_file(model_path.c_str(), model_params), llama_free_model);
+    if (!model) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+    }
+
+    return model;
+}
+
+// Initializes the context with the specified parameters.
+static std::unique_ptr<llama_context, decltype(&llama_free)> initialize_context(llama_model * model, int n_ctx) {
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = n_ctx;
+    ctx_params.n_batch = n_ctx;
+
+    auto context = std::unique_ptr<llama_context, decltype(&llama_free)>(
+        llama_new_context_with_model(model, ctx_params), llama_free);
+    if (!context) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+    }
+
+    return context;
+}
+
+// Initializes and configures the sampler.
+static std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)> initialize_sampler() {
+    auto sampler = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>(
+        llama_sampler_chain_init(llama_sampler_chain_default_params()), llama_sampler_free);
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+
+    return sampler;
+}
+
+int main(int argc, const char ** argv) {
+    std::string model_path;
+    int ngl = 99;
+    int n_ctx = 2048;
+    if (parse_arguments(argc, argv, model_path, n_ctx, ngl)) {
+        return 1;
+    }
+
+    llama_log_set(log_callback, nullptr);
+    auto model = initialize_model(model_path, ngl);
+    if (!model) {
+        return 1;
+    }
+
+    auto context = initialize_context(model.get(), n_ctx);
+    if (!context) {
+        return 1;
+    }
+
+    auto sampler = initialize_sampler();
+    std::vector<llama_chat_message> messages;
+    if (chat_loop(model.get(), sampler.get(), context.get(), messages)) {
+        return 1;
+    }
+
+    return 0;
+}
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index 5f9973163732d..dafcd385566ec 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -1,197 +1,356 @@
-#include "llama.h"
+#include <climits>
 #include <cstdio>
 #include <cstring>
-#include <iostream>
+#include <memory>
 #include <string>
 #include <vector>
 
-static void print_usage(int, char ** argv) {
+#include "llama.h"
+
+// Add a message to `messages` and store its content in `owned_content`
+static void add_message(const std::string & role, const std::string & text, std::vector<llama_chat_message> & messages,
+                        std::vector<std::unique_ptr<char[]>> & owned_content) {
+    auto content = std::unique_ptr<char[]>(new char[text.size() + 1]);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role.c_str(), content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Function to apply the chat template and resize `formatted` if needed
+static int apply_chat_template(const llama_model * model, const std::vector<llama_chat_message> & messages,
+                               std::vector<char> & formatted, const bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
+                                           formatted.size());
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
+                                           formatted.size());
+    }
+
+    return result;
+}
+
+// Function to tokenize the prompt
+static int tokenize_prompt(const llama_model * model, const std::string & prompt,
+                           std::vector<llama_token> & prompt_tokens) {
+    const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    prompt_tokens.resize(n_prompt_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
+        GGML_ABORT("failed to tokenize the prompt\n");
+    }
+
+    return n_prompt_tokens;
+}
+
+// Check if we have enough space in the context to evaluate this batch
+static int check_context_size(const llama_context * ctx, const llama_batch & batch) {
+    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+    if (n_ctx_used + batch.n_tokens > n_ctx) {
+        printf("\033[0m\n");
+        fprintf(stderr, "context size exceeded\n");
+        return 1;
+    }
+
+    return 0;
+}
+
+// convert the token to a string
+static int convert_token_to_string(const llama_model * model, const llama_token token_id, std::string & piece) {
+    char buf[256];
+    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
+    if (n < 0) {
+        GGML_ABORT("failed to convert token to piece\n");
+    }
+
+    piece = std::string(buf, n);
+    return 0;
+}
+
+static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
+    printf("%s", piece.c_str());
+    fflush(stdout);
+    response += piece;
+}
+
+// helper function to evaluate a prompt and generate a response
+static int generate(const llama_model * model, llama_sampler * smpl, llama_context * ctx, const std::string & prompt,
+                    std::string & response) {
+    std::vector<llama_token> prompt_tokens;
+    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
+    if (n_prompt_tokens < 0) {
+        return 1;
+    }
+
+    // prepare a batch for the prompt
+    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+    llama_token new_token_id;
+    while (true) {
+        check_context_size(ctx, batch);
+        if (llama_decode(ctx, batch)) {
+            GGML_ABORT("failed to decode\n");
+        }
+
+        // sample the next token and check whether it is an end of generation
+        new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        if (llama_token_is_eog(model, new_token_id)) {
+            break;
+        }
+
+        std::string piece;
+        if (convert_token_to_string(model, new_token_id, piece)) {
+            return 1;
+        }
+
+        print_word_and_concatenate_to_response(piece, response);
+
+        // prepare the next batch with the sampled token
+        batch = llama_batch_get_one(&new_token_id, 1);
+    }
+
+    return 0;
+}
+
+static void print_usage(int, const char ** argv) {
     printf("\nexample usage:\n");
     printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
     printf("\n");
 }
 
-int main(int argc, char ** argv) {
-    std::string model_path;
-    int ngl = 99;
-    int n_ctx = 2048;
+static int parse_int_arg(const char * arg, int & value) {
+    char * end;
+    long val = std::strtol(arg, &end, 10);
+    if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
+        value = static_cast<int>(val);
+        return 0;
+    }
 
-    // parse command line arguments
-    for (int i = 1; i < argc; i++) {
-        try {
-            if (strcmp(argv[i], "-m") == 0) {
-                if (i + 1 < argc) {
-                    model_path = argv[++i];
-                } else {
-                    print_usage(argc, argv);
-                    return 1;
-                }
-            } else if (strcmp(argv[i], "-c") == 0) {
-                if (i + 1 < argc) {
-                    n_ctx = std::stoi(argv[++i]);
-                } else {
-                    print_usage(argc, argv);
-                    return 1;
-                }
-            } else if (strcmp(argv[i], "-ngl") == 0) {
-                if (i + 1 < argc) {
-                    ngl = std::stoi(argv[++i]);
-                } else {
-                    print_usage(argc, argv);
-                    return 1;
-                }
-            } else {
-                print_usage(argc, argv);
+    return 1;
+}
+
+static int handle_model_path(const int argc, const char ** argv, int & i, std::string & model_path) {
+    if (i + 1 < argc) {
+        model_path = argv[++i];
+        return 0;
+    }
+
+    print_usage(argc, argv);
+    return 1;
+}
+
+static int handle_n_ctx(const int argc, const char ** argv, int & i, int & n_ctx) {
+    if (i + 1 < argc) {
+        if (parse_int_arg(argv[++i], n_ctx) == 0) {
+            return 0;
+        } else {
+            fprintf(stderr, "error: invalid value for -c: %s\n", argv[i]);
+            print_usage(argc, argv);
+        }
+    } else {
+        print_usage(argc, argv);
+    }
+
+    return 1;
+}
+
+static int handle_ngl(const int argc, const char ** argv, int & i, int & ngl) {
+    if (i + 1 < argc) {
+        if (parse_int_arg(argv[++i], ngl) == 0) {
+            return 0;
+        } else {
+            fprintf(stderr, "error: invalid value for -ngl: %s\n", argv[i]);
+            print_usage(argc, argv);
+        }
+    } else {
+        print_usage(argc, argv);
+    }
+
+    return 1;
+}
+
+static int parse_arguments(const int argc, const char ** argv, std::string & model_path, int & n_ctx, int & ngl) {
+    for (int i = 1; i < argc; ++i) {
+        if (strcmp(argv[i], "-m") == 0) {
+            if (handle_model_path(argc, argv, i, model_path)) {
                 return 1;
             }
-        } catch (std::exception & e) {
-            fprintf(stderr, "error: %s\n", e.what());
+        } else if (strcmp(argv[i], "-c") == 0) {
+            if (handle_n_ctx(argc, argv, i, n_ctx)) {
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-ngl") == 0) {
+            if (handle_ngl(argc, argv, i, ngl)) {
+                return 1;
+            }
+        } else {
             print_usage(argc, argv);
             return 1;
         }
     }
+
     if (model_path.empty()) {
         print_usage(argc, argv);
         return 1;
     }
 
-    // only print errors
-    llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
-        if (level >= GGML_LOG_LEVEL_ERROR) {
-            fprintf(stderr, "%s", text);
-        }
-    }, nullptr);
-
-    // initialize the model
-    llama_model_params model_params = llama_model_default_params();
-    model_params.n_gpu_layers = ngl;
+    return 0;
+}
 
-    llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
-    if (!model) {
-        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return 1;
-    }
+static int read_user_input(std::string & user_input) {
+    // Use unique_ptr with free as the deleter
+    std::unique_ptr<char, decltype(&free)> buffer(nullptr, &free);
 
-    // initialize the context
-    llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.n_ctx = n_ctx;
-    ctx_params.n_batch = n_ctx;
+    size_t buffer_size = 0;
+    char * raw_buffer = nullptr;
 
-    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-    if (!ctx) {
-        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return 1;
-    }
+    // Use getline to dynamically allocate the buffer and get input
+    const ssize_t line_size = getline(&raw_buffer, &buffer_size, stdin);
 
-    // initialize the sampler
-    llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
-    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
-    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
-    llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+    // Transfer ownership to unique_ptr
+    buffer.reset(raw_buffer);
 
-    // helper function to evaluate a prompt and generate a response
-    auto generate = [&](const std::string & prompt) {
-        std::string response;
-
-        // tokenize the prompt
-        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
-        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
-        if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
-            GGML_ABORT("failed to tokenize the prompt\n");
+    if (line_size > 0) {
+        // Remove the trailing newline character if present
+        if (buffer.get()[line_size - 1] == '\n') {
+            buffer.get()[line_size - 1] = '\0';
         }
 
-        // prepare a batch for the prompt
-        llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
-        llama_token new_token_id;
-        while (true) {
-            // check if we have enough space in the context to evaluate this batch
-            int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
-            if (n_ctx_used + batch.n_tokens > n_ctx) {
-                printf("\033[0m\n");
-                fprintf(stderr, "context size exceeded\n");
-                exit(0);
-            }
+        user_input = std::string(buffer.get());
 
-            if (llama_decode(ctx, batch)) {
-                GGML_ABORT("failed to decode\n");
-            }
-
-            // sample the next token
-            new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        return 0;  // Success
+    }
 
-            // is it an end of generation?
-            if (llama_token_is_eog(model, new_token_id)) {
-                break;
-            }
+    user_input.clear();
 
-            // convert the token to a string, print it and add it to the response
-            char buf[256];
-            int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
-            if (n < 0) {
-                GGML_ABORT("failed to convert token to piece\n");
-            }
-            std::string piece(buf, n);
-            printf("%s", piece.c_str());
-            fflush(stdout);
-            response += piece;
+    return 1;  // Indicate an error or empty input
+}
 
-            // prepare the next batch with the sampled token
-            batch = llama_batch_get_one(&new_token_id, 1);
-        }
+// Function to generate a response based on the prompt
+static int generate_response(llama_model * model, llama_sampler * sampler, llama_context * context,
+                             const std::string & prompt, std::string & response) {
+    // Set response color
+    printf("\033[33m");
+    if (generate(model, sampler, context, prompt, response)) {
+        fprintf(stderr, "failed to generate response\n");
+        return 1;
+    }
 
-        return response;
-    };
+    // End response with color reset and newline
+    printf("\n\033[0m");
+    return 0;
+}
 
-    std::vector<llama_chat_message> messages;
-    std::vector<char> formatted(llama_n_ctx(ctx));
+// The main chat loop where user inputs are processed and responses generated.
+static int chat_loop(llama_model * model, llama_sampler * sampler, llama_context * context,
+                     std::vector<llama_chat_message> & messages) {
+    std::vector<std::unique_ptr<char[]>> owned_content;
+    std::vector<char> formatted(llama_n_ctx(context));
     int prev_len = 0;
+
     while (true) {
-        // get user input
+        // Print prompt for user input
         printf("\033[32m> \033[0m");
         std::string user;
-        std::getline(std::cin, user);
-
-        if (user.empty()) {
+        if (read_user_input(user)) {
             break;
         }
 
-        // add the user input to the message list and format it
-        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        if (new_len > (int)formatted.size()) {
-            formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        }
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
 
-        // remove previous messages to obtain the prompt to generate the response
         std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        std::string response;
+        if (generate_response(model, sampler, context, prompt, response)) {
+            return 1;
+        }
 
-        // generate a response
-        printf("\033[33m");
-        std::string response = generate(prompt);
-        printf("\n\033[0m");
-
-        // add the response to the messages
-        messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        add_message("assistant", response, messages, owned_content);
+        prev_len = apply_chat_template(model, messages, formatted, false);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
     }
 
-    // free resources
-    for (auto & msg : messages) {
-        free(const_cast<char *>(msg.content));
+    return 0;
+}
+
+static void log_callback(const enum ggml_log_level level, const char * text, void *) {
+    if (level == GGML_LOG_LEVEL_ERROR) {
+        fprintf(stderr, "%s", text);
+    }
+}
+
+// Initializes the model and returns a unique pointer to it.
+static std::unique_ptr<llama_model, decltype(&llama_free_model)> initialize_model(const std::string & model_path,
+                                                                                  int ngl) {
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = ngl;
+
+    auto model = std::unique_ptr<llama_model, decltype(&llama_free_model)>(
+        llama_load_model_from_file(model_path.c_str(), model_params), llama_free_model);
+    if (!model) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+    }
+
+    return model;
+}
+
+// Initializes the context with the specified parameters.
+static std::unique_ptr<llama_context, decltype(&llama_free)> initialize_context(llama_model * model, int n_ctx) {
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = n_ctx;
+    ctx_params.n_batch = n_ctx;
+
+    auto context = std::unique_ptr<llama_context, decltype(&llama_free)>(
+        llama_new_context_with_model(model, ctx_params), llama_free);
+    if (!context) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+    }
+
+    return context;
+}
+
+// Initializes and configures the sampler.
+static std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)> initialize_sampler() {
+    auto sampler = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>(
+        llama_sampler_chain_init(llama_sampler_chain_default_params()), llama_sampler_free);
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+
+    return sampler;
+}
+
+int main(int argc, const char ** argv) {
+    std::string model_path;
+    int ngl = 99;
+    int n_ctx = 2048;
+    if (parse_arguments(argc, argv, model_path, n_ctx, ngl)) {
+        return 1;
+    }
+
+    llama_log_set(log_callback, nullptr);
+    auto model = initialize_model(model_path, ngl);
+    if (!model) {
+        return 1;
+    }
+
+    auto context = initialize_context(model.get(), n_ctx);
+    if (!context) {
+        return 1;
+    }
+
+    auto sampler = initialize_sampler();
+    std::vector<llama_chat_message> messages;
+    if (chat_loop(model.get(), sampler.get(), context.get(), messages)) {
+        return 1;
     }
-    llama_sampler_free(smpl);
-    llama_free(ctx);
-    llama_free_model(model);
 
     return 0;
 }
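The main structural change in both files is that the llama.cpp C handles (`llama_model`, `llama_context`, `llama_sampler`) are now owned by `std::unique_ptr` with the matching free functions as custom deleters, instead of being freed by hand at the end of `main`. A minimal, self-contained sketch of that pattern, using only llama.h functions that already appear in this patch; the type aliases and the trivial `main` are illustrative, not part of the change:

```cpp
#include <cstdio>
#include <memory>

#include "llama.h"

// Owning aliases for the C handles used in this patch; the deleters run
// automatically in reverse declaration order, even on early returns.
using llama_model_ptr   = std::unique_ptr<llama_model, decltype(&llama_free_model)>;
using llama_context_ptr = std::unique_ptr<llama_context, decltype(&llama_free)>;
using llama_sampler_ptr = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>;

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    llama_model_ptr model(llama_load_model_from_file(argv[1], llama_model_default_params()),
                          llama_free_model);
    if (!model) {
        return 1;  // nothing to clean up
    }

    llama_context_ptr ctx(llama_new_context_with_model(model.get(), llama_context_default_params()),
                          llama_free);
    if (!ctx) {
        return 1;  // model is released automatically
    }

    llama_sampler_ptr smpl(llama_sampler_chain_init(llama_sampler_chain_default_params()),
                           llama_sampler_free);

    // ... use model.get(), ctx.get(), smpl.get() as in the example above ...

    return 0;  // smpl, ctx, model are released here with no manual free calls
}
```

Because destruction happens in reverse declaration order, the context is always released before the model, and early-return paths no longer leak the handles they acquired.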