From 8c52621ba46c2c0a6081998292b8b3fcc5416501 Mon Sep 17 00:00:00 2001
From: marcus
Date: Sun, 21 Jan 2024 12:46:19 -0800
Subject: [PATCH] fixed up `LlamaContextParams` with new CB

---
 Cargo.lock                        |   4 +-
 llama-cpp-2/examples/simple.rs    |   8 +-
 llama-cpp-2/src/context/params.rs | 227 +++++++++++++-----------------
 llama-cpp-2/src/model.rs          |  11 +-
 4 files changed, 105 insertions(+), 145 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index de2aa408..91f101ed 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -714,7 +714,7 @@ checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"
 
 [[package]]
 name = "llama-cpp-2"
-version = "0.1.11"
+version = "0.1.13"
 dependencies = [
  "anyhow",
  "clap",
@@ -728,7 +728,7 @@ dependencies = [
 
 [[package]]
 name = "llama-cpp-sys-2"
-version = "0.1.11"
+version = "0.1.13"
 dependencies = [
  "bindgen",
  "cc",
diff --git a/llama-cpp-2/examples/simple.rs b/llama-cpp-2/examples/simple.rs
index 21197659..c69bb0a8 100644
--- a/llama-cpp-2/examples/simple.rs
+++ b/llama-cpp-2/examples/simple.rs
@@ -56,11 +56,9 @@ fn main() -> Result<()> {
         .with_context(|| "unable to load model")?;
 
     // initialize the context
-    let ctx_params = LlamaContextParams {
-        seed: 1234,
-        n_ctx: NonZeroU32::new(2048),
-        ..LlamaContextParams::default()
-    };
+    let ctx_params = LlamaContextParams::default()
+        .with_n_ctx(NonZeroU32::new(2048))
+        .with_seed(1234);
 
     let mut ctx = model.new_context(&backend, ctx_params)
         .with_context(|| "unable to create the llama_context")?;
diff --git a/llama-cpp-2/src/context/params.rs b/llama-cpp-2/src/context/params.rs
index 2352c548..775b7a49 100644
--- a/llama-cpp-2/src/context/params.rs
+++ b/llama-cpp-2/src/context/params.rs
@@ -1,5 +1,5 @@
 //! A safe wrapper around `llama_context_params`.
-use llama_cpp_sys_2::{ggml_type, llama_context_params};
+use llama_cpp_sys_2;
 use std::fmt::Debug;
 use std::num::NonZeroU32;
 
@@ -43,36 +43,102 @@ impl From<RopeScalingType> for i8 {
 }
 
 /// A safe wrapper around `llama_context_params`.
-#[derive(Debug, PartialEq)]
+///
+/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
+///
+/// # Examples
+///
+/// ```rust
+/// # use std::num::NonZeroU32;
+/// use llama_cpp_2::context::params::LlamaContextParams;
+///
+/// let ctx_params = LlamaContextParams::default()
+///     .with_n_ctx(NonZeroU32::new(2048))
+///     .with_seed(1234);
+///
+/// assert_eq!(ctx_params.seed(), 1234);
+/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
+/// ```
+#[derive(Debug, Clone)]
 #[allow(
     missing_docs,
     clippy::struct_excessive_bools,
     clippy::module_name_repetitions
 )]
 pub struct LlamaContextParams {
-    /// The random seed
-    pub seed: u32,
-    /// the number of tokens in the context - [`None`] if defined by the model.
-    pub n_ctx: Option<NonZeroU32>,
-    pub n_batch: u32,
-    pub n_threads: u32,
-    pub n_threads_batch: u32,
-    pub rope_scaling_type: RopeScalingType,
-    pub rope_freq_base: f32,
-    pub rope_freq_scale: f32,
-    pub yarn_ext_factor: f32,
-    pub yarn_attn_factor: f32,
-    pub yarn_beta_fast: f32,
-    pub yarn_beta_slow: f32,
-    pub yarn_orig_ctx: u32,
-    pub type_k: ggml_type,
-    pub type_v: ggml_type,
-    pub mul_mat_q: bool,
-    pub logits_all: bool,
-    pub embedding: bool,
-    pub offload_kqv: bool,
-    pub cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
-    pub cb_eval_user_data: *mut std::ffi::c_void,
+    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
+}
+
+impl LlamaContextParams {
+    /// Set the seed of the context
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// let params = params.with_seed(1234);
+    /// assert_eq!(params.seed(), 1234);
+    /// ```
+    pub fn with_seed(mut self, seed: u32) -> Self {
+        self.context_params.seed = seed;
+        self
+    }
+
+    /// Get the seed of the context
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default()
+    ///     .with_seed(1234);
+    /// assert_eq!(params.seed(), 1234);
+    /// ```
+    pub fn seed(&self) -> u32 {
+        self.context_params.seed
+    }
+
+    /// Set the size of the context
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// # use std::num::NonZeroU32;
+    /// use llama_cpp_2::context::params::LlamaContextParams;
+    /// let params = LlamaContextParams::default();
+    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
+    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
+    /// ```
+    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
+        self.context_params.n_ctx = n_ctx.map_or(0, |n_ctx| n_ctx.get());
+        self
+    }
+
+    /// Get the size of the context.
+    ///
+    /// [`None`] if the context size is specified by the model and not the context.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
+    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
+    /// ```
+    pub fn n_ctx(&self) -> Option<NonZeroU32> {
+        NonZeroU32::new(self.context_params.n_ctx)
+    }
+
+    /// Get the type of rope scaling.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
+    /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
+    /// ```
+    pub fn rope_scaling_type(&self) -> RopeScalingType {
+        RopeScalingType::from(self.context_params.rope_scaling_type)
+    }
 }
 
 /// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
@@ -80,115 +146,12 @@ pub struct LlamaContextParams {
 /// # use std::num::NonZeroU32;
 /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
 /// let params = LlamaContextParams::default();
-/// assert_eq!(params.n_ctx, NonZeroU32::new(512), "n_ctx should be 512");
-/// assert_eq!(params.rope_scaling_type, RopeScalingType::Unspecified);
+/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
+/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
 /// ```
 impl Default for LlamaContextParams {
     fn default() -> Self {
-        Self::from(unsafe { llama_cpp_sys_2::llama_context_default_params() })
-    }
-}
-
-impl From<llama_context_params> for LlamaContextParams {
-    fn from(
-        llama_context_params {
-            seed,
-            n_ctx,
-            n_batch,
-            n_threads,
-            n_threads_batch,
-            rope_freq_base,
-            rope_freq_scale,
-            cb_eval,
-            cb_eval_user_data,
-            type_k,
-            type_v,
-            mul_mat_q,
-            logits_all,
-            embedding,
-            rope_scaling_type,
-            yarn_ext_factor,
-            yarn_attn_factor,
-            yarn_beta_fast,
-            yarn_beta_slow,
-            yarn_orig_ctx,
-            offload_kqv,
-        }: llama_context_params,
-    ) -> Self {
-        Self {
-            seed,
-            n_ctx: NonZeroU32::new(n_ctx),
-            n_batch,
-            n_threads,
-            n_threads_batch,
-            rope_freq_base,
-            rope_freq_scale,
-            type_k,
-            type_v,
-            mul_mat_q,
-            logits_all,
-            embedding,
-            rope_scaling_type: RopeScalingType::from(rope_scaling_type),
-            yarn_ext_factor,
-            yarn_attn_factor,
-            yarn_beta_fast,
-            yarn_beta_slow,
-            yarn_orig_ctx,
-            offload_kqv,
-            cb_eval,
-            cb_eval_user_data,
-        }
+        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
+        Self { context_params, }
     }
 }
-
-impl From<LlamaContextParams> for llama_context_params {
-    fn from(
-        LlamaContextParams {
-            seed,
-            n_ctx,
-            n_batch,
-            n_threads,
-            n_threads_batch,
-            rope_freq_base,
-            rope_freq_scale,
-            type_k,
-            type_v,
-            mul_mat_q,
-            logits_all,
-            embedding,
-            rope_scaling_type,
-            yarn_ext_factor,
-            yarn_attn_factor,
-            yarn_beta_fast,
-            yarn_beta_slow,
-            yarn_orig_ctx,
-            offload_kqv,
-            cb_eval,
-            cb_eval_user_data,
-        }: LlamaContextParams,
-    ) -> Self {
-        llama_context_params {
-            seed,
-            n_ctx: n_ctx.map_or(0, NonZeroU32::get),
-            n_batch,
-            n_threads,
-            n_threads_batch,
-            rope_freq_base,
-            rope_freq_scale,
-            type_k,
-            type_v,
-            mul_mat_q,
-            logits_all,
-            embedding,
-            rope_scaling_type: i8::from(rope_scaling_type),
-            yarn_ext_factor,
-            yarn_attn_factor,
-            yarn_beta_fast,
-            yarn_beta_slow,
-            yarn_orig_ctx,
-            offload_kqv,
-            cb_eval,
-            cb_eval_user_data,
-        }
-    }
-}
\ No newline at end of file
diff --git a/llama-cpp-2/src/model.rs b/llama-cpp-2/src/model.rs
index 50584453..0e50c19d 100644
--- a/llama-cpp-2/src/model.rs
+++ b/llama-cpp-2/src/model.rs
@@ -6,7 +6,6 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
 use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
-use llama_cpp_sys_2::{llama_context_params, llama_token_get_type, llama_vocab_type};
 use std::ffi::CString;
 use std::os::raw::c_int;
 use std::path::Path;
@@ -184,7 +183,7 @@ impl LlamaModel {
     /// If the token type is not known to this library.
     #[must_use]
     pub fn token_type(&self, LlamaToken(id): LlamaToken) -> LlamaTokenType {
-        let token_type = unsafe { llama_token_get_type(self.model.as_ptr(), id) };
+        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_type(self.model.as_ptr(), id) };
         LlamaTokenType::try_from(token_type).expect("token type is valid")
     }
 
@@ -314,7 +313,7 @@ impl LlamaModel {
         _: &LlamaBackend,
         params: LlamaContextParams,
     ) -> Result<LlamaContext, LlamaContextLoadError> {
-        let context_params = llama_context_params::from(params);
+        let context_params = params.context_params;
         let context = unsafe {
             llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
         };
@@ -345,13 +344,13 @@ pub enum VocabType {
 pub enum LlamaTokenTypeFromIntError {
     /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
     #[error("Unknown Value {0}")]
-    UnknownValue(llama_vocab_type),
+    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
 }
 
-impl TryFrom<llama_vocab_type> for VocabType {
+impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
     type Error = LlamaTokenTypeFromIntError;
 
-    fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
+    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
         match value {
             llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
             llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),