Use smart pointers in simple-chat
Avoid manual memory cleanups. Fewer memory leaks in the code now.
Avoid printing multiple dots. Split code into smaller functions.
Use C-style I/O rather than a mix of C++ streams and C-style I/O.
No exception handling.

Signed-off-by: Eric Curtin <[email protected]>
ericcurtin committed Nov 15, 2024
1 parent 1842922 commit 83988df
Showing 6 changed files with 672 additions and 138 deletions.
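
For context, the pattern the commit message refers to ("avoid manual memory cleanups") is `std::unique_ptr` with the matching C free function as a custom deleter, which the new example applies to the model, context and sampler handles. A minimal sketch of the idea, using only API calls that appear in the diff below (the model path is a placeholder):

```cpp
#include <memory>

#include "llama.h"

int main() {
    llama_model_params params = llama_model_default_params();

    // The deleter is part of the pointer type, so llama_free_model() runs on
    // every return path without an explicit cleanup call.
    std::unique_ptr<llama_model, decltype(&llama_free_model)> model(
        llama_load_model_from_file("model.gguf", params), llama_free_model);
    if (!model) {
        return 1;  // nothing leaks: the deleter is simply not invoked for a null handle
    }

    // ... use model.get() with the rest of the llama.cpp C API ...

    return 0;  // the model is released here automatically
}
```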
6 changes: 6 additions & 0 deletions Makefile
@@ -34,6 +34,7 @@ BUILD_TARGETS = \
	llama-server \
	llama-simple \
	llama-simple-chat \
	llama-ramalama-core \
	llama-speculative \
	llama-tokenize \
	llama-vdot \
@@ -1382,6 +1383,11 @@ llama-infill: examples/infill/infill.cpp \
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

llama-ramalama-core: examples/ramalama-core/ramalama-core.cpp \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

llama-simple: examples/simple/simple.cpp \
	$(OBJ_ALL)
	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -47,6 +47,7 @@ else()
        add_subdirectory(sycl)
    endif()
    add_subdirectory(save-load-state)
    add_subdirectory(ramalama-core)
    add_subdirectory(simple)
    add_subdirectory(simple-chat)
    add_subdirectory(speculative)
5 changes: 5 additions & 0 deletions examples/ramalama-core/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET llama-ramalama-core)
add_executable(${TARGET} ramalama-core.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
7 changes: 7 additions & 0 deletions examples/ramalama-core/README.md
@@ -0,0 +1,7 @@
# llama.cpp/example/ramalama-core

The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.

```bash
./llama-ramalama-core -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
...
356 changes: 356 additions & 0 deletions examples/ramalama-core/ramalama-core.cpp
@@ -0,0 +1,356 @@
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "llama.h"

// Add a message to `messages` and store its content in `owned_content`
static void add_message(const std::string & role, const std::string & text, std::vector<llama_chat_message> & messages,
                        std::vector<std::unique_ptr<char[]>> & owned_content) {
    auto content = std::unique_ptr<char[]>(new char[text.size() + 1]);
    std::strcpy(content.get(), text.c_str());
    messages.push_back({role.c_str(), content.get()});
    owned_content.push_back(std::move(content));
}

// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const llama_model * model, const std::vector<llama_chat_message> & messages,
                               std::vector<char> & formatted, const bool append) {
    int result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
                                           formatted.size());
    if (result > static_cast<int>(formatted.size())) {
        formatted.resize(result);
        result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
                                           formatted.size());
    }

    return result;
}

// Function to tokenize the prompt
static int tokenize_prompt(const llama_model * model, const std::string & prompt,
                           std::vector<llama_token> & prompt_tokens) {
    const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
    prompt_tokens.resize(n_prompt_tokens);
    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
        0) {
        GGML_ABORT("failed to tokenize the prompt\n");
    }

    return n_prompt_tokens;
}

// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context * ctx, const llama_batch & batch) {
    const int n_ctx = llama_n_ctx(ctx);
    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
    if (n_ctx_used + batch.n_tokens > n_ctx) {
        printf("\033[0m\n");
        fprintf(stderr, "context size exceeded\n");
        return 1;
    }

    return 0;
}

// convert the token to a string
static int convert_token_to_string(const llama_model * model, const llama_token token_id, std::string & piece) {
    char buf[256];
    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
    if (n < 0) {
        GGML_ABORT("failed to convert token to piece\n");
    }

    piece = std::string(buf, n);
    return 0;
}

static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
    printf("%s", piece.c_str());
    fflush(stdout);
    response += piece;
}

// helper function to evaluate a prompt and generate a response
static int generate(const llama_model * model, llama_sampler * smpl, llama_context * ctx, const std::string & prompt,
                    std::string & response) {
    std::vector<llama_token> prompt_tokens;
    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
    if (n_prompt_tokens < 0) {
        return 1;
    }

    // prepare a batch for the prompt
    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
    llama_token new_token_id;
    while (true) {
        // stop if the batch would not fit into the remaining context
        if (check_context_size(ctx, batch)) {
            return 1;
        }

        if (llama_decode(ctx, batch)) {
            GGML_ABORT("failed to decode\n");
        }

        // sample the next token and check whether it is an end-of-generation token
        new_token_id = llama_sampler_sample(smpl, ctx, -1);
        if (llama_token_is_eog(model, new_token_id)) {
            break;
        }

        std::string piece;
        if (convert_token_to_string(model, new_token_id, piece)) {
            return 1;
        }

        print_word_and_concatenate_to_response(piece, response);

        // prepare the next batch with the sampled token
        batch = llama_batch_get_one(&new_token_id, 1);
    }

    return 0;
}

static void print_usage(int, const char ** argv) {
    printf("\nexample usage:\n");
    printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
    printf("\n");
}

static int parse_int_arg(const char * arg, int & value) {
    char * end;
    long val = std::strtol(arg, &end, 10);
    if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
        value = static_cast<int>(val);
        return 0;
    }

    return 1;
}

static int handle_model_path(const int argc, const char ** argv, int & i, std::string & model_path) {
    if (i + 1 < argc) {
        model_path = argv[++i];
        return 0;
    }

    print_usage(argc, argv);
    return 1;
}

static int handle_n_ctx(const int argc, const char ** argv, int & i, int & n_ctx) {
    if (i + 1 < argc) {
        // parse_int_arg() returns 0 on success
        if (parse_int_arg(argv[++i], n_ctx) == 0) {
            return 0;
        } else {
            fprintf(stderr, "error: invalid value for -c: %s\n", argv[i]);
            print_usage(argc, argv);
        }
    } else {
        print_usage(argc, argv);
    }

    return 1;
}

static int handle_ngl(const int argc, const char ** argv, int & i, int & ngl) {
    if (i + 1 < argc) {
        // parse_int_arg() returns 0 on success
        if (parse_int_arg(argv[++i], ngl) == 0) {
            return 0;
        } else {
            fprintf(stderr, "error: invalid value for -ngl: %s\n", argv[i]);
            print_usage(argc, argv);
        }
    } else {
        print_usage(argc, argv);
    }

    return 1;
}

static int parse_arguments(const int argc, const char ** argv, std::string & model_path, int & n_ctx, int & ngl) {
    for (int i = 1; i < argc; ++i) {
        if (strcmp(argv[i], "-m") == 0) {
            if (handle_model_path(argc, argv, i, model_path)) {
                return 1;
            }
        } else if (strcmp(argv[i], "-c") == 0) {
            if (handle_n_ctx(argc, argv, i, n_ctx)) {
                return 1;
            }
        } else if (strcmp(argv[i], "-ngl") == 0) {
            if (handle_ngl(argc, argv, i, ngl)) {
                return 1;
            }
        } else {
            print_usage(argc, argv);
            return 1;
        }
    }

    if (model_path.empty()) {
        print_usage(argc, argv);
        return 1;
    }

    return 0;
}

static int read_user_input(std::string & user_input) {
    // Use unique_ptr with free as the deleter
    std::unique_ptr<char, decltype(&free)> buffer(nullptr, &free);

    size_t buffer_size = 0;
    char * raw_buffer = nullptr;

    // Use getline to dynamically allocate the buffer and get input
    const ssize_t line_size = getline(&raw_buffer, &buffer_size, stdin);

    // Transfer ownership to unique_ptr
    buffer.reset(raw_buffer);

    if (line_size > 0) {
        // Remove the trailing newline character if present
        if (buffer.get()[line_size - 1] == '\n') {
            buffer.get()[line_size - 1] = '\0';
        }

        user_input = std::string(buffer.get());

        return 0; // Success
    }

    user_input.clear();

    return 1; // Indicate an error or empty input
}

// Function to generate a response based on the prompt
static int generate_response(llama_model * model, llama_sampler * sampler, llama_context * context,
                             const std::string & prompt, std::string & response) {
    // Set response color
    printf("\033[33m");
    if (generate(model, sampler, context, prompt, response)) {
        fprintf(stderr, "failed to generate response\n");
        return 1;
    }

    // End response with color reset and newline
    printf("\n\033[0m");
    return 0;
}

// The main chat loop where user inputs are processed and responses generated.
static int chat_loop(llama_model * model, llama_sampler * sampler, llama_context * context,
                     std::vector<llama_chat_message> & messages) {
    std::vector<std::unique_ptr<char[]>> owned_content;
    std::vector<char> formatted(llama_n_ctx(context));
    int prev_len = 0;

    while (true) {
        // Print prompt for user input
        printf("\033[32m> \033[0m");
        std::string user;
        if (read_user_input(user)) {
            break;
        }

        add_message("user", user, messages, owned_content);
        int new_len = apply_chat_template(model, messages, formatted, true);
        if (new_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }

        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
        std::string response;
        if (generate_response(model, sampler, context, prompt, response)) {
            return 1;
        }

        add_message("assistant", response, messages, owned_content);
        prev_len = apply_chat_template(model, messages, formatted, false);
        if (prev_len < 0) {
            fprintf(stderr, "failed to apply the chat template\n");
            return 1;
        }
    }

    return 0;
}

static void log_callback(const enum ggml_log_level level, const char * text, void *) {
    if (level == GGML_LOG_LEVEL_ERROR) {
        fprintf(stderr, "%s", text);
    }
}

// Initializes the model and returns a unique pointer to it.
static std::unique_ptr<llama_model, decltype(&llama_free_model)> initialize_model(const std::string & model_path,
                                                                                   int ngl) {
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;

    auto model = std::unique_ptr<llama_model, decltype(&llama_free_model)>(
        llama_load_model_from_file(model_path.c_str(), model_params), llama_free_model);
    if (!model) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
    }

    return model;
}

// Initializes the context with the specified parameters.
static std::unique_ptr<llama_context, decltype(&llama_free)> initialize_context(llama_model * model, int n_ctx) {
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_ctx;
    ctx_params.n_batch = n_ctx;

    auto context = std::unique_ptr<llama_context, decltype(&llama_free)>(
        llama_new_context_with_model(model, ctx_params), llama_free);
    if (!context) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
    }

    return context;
}

// Initializes and configures the sampler.
static std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)> initialize_sampler() {
    auto sampler = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>(
        llama_sampler_chain_init(llama_sampler_chain_default_params()), llama_sampler_free);
    llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
    llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return sampler;
}

int main(int argc, const char ** argv) {
    std::string model_path;
    int ngl = 99;
    int n_ctx = 2048;
    if (parse_arguments(argc, argv, model_path, n_ctx, ngl)) {
        return 1;
    }

    llama_log_set(log_callback, nullptr);
    auto model = initialize_model(model_path, ngl);
    if (!model) {
        return 1;
    }

    auto context = initialize_context(model.get(), n_ctx);
    if (!context) {
        return 1;
    }

    auto sampler = initialize_sampler();
    std::vector<llama_chat_message> messages;
    if (chat_loop(model.get(), sampler.get(), context.get(), messages)) {
        return 1;
    }

    return 0;
}
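
One detail of `chat_loop` above worth spelling out: `llama_chat_apply_template` always renders the full message history into `formatted`, so the loop keeps `prev_len` and decodes only the slice `[prev_len, new_len)` that the latest turn appended. The sketch below imitates that bookkeeping with plain strings purely for illustration; the tags are invented, whereas the real formatting comes from the GGUF chat template:

```cpp
#include <cstdio>
#include <string>

int main() {
    std::string formatted;  // stands in for the `formatted` buffer in chat_loop
    size_t prev_len = 0;

    // Turn 1: the whole history is rendered, and the new tail is what gets decoded.
    formatted += "<user>Hello</user><assistant>";
    printf("decode: %s\n", formatted.substr(prev_len).c_str());

    // Once the assistant reply is appended, remember how much has been consumed ...
    formatted += "Hi there!</assistant>";
    prev_len = formatted.size();

    // ... so turn 2 decodes only the newly appended part of the transcript.
    formatted += "<user>How are you?</user><assistant>";
    printf("decode: %s\n", formatted.substr(prev_len).c_str());

    return 0;
}
```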