
remove unsafe code
Signed-off-by: karthik2804 <[email protected]>
karthik2804 committed Sep 16, 2024
1 parent 4206f10 commit 3352e7e
Showing 2 changed files with 39 additions and 21 deletions.
crates/llm-local/src/lib.rs: 6 additions & 4 deletions
@@ -4,7 +4,7 @@ mod token_output_stream;
 
 use anyhow::Context;
 use bert::{BertModel, Config};
-use candle::DType;
+use candle::{safetensors::load_buffer, DType};
 use candle_nn::VarBuilder;
 use spin_common::ui::quoted_path;
 use spin_core::async_trait;
@@ -152,6 +152,7 @@ impl LocalLlmEngine {
         let model = match arch {
             InferencingModelArch::Llama => Arc::new(
                 llama::LlamaModels::new(&model_dir)
+                    .await
                     .map_err(|e| wasi_llm::Error::RuntimeError(e.to_string()))?,
             ),
         };
@@ -326,10 +327,11 @@ fn load_tokenizer(tokenizer_file: &Path) -> anyhow::Result<tokenizers::Tokenizer>
 }
 
 fn load_model(model_file: &Path) -> anyhow::Result<BertModel> {
+    let device = &candle::Device::Cpu;
     // TODO: Check if there is a safe way to load the model from the file
-    let vb = unsafe {
-        VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &candle::Device::Cpu)?
-    };
+    let data = std::fs::read(model_file)?;
+    let tensors = load_buffer(&data, device)?;
+    let vb = VarBuilder::from_tensors(tensors, DType::F32, device);
     let model = BertModel::load(vb, &Config::default()).context("error loading bert model")?;
     Ok(model)
 }
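
Note on the change above: candle documents `VarBuilder::from_mmaped_safetensors` as `unsafe` because it memory-maps the weight file, and a mapping can be invalidated if another process modifies the file while it is in use. The safe replacement reads the whole file into a private buffer and parses it with `candle::safetensors::load_buffer`, trading some extra memory and load time for soundness. A minimal standalone sketch of the same pattern (the file name is illustrative):

```rust
use candle::{safetensors::load_buffer, DType, Device};
use candle_nn::VarBuilder;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Read the whole file into a Vec<u8>. Unlike an mmap, this private
    // buffer cannot change behind our back, so no `unsafe` is needed.
    let data = std::fs::read("model.safetensors")?;
    // Parse the buffer into named tensors: a HashMap<String, Tensor>.
    let tensors = load_buffer(&data, &device)?;
    println!("loaded {} tensors", tensors.len());
    // Hand the in-memory tensors to a VarBuilder for model construction.
    let _vb = VarBuilder::from_tensors(tensors, DType::F32, &device);
    Ok(())
}
```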
crates/llm-local/src/llama.rs: 33 additions & 17 deletions
@@ -1,13 +1,12 @@
 use crate::{token_output_stream, CachedInferencingModel};
-use anyhow::{anyhow, Context, Result};
-use candle::{utils, Device, Tensor};
+use anyhow::{anyhow, bail, Result};
+use candle::{safetensors::load_buffer, utils, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::{
     generation::{LogitsProcessor, Sampling},
     models::llama::{self, Cache, Config, Llama, LlamaConfig},
 };
 use rand::{RngCore, SeedableRng};
-use serde::Deserialize;
 use spin_core::async_trait;
 use spin_world::v2::llm::{self as wasi_llm, InferencingUsage};
 use std::{collections::HashMap, fs, path::Path, sync::Arc};
@@ -38,7 +37,7 @@ pub(crate) struct LlamaModels {
 }
 
 impl LlamaModels {
-    pub fn new(model_dir: &Path) -> Result<Self> {
+    pub async fn new(model_dir: &Path) -> Result<Self> {
         let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
         let config_path = model_dir.join(CONFIG_FILENAME);
 
@@ -54,8 +53,16 @@ impl LlamaModels {
 
         let safetensor_files = load_safetensors(model_dir, MODEL_SAFETENSORS_INDEX_FILE)?;
 
-        // TODO: Check if there is a safe way to load the model from the file
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&safetensor_files, dtype, &device)? };
+        let mut tensor_map: HashMap<String, Tensor> = HashMap::new();
+
+        for file in safetensor_files {
+            let data = fs::read(file)?;
+            let tensors = load_buffer(&data, &device)?;
+            for (k, v) in tensors {
+                tensor_map.insert(k, v);
+            }
+        }
+        let vb = VarBuilder::from_tensors(tensor_map, dtype, &device);
         let model = Llama::load(vb, &config)?;
 
         Ok(Self {
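
The loop above replaces a single `unsafe` mmap call over all shard files. A sharded checkpoint stores a disjoint subset of the model's named tensors in each `.safetensors` file, so folding every shard into one map reassembles the full parameter set before the `VarBuilder` is built. A condensed sketch of the same merge (function name and signature are illustrative, not part of the crate):

```rust
use std::collections::HashMap;
use std::path::PathBuf;

use candle::{safetensors::load_buffer, Device, Tensor};

/// Fold every shard into one name-to-tensor map. Shards hold disjoint
/// tensor names, so plain insertion reassembles the full parameter set.
fn merge_shards(shards: &[PathBuf], device: &Device) -> anyhow::Result<HashMap<String, Tensor>> {
    let mut tensor_map = HashMap::new();
    for shard in shards {
        let data = std::fs::read(shard)?;
        // load_buffer parses one shard into a HashMap<String, Tensor>.
        tensor_map.extend(load_buffer(&data, device)?);
    }
    Ok(tensor_map)
}
```

The resulting map goes to `VarBuilder::from_tensors` exactly as in the hunk above.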
@@ -155,6 +162,7 @@ impl CachedInferencingModel for LlamaModels {
             }
             // Decode the token and add it to the output.
             if let Some(t) = tokenizer.next_token(next_token)? {
+                print!("{}", t);
                 output_text.push_str(&t);
             }
         }
@@ -174,20 +182,28 @@
     }
 }
 
-#[derive(Deserialize)]
-struct SafeTensorsJson {
-    weight_map: HashMap<String, String>,
-}
-
 /// Loads a list of SafeTensors file paths from a given model directory and
 /// path to the model index JSON file.
 fn load_safetensors(model_dir: &Path, json_file: &str) -> Result<Vec<std::path::PathBuf>> {
     let json_file = model_dir.join(json_file);
-    let json_file = std::fs::File::open(&json_file)
-        .with_context(|| format!("Error while opening {json_file:?}"))?;
-
-    let json: SafeTensorsJson = serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
-    let mut safetensors_files = Vec::new();
-    safetensors_files.extend(json.weight_map.values().map(|v| model_dir.join(v)));
+    let json_file = std::fs::File::open(json_file)?;
+    let json: serde_json::Value =
+        serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
+    let weight_map = match json.get("weight_map") {
+        None => bail!("no weight map in {json_file:?}"),
+        Some(serde_json::Value::Object(map)) => map,
+        Some(_) => bail!("weight map in {json_file:?} is not a map"),
+    };
+
+    let mut safetensors_files = std::collections::HashSet::new();
+    for value in weight_map.values() {
+        if let Some(file) = value.as_str() {
+            safetensors_files.insert(file.to_string());
+        }
+    }
+    let safetensors_files = safetensors_files
+        .iter()
+        .map(|v| model_dir.join(v))
+        .collect::<Vec<_>>();
     Ok(safetensors_files)
 }
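
The rewritten `load_safetensors` also swaps the typed `serde::Deserialize` struct for untyped `serde_json::Value` access and deduplicates shard names with a `HashSet`, since many entries in `weight_map` point at the same shard file. A small self-contained sketch of that lookup, using an invented index in the Hugging Face `model.safetensors.index.json` style:

```rust
use std::collections::HashSet;

use anyhow::{bail, Result};

fn main() -> Result<()> {
    // Invented example of a sharded-checkpoint index: weight_map maps each
    // tensor name to the shard file that contains it.
    let index = r#"{
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
            "lm_head.weight": "model-00002-of-00002.safetensors"
        }
    }"#;
    let json: serde_json::Value = serde_json::from_str(index)?;
    let weight_map = match json.get("weight_map") {
        Some(serde_json::Value::Object(map)) => map,
        _ => bail!("missing or malformed weight_map"),
    };
    // Many tensors share a shard, so a HashSet keeps each file name once.
    let shards: HashSet<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
    assert_eq!(shards.len(), 2);
    println!("{shards:?}");
    Ok(())
}
```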
