Cleanup chat template API
1. Make the template parameter non-optional for apply_chat_template. This
   ensures you don't accidentally use the chatml template.
2. Improve performance for the expected case of using get_chat_template
   by returning a new LlamaChatTemplate struct that internally stores
   the string as a CString. Unless you explicitly create a copy or print
   it, no extra copy into a Rust String is created. Similarly,
   get_chat_template -> apply_chat_template no longer copies the
   template string (see the usage sketch below).
3. Improve documentation, including documenting what the add_ass
   parameter does and suggesting which values you probably want to
   use. Additionally, I've made the get_chat_template and
   apply_chat_template docs refer to one another to make it easier to
   discover how to use them together.
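A minimal sketch of the new call flow, assuming a loaded LlamaModel named `model`; the helper function name, roles, and message contents are illustrative, and error handling is collapsed into a boxed error for brevity:

use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};

fn build_prompt(model: &LlamaModel) -> Result<String, Box<dyn std::error::Error>> {
    // The template is now an explicit, required argument rather than an Option.
    let template: LlamaChatTemplate = model.get_chat_template()?;

    let messages = vec![
        LlamaChatMessage::new("system".into(), "You are a helpful assistant.".into())?,
        LlamaChatMessage::new("user".into(), "Hello!".into())?,
    ];

    // add_ass = true leaves the prompt hanging on the assistant's opening tag
    // so that generation continues as the assistant's turn.
    let prompt = model.apply_chat_template(&template, &messages, true)?;
    Ok(prompt)
}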
vlovich committed Feb 13, 2025
1 parent 04e9e31 commit 7a29ac4
Showing 2 changed files with 124 additions and 35 deletions.
9 changes: 6 additions & 3 deletions llama-cpp-2/src/lib.rs
@@ -69,9 +69,6 @@ pub enum LLamaCppError {
/// There was an error while getting the chat template from a model.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum ChatTemplateError {
/// the buffer was too small.
#[error("The buffer was too small. However, a buffer size of {0} would be just large enough.")]
BuffSizeError(usize),
/// gguf has no chat template
#[error("the model has no meta val - returned code {0}")]
MissingTemplate(i32),
@@ -80,6 +77,12 @@ pub enum ChatTemplateError {
Utf8Error(#[from] std::str::Utf8Error),
}

enum InternalChatTemplateError {
Permanent(ChatTemplateError),
/// the buffer was too small.
RetryWithLargerBuffer(usize),
}

/// Failed to Load context
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaContextLoadError {
150 changes: 118 additions & 32 deletions llama-cpp-2/src/model.rs
@@ -1,9 +1,10 @@
//! A safe wrapper around `llama_model`.
use std::ffi::{c_char, CString};
use std::ffi::{c_char, CStr, CString};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::NonNull;
use std::str::{FromStr, Utf8Error};

use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
@@ -12,8 +13,9 @@ use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
ApplyChatTemplateError, ChatTemplateError, InternalChatTemplateError, LlamaContextLoadError,
LlamaLoraAdapterInitError, LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError,
TokenToStringError,
};

pub mod params;
@@ -34,6 +36,42 @@ pub struct LlamaLoraAdapter {
pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// A performance-friendly wrapper around the chat template returned by
/// [LlamaModel::get_chat_template], which is then fed into [LlamaModel::apply_chat_template]
/// to convert a list of messages into an LLM prompt. Internally the template is stored as a
/// CString to avoid round-trip conversions across the FFI boundary.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
/// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61)
/// like "chatml" or "llama3" or an actual Jinja template for llama.cpp to interpret.
pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
Ok(Self(CString::from_str(template)?))
}

/// Accesses the template as a c string reference.
pub fn as_c_str(&self) -> &CStr {
&self.0
}

/// Attempts to convert the CString into a Rust str reference.
pub fn to_str(&self) -> Result<&str, Utf8Error> {
self.0.to_str()
}

/// Convenience method to create an owned String.
pub fn to_string(&self) -> Result<String, Utf8Error> {
self.to_str().map(str::to_string)
}
}

impl std::fmt::Debug for LlamaChatTemplate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
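A brief, hedged sketch of constructing the new type (the template strings below are illustrative only):

// A llama.cpp built-in shortcode...
let builtin = LlamaChatTemplate::new("llama3").expect("no interior NUL bytes");
// ...or a raw Jinja template for llama.cpp to interpret.
let _custom = LlamaChatTemplate::new("{% for m in messages %}...{% endfor %}")
    .expect("no interior NUL bytes");
assert_eq!(builtin.to_str(), Ok("llama3"));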

/// A Safe wrapper around `llama_chat_message`
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
@@ -408,41 +446,84 @@ impl LlamaModel {
unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
}

/// Get chat template from model.
///
/// # Errors
///
/// * If the model has no chat template
/// * If the chat template is not a valid [`CString`].
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
fn get_chat_template_impl(
&self,
capacity: usize,
) -> Result<LlamaChatTemplate, InternalChatTemplateError> {
// longest known template is about 1200 bytes from llama.cpp
let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
let chat_ptr = chat_temp.into_raw();
let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
// TODO: Once MaybeUninit support is better, this can be converted to use that instead of dummy initializing such a large array.
let mut chat_temp = vec![b'*'; capacity];
let chat_name =
CStr::from_bytes_with_nul(b"tokenizer.chat_template\0").expect("should have null byte");

let ret = unsafe {
llama_cpp_sys_2::llama_model_meta_val_str(
self.model.as_ptr(),
chat_name.as_ptr(),
chat_ptr,
buf_size,
chat_temp.as_mut_ptr() as *mut c_char,
chat_temp.len(),
)
};

if ret < 0 {
return Err(ChatTemplateError::MissingTemplate(ret));
return Err(InternalChatTemplateError::Permanent(
ChatTemplateError::MissingTemplate(ret),
));
}

let template_c = unsafe { CString::from_raw(chat_ptr) };
let template = template_c.to_str()?;
let returned_len = ret as usize;

let ret: usize = ret.try_into().unwrap();
if template.len() < ret {
return Err(ChatTemplateError::BuffSizeError(ret + 1));
if ret as usize >= capacity {
// >= is important because if the returned length is equal to capacity, it means we're missing a trailing null
// since the returned length doesn't count the trailing null.
return Err(InternalChatTemplateError::RetryWithLargerBuffer(
returned_len,
));
}

Ok(template.to_owned())
assert_eq!(
chat_temp.get(returned_len),
Some(&0),
"should end with null byte"
);

chat_temp.resize(returned_len + 1, 0);

Ok(LlamaChatTemplate(unsafe {
CString::from_vec_with_nul_unchecked(chat_temp)
}))
}

/// Get chat template from model. If this fails, you may either want to fail the chat entirely or
/// fall back to one of the shortcodes for the templates llama.cpp has baked directly into its
/// codebase for models that don't ship one. NOTE: llama.cpp falls back to chatml when no template
/// is specified, which is unlikely to actually be the correct template for your model, and you'll
/// get weird results back.
///
/// You supply this into [Self::apply_chat_template] to get back a string with the appropriate template
/// substitution applied to convert a list of messages into a prompt the LLM can use to complete
/// the chat.
///
/// # Errors
///
/// * If the model has no chat template
/// * If the chat template is not a valid [`CString`].
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
pub fn get_chat_template(&self) -> Result<LlamaChatTemplate, ChatTemplateError> {
// Typical chat templates are quite small. Let's start with a small allocation likely to succeed.
// Ideally the performance of this would be negligible but uninitialized arrays in Rust are currently
// still not well supported so we end up initializing the chat template buffer twice. One idea might
// be to use a very small value here that will likely fail (like 0 or 1) and then use that to initialize.
// Not sure which approach is the most optimal but in practice this should work well.
match self.get_chat_template_impl(200) {
Ok(t) => Ok(t),
Err(InternalChatTemplateError::Permanent(e)) => Err(e),
Err(InternalChatTemplateError::RetryWithLargerBuffer(actual_len)) => match self.get_chat_template_impl(actual_len + 1) {
Ok(t) => Ok(t),
Err(InternalChatTemplateError::Permanent(e)) => Err(e),
Err(InternalChatTemplateError::RetryWithLargerBuffer(unexpected_len)) => panic!("Was told that the template length was {actual_len} but now it's {unexpected_len}"),
}
}
}
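A hedged sketch of the caller-side fallback the doc comment above suggests; "chatml" here is only an example fallback and should be replaced with the shortcode that actually matches your model:

let template = model
    .get_chat_template()
    .unwrap_or_else(|_| LlamaChatTemplate::new("chatml").expect("no interior NUL bytes"));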

/// Loads a model from a file.
@@ -526,15 +607,25 @@ impl LlamaModel {
/// Apply the model's chat template to some messages.
/// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
///
/// `tmpl` of None means to use the default template provided by llama.cpp for the model
/// Unlike llama.cpp's apply_chat_template, which silently falls back to the ChatML template when given
/// a null pointer for the template, this requires an explicit template to be specified. If you want to
/// use "chatml", just do `LlamaChatTemplate::new("chatml")`, or pass any other model name or template
/// string.
///
/// Use [Self::get_chat_template] to retrieve the template baked into the model (this is the preferred
/// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
///
/// You probably want to set `add_ass` to true so that the generated template string ends with the
/// opening tag of the assistant. If you fail to leave a hanging assistant tag, the model will likely
/// generate one into the output itself, and the output may contain other unexpected artifacts as well.
///
/// # Errors
/// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
#[tracing::instrument(skip_all)]
pub fn apply_chat_template(
&self,
tmpl: Option<String>,
chat: Vec<LlamaChatMessage>,
tmpl: &LlamaChatTemplate,
chat: &[LlamaChatMessage],
add_ass: bool,
) -> Result<String, ApplyChatTemplateError> {
// Buffer is twice the length of messages per their recommendation
Expand All @@ -552,12 +643,7 @@ impl LlamaModel {
})
.collect();

// Set the tmpl pointer
let tmpl = tmpl.map(CString::new);
let tmpl_ptr = match &tmpl {
Some(str) => str.as_ref().map_err(Clone::clone)?.as_ptr(),
None => std::ptr::null(),
};
let tmpl_ptr = tmpl.0.as_ptr();

let res = unsafe {
llama_cpp_sys_2::llama_chat_apply_template(