diff --git a/llama-cpp-2/src/context.rs b/llama-cpp-2/src/context.rs index 8946da2b..10f2d7eb 100644 --- a/llama-cpp-2/src/context.rs +++ b/llama-cpp-2/src/context.rs @@ -52,13 +52,13 @@ impl<'model> LlamaContext<'model> { } } - /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to n_ubatch. + /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`]. #[must_use] pub fn n_batch(&self) -> u32 { unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) } } - /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to n_batch. + /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`]. #[must_use] pub fn n_ubatch(&self) -> u32 { unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) } @@ -318,7 +318,7 @@ impl<'model> LlamaContext<'model> { scale: f32, ) -> Result<(), LlamaLoraAdapterSetError> { let err_code = unsafe { - llama_cpp_sys_2::llama_lora_adapter_set( + llama_cpp_sys_2::llama_set_adapter_lora( self.context.as_ptr(), adapter.lora_adapter.as_ptr(), scale, @@ -342,7 +342,7 @@ impl<'model> LlamaContext<'model> { adapter: &mut LlamaLoraAdapter, ) -> Result<(), LlamaLoraAdapterRemoveError> { let err_code = unsafe { - llama_cpp_sys_2::llama_lora_adapter_remove( + llama_cpp_sys_2::llama_rm_adapter_lora( self.context.as_ptr(), adapter.lora_adapter.as_ptr(), ) diff --git a/llama-cpp-2/src/context/kv_cache.rs b/llama-cpp-2/src/context/kv_cache.rs index d5a8ed65..d90a6b8a 100644 --- a/llama-cpp-2/src/context/kv_cache.rs +++ b/llama-cpp-2/src/context/kv_cache.rs @@ -6,6 +6,7 @@ use std::num::{NonZeroU8, TryFromIntError}; /// Errors that can occur when attempting to prepare values for the kv cache #[derive(Debug, Eq, PartialEq, thiserror::Error)] +#[allow(clippy::module_name_repetitions)] pub enum KvCacheConversionError { /// Sequence id conversion to i32 failed #[error("Provided sequence id is too large for a i32")] @@ -33,15 +34,16 @@ impl LlamaContext<'_> { /// Copy the cache from one sequence to another. /// /// # Returns - /// A `Result` indicating whether the operation was successful. If the either position exceeds - /// the maximum i32 value, no copy is attempted and an `Err` is returned. + /// A `Result` indicating whether the operation was successful. /// /// # Parameters - /// /// * `src` - The sequence id to copy the cache from. /// * `dest` - The sequence id to copy the cache to. /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`. /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`. + /// + /// # Errors + /// If either position exceeds [`i32::MAX`]. pub fn copy_kv_cache_seq( &mut self, src: i32, @@ -51,10 +53,10 @@ impl LlamaContext<'_> { ) -> Result<(), KvCacheConversionError> { let p0 = p0 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P0TooLarge(e))?; + .map_err(KvCacheConversionError::P0TooLarge)?; let p1 = p1 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P1TooLarge(e))?; + .map_err(KvCacheConversionError::P1TooLarge)?; unsafe { llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1); } @@ -69,10 +71,12 @@ impl LlamaContext<'_> { /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned. 
/// /// # Parameters - /// /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`. /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`. + /// + /// # Errors + /// If the sequence id or either position exceeds [`i32::MAX`]. pub fn clear_kv_cache_seq( &mut self, src: Option, @@ -81,13 +85,13 @@ impl LlamaContext<'_> { ) -> Result { let src = src .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?; + .map_err(KvCacheConversionError::SeqIdTooLarge)?; let p0 = p0 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P0TooLarge(e))?; + .map_err(KvCacheConversionError::P0TooLarge)?; let p1 = p1 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P1TooLarge(e))?; + .map_err(KvCacheConversionError::P1TooLarge)?; Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) }) } @@ -118,8 +122,7 @@ impl LlamaContext<'_> { /// - explicitly with [`Self::kv_cache_update`] /// /// # Returns - /// A `Result` indicating whether the operation was successful. If either position - /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned. + /// A `Result` indicating whether the operation was successful. /// /// # Parameters /// @@ -127,6 +130,9 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`. /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`. /// * `delta` - The relative position to add to the tokens + /// + /// # Errors + /// If either position exceeds [`i32::MAX`]. pub fn kv_cache_seq_add( &mut self, seq_id: i32, @@ -136,10 +142,10 @@ impl LlamaContext<'_> { ) -> Result<(), KvCacheConversionError> { let p0 = p0 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P0TooLarge(e))?; + .map_err(KvCacheConversionError::P0TooLarge)?; let p1 = p1 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P1TooLarge(e))?; + .map_err(KvCacheConversionError::P1TooLarge)?; unsafe { llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta); } @@ -152,8 +158,7 @@ impl LlamaContext<'_> { /// - explicitly with [`Self::kv_cache_update`] /// /// # Returns - /// A `Result` indicating whether the operation was successful. If either position - /// exceeds the maximum i32 value, no update is attempted and an `Err` is returned. + /// A `Result` indicating whether the operation was successful. /// /// # Parameters /// @@ -161,6 +166,9 @@ impl LlamaContext<'_> { /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`. /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`. /// * `d` - The factor to divide the positions by + /// + /// # Errors + /// If either position exceeds [`i32::MAX`]. 
pub fn kv_cache_seq_div( &mut self, seq_id: i32, @@ -170,10 +178,10 @@ impl LlamaContext<'_> { ) -> Result<(), KvCacheConversionError> { let p0 = p0 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P0TooLarge(e))?; + .map_err(KvCacheConversionError::P0TooLarge)?; let p1 = p1 .map_or(Ok(-1), i32::try_from) - .map_err(|e| KvCacheConversionError::P1TooLarge(e))?; + .map_err(KvCacheConversionError::P1TooLarge)?; let d = c_int::from(d.get()); unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) } Ok(()) @@ -239,7 +247,7 @@ pub struct KVCacheView<'a> { view: llama_cpp_sys_2::llama_kv_cache_view, } -impl<'a> KVCacheView<'a> { +impl KVCacheView<'_> { /// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) pub fn update(&mut self) { unsafe { @@ -314,7 +322,7 @@ impl<'a> KVCacheView<'a> { } } -impl<'a> Drop for KVCacheView<'a> { +impl Drop for KVCacheView<'_> { fn drop(&mut self) { unsafe { llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view); diff --git a/llama-cpp-2/src/llama_batch.rs b/llama-cpp-2/src/llama_batch.rs index 153f5d52..b96588c7 100644 --- a/llama-cpp-2/src/llama_batch.rs +++ b/llama-cpp-2/src/llama_batch.rs @@ -10,6 +10,7 @@ pub struct LlamaBatch { allocated: usize, /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed. pub(crate) initialized_logits: Vec, + #[allow(clippy::doc_markdown)] /// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, , )` pub(crate) llama_batch: llama_batch, } @@ -20,7 +21,7 @@ pub enum BatchAddError { /// There was not enough space in the batch to add the token. #[error("Insufficient Space of {0}")] InsufficientSpace(usize), - /// Empty buffer is provided for get_one + /// Empty buffer is provided for [`LlamaBatch::get_one`] #[error("Empty buffer")] EmptyBuffer, } @@ -152,22 +153,35 @@ impl LlamaBatch { } } - /// llama_batch_get_one - /// Return batch for single sequence of tokens starting at pos_0 + /// ``llama_batch_get_one`` + /// Return batch for single sequence of tokens /// /// NOTE: this is a helper function to facilitate transition to the new batch API /// + /// # Errors + /// If the provided token buffer is empty. + /// + /// # Panics + /// If the number of tokens in ``tokens`` exceeds [`i32::MAX`]. 
pub fn get_one(tokens: &[LlamaToken]) -> Result { if tokens.is_empty() { return Err(BatchAddError::EmptyBuffer); } let batch = unsafe { let ptr = tokens.as_ptr() as *mut i32; - llama_cpp_sys_2::llama_batch_get_one(ptr, tokens.len() as i32) + llama_cpp_sys_2::llama_batch_get_one( + ptr, + tokens + .len() + .try_into() + .expect("number of tokens exceeds i32::MAX"), + ) }; let batch = Self { allocated: 0, - initialized_logits: vec![(tokens.len() - 1) as i32], + initialized_logits: vec![(tokens.len() - 1) + .try_into() + .expect("number of tokens exceeds i32::MAX + 1")], llama_batch: batch, }; Ok(batch) diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs index 127ff6b7..85927ec6 100644 --- a/llama-cpp-2/src/model.rs +++ b/llama-cpp-2/src/model.rs @@ -31,7 +31,7 @@ pub struct LlamaModel { #[repr(transparent)] #[allow(clippy::module_name_repetitions)] pub struct LlamaLoraAdapter { - pub(crate) lora_adapter: NonNull, + pub(crate) lora_adapter: NonNull, } /// A Safe wrapper around `llama_chat_message` @@ -43,6 +43,9 @@ pub struct LlamaChatMessage { impl LlamaChatMessage { /// Create a new `LlamaChatMessage` + /// + /// # Errors + /// If either of ``role`` or ``content`` contain null bytes. pub fn new(role: String, content: String) -> Result { Ok(Self { role: CString::new(role)?, @@ -74,6 +77,10 @@ unsafe impl Send for LlamaModel {} unsafe impl Sync for LlamaModel {} impl LlamaModel { + pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab { + unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) } + } + /// get the number of tokens the model was trained on /// /// # Panics @@ -99,28 +106,28 @@ impl LlamaModel { /// Get the beginning of stream token. #[must_use] pub fn token_bos(&self) -> LlamaToken { - let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.model.as_ptr()) }; + let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) }; LlamaToken(token) } /// Get the end of stream token. #[must_use] pub fn token_eos(&self) -> LlamaToken { - let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.model.as_ptr()) }; + let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) }; LlamaToken(token) } /// Get the newline token. #[must_use] pub fn token_nl(&self) -> LlamaToken { - let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.model.as_ptr()) }; + let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) }; LlamaToken(token) } /// Check if a token represents the end of generation (end of turn, end of sequence, etc.) #[must_use] pub fn is_eog_token(&self, token: LlamaToken) -> bool { - unsafe { llama_cpp_sys_2::llama_token_is_eog(self.model.as_ptr(), token.0) } + unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) } } /// Get the decoder start token. @@ -148,17 +155,24 @@ impl LlamaModel { /// Convert single token to bytes. /// /// # Errors - /// /// See [`TokenToStringError`] for more information. + /// + /// # Panics + /// If a [`TokenToStringError::InsufficientBufferSpace`] error returned by + /// [`Self::token_to_bytes_with_size`] contains a positive nonzero value. This should never + /// happen. 
pub fn token_to_bytes( &self, token: LlamaToken, special: Special, ) -> Result, TokenToStringError> { match self.token_to_bytes_with_size(token, 8, special, None) { - Err(TokenToStringError::InsufficientBufferSpace(i)) => { - self.token_to_bytes_with_size(token, -i as usize, special, None) - } + Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_bytes_with_size( + token, + (-i).try_into().expect("Error buffer size is positive"), + special, + None, + ), x => x, } } @@ -225,10 +239,10 @@ impl LlamaModel { let size = unsafe { llama_cpp_sys_2::llama_tokenize( - self.model.as_ptr(), + self.vocab_ptr(), c_string.as_ptr(), c_int::try_from(c_string.as_bytes().len())?, - buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token, + buffer.as_mut_ptr().cast::(), buffer_capacity, add_bos, true, @@ -241,10 +255,10 @@ impl LlamaModel { buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger ")); unsafe { llama_cpp_sys_2::llama_tokenize( - self.model.as_ptr(), + self.vocab_ptr(), c_string.as_ptr(), c_int::try_from(c_string.as_bytes().len())?, - buffer.as_mut_ptr() as *mut llama_cpp_sys_2::llama_token, + buffer.as_mut_ptr().cast::(), -size, add_bos, true, @@ -268,7 +282,7 @@ impl LlamaModel { /// If the token type is not known to this library. #[must_use] pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs { - let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.model.as_ptr(), id) }; + let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) }; LlamaTokenAttrs::try_from(token_type).expect("token type is valid") } @@ -319,18 +333,16 @@ impl LlamaModel { lstrip: Option, ) -> Result, TokenToStringError> { if token == self.token_nl() { - return Ok(String::from("\n").into_bytes()); + return Ok(b"\n".to_vec()); } // unsure what to do with this in the face of the 'special' arg + attr changes let attrs = self.token_attr(token); - if attrs.contains(LlamaTokenAttr::Control) - && (token == self.token_bos() || token == self.token_eos()) - { - return Ok(Vec::new()); - } else if attrs.is_empty() + if attrs.is_empty() || attrs .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused) + || attrs.contains(LlamaTokenAttr::Control) + && (token == self.token_bos() || token == self.token_eos()) { return Ok(Vec::new()); } @@ -347,7 +359,7 @@ impl LlamaModel { let lstrip = lstrip.map_or(0, |it| i32::from(it.get())); let size = unsafe { llama_cpp_sys_2::llama_token_to_piece( - self.model.as_ptr(), + self.vocab_ptr(), token.0, buf, len, @@ -374,7 +386,7 @@ impl LlamaModel { /// without issue. #[must_use] pub fn n_vocab(&self) -> i32 { - unsafe { llama_cpp_sys_2::llama_n_vocab(self.model.as_ptr()) } + unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) } } /// The type of vocab the model was trained on. @@ -384,7 +396,8 @@ impl LlamaModel { /// If llama-cpp emits a vocab type that is not known to this library. 
#[must_use] pub fn vocab_type(&self) -> VocabType { - let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.model.as_ptr()) }; + // llama_cpp_sys_2::llama_model_get_vocab + let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) }; VocabType::try_from(vocab_type).expect("invalid vocab type") } @@ -479,7 +492,7 @@ impl LlamaModel { let cstr = CString::new(path)?; let adapter = - unsafe { llama_cpp_sys_2::llama_lora_adapter_init(self.model.as_ptr(), cstr.as_ptr()) }; + unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) }; let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?; @@ -548,33 +561,31 @@ impl LlamaModel { let res = unsafe { llama_cpp_sys_2::llama_chat_apply_template( - self.model.as_ptr(), tmpl_ptr, chat.as_ptr(), chat.len(), add_ass, buff.as_mut_ptr().cast::(), - buff.len() as i32, + buff.len().try_into().expect("Buffer size exceeds i32::MAX"), ) }; - if res > buff.len() as i32 { - buff.resize(res as usize, 0); + if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") { + buff.resize(res.try_into().expect("res is negative"), 0); let res = unsafe { llama_cpp_sys_2::llama_chat_apply_template( - self.model.as_ptr(), tmpl_ptr, chat.as_ptr(), chat.len(), add_ass, buff.as_mut_ptr().cast::(), - buff.len() as i32, + buff.len().try_into().expect("Buffer size exceeds i32::MAX"), ) }; - assert_eq!(res, buff.len() as i32); + assert_eq!(Ok(res), buff.len().try_into()); } - buff.truncate(res as usize); + buff.truncate(res.try_into().expect("res is negative")); Ok(String::from_utf8(buff)?) } } diff --git a/llama-cpp-2/src/model/params/kv_overrides.rs b/llama-cpp-2/src/model/params/kv_overrides.rs index 8bbcbdd4..b17516a1 100644 --- a/llama-cpp-2/src/model/params/kv_overrides.rs +++ b/llama-cpp-2/src/model/params/kv_overrides.rs @@ -104,7 +104,7 @@ pub struct KvOverrideValueIterator<'a> { current: usize, } -impl<'a> Iterator for KvOverrideValueIterator<'a> { +impl Iterator for KvOverrideValueIterator<'_> { type Item = (CString, ParamOverrideValue); fn next(&mut self) -> Option { diff --git a/llama-cpp-2/src/sampling.rs b/llama-cpp-2/src/sampling.rs index 88c4ee5d..b3a2cf4f 100644 --- a/llama-cpp-2/src/sampling.rs +++ b/llama-cpp-2/src/sampling.rs @@ -21,14 +21,13 @@ impl Debug for LlamaSampler { } // this is needed for the dry sampler to typecheck on android -// ...because what is normally an i8, is an u8 +// ...because what is normally an i8, is an u8 #[cfg(target_os = "android")] type CChar = u8; #[cfg(not(target_os = "android"))] type CChar = i8; - impl LlamaSampler { /// Sample and accept a token from the idx-th output of the last evaluation #[must_use] @@ -129,6 +128,7 @@ impl LlamaSampler { Self::chain(samplers, false) } + #[allow(clippy::doc_markdown)] /// Updates the logits l_i' = l_i/t. 
When t <= 0.0f, the maximum logit is kept at it's original /// value, the rest are set to -inf /// @@ -238,7 +238,7 @@ impl LlamaSampler { let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_grammar( - model.model.as_ptr(), + model.vocab_ptr(), grammar_str.as_ptr(), grammar_root.as_ptr(), ) @@ -264,14 +264,18 @@ impl LlamaSampler { ) -> Self { let seq_breakers: Vec = seq_breakers .into_iter() - .map(|s| CString::new(s.as_ref()).unwrap()) + .map(|s| CString::new(s.as_ref()).expect("A sequence breaker contains null bytes")) .collect(); let mut seq_breaker_pointers: Vec<*const CChar> = seq_breakers.iter().map(|s| s.as_ptr()).collect(); let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_dry( - model.model.as_ptr(), + model.vocab_ptr(), + model + .n_ctx_train() + .try_into() + .expect("n_ctx_train exceeds i32::MAX"), multiplier, base, allowed_length, @@ -286,74 +290,29 @@ impl LlamaSampler { /// Penalizes tokens for being present in the context. /// /// Parameters: - /// - ``n_vocab``: [`LlamaModel::n_vocab`] - /// - ``special_eos)id``: [`LlamaModel::token_eos`] - /// - ``linefeed_id``: [`LlamaModel::token_nl`] /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size) /// - ``penalty_repeat``: 1.0 = disabled /// - ``penalty_freq``: 0.0 = disabled /// - ``penalty_present``: 0.0 = disabled - /// - ``penalize_nl``: consider newlines as a repeatable token - /// - ``ignore_eos``: ignore the end-of-sequence token #[allow(clippy::too_many_arguments)] #[must_use] pub fn penalties( - n_vocab: i32, - special_eos_id: i32, - linefeed_id: i32, penalty_last_n: i32, penalty_repeat: f32, penalty_freq: f32, penalty_present: f32, - penalize_nl: bool, - ignore_eos: bool, ) -> Self { let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_penalties( - n_vocab, - special_eos_id, - linefeed_id, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, - penalize_nl, - ignore_eos, ) }; Self { sampler } } - /// Same as [`Self::penalties`], but with `n_vocab`, `special_eos_id`, and `linefeed_id` - /// initialized from `model`, `penalize_nl = false`, and `ignore_eos = true`. - /// - /// Parameters: - /// - ``model``: The model's tokenizer to use to initialize the sampler. - /// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size) - /// - ``penalty_repeat``: 1.0 = disabled - /// - ``penalty_freq``: 0.0 = disabled - /// - ``penalty_present``: 0.0 = disabled - #[must_use] - pub fn penalties_simple( - model: &LlamaModel, - penalty_last_n: i32, - penalty_repeat: f32, - penalty_freq: f32, - penalty_present: f32, - ) -> Self { - Self::penalties( - model.n_vocab(), - model.token_eos().0, - model.token_nl().0, - penalty_last_n, - penalty_repeat, - penalty_freq, - penalty_present, - false, - true, - ) - } - /// Mirostat 1.0 algorithm described in the paper . Uses tokens instead of words. /// /// # Parameters: diff --git a/llama-cpp-2/src/token/data_array.rs b/llama-cpp-2/src/token/data_array.rs index 7f583064..448864b9 100644 --- a/llama-cpp-2/src/token/data_array.rs +++ b/llama-cpp-2/src/token/data_array.rs @@ -141,7 +141,7 @@ impl LlamaTokenDataArray { /// # Panics /// If the internal llama.cpp sampler fails to select a token. 
pub fn sample_token(&mut self, seed: u32) -> LlamaToken { - self.apply_sampler(&mut LlamaSampler::dist(seed)); + self.apply_sampler(&LlamaSampler::dist(seed)); self.selected_token() .expect("Dist sampler failed to select a token!") } @@ -151,7 +151,7 @@ impl LlamaTokenDataArray { /// # Panics /// If the internal llama.cpp sampler fails to select a token. pub fn sample_token_greedy(&mut self) -> LlamaToken { - self.apply_sampler(&mut LlamaSampler::greedy()); + self.apply_sampler(&LlamaSampler::greedy()); self.selected_token() .expect("Greedy sampler failed to select a token!") } diff --git a/llama-cpp-sys-2/build.rs b/llama-cpp-sys-2/build.rs index cdec57e1..7fff6bba 100644 --- a/llama-cpp-sys-2/build.rs +++ b/llama-cpp-sys-2/build.rs @@ -58,12 +58,10 @@ fn extract_lib_names(out_dir: &Path, build_shared_libs: bool) -> Vec { } else { "*.a" } + } else if build_shared_libs { + "*.so" } else { - if build_shared_libs { - "*.so" - } else { - "*.a" - } + "*.a" }; let libs_dir = out_dir.join("lib*"); let pattern = libs_dir.join(lib_pattern); @@ -294,21 +292,14 @@ fn main() { assert_ne!(llama_libs.len(), 0); for lib in llama_libs { - debug_log!( - "LINK {}", - format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib) - ); - println!( - "{}", - format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib) - ); + let link = format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib); + debug_log!("LINK {link}",); + println!("{link}",); } // OpenMP - if cfg!(feature = "openmp") { - if target.contains("gnu") { - println!("cargo:rustc-link-lib=gomp"); - } + if cfg!(feature = "openmp") && target.contains("gnu") { + println!("cargo:rustc-link-lib=gomp"); } // Windows debug diff --git a/llama-cpp-sys-2/llama.cpp b/llama-cpp-sys-2/llama.cpp index 64ed2091..6171c9d2 160000 --- a/llama-cpp-sys-2/llama.cpp +++ b/llama-cpp-sys-2/llama.cpp @@ -1 +1 @@ -Subproject commit 64ed2091b24b2f9747148fdf49a34ed5938762c3 +Subproject commit 6171c9d25820ccf676b243c172868819d882848f
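Below is a minimal usage sketch of the crate-level API as it stands after this patch, for orientation only; it is not part of the changeset. The tokenized `prompt_tokens` and the surrounding model/context setup are assumed rather than shown, `chain_simple` and `dist` are taken to be the crate's existing sampler constructors, and the four-argument `penalties` and the fallible `get_one` are the signatures introduced above.

```rust
use llama_cpp_2::llama_batch::{BatchAddError, LlamaBatch};
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;

// Sketch: build a sampler chain with the new four-argument `penalties` and
// turn a tokenized prompt into a batch with the now-fallible `get_one`.
fn build_sampler_and_batch(
    prompt_tokens: &[LlamaToken],
) -> Result<(LlamaSampler, LlamaBatch), BatchAddError> {
    // `penalties` no longer takes n_vocab / special_eos_id / linefeed_id /
    // penalize_nl / ignore_eos; only the four penalty knobs remain
    // (penalty_last_n, penalty_repeat, penalty_freq, penalty_present).
    let sampler = LlamaSampler::chain_simple([
        LlamaSampler::penalties(64, 1.1, 0.0, 0.0),
        LlamaSampler::dist(1234),
    ]);

    // `get_one` now reports an empty token buffer as
    // `BatchAddError::EmptyBuffer` instead of handing an invalid batch to llama.cpp.
    let batch = LlamaBatch::get_one(prompt_tokens)?;

    Ok((sampler, batch))
}
```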