Avoid manual memory cleanups; fewer memory leaks in the code now. Avoid printing multiple dots. Split code into smaller functions. Use C-style IO rather than a mix of C++ streams and C-style IO. No exception handling.

Signed-off-by: Eric Curtin <[email protected]>
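The change leans on std::unique_ptr with custom deleters instead of manual free calls. A minimal sketch of that pattern, using the same llama.h calls that appear in the diff below:

```cpp
// Sketch only: RAII ownership of a llama.cpp handle via a custom deleter.
#include <memory>

#include "llama.h"

using model_ptr = std::unique_ptr<llama_model, decltype(&llama_free_model)>;

static model_ptr load_model(const char * path, int n_gpu_layers) {
    llama_model_params params = llama_model_default_params();
    params.n_gpu_layers = n_gpu_layers;

    // If loading fails the returned pointer is empty; if it succeeds,
    // llama_free_model runs automatically when the pointer goes out of scope,
    // including on early error returns in the caller.
    return model_ptr(llama_load_model_from_file(path, params), llama_free_model);
}
```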
Commit 83988df (1 parent: 1842922)
Showing 6 changed files with 672 additions and 138 deletions.
CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET llama-ramalama-core)
add_executable(${TARGET} ramalama-core.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
README.md
@@ -0,0 +1,7 @@
# llama.cpp/example/ramalama-core

The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.

```bash
./llama-ramalama-core -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
...
```
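Building the binary first is assumed to follow the standard llama.cpp CMake workflow; a sketch, with the target name taken from the CMakeLists.txt above:

```bash
# Run from the llama.cpp checkout; assumes the usual CMake setup.
cmake -B build
cmake --build build --target llama-ramalama-core
```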
ramalama-core.cpp
@@ -0,0 +1,356 @@
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "llama.h"

// Add a message to `messages` and store its content in `owned_content`
static void add_message(const std::string & role, const std::string & text, std::vector<llama_chat_message> & messages,
                        std::vector<std::unique_ptr<char[]>> & owned_content) {
    auto content = std::unique_ptr<char[]>(new char[text.size() + 1]);
    std::strcpy(content.get(), text.c_str());
    messages.push_back({role.c_str(), content.get()});
    owned_content.push_back(std::move(content));
}

// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const llama_model * model, const std::vector<llama_chat_message> & messages,
                               std::vector<char> & formatted, const bool append) {
    int result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
                                           formatted.size());
    if (result > static_cast<int>(formatted.size())) {
        // The formatted conversation did not fit: grow the buffer and apply the template again
        formatted.resize(result);
        result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
                                           formatted.size());
    }

    return result;
}

// Function to tokenize the prompt
static int tokenize_prompt(const llama_model * model, const std::string & prompt,
                           std::vector<llama_token> & prompt_tokens) {
    // With a NULL output buffer, llama_tokenize returns the negative of the required token count
    const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
    prompt_tokens.resize(n_prompt_tokens);
    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
        0) {
        GGML_ABORT("failed to tokenize the prompt\n");
    }

    return n_prompt_tokens;
}

// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context * ctx, const llama_batch & batch) {
    const int n_ctx = llama_n_ctx(ctx);
    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
    if (n_ctx_used + batch.n_tokens > n_ctx) {
        printf("\033[0m\n");
        fprintf(stderr, "context size exceeded\n");
        return 1;
    }

    return 0;
}

// convert the token to a string
static int convert_token_to_string(const llama_model * model, const llama_token token_id, std::string & piece) {
    char buf[256];
    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
    if (n < 0) {
        GGML_ABORT("failed to convert token to piece\n");
    }

    piece = std::string(buf, n);
    return 0;
}

static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
    printf("%s", piece.c_str());
    fflush(stdout);
    response += piece;
}

// helper function to evaluate a prompt and generate a response
static int generate(const llama_model * model, llama_sampler * smpl, llama_context * ctx, const std::string & prompt,
                    std::string & response) {
    std::vector<llama_token> prompt_tokens;
    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
    if (n_prompt_tokens < 0) {
        return 1;
    }

    // prepare a batch for the prompt
    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
    llama_token new_token_id;
    while (true) {
        // stop generating if this batch would no longer fit in the context
        if (check_context_size(ctx, batch)) {
            return 1;
        }

        if (llama_decode(ctx, batch)) {
            GGML_ABORT("failed to decode\n");
        }

        // sample the next token and check whether it is an end-of-generation token
        new_token_id = llama_sampler_sample(smpl, ctx, -1);
        if (llama_token_is_eog(model, new_token_id)) {
            break;
        }

        std::string piece;
        if (convert_token_to_string(model, new_token_id, piece)) {
            return 1;
        }

        print_word_and_concatenate_to_response(piece, response);

        // prepare the next batch with the sampled token
        batch = llama_batch_get_one(&new_token_id, 1);
    }

    return 0;
}

static void print_usage(int, const char ** argv) {
    printf("\nexample usage:\n");
    printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
    printf("\n");
}

static int parse_int_arg(const char * arg, int & value) {
    char * end;
    long val = std::strtol(arg, &end, 10);
    if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
        value = static_cast<int>(val);
        return 0;
    }

    return 1;
}

static int handle_model_path(const int argc, const char ** argv, int & i, std::string & model_path) {
    if (i + 1 < argc) {
        model_path = argv[++i];
        return 0;
    }

    print_usage(argc, argv);
    return 1;
}

static int handle_n_ctx(const int argc, const char ** argv, int & i, int & n_ctx) {
    if (i + 1 < argc) {
        // parse_int_arg() returns 0 on success
        if (parse_int_arg(argv[++i], n_ctx) == 0) {
            return 0;
        } else {
            fprintf(stderr, "error: invalid value for -c: %s\n", argv[i]);
            print_usage(argc, argv);
        }
    } else {
        print_usage(argc, argv);
    }

    return 1;
}

static int handle_ngl(const int argc, const char ** argv, int & i, int & ngl) {
    if (i + 1 < argc) {
        // parse_int_arg() returns 0 on success
        if (parse_int_arg(argv[++i], ngl) == 0) {
            return 0;
        } else {
            fprintf(stderr, "error: invalid value for -ngl: %s\n", argv[i]);
            print_usage(argc, argv);
        }
    } else {
        print_usage(argc, argv);
    }

    return 1;
}

static int parse_arguments(const int argc, const char ** argv, std::string & model_path, int & n_ctx, int & ngl) {
    for (int i = 1; i < argc; ++i) {
        if (strcmp(argv[i], "-m") == 0) {
            if (handle_model_path(argc, argv, i, model_path)) {
                return 1;
            }
        } else if (strcmp(argv[i], "-c") == 0) {
            if (handle_n_ctx(argc, argv, i, n_ctx)) {
                return 1;
            }
        } else if (strcmp(argv[i], "-ngl") == 0) {
            if (handle_ngl(argc, argv, i, ngl)) {
                return 1;
            }
        } else {
            print_usage(argc, argv);
            return 1;
        }
    }

    if (model_path.empty()) {
        print_usage(argc, argv);
        return 1;
    }

    return 0;
}

static int read_user_input(std::string & user_input) {
    // Use unique_ptr with free as the deleter
    std::unique_ptr<char, decltype(&free)> buffer(nullptr, &free);

    size_t buffer_size = 0;
    char * raw_buffer = nullptr;

    // Use getline to dynamically allocate the buffer and get input
    const ssize_t line_size = getline(&raw_buffer, &buffer_size, stdin);

    // Transfer ownership to unique_ptr
    buffer.reset(raw_buffer);

    if (line_size > 0) {
        // Remove the trailing newline character if present
        if (buffer.get()[line_size - 1] == '\n') {
            buffer.get()[line_size - 1] = '\0';
        }

        user_input = std::string(buffer.get());

        return 0;  // Success
    }

    user_input.clear();

    return 1;  // Indicate an error or empty input
}

// Function to generate a response based on the prompt
static int generate_response(llama_model * model, llama_sampler * sampler, llama_context * context,
                             const std::string & prompt, std::string & response) {
    // Set response color
    printf("\033[33m");
    if (generate(model, sampler, context, prompt, response)) {
        fprintf(stderr, "failed to generate response\n");
        return 1;
    }

    // End response with color reset and newline
    printf("\n\033[0m");
    return 0;
}

// The main chat loop where user inputs are processed and responses generated.
static int chat_loop(llama_model * model, llama_sampler * sampler, llama_context * context,
                     std::vector<llama_chat_message> & messages) {
    std::vector<std::unique_ptr<char[]>> owned_content;
    std::vector<char> formatted(llama_n_ctx(context));
    int prev_len = 0;

    while (true) {
        // Print prompt for user input
        printf("\033[32m> \033[0m");
        std::string user;
        if (read_user_input(user)) {
            break;
        }

        add_message("user", user, messages, owned_content);
        int new_len = apply_chat_template(model, messages, formatted, true);
        if (new_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }

        // Only the part of the formatted conversation added since the previous
        // turn (between prev_len and new_len) is sent to the model
        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
        std::string response;
        if (generate_response(model, sampler, context, prompt, response)) {
            return 1;
        }

        add_message("assistant", response, messages, owned_content);
        prev_len = apply_chat_template(model, messages, formatted, false);
        if (prev_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }
    }

    return 0;
}

static void log_callback(const enum ggml_log_level level, const char * text, void *) {
    if (level == GGML_LOG_LEVEL_ERROR) {
        fprintf(stderr, "%s", text);
    }
}

// Initializes the model and returns a unique pointer to it.
static std::unique_ptr<llama_model, decltype(&llama_free_model)> initialize_model(const std::string & model_path,
                                                                                  int ngl) {
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    auto model = std::unique_ptr<llama_model, decltype(&llama_free_model)>(
        llama_load_model_from_file(model_path.c_str(), model_params), llama_free_model);
    if (!model) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
    }

    return model;
}

// Initializes the context with the specified parameters.
static std::unique_ptr<llama_context, decltype(&llama_free)> initialize_context(llama_model * model, int n_ctx) {
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_ctx;
    ctx_params.n_batch = n_ctx;

    auto context = std::unique_ptr<llama_context, decltype(&llama_free)>(
        llama_new_context_with_model(model, ctx_params), llama_free);
    if (!context) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
    }

    return context;
}

// Initializes and configures the sampler.
static std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)> initialize_sampler() {
    auto sampler = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>(
        llama_sampler_chain_init(llama_sampler_chain_default_params()), llama_sampler_free);
    llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return sampler;
}

int main(int argc, const char ** argv) {
    std::string model_path;
    int ngl = 99;
    int n_ctx = 2048;
    if (parse_arguments(argc, argv, model_path, n_ctx, ngl)) {
        return 1;
    }

    llama_log_set(log_callback, nullptr);
    auto model = initialize_model(model_path, ngl);
    if (!model) {
        return 1;
    }

    auto context = initialize_context(model.get(), n_ctx);
    if (!context) {
        return 1;
    }

    auto sampler = initialize_sampler();
    std::vector<llama_chat_message> messages;
    if (chat_loop(model.get(), sampler.get(), context.get(), messages)) {
        return 1;
    }

    return 0;
}