diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs
index 8946da2b..94e08b10 100644
--- a/llama-cpp-2/src/context.rs
+++ b/llama-cpp-2/src/context.rs
@@ -318,7 +318,7 @@ impl<'model> LlamaContext<'model> {
         scale: f32,
     ) -> Result<(), LlamaLoraAdapterSetError> {
         let err_code = unsafe {
-            llama_cpp_sys_2::llama_lora_adapter_set(
+            llama_cpp_sys_2::llama_set_adapter_lora(
                 self.context.as_ptr(),
                 adapter.lora_adapter.as_ptr(),
                 scale,
@@ -342,7 +342,7 @@ impl<'model> LlamaContext<'model> {
         adapter: &mut LlamaLoraAdapter,
     ) -> Result<(), LlamaLoraAdapterRemoveError> {
         let err_code = unsafe {
-            llama_cpp_sys_2::llama_lora_adapter_remove(
+            llama_cpp_sys_2::llama_rm_adapter_lora(
                 self.context.as_ptr(),
                 adapter.lora_adapter.as_ptr(),
             )
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 127ff6b7..deefaf0d 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -31,7 +31,7 @@ pub struct LlamaModel {
 #[repr(transparent)]
 #[allow(clippy::module_name_repetitions)]
 pub struct LlamaLoraAdapter {
-    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_lora_adapter>,
+    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
 }
 
 /// A Safe wrapper around `llama_chat_message`
@@ -74,6 +74,10 @@ unsafe impl Send for LlamaModel {}
 unsafe impl Sync for LlamaModel {}
 
 impl LlamaModel {
+    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
+        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
+    }
+
     /// get the number of tokens the model was trained on
     ///
     /// # Panics
@@ -99,28 +103,28 @@ impl LlamaModel {
     /// Get the beginning of stream token.
     #[must_use]
     pub fn token_bos(&self) -> LlamaToken {
-        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.model.as_ptr()) };
+        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
         LlamaToken(token)
     }
 
     /// Get the end of stream token.
     #[must_use]
     pub fn token_eos(&self) -> LlamaToken {
-        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.model.as_ptr()) };
+        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
         LlamaToken(token)
     }
 
     /// Get the newline token.
     #[must_use]
     pub fn token_nl(&self) -> LlamaToken {
-        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.model.as_ptr()) };
+        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
         LlamaToken(token)
     }
 
     /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
     #[must_use]
     pub fn is_eog_token(&self, token: LlamaToken) -> bool {
-        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) }
+        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
     }
 
     /// Get the decoder start token.
@@ -225,7 +229,7 @@ impl LlamaModel {
 
         let size = unsafe {
             llama_cpp_sys_2::llama_tokenize(
-                self.model.as_ptr(),
+                self.vocab_ptr(),
                 c_string.as_ptr(),
                 c_int::try_from(c_string.as_bytes().len())?,
                 buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token,
@@ -241,7 +245,7 @@ impl LlamaModel {
             buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
             unsafe {
                 llama_cpp_sys_2::llama_tokenize(
-                    self.model.as_ptr(),
+                    self.vocab_ptr(),
                     c_string.as_ptr(),
                     c_int::try_from(c_string.as_bytes().len())?,
                     buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token,
@@ -268,7 +272,7 @@ impl LlamaModel {
     /// If the token type is not known to this library.
     #[must_use]
     pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
-        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.model.as_ptr(), id) };
+        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
         LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
     }
 
@@ -347,7 +351,7 @@ impl LlamaModel {
         let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
         let size = unsafe {
             llama_cpp_sys_2::llama_token_to_piece(
-                self.model.as_ptr(),
+                self.vocab_ptr(),
                 token.0,
                 buf,
                 len,
@@ -374,7 +378,7 @@ impl LlamaModel {
     /// without issue.
     #[must_use]
     pub fn n_vocab(&self) -> i32 {
-        unsafe { llama_cpp_sys_2::llama_n_vocab(self.model.as_ptr()) }
+        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
     }
 
     /// The type of vocab the model was trained on.
@@ -384,7 +388,8 @@ impl LlamaModel {
     /// If llama-cpp emits a vocab type that is not known to this library.
     #[must_use]
     pub fn vocab_type(&self) -> VocabType {
-        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.model.as_ptr()) };
+        // llama_cpp_sys_2::llama_model_get_vocab
+        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
         VocabType::try_from(vocab_type).expect("invalid vocab type")
     }
 
@@ -479,7 +484,7 @@ impl LlamaModel {
         let cstr = CString::new(path)?;
 
         let adapter =
-            unsafe { llama_cpp_sys_2::llama_lora_adapter_init(self.model.as_ptr(), cstr.as_ptr()) };
+            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
 
         let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
 
@@ -548,7 +553,6 @@ impl LlamaModel {
 
         let res = unsafe {
             llama_cpp_sys_2::llama_chat_apply_template(
-                self.model.as_ptr(),
                 tmpl_ptr,
                 chat.as_ptr(),
                 chat.len(),
@@ -563,7 +567,6 @@ impl LlamaModel {
 
         let res = unsafe {
             llama_cpp_sys_2::llama_chat_apply_template(
-                self.model.as_ptr(),
                 tmpl_ptr,
                 chat.as_ptr(),
                 chat.len(),
diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs
index 88c4ee5d..d33f92c8 100644
--- a/llama-cpp-2/src/sampling.rs
+++ b/llama-cpp-2/src/sampling.rs
@@ -238,7 +238,7 @@ impl LlamaSampler {
 
         let sampler = unsafe {
             llama_cpp_sys_2::llama_sampler_init_grammar(
-                model.model.as_ptr(),
+                model.vocab_ptr(),
                 grammar_str.as_ptr(),
                 grammar_root.as_ptr(),
             )
@@ -264,14 +264,15 @@ impl LlamaSampler {
     ) -> Self {
         let seq_breakers: Vec<CString> = seq_breakers
             .into_iter()
-            .map(|s| CString::new(s.as_ref()).unwrap())
+            .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes"))
             .collect();
         let mut seq_breaker_pointers: Vec<*const CChar> =
             seq_breakers.iter().map(|s| s.as_ptr()).collect();
 
         let sampler = unsafe {
             llama_cpp_sys_2::llama_sampler_init_dry(
-                model.model.as_ptr(),
+                model.vocab_ptr(),
+                model.n_ctx_train().try_into().expect("n_ctx_train is greater than two billion"),
                 multiplier,
                 base,
                 allowed_length,
@@ -286,74 +287,29 @@ impl LlamaSampler {
     /// Penalizes tokens for being present in the context.
     ///
     /// Parameters:
-    /// - ``n_vocab``: [`LlamaModel::n_vocab`]
-    /// - ``special_eos_id``: [`LlamaModel::token_eos`]
-    /// - ``linefeed_id``: [`LlamaModel::token_nl`]
     /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
     /// - ``penalty_repeat``: 1.0 = disabled
     /// - ``penalty_freq``: 0.0 = disabled
     /// - ``penalty_present``: 0.0 = disabled
-    /// - ``penalize_nl``: consider newlines as a repeatable token
-    /// - ``ignore_eos``: ignore the end-of-sequence token
     #[allow(clippy::too_many_arguments)]
     #[must_use]
     pub fn penalties(
-        n_vocab: i32,
-        special_eos_id: i32,
-        linefeed_id: i32,
         penalty_last_n: i32,
         penalty_repeat: f32,
         penalty_freq: f32,
         penalty_present: f32,
-        penalize_nl: bool,
-        ignore_eos: bool,
     ) -> Self {
         let sampler = unsafe {
             llama_cpp_sys_2::llama_sampler_init_penalties(
-                n_vocab,
-                special_eos_id,
-                linefeed_id,
                 penalty_last_n,
                 penalty_repeat,
                 penalty_freq,
                 penalty_present,
-                penalize_nl,
-                ignore_eos,
             )
         };
         Self { sampler }
     }
 
-    /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id`
-    /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`.
-    ///
-    /// Parameters:
-    /// - ``model``: The model's tokenizer to use to initialize the sampler.
-    /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
-    /// - ``penalty_repeat``: 1.0 = disabled
-    /// - ``penalty_freq``: 0.0 = disabled
-    /// - ``penalty_present``: 0.0 = disabled
-    #[must_use]
-    pub fn penalties_simple(
-        model: &LlamaModel,
-        penalty_last_n: i32,
-        penalty_repeat: f32,
-        penalty_freq: f32,
-        penalty_present: f32,
-    ) -> Self {
-        Self::penalties(
-            model.n_vocab(),
-            model.token_eos().0,
-            model.token_nl().0,
-            penalty_last_n,
-            penalty_repeat,
-            penalty_freq,
-            penalty_present,
-            false,
-            true,
-        )
-    }
-
     /// Mirostat 1.0 algorithm described in the paper <https://arxiv.org/abs/2007.14966>. Uses tokens instead of words.
     ///
     /// # Parameters:
diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp
index 64ed2091..6171c9d2 160000
--- a/llama-cpp-sys-2/llama.cpp
+++ b/llama-cpp-sys-2/llama.cpp
@@ -1 +1 @@
-Subproject commit 64ed2091b24b2f9747148fdf49a34ed5938762c3
+Subproject commit 6171c9d25820ccf676b243c172868819d882848f
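
Caller-facing note: `llama_sampler_init_penalties` no longer takes the vocabulary size, EOS/newline token ids, or the `penalize_nl`/`ignore_eos` flags, so the safe `LlamaSampler::penalties` wrapper shrinks to four arguments and `penalties_simple` is removed. A minimal migration sketch (the numeric values are illustrative, not defaults of this crate):

```rust
use llama_cpp_2::sampling::LlamaSampler;

fn build_penalties() -> LlamaSampler {
    // Before this change: LlamaSampler::penalties_simple(&model, 64, 1.1, 0.0, 0.0)
    // or the nine-argument LlamaSampler::penalties(...). After it, only the four
    // penalty parameters remain.
    LlamaSampler::penalties(
        64,  // penalty_last_n: last n tokens to penalize (0 = disable, -1 = context size)
        1.1, // penalty_repeat: 1.0 = disabled
        0.0, // penalty_freq: 0.0 = disabled
        0.0, // penalty_present: 0.0 = disabled
    )
}
```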
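The LoRA entry points are only renamed at the sys level (`llama_lora_adapter_init`/`_set`/`_remove` become `llama_adapter_lora_init`, `llama_set_adapter_lora`, `llama_rm_adapter_lora`); the safe wrappers keep their signatures, so downstream code should compile unchanged. A rough sketch of the wrapper-level flow, assuming the method names suggested by the error types in this diff (`lora_adapter_init`, `lora_adapter_set`); the adapter path and scale are placeholders:

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::model::LlamaModel;

// Hypothetical usage; the method names are inferred from
// LlamaLoraAdapterInitError / LlamaLoraAdapterSetError, not from this diff.
fn apply_lora(
    model: &LlamaModel,
    ctx: &mut LlamaContext<'_>,
) -> Result<(), Box<dyn std::error::Error>> {
    // Backed by llama_adapter_lora_init after this change.
    let mut adapter = model.lora_adapter_init("adapter.gguf")?;
    // Backed by llama_set_adapter_lora after this change.
    ctx.lora_adapter_set(&mut adapter, 1.0)?;
    Ok(())
}
```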