attempt to add metal on mac #65

Merged: 9 commits, Feb 25, 2024
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions llama-cpp-2/Cargo.toml
@@ -26,6 +26,10 @@ anyhow = "1.0.80"
name = "grammar_bias"
harness = false

[[bench]]
name = "generate"
harness = false

[features]
cublas = ["llama-cpp-sys-2/cublas"]

61 changes: 61 additions & 0 deletions llama-cpp-2/benches/generate.rs
@@ -0,0 +1,61 @@
use anyhow::Context;
use criterion::{Criterion, criterion_group, criterion_main};
use pprof::criterion::{Output, PProfProfiler};
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;

fn generate(c: &mut Criterion) {
    // Download the model from the Hugging Face Hub (cached locally after the first run).
    let api = hf_hub::api::sync::ApiBuilder::new()
        .with_progress(true)
        .build()
        .unwrap();
    let file = api
        .model("TheBloke/Llama-2-7B-Chat-GGUF".to_string())
        .get("llama-2-7b-chat.Q4_K_M.gguf")
        .unwrap();
    let backend = LlamaBackend::init().unwrap();
    let model_params = LlamaModelParams::default();
    let model = LlamaModel::load_from_file(&backend, &file, &model_params).unwrap();
    let mut ctx = model
        .new_context(&backend, LlamaContextParams::default())
        .unwrap();

    c.bench_function("generate 50 tokens", |b| {
        b.iter(|| {
            // Tokenize the prompt and feed it to the model as a single batch.
            let tokens_list = model.str_to_token("Hello, my name is", AddBos::Always).unwrap();
            let mut n_ctx = tokens_list.len() as i32;
            let mut batch = LlamaBatch::new(512, 1);
            let last_index: i32 = (tokens_list.len() - 1) as i32;
            for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
                let is_last = i == last_index;
                batch.add(token, i, &[0], is_last).unwrap();
            }
            ctx.decode(&mut batch).unwrap();

            // Greedily sample up to 50 new tokens, decoding one token per step.
            for _ in 0..50 {
                let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
                let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
                let new_token_id = ctx.sample_token_greedy(candidates_p);
                if new_token_id == model.token_eos() {
                    break;
                }
                batch.clear();
                batch.add(new_token_id, n_ctx, &[0], true).unwrap();
                n_ctx += 1;
                ctx.decode(&mut batch).unwrap();
            }
            // Reset the KV cache so the next iteration starts from a clean state.
            ctx.clear_kv_cache_seq(0, None, None)
        });
    });
}

criterion_group!(
    name = benches;
    config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
    targets = generate
);
criterion_main!(benches);
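
Since the point of this PR is GPU acceleration, a natural follow-up to the benchmark above is to offload layers to whichever GPU backend build.rs compiles in (Metal on macOS, cuBLAS behind the `cublas` feature). Below is a minimal sketch; the `with_n_gpu_layers` builder is an assumption about the `llama-cpp-2` API, so verify the setter name against the crate before relying on it.

use std::path::Path;

use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;

// Sketch: load a GGUF model with every layer offloaded to the GPU backend
// compiled by build.rs (Metal on macOS, cuBLAS when `cublas` is enabled).
// NOTE: `with_n_gpu_layers` is an assumed builder name, not confirmed by this PR.
fn load_model_with_gpu_offload(path: &Path) -> anyhow::Result<LlamaModel> {
    let backend = LlamaBackend::init()?;
    let params = LlamaModelParams::default().with_n_gpu_layers(1000);
    Ok(LlamaModel::load_from_file(&backend, path, &params)?)
}

With that in place, `cargo bench --bench generate` on a Metal-capable Mac should exercise the GPU path added here.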
1 change: 1 addition & 0 deletions llama-cpp-sys-2/Cargo.toml
@@ -43,3 +43,4 @@ cc = { workspace = true }

[features]
cublas = []

99 changes: 82 additions & 17 deletions llama-cpp-sys-2/build.rs
@@ -7,24 +7,23 @@ fn main() {

let cublas_enabled = env::var("CARGO_FEATURE_CUBLAS").is_ok();

let mut ggml_cuda = if cublas_enabled { Some(cc::Build::new()) } else { None };

if !Path::new("llama.cpp/ggml.c").exists() {
panic!("llama.cpp seems to not be populated, try running `git submodule update --init --recursive` to init.")
}

let mut ggml = cc::Build::new();
let mut ggml_cuda = if cublas_enabled {
Some(cc::Build::new())
} else {
None
};
let mut llama_cpp = cc::Build::new();

ggml.cpp(false);
llama_cpp.cpp(true);

// https://github.com/ggerganov/llama.cpp/blob/a836c8f534ab789b02da149fbdaf7735500bff74/Makefile#L364-L368
if let Some(ggml_cuda) = &mut ggml_cuda {
for lib in ["cuda", "cublas", "cudart", "cublasLt"] {
for lib in [
"cuda", "cublas", "culibos", "cudart", "cublasLt", "pthread", "dl", "rt",
] {
println!("cargo:rustc-link-lib={}", lib);
}
if !ggml_cuda.get_compiler().is_like_msvc() {
@@ -39,23 +38,27 @@ fn main()
ggml_cuda
.flag_if_supported("-mfp16-format=ieee")
.flag_if_supported("-mno-unaligned-access");
ggml.flag_if_supported("-mfp16-format=ieee")
.flag_if_supported("-mno-unaligned-access");
llama_cpp
.flag_if_supported("-mfp16-format=ieee")
.flag_if_supported("-mno-unaligned-access");
ggml_cuda
.flag_if_supported("-mfp16-format=ieee")
ggml.flag_if_supported("-mfp16-format=ieee")
.flag_if_supported("-mno-unaligned-access");
}

ggml_cuda
.cuda(true)
.flag("-arch=all")
.file("llama.cpp/ggml-cuda.cu");
.file("llama.cpp/ggml-cuda.cu")
.include("llama.cpp");

if ggml_cuda.get_compiler().is_like_msvc() {
ggml_cuda.std("c++14");
} else {
ggml_cuda.std("c++17");
ggml_cuda
.flag("-std=c++11")
.std("c++11");
}

ggml.define("GGML_USE_CUBLAS", None);
@@ -65,22 +68,36 @@ fn main()

// https://github.com/ggerganov/llama.cpp/blob/191221178f51b6e81122c5bda0fd79620e547d07/Makefile#L133-L141
if cfg!(target_os = "macos") {
assert!(!cublas_enabled, "CUBLAS is not supported on macOS");

println!("cargo:rustc-link-lib=framework=Metal");
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=MetalPerformanceShaders");
println!("cargo:rustc-link-lib=framework=MetalKit");

llama_cpp.define("_DARWIN_C_SOURCE", None);

// https://github.com/ggerganov/llama.cpp/blob/3c0d25c4756742ebf15ad44700fabc0700c638bd/Makefile#L340-L343
llama_cpp.define("GGML_USE_METAL", None);
llama_cpp.define("GGML_USE_ACCELERATE", None);
llama_cpp.define("ACCELERATE_NEW_LAPACK", None);
llama_cpp.define("ACCELERATE_LAPACK_ILP64", None);
println!("cargo:rustc-link-arg=framework=Accelerate");

metal_hack(&mut ggml);
ggml.include("./llama.cpp/ggml-metal.h");
}

if cfg!(target_os = "dragonfly") {
llama_cpp.define("__BSD_VISIBLE", None);
}

if let Some(ggml_cuda) = ggml_cuda {
println!("compiling ggml-cuda");
ggml_cuda.compile("ggml-cuda");
}

if cfg!(target_os = "linux") {
ggml.define("_GNU_SOURCE", None);
}

ggml.std("c17")
ggml.std("c11")
.include("./llama.cpp")
.file("llama.cpp/ggml.c")
.file("llama.cpp/ggml-alloc.c")
.file("llama.cpp/ggml-backend.c")
@@ -89,14 +106,23 @@ fn main()

llama_cpp
.define("_XOPEN_SOURCE", Some("600"))
.std("c++17")
.include("llama.cpp")
.std("c++11")
.file("llama.cpp/llama.cpp");

if let Some(ggml_cuda) = ggml_cuda {
println!("compiling ggml-cuda");
ggml_cuda.compile("ggml-cuda");
println!("compiled ggml-cuda");
}

println!("compiling ggml");
ggml.compile("ggml");
println!("compiled ggml");

println!("compiling llama");
llama_cpp.compile("llama");
println!("compiled llama");

let header = "llama.cpp/llama.h";

@@ -116,3 +142,42 @@ fn main()
.write_to_file(out_path.join("bindings.rs"))
.expect("failed to write bindings to file");
}

// courtesy of https://github.com/rustformers/llm
fn metal_hack(build: &mut cc::Build) {
    const GGML_METAL_METAL_PATH: &str = "llama.cpp/ggml-metal.metal";
    const GGML_METAL_PATH: &str = "llama.cpp/ggml-metal.m";

    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is not defined"));

    let ggml_metal_path = {
        let ggml_metal_metal = std::fs::read_to_string(GGML_METAL_METAL_PATH)
            .expect("Could not read ggml-metal.metal")
            .replace('\\', "\\\\")
            .replace('\n', "\\n")
            .replace('\r', "\\r")
            .replace('\"', "\\\"");

        let ggml_metal =
            std::fs::read_to_string(GGML_METAL_PATH).expect("Could not read ggml-metal.m");

        let needle = r#"NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];"#;
        if !ggml_metal.contains(needle) {
            panic!("ggml-metal.m does not contain the needle to be replaced; the patching logic needs to be reinvestigated. Contact a `llama-cpp-sys-2` developer!");
        }

        // Replace the runtime read of the file with a compile-time string
        let ggml_metal = ggml_metal.replace(
            needle,
            &format!(r#"NSString * src = @"{ggml_metal_metal}";"#),
        );

        let patched_ggml_metal_path = out_dir.join("ggml-metal.m");
        std::fs::write(&patched_ggml_metal_path, ggml_metal)
            .expect("Could not write temporary patched ggml-metal.m");

        patched_ggml_metal_path
    };

    build.file(ggml_metal_path);
}
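
For readers unfamiliar with the trick: `metal_hack` inlines the contents of ggml-metal.metal into ggml-metal.m as a string literal, so the compiled library never has to locate the .metal file on disk at runtime. Below is a standalone sketch of that escape-and-embed step (illustration only, not part of the build script).

// Escape a shader source so it can be spliced into an Objective-C string
// literal, mirroring the replacements metal_hack performs above.
fn escape_for_objc_string_literal(src: &str) -> String {
    src.replace('\\', "\\\\") // backslashes first, so later escapes are not doubled
        .replace('\n', "\\n")
        .replace('\r', "\\r")
        .replace('"', "\\\"")
}

fn main() {
    let shader = "kernel void f() {\n    // \"metal\" source\n}";
    // Produces the same kind of line that replaces the needle in ggml-metal.m.
    let embedded = format!(r#"NSString * src = @"{}";"#, escape_for_objc_string_literal(shader));
    println!("{embedded}");
}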