Skip to content

Commit

Permalink
added ability to configure rope
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn committed Jan 24, 2024
1 parent fdac4ac commit 3a7c99d
Showing 1 changed file with 70 additions and 1 deletion.
71 changes: 70 additions & 1 deletion llama-cpp-2/src/context/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,21 @@ impl LlamaContextParams {
NonZeroU32::new(self.context_params.n_ctx)
}

/// Set the type of rope scaling.
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default()
/// .with_rope_scaling_type(RopeScalingType::Linear);
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
/// ```
pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
self.context_params.rope_scaling_type = i8::from(rope_scaling_type);
self
}

/// Get the type of rope scaling.
///
/// # Examples
Expand All @@ -143,6 +158,60 @@ impl LlamaContextParams {
pub fn rope_scaling_type(&self) -> RopeScalingType {
RopeScalingType::from(self.context_params.rope_scaling_type)
}

/// Set the rope frequency scale.
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_rope_freq_base(0.5);
/// assert_eq!(params.rope_freq_base(), 0.5);
/// ```
pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
self.context_params.rope_freq_base = rope_freq_base;
self
}

/// Get the rope frequency base.
///
/// # Examples
///
/// ```rust
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.rope_freq_base(), 0.0);
/// ```
pub fn rope_freq_base(&self) -> f32 {
self.context_params.rope_freq_base
}

/// Set the rope frequency scale.
///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default()
/// .with_rope_freq_scale(0.5);
/// assert_eq!(params.rope_freq_scale(), 0.5);
/// ```
pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
self.context_params.rope_freq_scale = rope_freq_scale;
self
}

/// Get the rope frequency scale.
///
/// # Examples
///
/// ```rust
/// let params = llama_cpp_2::context::params::LlamaContextParams::default();
/// assert_eq!(params.rope_freq_scale(), 0.0);
/// ```
pub fn rope_freq_scale(&self) -> f32 {
self.context_params.rope_freq_scale
}
}

/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
Expand All @@ -156,6 +225,6 @@ impl LlamaContextParams {
impl Default for LlamaContextParams {
fn default() -> Self {
let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
Self { context_params, }
Self { context_params }
}
}

0 comments on commit 3a7c99d

Please sign in to comment.