Skip to content

Commit

Permalink
added cb_eval & cb_eval_user_data to context_params.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusDunn committed Jan 19, 2024
1 parent 8ddc126 commit a5fd538
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
2 changes: 1 addition & 1 deletion llama-cpp-2/examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn main() -> Result<()> {
..LlamaContextParams::default()
};

let mut ctx = model.new_context(&backend, &ctx_params)
let mut ctx = model.new_context(&backend, ctx_params)
.with_context(|| "unable to create the llama_context")?;

// tokenize the prompt
Expand Down
14 changes: 12 additions & 2 deletions llama-cpp-2/src/context/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl From<RopeScalingType> for i8 {
}

/// A safe wrapper around `llama_context_params`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Debug, PartialEq)]
#[allow(
missing_docs,
clippy::struct_excessive_bools,
Expand Down Expand Up @@ -71,6 +71,8 @@ pub struct LlamaContextParams {
pub logits_all: bool,
pub embedding: bool,
pub offload_kqv: bool,
pub cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
pub cb_eval_user_data: *mut std::ffi::c_void,
}

/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
Expand All @@ -97,6 +99,8 @@ impl From<llama_context_params> for LlamaContextParams {
n_threads_batch,
rope_freq_base,
rope_freq_scale,
cb_eval,
cb_eval_user_data,
type_k,
type_v,
mul_mat_q,
Expand Down Expand Up @@ -131,6 +135,8 @@ impl From<llama_context_params> for LlamaContextParams {
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}
}
}
Expand All @@ -157,6 +163,8 @@ impl From<LlamaContextParams> for llama_context_params {
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}: LlamaContextParams,
) -> Self {
llama_context_params {
Expand All @@ -179,6 +187,8 @@ impl From<LlamaContextParams> for llama_context_params {
yarn_beta_slow,
yarn_orig_ctx,
offload_kqv,
cb_eval,
cb_eval_user_data,
}
}
}
}
10 changes: 5 additions & 5 deletions llama-cpp-2/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,12 @@ impl LlamaModel {
/// # Errors
///
/// There are many ways this can fail. See [`LlamaContextLoadError`] for more information.
pub fn new_context<'a>(
&'a self,
pub fn new_context(
&self,
_: &LlamaBackend,
params: &LlamaContextParams,
) -> Result<LlamaContext<'a>, LlamaContextLoadError> {
let context_params = llama_context_params::from(*params);
params: LlamaContextParams,
) -> Result<LlamaContext, LlamaContextLoadError> {
let context_params = llama_context_params::from(params);
let context = unsafe {
llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
};
Expand Down

0 comments on commit a5fd538

Please sign in to comment.