Merge pull request #650 from vlovich/fix-chat-template
Cleanup chat template API
MarcusDunn authored Feb 14, 2025
2 parents 5c8e81b + 72c1255 commit 00f4163
Showing 4 changed files with 128 additions and 42 deletions.
8 changes: 2 additions & 6 deletions examples/simple/src/main.rs
@@ -10,14 +10,14 @@ use anyhow::{anyhow, bail, Context, Result};
use clap::Parser;
use hf_hub::api::sync::ApiBuilder;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::{ggml_time_us, send_logs_to_tracing, LogOptions};
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::{ggml_time_us, send_logs_to_tracing, LogOptions};

use std::ffi::CString;
use std::io::Write;
@@ -67,11 +67,7 @@ struct Args {
help = "size of the prompt context (default: loaded from themodel)"
)]
ctx_size: Option<NonZeroU32>,
#[arg(
short = 'v',
long,
help = "enable verbose llama.cpp logs",
)]
#[arg(short = 'v', long, help = "enable verbose llama.cpp logs")]
verbose: bool,
}

9 changes: 6 additions & 3 deletions llama-cpp-2/src/lib.rs
@@ -69,9 +69,6 @@ pub enum LLamaCppError {
/// There was an error while getting the chat template from a model.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum ChatTemplateError {
/// the buffer was too small.
#[error("The buffer was too small. However, a buffer size of {0} would be just large enough.")]
BuffSizeError(usize),
/// gguf has no chat template
#[error("the model has no meta val - returned code {0}")]
MissingTemplate(i32),
@@ -80,6 +77,12 @@ pub enum ChatTemplateError {
Utf8Error(#[from] std::str::Utf8Error),
}

enum InternalChatTemplateError {
Permanent(ChatTemplateError),
/// the buffer was too small.
RetryWithLargerBuffer(usize),
}

/// Failed to Load context
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaContextLoadError {
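
For illustration only (not part of the diff): with `BuffSizeError` removed and the buffer retry handled internally through the crate-private `InternalChatTemplateError`, a caller only has to distinguish a missing template from a badly encoded one. A minimal sketch, assuming `model` is an already-loaded `LlamaModel`:

    use llama_cpp_2::ChatTemplateError;

    match model.get_chat_template() {
        Ok(template) => println!("model ships its own template: {template:?}"),
        Err(ChatTemplateError::MissingTemplate(code)) => {
            // The GGUF metadata has no tokenizer.chat_template entry.
            eprintln!("no chat template in model (return code {code})");
        }
        Err(other) => eprintln!("could not read chat template: {other}"),
    }
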
3 changes: 2 additions & 1 deletion llama-cpp-2/src/log.rs
@@ -171,7 +171,8 @@ impl State {
} else {
let level = self
.previous_level
.load(std::sync::atomic::Ordering::Acquire) as llama_cpp_sys_2::ggml_log_level;
.load(std::sync::atomic::Ordering::Acquire)
as llama_cpp_sys_2::ggml_log_level;
tracing::warn!(
inferred_level = level,
text = text,
150 changes: 118 additions & 32 deletions llama-cpp-2/src/model.rs
@@ -1,9 +1,10 @@
//! A safe wrapper around `llama_model`.
use std::ffi::{c_char, CString};
use std::ffi::{c_char, CStr, CString};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::NonNull;
use std::str::{FromStr, Utf8Error};

use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
@@ -12,8 +13,9 @@ use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
ApplyChatTemplateError, ChatTemplateError, InternalChatTemplateError, LlamaContextLoadError,
LlamaLoraAdapterInitError, LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError,
TokenToStringError,
};

pub mod params;
@@ -34,6 +36,42 @@ pub struct LlamaLoraAdapter {
pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// A performance-friendly wrapper around [LlamaModel::get_chat_template] which is then
/// fed into [LlamaModel::apply_chat_template] to convert a list of messages into an LLM
/// prompt. Internally the template is stored as a CString to avoid round-trip conversions
/// within the FFI.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
/// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61)
/// like "chatml" or "llama3" or an actual Jinja template for llama.cpp to interpret.
pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
Ok(Self(CString::from_str(template)?))
}

/// Accesses the template as a c string reference.
pub fn as_c_str(&self) -> &CStr {
&self.0
}

/// Attempts to convert the CString into a Rust str reference.
pub fn to_str(&self) -> Result<&str, Utf8Error> {
self.0.to_str()
}

/// Convenience method to create an owned String.
pub fn to_string(&self) -> Result<String, Utf8Error> {
self.to_str().map(str::to_string)
}
}

impl std::fmt::Debug for LlamaChatTemplate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
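
A brief usage sketch (illustrative, not part of the diff): the wrapper accepts either one of llama.cpp's built-in template names or a raw Jinja template string; the Jinja snippet below is made up purely for demonstration:

    use llama_cpp_2::model::LlamaChatTemplate;

    // From a built-in llama.cpp template name...
    let chatml = LlamaChatTemplate::new("chatml").expect("no interior NUL bytes");
    assert_eq!(chatml.to_str().unwrap(), "chatml");

    // ...or from a literal Jinja template for llama.cpp to interpret.
    let _custom = LlamaChatTemplate::new(
        "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}{% endfor %}",
    )
    .expect("no interior NUL bytes");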

/// A Safe wrapper around `llama_chat_message`
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
@@ -408,41 +446,84 @@ impl LlamaModel {
unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
}

/// Get chat template from model.
///
/// # Errors
///
/// * If the model has no chat template
/// * If the chat template is not a valid [`CString`].
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
fn get_chat_template_impl(
&self,
capacity: usize,
) -> Result<LlamaChatTemplate, InternalChatTemplateError> {
// longest known template is about 1200 bytes from llama.cpp
let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
let chat_ptr = chat_temp.into_raw();
let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
// TODO: Once MaybeUninit support is better, this can be converted to use that instead of dummy initializing such a large array.
let mut chat_temp = vec![b'*' as u8; capacity];
let chat_name =
CStr::from_bytes_with_nul(b"tokenizer.chat_template\0").expect("should have null byte");

let ret = unsafe {
llama_cpp_sys_2::llama_model_meta_val_str(
self.model.as_ptr(),
chat_name.as_ptr(),
chat_ptr,
buf_size,
chat_temp.as_mut_ptr() as *mut c_char,
chat_temp.len(),
)
};

if ret < 0 {
return Err(ChatTemplateError::MissingTemplate(ret));
return Err(InternalChatTemplateError::Permanent(
ChatTemplateError::MissingTemplate(ret),
));
}

let template_c = unsafe { CString::from_raw(chat_ptr) };
let template = template_c.to_str()?;
let returned_len = ret as usize;

let ret: usize = ret.try_into().unwrap();
if template.len() < ret {
return Err(ChatTemplateError::BuffSizeError(ret + 1));
if ret as usize >= capacity {
// >= is important because if the returned length is equal to capacity, it means we're missing a trailing null
// since the returned length doesn't count the trailing null.
return Err(InternalChatTemplateError::RetryWithLargerBuffer(
returned_len,
));
}

Ok(template.to_owned())
assert_eq!(
chat_temp.get(returned_len),
Some(&0),
"should end with null byte"
);

chat_temp.resize(returned_len + 1, 0);

Ok(LlamaChatTemplate(unsafe {
CString::from_vec_with_nul_unchecked(chat_temp)
}))
}

/// Get the chat template from the model. If this fails, you can either refuse to chat or fall back to
/// one of the shortcode names for the templates that llama.cpp bakes directly into its codebase for
/// models that don't ship their own. NOTE: if you don't specify a chat template, llama.cpp defaults to
/// chatml, which is unlikely to be the correct template for your model, and you'll get strange results
/// back.
///
/// You supply this into [Self::apply_chat_template] to get back a string with the appropriate template
/// substitution applied to convert a list of messages into a prompt the LLM can use to complete
/// the chat.
///
/// # Errors
///
/// * If the model has no chat template
/// * If the chat template is not a valid [`CString`].
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
pub fn get_chat_template(&self) -> Result<LlamaChatTemplate, ChatTemplateError> {
// Typical chat templates are quite small. Let's start with a small allocation likely to succeed.
// Ideally the performance of this would be negligible but uninitialized arrays in Rust are currently
// still not well supported so we end up initializing the chat template buffer twice. One idea might
// be to use a very small value here that will likely fail (like 0 or 1) and then use that to initialize.
// Not sure which approach is the most optimal but in practice this should work well.
match self.get_chat_template_impl(200) {
Ok(t) => Ok(t),
Err(InternalChatTemplateError::Permanent(e)) => Err(e),
Err(InternalChatTemplateError::RetryWithLargerBuffer(actual_len)) => match self.get_chat_template_impl(actual_len + 1) {
Ok(t) => Ok(t),
Err(InternalChatTemplateError::Permanent(e)) => Err(e),
Err(InternalChatTemplateError::RetryWithLargerBuffer(unexpected_len)) => panic!("Was told that the template length was {actual_len} but now it's {unexpected_len}"),
}
}
}
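
As a hedged sketch of the call pattern the docs above suggest (assuming `model` is a loaded `LlamaModel`): prefer the template embedded in the GGUF, and only fall back to an explicitly named builtin when the model does not carry one:

    use llama_cpp_2::model::LlamaChatTemplate;

    let template = model.get_chat_template().unwrap_or_else(|err| {
        // Fall back to a named builtin rather than llama.cpp's silent chatml default.
        eprintln!("model has no usable chat template ({err}); falling back to chatml");
        LlamaChatTemplate::new("chatml").expect("static name has no NUL bytes")
    });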

/// Loads a model from a file.
@@ -526,15 +607,25 @@ impl LlamaModel {
/// Apply the model's chat template to some messages.
/// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
///
/// `tmpl` of None means to use the default template provided by llama.cpp for the model
/// Unlike llama.cpp's `apply_chat_template`, which silently falls back to the ChatML template when given
/// a null pointer for the template, this requires an explicit template to be specified. If you want to
/// use "chatml", just do `LlamaChatTemplate::new("chatml")`, or pass any other model name or template
/// string.
///
/// Use [Self::get_chat_template] to retrieve the template baked into the model (this is the preferred
/// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
///
/// You probably want to set `add_ass` to true so that the generated template string ends with the
/// opening tag of the assistant turn. If you don't leave that tag hanging, the model will likely
/// generate one itself, and the rest of the output may be unexpected as well.
///
/// # Errors
/// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
#[tracing::instrument(skip_all)]
pub fn apply_chat_template(
&self,
tmpl: Option<String>,
chat: Vec<LlamaChatMessage>,
tmpl: &LlamaChatTemplate,
chat: &[LlamaChatMessage],
add_ass: bool,
) -> Result<String, ApplyChatTemplateError> {
// Buffer is twice the length of messages per their recommendation
Expand All @@ -552,12 +643,7 @@ impl LlamaModel {
})
.collect();

// Set the tmpl pointer
let tmpl = tmpl.map(CString::new);
let tmpl_ptr = match &tmpl {
Some(str) => str.as_ref().map_err(Clone::clone)?.as_ptr(),
None => std::ptr::null(),
};
let tmpl_ptr = tmpl.0.as_ptr();

let res = unsafe {
llama_cpp_sys_2::llama_chat_apply_template(

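Putting the reworked API together, a minimal end-to-end sketch (not part of the diff) of building a prompt, written as if inside a function returning `anyhow::Result<()>`; `model` is assumed to be a loaded `LlamaModel`, and `LlamaChatMessage::new(role, content)` is the message constructor the crate exposes elsewhere (not shown in this hunk):

    use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate};

    let template: LlamaChatTemplate = model
        .get_chat_template()
        .unwrap_or_else(|_| LlamaChatTemplate::new("chatml").expect("no NUL bytes"));

    let chat = vec![
        LlamaChatMessage::new("system".into(), "You are a helpful assistant.".into())?,
        LlamaChatMessage::new("user".into(), "Write a haiku about Rust.".into())?,
    ];

    // add_ass = true leaves the assistant turn open so generation continues the reply.
    let prompt = model.apply_chat_template(&template, &chat, true)?;
    println!("{prompt}");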