From 2e6bf15ca23061900a03ba6a62264a8bebf95895 Mon Sep 17 00:00:00 2001 From: karthik2804 Date: Thu, 22 Aug 2024 16:12:07 +0200 Subject: [PATCH] Replace rustformers/llm with candle Signed-off-by: karthik2804 use arc instead of box Signed-off-by: karthik2804 send response back to guest Signed-off-by: karthik2804 resolve todos Signed-off-by: karthik2804 some more fixes Signed-off-by: karthik2804 remove Box::clone on cache Signed-off-by: karthik2804 pin candle crate versions Signed-off-by: karthik2804 --- Cargo.lock | 1083 ++++++++++--------- crates/llm-local/Cargo.toml | 16 +- crates/llm-local/src/bert.rs | 2 +- crates/llm-local/src/lib.rs | 289 ++--- crates/llm-local/src/llama.rs | 168 +++ crates/llm-local/src/token_output_stream.rs | 85 ++ crates/llm-local/src/utils.rs | 25 + 7 files changed, 947 insertions(+), 721 deletions(-) create mode 100644 crates/llm-local/src/llama.rs create mode 100644 crates/llm-local/src/token_output_stream.rs create mode 100644 crates/llm-local/src/utils.rs diff --git a/Cargo.lock b/Cargo.lock index a4ceb7fd4..a89c0ec9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,22 +30,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8" dependencies = [ "cfg-if", - "cipher 0.3.0", + "cipher", "cpufeatures", "opaque-debug", ] -[[package]] -name = "aes" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" -dependencies = [ - "cfg-if", - "cipher 0.4.4", - "cpufeatures", -] - [[package]] name = "ahash" version = "0.8.11" @@ -58,15 +47,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "aho-corasick" -version = "0.7.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" -dependencies = [ - "memchr", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -180,6 +160,9 @@ name = "arbitrary" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +dependencies = [ + "derive_arbitrary", +] [[package]] name = "arc-swap" @@ -394,7 +377,7 @@ checksum = "30c5ef0ede93efbf733c1a727f3b6b5a1060bbedd5600183e66f6e4be4af0ec5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -461,7 +444,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -491,7 +474,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -715,9 +698,35 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.58", + "syn 2.0.75", ] +[[package]] +name = "bindgen_cuda" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f8489af5b7d17a81bffe37e0f4d6e1e4de87c87329d05447f22c35d95a1227d" +dependencies = [ + "glob", + "num_cpus", + "rayon", +] + +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -748,6 +757,12 @@ dependencies = [ "digest", ] +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "block-buffer" version = "0.10.4" @@ -764,7 +779,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2cb03d1bed155d89dce0f845b7899b18a9a163e148fd004e1c28421a783e2d8e" dependencies = [ "block-padding", - "cipher 0.3.0", + "cipher", ] [[package]] @@ -820,63 +835,34 @@ name = "bytemuck" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" - -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - -[[package]] -name = "bytes" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" dependencies = [ - "serde 1.0.197", + "bytemuck_derive", ] [[package]] -name = "bzip2" -version = "0.4.4" +name = "bytemuck_derive" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" dependencies = [ - "bzip2-sys", - "libc", + "proc-macro2", + "quote", + "syn 2.0.75", ] [[package]] -name = "bzip2-sys" -version = "0.1.11+1.0.8" +name = "byteorder" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" -dependencies = [ - "cc", - "libc", - "pkg-config", -] +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] -name = "cached-path" -version = "0.6.1" +name = "bytes" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" dependencies = [ - "flate2", - "fs2", - "glob", - "indicatif 0.16.2", - "log", - "rand 0.8.5", - "reqwest 0.11.27", "serde 1.0.197", - "serde_json", - "sha2", - "tar", - "tempfile", - "thiserror", - "zip", ] [[package]] @@ -890,154 +876,79 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.1.0" -source = "git+https://github.com/huggingface/candle?rev=b80348d22f8f0dadb6cc4101bde031d5de69a9a5#b80348d22f8f0dadb6cc4101bde031d5de69a9a5" +version = "0.6.1" +source = "git+https://github.com/huggingface/candle?rev=e3261216b157a7305c18ccdd766b6e2a41afe483#e3261216b157a7305c18ccdd766b6e2a41afe483" dependencies = [ "byteorder", - "candle-gemm", + "candle-kernels", + "candle-metal-kernels", + "cudarc", + "gemm", "half", - "memmap2 0.7.1", + "memmap2 0.9.4", + "metal", "num-traits 0.2.18", "num_cpus", "rand 0.8.5", - "safetensors", + "rand_distr", + "rayon", + "safetensors 0.4.4", "thiserror", + "yoke", "zip", ] [[package]] -name = "candle-gemm" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b726a1f6cdd7ff080e95e3d91694701b1e04a58acd198e4a78c39428b2274e" -dependencies = [ - "candle-gemm-c32", - "candle-gemm-c64", - "candle-gemm-common", - "candle-gemm-f16", - "candle-gemm-f32", - "candle-gemm-f64", - "dyn-stack", - "lazy_static 1.4.0", - "num-complex", - "num-traits 0.2.18", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "candle-gemm-c32" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "661470663389f0c99fd8449e620bfae630a662739f830a323eda4dcf80888843" -dependencies = [ - "candle-gemm-common", - "dyn-stack", - "lazy_static 1.4.0", - "num-complex", - "num-traits 0.2.18", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "candle-gemm-c64" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a111ddf61db562854a6d2ff4dfe1e8a84066431b7bc68d3afae4bf60874fda0" +name = "candle-kernels" +version = "0.6.1" +source = "git+https://github.com/huggingface/candle?rev=e3261216b157a7305c18ccdd766b6e2a41afe483#e3261216b157a7305c18ccdd766b6e2a41afe483" dependencies = [ - "candle-gemm-common", - "dyn-stack", - "lazy_static 1.4.0", - "num-complex", - "num-traits 0.2.18", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", + "bindgen_cuda", ] [[package]] -name = "candle-gemm-common" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a6dd93783ead7eeef14361667ea32014dc6f716a2fc956b075fe78729e10dd5" +name = "candle-metal-kernels" +version = "0.6.1" +source = "git+https://github.com/huggingface/candle?rev=e3261216b157a7305c18ccdd766b6e2a41afe483#e3261216b157a7305c18ccdd766b6e2a41afe483" dependencies = [ - "dyn-stack", - "lazy_static 1.4.0", - "num-complex", - "num-traits 0.2.18", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", + "metal", + "once_cell", + "thiserror", + "tracing", ] [[package]] -name = "candle-gemm-f16" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b76499bf4b858cacc526c5c8f948bc7152774247dce8568f174b743ab1363fa4" +name = "candle-nn" +version = "0.6.1" +source = "git+https://github.com/huggingface/candle?rev=e3261216b157a7305c18ccdd766b6e2a41afe483#e3261216b157a7305c18ccdd766b6e2a41afe483" dependencies = [ - "candle-gemm-common", - "candle-gemm-f32", - "dyn-stack", + "candle-core", + "candle-metal-kernels", "half", - "lazy_static 1.4.0", - "num-complex", + "metal", "num-traits 0.2.18", - "paste", - "raw-cpuid", "rayon", - "seq-macro", -] - -[[package]] -name = "candle-gemm-f32" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bec152e7d36339d3785e0d746d75ee94a4e92968fbb12ddcc91b536b938d016" -dependencies = [ - "candle-gemm-common", - "dyn-stack", - "lazy_static 1.4.0", - "num-complex", - "num-traits 0.2.18", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", + "safetensors 0.4.4", + "serde 1.0.197", + "thiserror", ] [[package]] -name = "candle-gemm-f64" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00f59ac68a5521e2ff71431bb7f1b22126ff0b60c5e66599b1f4676433da6e69" +name = "candle-transformers" +version = "0.6.1" +source = "git+https://github.com/huggingface/candle?rev=e3261216b157a7305c18ccdd766b6e2a41afe483#e3261216b157a7305c18ccdd766b6e2a41afe483" dependencies = [ - "candle-gemm-common", - "dyn-stack", - "lazy_static 1.4.0", - "num-complex", + "byteorder", + "candle-core", + "candle-nn", + "fancy-regex", "num-traits 0.2.18", - "paste", - "raw-cpuid", + "rand 0.8.5", "rayon", - "seq-macro", -] - -[[package]] -name = "candle-nn" -version = "0.1.0" -source = "git+https://github.com/huggingface/candle?rev=b80348d22f8f0dadb6cc4101bde031d5de69a9a5#b80348d22f8f0dadb6cc4101bde031d5de69a9a5" -dependencies = [ - "candle-core", - "safetensors", - "thiserror", + "serde 1.0.197", + "serde_json", + "serde_plain", + "tracing", ] [[package]] @@ -1242,16 +1153,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "cipher" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common", - "inout", -] - [[package]] name = "clang-sys" version = "1.7.0" @@ -1324,7 +1225,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -1489,12 +1390,6 @@ version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "373e9fafaa20882876db20562275ff58d50e0caa2590077fe7ce7bef90211d0d" -[[package]] -name = "constant_time_eq" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" - [[package]] name = "core-foundation" version = "0.9.4" @@ -1511,6 +1406,17 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "cpp_demangle" version = "0.4.3" @@ -1767,6 +1673,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "cudarc" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" +dependencies = [ + "half", + "libloading", +] + [[package]] name = "darling" version = "0.14.4" @@ -1812,7 +1728,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -1834,7 +1750,7 @@ checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" dependencies = [ "darling_core 0.20.9", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -1878,6 +1794,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.75", +] + [[package]] name = "derive_builder" version = "0.11.2" @@ -1889,11 +1816,11 @@ dependencies = [ [[package]] name = "derive_builder" -version = "0.12.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" dependencies = [ - "derive_builder_macro 0.12.0", + "derive_builder_macro 0.20.0", ] [[package]] @@ -1910,14 +1837,14 @@ dependencies = [ [[package]] name = "derive_builder_core" -version = "0.12.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" dependencies = [ - "darling 0.14.4", + "darling 0.20.9", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.75", ] [[package]] @@ -1932,12 +1859,12 @@ dependencies = [ [[package]] name = "derive_builder_macro" -version = "0.12.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ - "derive_builder_core 0.12.0", - "syn 1.0.109", + "derive_builder_core 0.20.0", + "syn 2.0.75", ] [[package]] @@ -2057,6 +1984,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.75", +] + [[package]] name = "dkregistry" version = "0.5.1-alpha.0" @@ -2115,9 +2053,9 @@ checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" [[package]] name = "dyn-stack" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe7f8d7bcc523381d3c437b82cf74805de3931de0da69309ae0fe1bdf7a256e" +checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" dependencies = [ "bytemuck", "reborrow", @@ -2184,6 +2122,18 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-as-inner" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 2.0.75", +] + [[package]] name = "enumflags2" version = "0.7.9" @@ -2202,7 +2152,7 @@ checksum = "5c785274071b1b420972453b306eeca06acf4633829db4223b58a2a8c5953bc4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -2320,7 +2270,7 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -2341,6 +2291,17 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fancy-regex" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" +dependencies = [ + "bit-set", + "regex-automata 0.4.6", + "regex-syntax 0.8.3", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -2455,7 +2416,28 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ - "foreign-types-shared", + "foreign-types-shared 0.1.1", +] + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared 0.3.1", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.75", ] [[package]] @@ -2464,6 +2446,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -2493,16 +2481,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "fs2" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "fs_extra" version = "1.3.0" @@ -2612,7 +2590,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -2667,6 +2645,124 @@ dependencies = [ "serde_json", ] +[[package]] +name = "gemm" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits 0.2.18", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits 0.2.18", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits 0.2.18", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8" +dependencies = [ + "bytemuck", + "dyn-stack", + "half", + "num-complex", + "num-traits 0.2.18", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", + "sysctl", +] + +[[package]] +name = "gemm-f16" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half", + "num-complex", + "num-traits 0.2.18", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits 0.2.18", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits 0.2.18", + "paste", + "raw-cpuid", + "seq-macro", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -2702,24 +2798,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "ggml" -version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663" -dependencies = [ - "ggml-sys", - "memmap2 0.5.10", - "thiserror", -] - -[[package]] -name = "ggml-sys" -version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663" -dependencies = [ - "cc", -] - [[package]] name = "gimli" version = "0.28.1" @@ -2969,7 +3047,7 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57da3b9b5b85bd66f31093f8c408b90a74431672542466497dcbdfdc02034be1" dependencies = [ - "aho-corasick 1.1.3", + "aho-corasick", "bstr", "log", "regex-automata 0.4.6", @@ -3039,10 +3117,11 @@ dependencies = [ [[package]] name = "half" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5eceaaeec696539ddaf7b333340f1af35a5aa87ae3e4f3ead0532f72affab2e" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" dependencies = [ + "bytemuck", "cfg-if", "crunchy", "num-traits 0.2.18", @@ -3524,30 +3603,6 @@ dependencies = [ "serde 1.0.197", ] -[[package]] -name = "indicatif" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7baab56125e25686df467fe470785512329883aab42696d661247aca2a2896e4" -dependencies = [ - "console", - "lazy_static 1.4.0", - "number_prefix 0.3.0", - "regex", -] - -[[package]] -name = "indicatif" -version = "0.16.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d207dc617c7a380ab07ff572a6e52fa202a2a8f355860ac9c38e23f8196be1b" -dependencies = [ - "console", - "lazy_static 1.4.0", - "number_prefix 0.4.0", - "regex", -] - [[package]] name = "indicatif" version = "0.17.8" @@ -3556,7 +3611,7 @@ checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" dependencies = [ "console", "instant", - "number_prefix 0.4.0", + "number_prefix", "portable-atomic", "unicode-width", ] @@ -3584,16 +3639,7 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" dependencies = [ - "libc", -] - -[[package]] -name = "inout" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" -dependencies = [ - "generic-array", + "libc", ] [[package]] @@ -3649,24 +3695,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "itertools" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484" -dependencies = [ - "either", -] - -[[package]] -name = "itertools" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.10.5" @@ -3896,12 +3924,12 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -4098,58 +4126,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "llm" -version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663" -dependencies = [ - "llm-base", - "llm-llama", - "serde 1.0.197", - "tracing", -] - -[[package]] -name = "llm-base" -version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663" -dependencies = [ - "bytemuck", - "ggml", - "half", - "llm-samplers", - "memmap2 0.5.10", - "partial_sort", - "rand 0.8.5", - "regex", - "serde 1.0.197", - "serde_bytes", - "thiserror", - "tokenizers", - "tracing", -] - -[[package]] -name = "llm-llama" -version = "0.2.0-dev" -source = "git+https://github.com/rustformers/llm?rev=2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663#2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663" -dependencies = [ - "llm-base", - "tracing", -] - -[[package]] -name = "llm-samplers" -version = "0.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7553f60d113c9cdc6a5402456a31cd9a273bef79f6f16d8a4f7b4bedf5f754b2" -dependencies = [ - "anyhow", - "num-traits 0.2.18", - "rand 0.8.5", - "thiserror", -] - [[package]] name = "lock_api" version = "0.4.11" @@ -4198,7 +4174,7 @@ dependencies = [ "proc-macro2", "quote", "regex-syntax 0.6.29", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -4213,7 +4189,7 @@ dependencies = [ "proc-macro2", "quote", "regex-syntax 0.8.3", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -4263,9 +4239,9 @@ dependencies = [ [[package]] name = "macro_rules_attribute" -version = "0.1.3" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862" +checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13" dependencies = [ "macro_rules_attribute-proc_macro", "paste", @@ -4273,9 +4249,18 @@ dependencies = [ [[package]] name = "macro_rules_attribute-proc_macro" -version = "0.1.3" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" + +[[package]] +name = "malloc_buf" +version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] [[package]] name = "matchers" @@ -4334,11 +4319,12 @@ dependencies = [ [[package]] name = "memmap2" -version = "0.7.1" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322" dependencies = [ "libc", + "stable_deref_trait", ] [[package]] @@ -4368,6 +4354,21 @@ dependencies = [ "autocfg", ] +[[package]] +name = "metal" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" +dependencies = [ + "bitflags 2.5.0", + "block", + "core-graphics-types", + "foreign-types 0.5.0", + "log", + "objc", + "paste", +] + [[package]] name = "miette" version = "5.10.0" @@ -4400,7 +4401,7 @@ checksum = "49e7bc1560b95a3c4a25d03de42fe76ca718ab92d1a22a55b9b4cf67b3ae635c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -4411,7 +4412,7 @@ checksum = "dcf09caffaac8068c346b6df2a7fc27a177fd20b39421a39ce0a211bde679a6c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -4471,9 +4472,9 @@ dependencies = [ [[package]] name = "monostate" -version = "0.1.11" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "878c2a1f1c70e5724fa28f101ca787b6a7e8ad5c5e4ae4ca3b0fa4a419fa9075" +checksum = "0d208407d7552cd041d8cdb69a1bc3303e029c598738177a3d87082004dc0e1e" dependencies = [ "monostate-impl", "serde 1.0.197", @@ -4481,13 +4482,13 @@ dependencies = [ [[package]] name = "monostate-impl" -version = "0.1.11" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f686d68a09079e63b1d2c64aa305095887ce50565f00a922ebfaeeee0d9ba6ce" +checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -4718,6 +4719,7 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" dependencies = [ + "bytemuck", "num-traits 0.2.18", ] @@ -4787,6 +4789,27 @@ dependencies = [ "libc", ] +[[package]] +name = "num_enum" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.75", +] + [[package]] name = "num_threads" version = "0.1.7" @@ -4796,12 +4819,6 @@ dependencies = [ "libc", ] -[[package]] -name = "number_prefix" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a" - [[package]] name = "number_prefix" version = "0.4.0" @@ -4827,6 +4844,25 @@ dependencies = [ "url", ] +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", + "objc_exception", +] + +[[package]] +name = "objc_exception" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad970fb455818ad6cba4c122ad012fae53ae8b4795f86378bce65e4f6bab2ca4" +dependencies = [ + "cc", +] + [[package]] name = "object" version = "0.32.2" @@ -4967,7 +5003,7 @@ checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ "bitflags 2.5.0", "cfg-if", - "foreign-types", + "foreign-types 0.3.2", "libc", "once_cell", "openssl-macros", @@ -4982,7 +5018,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -5180,23 +5216,6 @@ dependencies = [ "windows-targets 0.48.5", ] -[[package]] -name = "partial_sort" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7924d1d0ad836f665c9065e26d016c673ece3993f30d340068b16f282afc1156" - -[[package]] -name = "password-hash" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" -dependencies = [ - "base64ct", - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "paste" version = "1.0.14" @@ -5264,18 +5283,6 @@ dependencies = [ "serde 1.0.197", ] -[[package]] -name = "pbkdf2" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" -dependencies = [ - "digest", - "hmac", - "password-hash", - "sha2", -] - [[package]] name = "pem" version = "3.0.3" @@ -5332,7 +5339,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -5412,7 +5419,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -5563,7 +5570,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7" dependencies = [ "proc-macro2", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -5617,9 +5624,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.79" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -5685,7 +5692,7 @@ dependencies = [ "prost", "prost-types", "regex", - "syn 2.0.58", + "syn 2.0.75", "tempfile", ] @@ -5699,7 +5706,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -5776,6 +5783,18 @@ dependencies = [ "tint", ] +[[package]] +name = "pulp" +version = "0.18.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6" +dependencies = [ + "bytemuck", + "libm", + "num-complex", + "reborrow", +] + [[package]] name = "quote" version = "1.0.35" @@ -5896,12 +5915,12 @@ dependencies = [ [[package]] name = "rayon-cond" -version = "0.1.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" dependencies = [ "either", - "itertools 0.8.2", + "itertools 0.11.0", "rayon", ] @@ -6029,7 +6048,7 @@ version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ - "aho-corasick 1.1.3", + "aho-corasick", "memchr", "regex-automata 0.4.6", "regex-syntax 0.8.3", @@ -6050,7 +6069,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ - "aho-corasick 1.1.3", + "aho-corasick", "memchr", "regex-syntax 0.8.3", ] @@ -6061,12 +6080,6 @@ version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" -[[package]] -name = "regex-syntax" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" - [[package]] name = "regex-syntax" version = "0.8.3" @@ -6332,7 +6345,7 @@ dependencies = [ "regex", "serde_urlencoded", "syn 1.0.109", - "synstructure", + "synstructure 0.12.6", ] [[package]] @@ -6486,6 +6499,16 @@ dependencies = [ "serde_json", ] +[[package]] +name = "safetensors" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7725d4d98fa515472f43a6e2bbf956c48e06b89bb50593a040e5945160214450" +dependencies = [ + "serde 1.0.197", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -6566,7 +6589,7 @@ version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5da1a5ad4d28c03536f82f77d9f36603f5e37d8869ac98f0a750d5b5686d8d95" dependencies = [ - "aes 0.7.5", + "aes", "block-modes", "futures-util", "generic-array", @@ -6654,15 +6677,6 @@ dependencies = [ "serde 1.0.197", ] -[[package]] -name = "serde_bytes" -version = "0.11.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b8497c313fd43ab992087548117643f6fcd935cbf36f176ffda0aacf9591734" -dependencies = [ - "serde 1.0.197", -] - [[package]] name = "serde_derive" version = "1.0.197" @@ -6671,7 +6685,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -6685,11 +6699,12 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.115" +version = "1.0.125" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" +checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" dependencies = [ "itoa", + "memchr", "ryu", "serde 1.0.197", ] @@ -6704,6 +6719,15 @@ dependencies = [ "serde 1.0.197", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde 1.0.197", +] + [[package]] name = "serde_qs" version = "0.8.5" @@ -6723,7 +6747,7 @@ checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -6774,7 +6798,7 @@ dependencies = [ "darling 0.20.9", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -7097,7 +7121,7 @@ dependencies = [ "http-body-util", "hyper 1.4.1", "hyper-util", - "indicatif 0.17.8", + "indicatif", "is-terminal", "itertools 0.11.0", "lazy_static 1.4.0", @@ -7500,7 +7524,7 @@ dependencies = [ "expander", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -7598,13 +7622,14 @@ dependencies = [ "anyhow", "candle-core", "candle-nn", + "candle-transformers", "chrono", - "llm", "lru 0.9.0", "num_cpus", "rand 0.8.5", - "safetensors", + "safetensors 0.3.3", "serde 1.0.197", + "serde_json", "spin-common", "spin-core", "spin-world", @@ -8133,9 +8158,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.58" +version = "2.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" +checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9" dependencies = [ "proc-macro2", "quote", @@ -8169,6 +8194,31 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.75", +] + +[[package]] +name = "sysctl" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" +dependencies = [ + "bitflags 2.5.0", + "byteorder", + "enum-as-inner", + "libc", + "thiserror", + "walkdir", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -8304,7 +8354,7 @@ version = "0.0.0" dependencies = [ "heck 0.4.1", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -8377,7 +8427,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -8472,19 +8522,16 @@ dependencies = [ [[package]] name = "tokenizers" -version = "0.13.4" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aea68938177975ab09da68552b720eac941779ff386baceaf77e0f5f9cea645f" +checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" dependencies = [ - "aho-corasick 0.7.20", - "cached-path", - "clap 4.5.4", - "derive_builder 0.12.0", - "dirs 4.0.0", + "aho-corasick", + "derive_builder 0.20.0", "esaxx-rs", "getrandom 0.2.12", - "indicatif 0.15.0", - "itertools 0.9.0", + "indicatif", + "itertools 0.12.1", "lazy_static 1.4.0", "log", "macro_rules_attribute", @@ -8495,8 +8542,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax 0.7.5", - "reqwest 0.11.27", + "regex-syntax 0.8.3", "serde 1.0.197", "serde_json", "spm_precompiled", @@ -8542,7 +8588,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -8824,7 +8870,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -8910,7 +8956,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ "cfg-if", - "rand 0.8.5", + "rand 0.7.3", "static_assertions", ] @@ -9373,7 +9419,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", "wasm-bindgen-shared", ] @@ -9407,7 +9453,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -9783,7 +9829,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", "wasmtime-component-util", "wasmtime-wit-bindgen", "wit-parser 0.209.1", @@ -9910,7 +9956,7 @@ checksum = "de5a9bc4f44ceeb168e9e8e3be4e0b4beb9095b468479663a9e24c667e36826f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", ] [[package]] @@ -10180,7 +10226,7 @@ dependencies = [ "proc-macro2", "quote", "shellexpand 2.1.2", - "syn 2.0.58", + "syn 2.0.75", "witx", ] @@ -10192,7 +10238,7 @@ checksum = "cc26129a8aea20b62c961d1b9ab4a3c3b56b10042ed85d004f8678af0f21ba6e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", "wiggle-generate", ] @@ -10711,6 +10757,30 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde 1.0.197", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.75", + "synstructure 0.13.1", +] + [[package]] name = "zbus" version = "3.15.2" @@ -10794,7 +10864,28 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.58", + "syn 2.0.75", +] + +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.75", + "synstructure 0.13.1", ] [[package]] @@ -10805,31 +10896,17 @@ checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" [[package]] name = "zip" -version = "0.6.6" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +checksum = "9cc23c04387f4da0374be4533ad1208cbb091d5c11d070dfef13676ad6497164" dependencies = [ - "aes 0.8.4", - "byteorder", - "bzip2", - "constant_time_eq", + "arbitrary", "crc32fast", "crossbeam-utils", - "flate2", - "hmac", - "pbkdf2", - "sha1 0.10.6", - "time", - "zstd 0.11.2+zstd.1.5.2", -] - -[[package]] -name = "zstd" -version = "0.11.2+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" -dependencies = [ - "zstd-safe 5.0.2+zstd.1.5.2", + "displaydoc", + "indexmap 2.2.6", + "num_enum", + "thiserror", ] [[package]] @@ -10850,16 +10927,6 @@ dependencies = [ "zstd-safe 7.1.0", ] -[[package]] -name = "zstd-safe" -version = "5.0.2+zstd.1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" -dependencies = [ - "libc", - "zstd-sys", -] - [[package]] name = "zstd-safe" version = "6.0.6" diff --git a/crates/llm-local/Cargo.toml b/crates/llm-local/Cargo.toml index 9e559d76a..f908fcf63 100644 --- a/crates/llm-local/Cargo.toml +++ b/crates/llm-local/Cargo.toml @@ -6,30 +6,28 @@ edition = { workspace = true } [dependencies] anyhow = "1.0" -candle = { git = "https://github.com/huggingface/candle", rev = "b80348d22f8f0dadb6cc4101bde031d5de69a9a5", package = "candle-core" } -candle-nn = { git = "https://github.com/huggingface/candle", rev = "b80348d22f8f0dadb6cc4101bde031d5de69a9a5" } +candle = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483", package = "candle-core" } +candle-nn = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483" } +candle-transformers = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483" } chrono = "0.4.26" -llm = { git = "https://github.com/rustformers/llm", rev = "2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663", features = [ - "tokenizers-remote", - "llama", -], default-features = false } lru = "0.9.0" num_cpus = "1" rand = "0.8.5" safetensors = "0.3.3" serde = { version = "1.0.150", features = ["derive"] } +serde_json = "1.0.125" spin-common = { path = "../common" } spin-core = { path = "../core" } spin-world = { path = "../world" } terminal = { path = "../terminal" } -tokenizers = "0.13.4" +tokenizers = "0.19.1" tokio = { version = "1.32.0", features = ["macros", "sync"] } tracing = { workspace = true } [features] default = [] -metal = ["llm/metal"] -cublas = ["llm/cublas"] +metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"] +cublas = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] [lints] workspace = true diff --git a/crates/llm-local/src/bert.rs b/crates/llm-local/src/bert.rs index 3ecb2ac28..55f991250 100644 --- a/crates/llm-local/src/bert.rs +++ b/crates/llm-local/src/bert.rs @@ -4,7 +4,7 @@ /// /// TODO: Remove this file when a new release of Candle makes it obsolete. use anyhow::{bail, Result}; -use candle::{DType, Tensor}; +use candle::{DType, Module, Tensor}; use candle_nn::{Embedding, VarBuilder}; use serde::Deserialize; diff --git a/crates/llm-local/src/lib.rs b/crates/llm-local/src/lib.rs index e4db26193..fbe4bde2c 100644 --- a/crates/llm-local/src/lib.rs +++ b/crates/llm-local/src/lib.rs @@ -1,35 +1,59 @@ mod bert; +mod llama; +mod token_output_stream; +mod utils; use anyhow::Context; use bert::{BertModel, Config}; use candle::DType; use candle_nn::VarBuilder; -use llm::{ - InferenceFeedback, InferenceParameters, InferenceResponse, InferenceSessionConfig, Model, - ModelArchitecture, ModelKVMemoryType, ModelParameters, -}; -use rand::SeedableRng; use spin_common::ui::quoted_path; +use spin_core::async_trait; use spin_world::v2::llm::{self as wasi_llm}; use std::{ - collections::hash_map::Entry, - collections::HashMap, - convert::Infallible, + collections::{hash_map::Entry, HashMap}, path::{Path, PathBuf}, - sync::{Arc, Mutex}, + str::FromStr, + sync::Arc, }; use tokenizers::PaddingParams; const MODEL_ALL_MINILM_L6_V2: &str = "all-minilm-l6-v2"; +type ModelName = String; #[derive(Clone)] pub struct LocalLlmEngine { registry: PathBuf, - use_gpu: bool, - inferencing_models: HashMap<(String, bool), Arc>, + _use_gpu: bool, + inferencing_models: HashMap>, embeddings_models: HashMap>, } +#[derive(Debug)] +enum InferencingModelArch { + Llama, +} + +impl FromStr for InferencingModelArch { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "llama" => Ok(InferencingModelArch::Llama), + _ => Err(()), + } + } +} + +#[async_trait] +trait CachedInferencingModel: Send + Sync { + async fn infer( + &self, + prompt: String, + params: wasi_llm::InferencingParams, + ) -> anyhow::Result; +} + impl LocalLlmEngine { pub async fn infer( &mut self, @@ -37,57 +61,13 @@ impl LocalLlmEngine { prompt: String, params: wasi_llm::InferencingParams, ) -> Result { + // return self.inference(model).await; let model = self.inferencing_model(model).await?; - let cfg = InferenceSessionConfig { - memory_k_type: ModelKVMemoryType::Float16, - memory_v_type: ModelKVMemoryType::Float16, - n_batch: 8, - n_threads: num_cpus::get(), - }; - - let mut session = Model::start_session(model.as_ref(), cfg); - let inference_params = InferenceParameters { - sampler: generate_sampler(params), - }; - let mut rng = rand::rngs::StdRng::from_entropy(); - let mut text = String::new(); - #[cfg(debug_assertions)] - { - terminal::warn!( - "\ - This is a debug build - running inference might be prohibitively slow\n\ - You may want to consider switching to the release build" - ) - } - let res = session.infer::( - model.as_ref(), - &mut rng, - &llm::InferenceRequest { - prompt: prompt.as_str().into(), - parameters: &inference_params, - play_back_previous_tokens: false, - maximum_token_count: Some(params.max_tokens as usize), - }, - &mut Default::default(), - |r| { - match r { - InferenceResponse::InferredToken(t) => text.push_str(&t), - InferenceResponse::EotToken => return Ok(InferenceFeedback::Halt), - _ => {} - }; - Ok(InferenceFeedback::Continue) - }, - ); - let stats = res.map_err(|e| { - wasi_llm::Error::RuntimeError(format!("Error occurred during inferencing: {e}")) - })?; - let usage = wasi_llm::InferencingUsage { - prompt_token_count: stats.prompt_tokens as u32, - generated_token_count: (stats.predict_tokens - stats.prompt_tokens) as u32, - }; - let response = wasi_llm::InferencingResult { text, usage }; - Ok(response) + model + .infer(prompt, params) + .await + .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string())) } pub async fn generate_embeddings( @@ -103,10 +83,10 @@ impl LocalLlmEngine { } impl LocalLlmEngine { - pub fn new(registry: PathBuf, use_gpu: bool) -> Self { + pub fn new(registry: PathBuf, _use_gpu: bool) -> Self { Self { registry, - use_gpu, + _use_gpu, inferencing_models: Default::default(), embeddings_models: Default::default(), } @@ -164,73 +144,36 @@ impl LocalLlmEngine { async fn inferencing_model( &mut self, model: wasi_llm::InferencingModel, - ) -> Result, wasi_llm::Error> { - let use_gpu = self.use_gpu; - let progress_fn = |_| {}; - let model = match self.inferencing_models.entry((model.clone(), use_gpu)) { + ) -> Result, wasi_llm::Error> { + // let use_gpu = self.use_gpu; + + let model = match self.inferencing_models.entry(model.clone()) { Entry::Occupied(o) => o.get().clone(), - Entry::Vacant(v) => v - .insert({ - let (path, arch) = if let Some(arch) = well_known_inferencing_model_arch(&model) { - let model_binary = self.registry.join(&model); - if model_binary.exists() { - (model_binary, arch.to_owned()) - } else { - walk_registry_for_model(&self.registry, model).await? - } - } else { - walk_registry_for_model(&self.registry, model).await? - }; - if !self.registry.exists() { - return Err(wasi_llm::Error::RuntimeError( - format!("The directory expected to house the inferencing model '{}' does not exist.", self.registry.display()) - )); - } - if !path.exists() { - return Err(wasi_llm::Error::RuntimeError( - format!("The inferencing model file '{}' does not exist.", path.display()) - )); - } - tokio::task::spawn_blocking(move || { - let params = ModelParameters { - prefer_mmap: true, - context_size: 2048, - lora_adapters: None, - use_gpu, - gpu_layers: None, - rope_overrides: None, - n_gqa: None, - }; - let model = llm::load_dynamic( - Some(arch), - &path, - llm::TokenizerSource::Embedded, - params, - progress_fn, - ) - .map_err(|e| { - wasi_llm::Error::RuntimeError(format!( - "Failed to load model from model registry: {e}" - )) - })?; - Ok(Arc::from(model)) - }) - .await - .map_err(|_| { - wasi_llm::Error::RuntimeError("Error loading inferencing model".into()) - })?? - }) - .clone(), + Entry::Vacant(v) => { + let (model_dir, arch) = + walk_registry_for_model(&self.registry, model.clone()).await?; + let model = match arch { + InferencingModelArch::Llama => Arc::new( + llama::LlamaModels::new(&model_dir) + .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))?, + ), + }; + + v.insert(model.clone()); + + model + } }; Ok(model) } } -/// Get the model binary and arch from walking the registry file structure +/// Walks the registry file structure and returns the directory the model is +/// present along with its architecture async fn walk_registry_for_model( registry_path: &Path, model: String, -) -> Result<(PathBuf, ModelArchitecture), wasi_llm::Error> { +) -> Result<(PathBuf, InferencingModelArch), wasi_llm::Error> { let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| { wasi_llm::Error::RuntimeError(format!( "Could not read model registry directory '{}': {e}", @@ -256,17 +199,31 @@ async fn walk_registry_for_model( { continue; } - let mut model_files = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| { + let mut model_dirs = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| { wasi_llm::Error::RuntimeError(format!( "Error reading architecture directory in model registry: {e}" )) })?; - while let Some(model_file) = model_files.next_entry().await.map_err(|e| { + while let Some(model_dir) = model_dirs.next_entry().await.map_err(|e| { wasi_llm::Error::RuntimeError(format!( - "Error reading model file in model registry: {e}" + "Error reading model folder in model registry: {e}" )) })? { - if model_file + // Models need to be a directory. So ignore any files. + if model_dir + .file_type() + .await + .map_err(|e| { + wasi_llm::Error::RuntimeError(format!( + "Could not read file type of '{}' dir: {e}", + model_dir.path().display() + )) + })? + .is_file() + { + continue; + } + if model_dir .file_name() .to_str() .map(|m| m == model) @@ -278,7 +235,7 @@ async fn walk_registry_for_model( .ok_or(wasi_llm::Error::ModelNotSupported)? .parse() .map_err(|_| wasi_llm::Error::ModelNotSupported)?; - result = Some((model_file.path(), arch)); + result = Some((model_dir.path(), arch)); break 'outer; } } @@ -291,15 +248,6 @@ async fn walk_registry_for_model( }) } -fn well_known_inferencing_model_arch( - model: &wasi_llm::InferencingModel, -) -> Option { - match model.as_str() { - "llama2-chat" | "code_llama" => Some(ModelArchitecture::Llama), - _ => None, - } -} - async fn generate_embeddings( data: Vec, model: Arc<(tokenizers::Tokenizer, BertModel)>, @@ -381,75 +329,10 @@ fn load_tokenizer(tokenizer_file: &Path) -> anyhow::Result anyhow::Result { - let buffer = std::fs::read(model_file) - .with_context(|| format!("Failed to read model file {}", quoted_path(model_file)))?; - let weights = safetensors::SafeTensors::deserialize(&buffer)?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &candle::Device::Cpu); + // TODO: Check if there is a safe way to load the model from the file + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &candle::Device::Cpu)? + }; let model = BertModel::load(vb, &Config::default()).context("error loading bert model")?; Ok(model) } - -// Sampling options for picking the next token in the sequence. -// We start with a default sampler, then add the inference parameters supplied by the request. -fn generate_sampler( - params: wasi_llm::InferencingParams, -) -> Arc>> { - let mut result = llm::samplers::ConfiguredSamplers { - // We are *not* using the default implementation for ConfiguredSamplers here - // because the builder already sets values for parameters, which we cannot replace. - builder: llm::samplers::llm_samplers::configure::SamplerChainBuilder::default(), - ..Default::default() - }; - - result.builder += ( - "temperature".into(), - llm::samplers::llm_samplers::configure::SamplerSlot::new_single( - move || { - Box::new( - llm::samplers::llm_samplers::samplers::SampleTemperature::default() - .temperature(params.temperature), - ) - }, - Option::::None, - ), - ); - result.builder += ( - "topp".into(), - llm::samplers::llm_samplers::configure::SamplerSlot::new_single( - move || { - Box::new( - llm::samplers::llm_samplers::samplers::SampleTopP::default().p(params.top_p), - ) - }, - Option::::None, - ), - ); - result.builder += ( - "topk".into(), - llm::samplers::llm_samplers::configure::SamplerSlot::new_single( - move || { - Box::new( - llm::samplers::llm_samplers::samplers::SampleTopK::default() - .k(params.top_k as usize), - ) - }, - Option::::None, - ), - ); - result.builder += ( - "repetition".into(), - llm::samplers::llm_samplers::configure::SamplerSlot::new_chain( - move || { - Box::new( - llm::samplers::llm_samplers::samplers::SampleRepetition::default() - .penalty(params.repeat_penalty) - .last_n(params.repeat_penalty_last_n_token_count as usize), - ) - }, - [], - ), - ); - - result.ensure_default_slots(); - Arc::new(Mutex::new(result.builder.into_chain())) -} diff --git a/crates/llm-local/src/llama.rs b/crates/llm-local/src/llama.rs new file mode 100644 index 000000000..84f218c8e --- /dev/null +++ b/crates/llm-local/src/llama.rs @@ -0,0 +1,168 @@ +use crate::{token_output_stream, utils::load_safetensors, CachedInferencingModel}; +use anyhow::{anyhow, Result}; +use candle::{utils, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::{ + generation::{LogitsProcessor, Sampling}, + models::llama::{self, Cache, Config, Llama, LlamaConfig}, +}; +use rand::{RngCore, SeedableRng}; +use spin_core::async_trait; +use spin_world::v2::llm::{self as wasi_llm, InferencingUsage}; +use std::{fs, path::PathBuf, sync::Arc}; +use tokenizers::Tokenizer; + +const TOKENIZER_FILENAME: &str = "tokenizer.json"; +const CONFIG_FILENAME: &str = "config.json"; +const EOS_TOKEN: &str = ""; +const MODEL_SAFETENSORS_INDEX: &str = "model.safetensors.index.json"; + +pub fn auto_device() -> Result { + if utils::cuda_is_available() { + Ok(Device::new_cuda(0)?) + } else if utils::metal_is_available() { + Ok(Device::new_metal(0)?) + } else { + Ok(Device::Cpu) + } +} + +#[derive(Clone)] +pub(crate) struct LlamaModels { + model: Arc, + config: Config, + cache: Cache, + tokenizer: Tokenizer, + device: Device, +} + +impl LlamaModels { + pub fn new(model_dir: &PathBuf) -> Result { + let tokenizer_path = model_dir.join(TOKENIZER_FILENAME); + let config_path = model_dir.join(CONFIG_FILENAME); + + let dtype = candle::DType::F16; + let device = auto_device()?; + + let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow!(e.to_string()))?; + let config: LlamaConfig = serde_json::from_slice(&fs::read(config_path)?)?; + + // TODO: flash attention is supposed to minimize memory read and writes - Do we want to turn it on + let config = config.into_config(false); + let cache = llama::Cache::new(true, dtype, &config, &device)?; + + let safetensor_files = load_safetensors(&model_dir, MODEL_SAFETENSORS_INDEX)?; + + // TODO: Check if there is a safe way to load the model from the file + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&safetensor_files, dtype, &device)? }; + let model = Llama::load(vb, &config)?; + + Ok(Self { + model: Arc::new(model), + config, + cache, + tokenizer, + device, + }) + } +} + +#[async_trait] +impl CachedInferencingModel for LlamaModels { + async fn infer( + &self, + prompt: String, + params: wasi_llm::InferencingParams, + ) -> anyhow::Result { + let model = Arc::clone(&self.model); + let config = &self.config; + let tokenizer = self.tokenizer.clone(); + let mut cache = self.cache.clone(); + let eos_token_id = config.clone().eos_token_id.or_else(|| { + tokenizer + .token_to_id(EOS_TOKEN) + .map(llama::LlamaEosToks::Single) + }); + + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(|e| anyhow!(e.to_string()))? + .get_ids() + .to_vec(); + let mut tokenizer = token_output_stream::TokenOutputStream::new(tokenizer); + let mut rng = rand::rngs::StdRng::from_entropy(); + + let mut logits_processor = { + let temperature = params.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + Sampling::TopKThenTopP { + k: params.top_k as usize, + p: params.top_p as f64, + temperature: params.temperature as f64, + } + }; + LogitsProcessor::from_sampling(rng.next_u64(), sampling) + }; + + let mut index_pos = 0; + let mut tokens_generated = 0; + let mut output_text: String = String::default(); + + for index in 0..params.max_tokens { + let (context_size, context_index) = if self.cache.use_kv_cache && index > 0 { + (1, index_pos) + } else { + (tokens.len(), 0) + }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = model.forward(&input, context_index, &mut cache)?; + let logits = logits.squeeze(0)?; + let logits = if params.repeat_penalty == 1. { + logits + } else { + let start_at = tokens + .len() + .saturating_sub(params.repeat_penalty_last_n_token_count as usize); + candle_transformers::utils::apply_repeat_penalty( + &logits, + params.repeat_penalty, + &tokens[start_at..], + )? + }; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + tokens_generated += 1; + tokens.push(next_token); + + match eos_token_id { + Some(llama::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => { + break; + } + Some(llama::LlamaEosToks::Multiple(ref eos_ids)) + if eos_ids.contains(&next_token) => + { + break; + } + _ => (), + } + if let Some(t) = tokenizer.next_token(next_token)? { + output_text.push_str(&t); + } + } + if let Some(rest) = tokenizer.decode_rest()? { + output_text.push_str(&rest); + } + + Ok(wasi_llm::InferencingResult { + text: output_text, + usage: InferencingUsage { + prompt_token_count: tokens.len() as u32, + generated_token_count: tokens_generated, + }, + }) + } +} diff --git a/crates/llm-local/src/token_output_stream.rs b/crates/llm-local/src/token_output_stream.rs new file mode 100644 index 000000000..34af97ca5 --- /dev/null +++ b/crates/llm-local/src/token_output_stream.rs @@ -0,0 +1,85 @@ +// Implementation for TokenOutputStream Code is borrow from +// https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs +// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711) +pub struct TokenOutputStream { + tokenizer: tokenizers::Tokenizer, + tokens: Vec, + prev_index: usize, + current_index: usize, +} + +impl TokenOutputStream { + pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { + Self { + tokenizer, + tokens: Vec::new(), + prev_index: 0, + current_index: 0, + } + } + + pub fn _into_inner(self) -> tokenizers::Tokenizer { + self.tokenizer + } + + fn decode(&self, tokens: &[u32]) -> anyhow::Result { + match self.tokenizer.decode(tokens, true) { + Ok(str) => Ok(str), + Err(err) => anyhow::bail!("cannot decode: {err}"), + } + } + + // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 + pub fn next_token(&mut self, token: u32) -> anyhow::Result> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + self.tokens.push(token); + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { + let text = text.split_at(prev_text.len()); + self.prev_index = self.current_index; + self.current_index = self.tokens.len(); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_rest(&self) -> anyhow::Result> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() { + let text = text.split_at(prev_text.len()); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn _decode_all(&self) -> anyhow::Result { + self.decode(&self.tokens) + } + + pub fn _get_token(&self, token_s: &str) -> Option { + self.tokenizer.get_vocab(true).get(token_s).copied() + } + + pub fn _tokenizer(&self) -> &tokenizers::Tokenizer { + &self.tokenizer + } + + pub fn _clear(&mut self) { + self.tokens.clear(); + self.prev_index = 0; + self.current_index = 0; + } +} diff --git a/crates/llm-local/src/utils.rs b/crates/llm-local/src/utils.rs new file mode 100644 index 000000000..d50d10ec8 --- /dev/null +++ b/crates/llm-local/src/utils.rs @@ -0,0 +1,25 @@ +use candle::Result; +use std::path::Path; + +pub fn load_safetensors(model_dir: &Path, json_file: &str) -> Result> { + let json_file = model_dir.join(json_file); + let json_file = std::fs::File::open(json_file)?; + let json: serde_json::Value = + serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; + let weight_map = match json.get("weight_map") { + None => candle::bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file.to_string()); + } + } + let safetensors_files = safetensors_files + .iter() + .map(|v| model_dir.join(v)) + .collect::>(); + Ok(safetensors_files) +}