updated type for RopeScalingType + fmt
MarcusDunn committed Feb 5, 2024
1 parent e94e54b commit c631133
Showing 6 changed files with 53 additions and 38 deletions.
43 changes: 24 additions & 19 deletions llama-cpp-2/examples/simple.rs
@@ -1,21 +1,20 @@
//! This is a translation of simple.cpp in llama.cpp using llama-cpp-2.
#![allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]

use std::io::Write;
use std::num::NonZeroU32;
use std::path::PathBuf;
use std::time::Duration;
use anyhow::{bail, Context, Result};
use clap::Parser;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::params::LlamaModelParams;
use anyhow::{bail, Context, Result};
use llama_cpp_2::ggml_time_us;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::AddBos;

use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use std::io::Write;
use std::num::NonZeroU32;
use std::path::PathBuf;
use std::time::Duration;

#[derive(clap::Parser)]
struct Args {
@@ -30,7 +29,6 @@ struct Args {
disable_gpu: bool,
}


fn main() -> Result<()> {
let params = Args::parse();

@@ -60,12 +58,14 @@ fn main() -> Result<()> {
.with_n_ctx(NonZeroU32::new(2048))
.with_seed(1234);

let mut ctx = model.new_context(&backend, ctx_params)
let mut ctx = model
.new_context(&backend, ctx_params)
.with_context(|| "unable to create the llama_context")?;

// tokenize the prompt

let tokens_list = model.str_to_token(&params.prompt, AddBos::Always)
let tokens_list = model
.str_to_token(&params.prompt, AddBos::Always)
.with_context(|| format!("failed to tokenize {}", params.prompt))?;

let n_cxt = ctx.n_ctx() as i32;
@@ -75,8 +75,10 @@ fn main() -> Result<()> {

// make sure the KV cache is big enough to hold all the prompt and generated tokens
if n_kv_req > n_cxt {
bail!("n_kv_req > n_ctx, the required kv cache size is not big enough
either reduce n_len or increase n_ctx")
bail!(
"n_kv_req > n_ctx, the required kv cache size is not big enough
either reduce n_len or increase n_ctx"
)
}

// print the prompt token-by-token
@@ -137,7 +139,6 @@ either reduce n_len or increase n_ctx")
ctx.decode(&mut batch).with_context(|| "failed to eval")?;

n_decode += 1;

}

eprintln!("\n");
@@ -146,10 +147,14 @@ either reduce n_len or increase n_ctx")

let duration = Duration::from_micros((t_main_end - t_main_start) as u64);

eprintln!("decoded {} tokens in {:.2} s, speed {:.2} t/s\n", n_decode, duration.as_secs_f32(), n_decode as f32 / duration.as_secs_f32());
eprintln!(
"decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
n_decode,
duration.as_secs_f32(),
n_decode as f32 / duration.as_secs_f32()
);

println!("{}", ctx.timings());

Ok(())

}
}
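
For reference, the throughput arithmetic at the end of the example, pulled out into a standalone helper. This is a sketch, not part of the diff; the timestamps are the microsecond values returned by ggml_time_us, passed in by the caller.

use std::time::Duration;

/// Prints tokens/second given start/end timestamps in microseconds and the
/// number of decoded tokens, mirroring the eprintln! at the end of simple.rs.
fn report_throughput(t_main_start: i64, t_main_end: i64, n_decode: i32) {
    let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
    eprintln!(
        "decoded {} tokens in {:.2} s, speed {:.2} t/s",
        n_decode,
        duration.as_secs_f32(),
        n_decode as f32 / duration.as_secs_f32()
    );
}
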
8 changes: 4 additions & 4 deletions llama-cpp-2/src/context/params.rs
@@ -19,8 +19,8 @@ pub enum RopeScalingType {

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if
/// the value is not recognized.
impl From<i8> for RopeScalingType {
fn from(value: i8) -> Self {
impl From<i32> for RopeScalingType {
fn from(value: i32) -> Self {
match value {
0 => Self::None,
1 => Self::Linear,
@@ -31,7 +31,7 @@ impl From<i8> for RopeScalingType {
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i8 {
impl From<RopeScalingType> for i32 {
fn from(value: RopeScalingType) -> Self {
match value {
RopeScalingType::None => 0,
@@ -172,7 +172,7 @@ impl LlamaContextParams {
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
/// ```
pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
self.context_params.rope_scaling_type = i8::from(rope_scaling_type);
self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
self
}

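
A minimal sketch of what the widened conversion enables. It assumes RopeScalingType is exported from the same module, that LlamaContextParams implements Default, and that the builder/getter pair behaves as the doc comment touched above shows; it is not part of the diff.

use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};

fn main() {
    // Round-trip through the new i32 conversions; unrecognised values fall
    // back to RopeScalingType::ScalingUnspecified per the doc comment above.
    assert_eq!(RopeScalingType::from(1), RopeScalingType::Linear);
    assert_eq!(i32::from(RopeScalingType::Linear), 1);

    // Builder usage, mirroring the doc-test touched by this commit.
    let params = LlamaContextParams::default()
        .with_rope_scaling_type(RopeScalingType::Linear);
    assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
}
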
17 changes: 12 additions & 5 deletions llama-cpp-2/src/llama_batch.rs
@@ -44,8 +44,10 @@ impl LlamaBatch {
seq_ids: &[i32],
logits: bool,
) -> Result<(), BatchAddError> {
if self.allocated < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize") {
return Err(BatchAddError::InsufficientSpace(self.allocated))
if self.allocated
< usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize")
{
return Err(BatchAddError::InsufficientSpace(self.allocated));
}
let offset = self.llama_batch.n_tokens;
let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize");
@@ -55,8 +57,10 @@ impl LlamaBatch {
// batch.pos [batch.n_tokens] = pos,
self.llama_batch.pos.add(offset_usize).write(pos);
// batch.n_seq_id[batch.n_tokens] = seq_ids.size();
self.llama_batch.n_seq_id.add(offset_usize).write(llama_seq_id::try_from(seq_ids.len())
.expect("cannot fit seq_ids.len() into a llama_seq_id"));
self.llama_batch.n_seq_id.add(offset_usize).write(
llama_seq_id::try_from(seq_ids.len())
.expect("cannot fit seq_ids.len() into a llama_seq_id"),
);
// for (size_t i = 0; i < seq_ids.size(); ++i) {
// batch.seq_id[batch.n_tokens][i] = seq_ids[i];
// }
@@ -65,7 +69,10 @@ impl LlamaBatch {
tmp.add(i).write(*seq_id);
}
// batch.logits [batch.n_tokens] = logits;
self.llama_batch.logits.add(offset_usize).write(i8::from(logits));
self.llama_batch
.logits
.add(offset_usize)
.write(i8::from(logits));
}

if logits {
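
Call-site view of `add` after the reformat, as simple.rs uses it to queue a prompt before `ctx.decode(&mut batch)`. The leading parameters of `add` are elided in the hunk above, so the `(token, pos)` pair here is an assumption carried over from llama.cpp's llama_batch_add; treat this as a sketch, not the definitive signature.

use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::token::LlamaToken;

/// Queues a tokenized prompt into the batch, requesting logits only for the
/// last token (sequence id 0), as the simple example does.
/// Assumes a non-empty prompt.
fn queue_prompt(batch: &mut LlamaBatch, tokens: &[LlamaToken]) {
    let last_index = tokens.len() - 1;
    for (pos, token) in tokens.iter().enumerate() {
        batch
            .add(*token, pos as i32, &[0], pos == last_index)
            .expect("batch should have been allocated large enough for the prompt");
    }
}
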
4 changes: 1 addition & 3 deletions llama-cpp-2/src/model.rs
@@ -126,7 +126,7 @@ impl LlamaModel {
) -> Result<Vec<LlamaToken>, StringToTokenError> {
let add_bos = match add_bos {
AddBos::Always => true,
AddBos::Never => false
AddBos::Never => false,
};

let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
@@ -136,8 +136,6 @@ impl LlamaModel {
let buffer_capacity =
c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");



let size = unsafe {
llama_cpp_sys_2::llama_tokenize(
self.model.as_ptr(),
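
The caller's view of the tokenizer path touched above, mirroring examples/simple.rs; a sketch, with the anyhow context wrapping taken directly from that example.

use anyhow::{Context, Result};
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::token::LlamaToken;

/// Tokenizes a prompt with a leading BOS token, as simple.rs does.
fn tokenize_prompt(model: &LlamaModel, prompt: &str) -> Result<Vec<LlamaToken>> {
    model
        .str_to_token(prompt, AddBos::Always)
        .with_context(|| format!("failed to tokenize {}", prompt))
}
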
2 changes: 1 addition & 1 deletion llama-cpp-2/src/token.rs
@@ -10,7 +10,7 @@ pub mod data_array;
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaToken( pub llama_cpp_sys_2::llama_token);
pub struct LlamaToken(pub llama_cpp_sys_2::llama_token);

impl Display for LlamaToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
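
Because the tuple field is `pub`, a token can be built straight from a raw id and printed through the `Display` impl shown above. A sketch only; it assumes `llama_cpp_sys_2::llama_token` is a plain integer id, as in llama.cpp.

use llama_cpp_2::token::LlamaToken;

fn main() {
    // Wrap a raw token id and print it via Display.
    let token = LlamaToken(42);
    println!("token id: {token}");
}
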
17 changes: 11 additions & 6 deletions llama-cpp-sys-2/build.rs
@@ -1,26 +1,32 @@
use std::env;
use std::path::PathBuf;
use std::path::Path;
use std::path::PathBuf;

fn main() {
println!("cargo:rerun-if-changed=llama.cpp");

let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();

if !Path::new("llama.cpp/ggml.c").exists() {
panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
}

let mut ggml = cc::Build::new();
let mut ggml_cuda = if cublas_enabled { Some(cc::Build::new()) } else { None };
let mut ggml_cuda = if cublas_enabled {
Some(cc::Build::new())
} else {
None
};
let mut llama_cpp = cc::Build::new();

ggml.cpp(false);
llama_cpp.cpp(true);

// https://github.com/ggerganov/llama.cpp/blob/a836c8f534ab789b02da149fbdaf7735500bff74/Makefile#L364-L368
if let Some(ggml_cuda) = &mut ggml_cuda {
for lib in ["cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt"] {
for lib in [
"cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
] {
println!("cargo:rustc-link-lib={}", lib);
}

@@ -66,8 +72,7 @@ fn main() {
ggml.define("_GNU_SOURCE", None);
}

ggml
.std("c17")
ggml.std("c17")
.file("llama.cpp/ggml.c")
.file("llama.cpp/ggml-alloc.c")
.file("llama.cpp/ggml-backend.c")
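
To make the gate explicit, a condensed, standalone sketch of the conditional CUDA linking above. The `cublas` feature name is inferred from `CARGO_FEATURE_CUBLAS` and should be treated as an assumption; this is illustration, not part of the diff.

use std::env;
use std::path::Path;

fn main() {
    // The submodule must be populated first (see the panic above):
    //   git submodule update --init --recursive
    assert!(Path::new("llama.cpp/ggml.c").exists());

    // Cargo exports CARGO_FEATURE_<NAME> for each enabled feature, so this is
    // true when building with the cublas feature turned on.
    let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();

    if cublas_enabled {
        // Same libraries the script links for the CUDA path.
        for lib in [
            "cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
        ] {
            println!("cargo:rustc-link-lib={lib}");
        }
    }
}
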
