From 7a29ac49b05d06c72176d914a1fa8794b1a80543 Mon Sep 17 00:00:00 2001
From: Vitali Lovich
Date: Thu, 13 Feb 2025 12:02:08 -0800
Subject: [PATCH] Cleanup chat template API

1. Make the template parameter non-optional for apply_chat_template. This
   ensures you don't accidentally use the chatml template.
2. Improve performance for the expected case of using get_chat_template by
   returning a new LlamaChatTemplate struct that internally stores the string
   as a CString. Unless you explicitly create a copy or print it, no extra
   copy into a Rust string is made. Similarly, get_chat_template ->
   apply_chat_template no longer copies the template string.
3. Improve documentation, including documenting what the add_ass parameter
   does and which values you probably want to use. Additionally, I've made
   the get_chat_template and apply_chat_template docs refer to one another to
   make it easier to discover how to use them.
---
 llama-cpp-2/src/lib.rs   |   9 ++-
 llama-cpp-2/src/model.rs | 150 ++++++++++++++++++++++++++++++---------
 2 files changed, 124 insertions(+), 35 deletions(-)

diff --git a/llama-cpp-2/src/lib.rs b/llama-cpp-2/src/lib.rs
index 61de5a65..3d79337f 100644
--- a/llama-cpp-2/src/lib.rs
+++ b/llama-cpp-2/src/lib.rs
@@ -69,9 +69,6 @@ pub enum LLamaCppError {
 /// There was an error while getting the chat template from a model.
 #[derive(Debug, Eq, PartialEq, thiserror::Error)]
 pub enum ChatTemplateError {
-    /// the buffer was too small.
-    #[error("The buffer was too small. However, a buffer size of {0} would be just large enough.")]
-    BuffSizeError(usize),
     /// gguf has no chat template
     #[error("the model has no meta val - returned code {0}")]
     MissingTemplate(i32),
@@ -80,6 +77,12 @@ pub enum ChatTemplateError {
     Utf8Error(#[from] std::str::Utf8Error),
 }
 
+enum InternalChatTemplateError {
+    Permanent(ChatTemplateError),
+    /// the buffer was too small.
+    RetryWithLargerBuffer(usize),
+}
+
 /// Failed to Load context
 #[derive(Debug, Eq, PartialEq, thiserror::Error)]
 pub enum LlamaContextLoadError {
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 3dc02ee9..8b19c4bb 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -1,9 +1,10 @@
 //! A safe wrapper around `llama_model`.
-use std::ffi::{c_char, CString};
+use std::ffi::{c_char, CStr, CString};
 use std::num::NonZeroU16;
 use std::os::raw::c_int;
 use std::path::Path;
 use std::ptr::NonNull;
+use std::str::{FromStr, Utf8Error};
 
 use crate::context::params::LlamaContextParams;
 use crate::context::LlamaContext;
@@ -12,8 +13,9 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
 use crate::{
-    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
-    LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, InternalChatTemplateError, LlamaContextLoadError,
+    LlamaLoraAdapterInitError, LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError,
+    TokenToStringError,
 };
 
 pub mod params;
@@ -34,6 +36,42 @@ pub struct LlamaLoraAdapter {
     pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_lora_adapter>,
 }
 
+/// A performance-friendly wrapper around [LlamaModel::get_chat_template] which is then
+/// fed into [LlamaModel::apply_chat_template] to convert a list of messages into an LLM
+/// prompt. Internally the template is stored as a CString to avoid round-trip conversions
+/// within the FFI.
+#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
+pub struct LlamaChatTemplate(CString);
+
+impl LlamaChatTemplate {
+    /// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61)
+    /// like "chatml" or "llama3", or an actual Jinja template for llama.cpp to interpret.
+    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
+        Ok(Self(CString::from_str(template)?))
+    }
+
+    /// Accesses the template as a C string reference.
+    pub fn as_c_str(&self) -> &CStr {
+        &self.0
+    }
+
+    /// Attempts to convert the CString into a Rust str reference.
+    pub fn to_str(&self) -> Result<&str, Utf8Error> {
+        self.0.to_str()
+    }
+
+    /// Convenience method to create an owned String.
+    pub fn to_string(&self) -> Result<String, Utf8Error> {
+        self.to_str().map(str::to_string)
+    }
+}
+
+impl std::fmt::Debug for LlamaChatTemplate {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
 /// A Safe wrapper around `llama_chat_message`
 #[derive(Debug, Eq, PartialEq, Clone)]
 pub struct LlamaChatMessage {
@@ -408,41 +446,84 @@ impl LlamaModel {
         unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
     }
 
-    /// Get chat template from model.
-    ///
-    /// # Errors
-    ///
-    /// * If the model has no chat template
-    /// * If the chat template is not a valid [`CString`].
-    #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
-    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
+    fn get_chat_template_impl(
+        &self,
+        capacity: usize,
+    ) -> Result<LlamaChatTemplate, InternalChatTemplateError> {
         // longest known template is about 1200 bytes from llama.cpp
-        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
-        let chat_ptr = chat_temp.into_raw();
-        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
+        // TODO: Once MaybeUninit support is better, this can be converted to use that instead of dummy-initializing such a large array.
+        let mut chat_temp = vec![b'*'; capacity];
+        let chat_name =
+            CStr::from_bytes_with_nul(b"tokenizer.chat_template\0").expect("should have null byte");
 
         let ret = unsafe {
             llama_cpp_sys_2::llama_model_meta_val_str(
                 self.model.as_ptr(),
                 chat_name.as_ptr(),
-                chat_ptr,
-                buf_size,
+                chat_temp.as_mut_ptr() as *mut c_char,
+                chat_temp.len(),
             )
         };
 
         if ret < 0 {
-            return Err(ChatTemplateError::MissingTemplate(ret));
+            return Err(InternalChatTemplateError::Permanent(
+                ChatTemplateError::MissingTemplate(ret),
+            ));
         }
 
-        let template_c = unsafe { CString::from_raw(chat_ptr) };
-        let template = template_c.to_str()?;
+        let returned_len = ret as usize;
 
-        let ret: usize = ret.try_into().unwrap();
-        if template.len() < ret {
-            return Err(ChatTemplateError::BuffSizeError(ret + 1));
+        if returned_len >= capacity {
+            // >= is important: if the returned length equals the capacity, we're missing a trailing null,
+            // since the returned length doesn't count the trailing null.
+            return Err(InternalChatTemplateError::RetryWithLargerBuffer(
+                returned_len,
+            ));
         }
 
-        Ok(template.to_owned())
+        assert_eq!(
+            chat_temp.get(returned_len),
+            Some(&0),
+            "should end with null byte"
+        );
+
+        chat_temp.resize(returned_len + 1, 0);
+
+        Ok(LlamaChatTemplate(unsafe {
+            CString::from_vec_with_nul_unchecked(chat_temp)
+        }))
+    }
+
+    /// Get the chat template embedded in the model's metadata. If this fails, you may either want to
+    /// fail the chat entirely or fall back to one of the template shortcodes that llama.cpp bakes
+    /// directly into its codebase for models that don't ship a template. NOTE: if you don't specify
+    /// a chat template, llama.cpp defaults to chatml, which is unlikely to be the correct template
+    /// for your model, and you'll get strange results back.
+    ///
+    /// You supply the result to [Self::apply_chat_template] to get back a string with the appropriate
+    /// template substitution applied, converting a list of messages into a prompt the LLM can use to
+    /// complete the chat.
+    ///
+    /// # Errors
+    ///
+    /// * If the model has no chat template
+    /// * If the chat template is not a valid [`CString`].
+    #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
+    pub fn get_chat_template(&self) -> Result<LlamaChatTemplate, ChatTemplateError> {
+        // Typical chat templates are quite small, so start with a small allocation that is likely to
+        // succeed. Ideally the cost of this would be negligible, but uninitialized arrays in Rust are
+        // still not well supported, so we end up initializing the chat template buffer twice. One idea
+        // might be to use a very small value here that will likely fail (like 0 or 1) and then use the
+        // returned length to allocate. Not sure which approach is optimal, but in practice this should
+        // work well.
+        match self.get_chat_template_impl(200) {
+            Ok(t) => Ok(t),
+            Err(InternalChatTemplateError::Permanent(e)) => Err(e),
+            Err(InternalChatTemplateError::RetryWithLargerBuffer(actual_len)) => {
+                match self.get_chat_template_impl(actual_len + 1) {
+                    Ok(t) => Ok(t),
+                    Err(InternalChatTemplateError::Permanent(e)) => Err(e),
+                    Err(InternalChatTemplateError::RetryWithLargerBuffer(unexpected_len)) => {
+                        panic!("Was told that the template length was {actual_len} but now it's {unexpected_len}")
+                    }
+                }
+            }
+        }
     }
 
     /// Loads a model from a file.
@@ -526,15 +607,25 @@ impl LlamaModel {
     /// Apply the models chat template to some messages.
     /// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
     ///
-    /// `tmpl` of None means to use the default template provided by llama.cpp for the model
+    /// Unlike llama.cpp's apply_chat_template, which silently falls back to the ChatML template when
+    /// given a null pointer for the template, this requires an explicit template to be specified. If
+    /// you want to use "chatml", just do `LlamaChatTemplate::new("chatml")`, or pass any other model
+    /// name or template string.
+    ///
+    /// Use [Self::get_chat_template] to retrieve the template baked into the model (this is the
+    /// preferred mechanism, as using the wrong chat template can result in very unexpected responses
+    /// from the LLM).
+    ///
+    /// You probably want to set `add_ass` to true so that the generated prompt ends with the opening
+    /// tag of the assistant turn. If you don't leave that tag hanging, the model will likely generate
+    /// one into the output itself, and the rest of the output may be unexpected as well.
     ///
     /// # Errors
     /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
     #[tracing::instrument(skip_all)]
     pub fn apply_chat_template(
         &self,
-        tmpl: Option<String>,
-        chat: Vec<LlamaChatMessage>,
+        tmpl: &LlamaChatTemplate,
+        chat: &[LlamaChatMessage],
         add_ass: bool,
     ) -> Result<String, ApplyChatTemplateError> {
         // Buffer is twice the length of messages per their recommendation
@@ -552,12 +643,7 @@ impl LlamaModel {
             })
             .collect();
 
-        // Set the tmpl pointer
-        let tmpl = tmpl.map(CString::new);
-        let tmpl_ptr = match &tmpl {
-            Some(str) => str.as_ref().map_err(Clone::clone)?.as_ptr(),
-            None => std::ptr::null(),
-        };
+        let tmpl_ptr = tmpl.0.as_ptr();
 
         let res = unsafe {
            llama_cpp_sys_2::llama_chat_apply_template(
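
For context, a rough usage sketch of the reworked API follows. This is an illustration only; it assumes `model` is an already-loaded LlamaModel and that LlamaChatMessage::new(role, content) is the crate's existing constructor, neither of which appears in this patch.

use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};

// Usage sketch only: `model` is assumed to have been loaded elsewhere, and
// `LlamaChatMessage::new(role, content)` is assumed to be the crate's existing constructor.
fn build_prompt(model: &LlamaModel) -> Result<String, Box<dyn std::error::Error>> {
    // Prefer the template embedded in the model's metadata; falling back to an
    // explicit shortcode like "chatml" is now the caller's decision, not a silent default.
    let template = model
        .get_chat_template()
        .unwrap_or_else(|_| LlamaChatTemplate::new("chatml").expect("known template name"));

    let messages = vec![
        LlamaChatMessage::new("system".to_string(), "You are a helpful assistant.".to_string())?,
        LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
    ];

    // add_ass = true leaves the prompt hanging on the assistant's opening tag so the
    // model completes the assistant turn instead of inventing its own tag.
    Ok(model.apply_chat_template(&template, &messages, true)?)
}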