Update llama.cpp dependency to the latest version; Fix all clippy lints. #622

Merged · 2 commits · Jan 22, 2025
8 changes: 4 additions & 4 deletions llama-cpp-2/src/context.rs
@@ -52,13 +52,13 @@ impl<'model> LlamaContext<'model> {
}
}

/// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to n_ubatch.
/// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
#[must_use]
pub fn n_batch(&self) -> u32 {
unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
}

/// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to n_batch.
/// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
#[must_use]
pub fn n_ubatch(&self) -> u32 {
unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
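The retargeted doc links encode an invariant worth stating once in code. A minimal sketch, assuming only the crate's public `LlamaContext` path:

```rust
use llama_cpp_2::context::LlamaContext;

// Minimal sketch of the documented invariant: the logical batch capacity
// (n_batch) must be at least the physical, hardware-level one (n_ubatch).
fn assert_batch_invariant(ctx: &LlamaContext<'_>) {
    assert!(ctx.n_batch() >= ctx.n_ubatch());
}
```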
@@ -318,7 +318,7 @@ impl<'model> LlamaContext<'model> {
scale: f32,
) -> Result<(), LlamaLoraAdapterSetError> {
let err_code = unsafe {
llama_cpp_sys_2::llama_lora_adapter_set(
llama_cpp_sys_2::llama_set_adapter_lora(
self.context.as_ptr(),
adapter.lora_adapter.as_ptr(),
scale,
@@ -342,7 +342,7 @@ impl<'model> LlamaContext<'model> {
adapter: &mut LlamaLoraAdapter,
) -> Result<(), LlamaLoraAdapterRemoveError> {
let err_code = unsafe {
llama_cpp_sys_2::llama_lora_adapter_remove(
llama_cpp_sys_2::llama_rm_adapter_lora(
self.context.as_ptr(),
adapter.lora_adapter.as_ptr(),
)
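Upstream renamed the sys-level LoRA symbols (`llama_lora_adapter_set` → `llama_set_adapter_lora`, `llama_lora_adapter_remove` → `llama_rm_adapter_lora`); the safe wrapper's surface is unchanged. A hedged usage sketch, where the wrapper method names and the `LlamaLoraAdapter` path are assumptions inferred from the error types above, not confirmed by the diff:

```rust
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::model::LlamaLoraAdapter;

// Assumed wrapper API: apply an adapter at full strength, then detach it.
// Only the sys symbols underneath changed in this PR.
fn toggle_adapter(
    ctx: &mut LlamaContext<'_>,
    adapter: &mut LlamaLoraAdapter,
) -> Result<(), Box<dyn std::error::Error>> {
    ctx.lora_adapter_set(adapter, 1.0)?; // now backed by llama_set_adapter_lora
    ctx.lora_adapter_remove(adapter)?; // now backed by llama_rm_adapter_lora
    Ok(())
}
```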
46 changes: 27 additions & 19 deletions llama-cpp-2/src/context/kv_cache.rs
@@ -6,6 +6,7 @@ use std::num::{NonZeroU8, TryFromIntError};

/// Errors that can occur when attempting to prepare values for the kv cache
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
#[allow(clippy::module_name_repetitions)]
pub enum KvCacheConversionError {
/// Sequence id conversion to i32 failed
#[error("Provided sequence id is too large for a i32")]
@@ -33,15 +34,16 @@ impl LlamaContext<'_> {
/// Copy the cache from one sequence to another.
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If the either position exceeds
/// the maximum i32 value, no copy is attempted and an `Err` is returned.
/// A `Result` indicating whether the operation was successful.
///
/// # Parameters
///
/// * `src` - The sequence id to copy the cache from.
/// * `dest` - The sequence id to copy the cache to.
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
///
/// # Errors
/// If either position exceeds [`i32::MAX`].
pub fn copy_kv_cache_seq(
&mut self,
src: i32,
@@ -51,10 +53,10 @@
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
.map_err(KvCacheConversionError::P0TooLarge)?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
.map_err(KvCacheConversionError::P1TooLarge)?;
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
}
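The `.map_err(KvCacheConversionError::P0TooLarge)` rewrites in this file are the canonical fix for `clippy::redundant_closure`: a tuple-variant constructor is already a function, so the closure adds nothing. A self-contained illustration with made-up names:

```rust
use std::num::TryFromIntError;

#[derive(Debug)]
enum ConvError {
    TooLarge(TryFromIntError),
}

fn to_i32(v: u64) -> Result<i32, ConvError> {
    // Before: .map_err(|e| ConvError::TooLarge(e)) -- flagged as a redundant closure.
    // After: the variant constructor itself is a fn(TryFromIntError) -> ConvError.
    i32::try_from(v).map_err(ConvError::TooLarge)
}
```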
@@ -69,10 +71,12 @@ impl LlamaContext<'_> {
/// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
///
/// # Parameters
///
/// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
/// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
/// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
///
/// # Errors
/// If the sequence id or either position exceeds [`i32::MAX`].
pub fn clear_kv_cache_seq(
&mut self,
src: Option<u32>,
@@ -81,13 +85,13 @@
) -> Result<bool, KvCacheConversionError> {
let src = src
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::SeqIdTooLarge(e))?;
.map_err(KvCacheConversionError::SeqIdTooLarge)?;
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
.map_err(KvCacheConversionError::P0TooLarge)?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
.map_err(KvCacheConversionError::P1TooLarge)?;
Ok(unsafe { llama_cpp_sys_2::llama_kv_cache_seq_rm(self.context.as_ptr(), src, p0, p1) })
}
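For orientation, a short usage sketch of the method above; the module paths are assumed from the file layout, and `None` bounds map to `-1` at the FFI boundary, which llama.cpp reads as "the whole range":

```rust
use llama_cpp_2::context::kv_cache::KvCacheConversionError;
use llama_cpp_2::context::LlamaContext;

// Assumed paths: drop everything cached for sequence 0.
fn clear_seq0(ctx: &mut LlamaContext<'_>) -> Result<bool, KvCacheConversionError> {
    ctx.clear_kv_cache_seq(Some(0), None, None)
}
```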

@@ -118,15 +122,17 @@ impl LlamaContext<'_> {
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
/// A `Result` indicating whether the operation was successful.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `delta` - The relative position to add to the tokens
///
/// # Errors
/// If either position exceeds [`i32::MAX`].
pub fn kv_cache_seq_add(
&mut self,
seq_id: i32,
@@ -136,10 +142,10 @@
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
.map_err(KvCacheConversionError::P0TooLarge)?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
.map_err(KvCacheConversionError::P1TooLarge)?;
unsafe {
llama_cpp_sys_2::llama_kv_cache_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
}
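These primitives compose into the usual "context shift" pattern; a hedged sketch (paths assumed as above, and the `delta` type assumed to be `i32`, matching llama.cpp's `llama_pos`):

```rust
use llama_cpp_2::context::kv_cache::KvCacheConversionError;
use llama_cpp_2::context::LlamaContext;

// Hedged sketch: evict the first 32 positions of sequence 0, then slide the
// survivors left so decoding can continue in the freed space.
fn shift_context(ctx: &mut LlamaContext<'_>) -> Result<(), KvCacheConversionError> {
    ctx.clear_kv_cache_seq(Some(0), None, Some(32))?;
    ctx.kv_cache_seq_add(0, Some(32), None, -32)?;
    Ok(())
}
```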
@@ -152,15 +158,17 @@
/// - explicitly with [`Self::kv_cache_update`]
///
/// # Returns
/// A `Result` indicating whether the operation was successful. If either position
/// exceeds the maximum i32 value, no update is attempted and an `Err` is returned.
/// A `Result` indicating whether the operation was successful.
///
/// # Parameters
///
/// * `seq_id` - The sequence id to update
/// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
/// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
/// * `d` - The factor to divide the positions by
///
/// # Errors
/// If either position exceeds [`i32::MAX`].
pub fn kv_cache_seq_div(
&mut self,
seq_id: i32,
@@ -170,10 +178,10 @@
) -> Result<(), KvCacheConversionError> {
let p0 = p0
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P0TooLarge(e))?;
.map_err(KvCacheConversionError::P0TooLarge)?;
let p1 = p1
.map_or(Ok(-1), i32::try_from)
.map_err(|e| KvCacheConversionError::P1TooLarge(e))?;
.map_err(KvCacheConversionError::P1TooLarge)?;
let d = c_int::from(d.get());
unsafe { llama_cpp_sys_2::llama_kv_cache_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
Ok(())
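`NonZeroU8` makes the divisor's zero case unrepresentable rather than checked at runtime. A hedged sketch of halving all cached positions for one sequence (paths assumed as above):

```rust
use std::num::NonZeroU8;

use llama_cpp_2::context::kv_cache::KvCacheConversionError;
use llama_cpp_2::context::LlamaContext;

// Compress sequence 0's positions by a factor of 2 across the whole cache.
fn halve_positions(ctx: &mut LlamaContext<'_>) -> Result<(), KvCacheConversionError> {
    ctx.kv_cache_seq_div(0, None, None, NonZeroU8::new(2).expect("2 is non-zero"))
}
```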
@@ -239,7 +247,7 @@ pub struct KVCacheView<'a> {
view: llama_cpp_sys_2::llama_kv_cache_view,
}

impl<'a> KVCacheView<'a> {
impl KVCacheView<'_> {
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
pub fn update(&mut self) {
unsafe {
@@ -314,7 +322,7 @@ impl<'a> KVCacheView<'a> {
}
}

impl<'a> Drop for KVCacheView<'a> {
impl Drop for KVCacheView<'_> {
fn drop(&mut self) {
unsafe {
llama_cpp_sys_2::llama_kv_cache_view_free(&mut self.view);
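The `impl KVCacheView<'_>` and `impl Drop for KVCacheView<'_>` headers follow `clippy::needless_lifetimes`, which recent clippy releases extended to impl blocks whose named lifetime only restates the self type. An illustrative, self-contained version of the same change:

```rust
struct View<'a>(&'a [i64]);

// Before: impl<'a> View<'a> { ... } -- 'a only echoes the self type.
// After: the anonymous lifetime expresses the same thing with less noise.
impl View<'_> {
    fn max_pos(&self) -> Option<i64> {
        self.0.iter().copied().max()
    }
}

impl Drop for View<'_> {
    fn drop(&mut self) {}
}
```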
24 changes: 19 additions & 5 deletions llama-cpp-2/src/llama_batch.rs
@@ -10,6 +10,7 @@ pub struct LlamaBatch {
allocated: usize,
/// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
pub(crate) initialized_logits: Vec<i32>,
#[allow(clippy::doc_markdown)]
/// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, <unknown>, <unknown>)`
pub(crate) llama_batch: llama_batch,
}
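The `#[allow(clippy::doc_markdown)]` above quiets the lint that wants identifier-like words such as llama_cpp backticked in doc comments. Both ways to satisfy it, in miniature:

```rust
// Backticks satisfy the lint:
/// Wraps the `llama_cpp` batch.
struct Backticked;

// Or silence it locally, as the field above does:
#[allow(clippy::doc_markdown)]
/// Wraps the llama_cpp batch.
struct Allowed;
```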
@@ -20,7 +21,7 @@ pub enum BatchAddError {
/// There was not enough space in the batch to add the token.
#[error("Insufficient Space of {0}")]
InsufficientSpace(usize),
/// Empty buffer is provided for get_one
/// Empty buffer is provided for [`LlamaBatch::get_one`]
#[error("Empty buffer")]
EmptyBuffer,
}
@@ -152,22 +153,35 @@ impl LlamaBatch {
}
}

/// llama_batch_get_one
/// Return batch for single sequence of tokens starting at pos_0
/// ``llama_batch_get_one``
/// Return batch for single sequence of tokens
///
/// NOTE: this is a helper function to facilitate transition to the new batch API
///
/// # Errors
/// If the provided token buffer is empty.
///
/// # Panics
/// If the number of tokens in ``tokens`` exceeds [`i32::MAX`].
pub fn get_one(tokens: &[LlamaToken]) -> Result<Self, BatchAddError> {
if tokens.is_empty() {
return Err(BatchAddError::EmptyBuffer);
}
let batch = unsafe {
let ptr = tokens.as_ptr() as *mut i32;
llama_cpp_sys_2::llama_batch_get_one(ptr, tokens.len() as i32)
llama_cpp_sys_2::llama_batch_get_one(
ptr,
tokens
.len()
.try_into()
.expect("number of tokens exceeds i32::MAX"),
)
};
let batch = Self {
allocated: 0,
initialized_logits: vec![(tokens.len() - 1) as i32],
initialized_logits: vec![(tokens.len() - 1)
.try_into()
.expect("number of tokens exceeds i32::MAX + 1")],
llama_batch: batch,
};
Ok(batch)
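Finally, a hedged usage sketch of the hardened `get_one`; the `LlamaToken` path and its public tuple field are assumptions about the crate's API, and the token values are illustrative:

```rust
use llama_cpp_2::llama_batch::{BatchAddError, LlamaBatch};
use llama_cpp_2::token::LlamaToken;

fn main() -> Result<(), BatchAddError> {
    // Empty input now returns BatchAddError::EmptyBuffer instead of reaching
    // the FFI, and a length above i32::MAX panics instead of wrapping.
    let tokens = vec![LlamaToken(1), LlamaToken(2), LlamaToken(3)];
    let _batch = LlamaBatch::get_one(&tokens)?;
    Ok(())
}
```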