
Commit

more fixes
Signed-off-by: karthik2804 <[email protected]>
karthik2804 committed Sep 17, 2024
1 parent 3db74cc commit d4c0a2e
Showing 2 changed files with 11 additions and 11 deletions.
13 changes: 6 additions & 7 deletions crates/llm-local/src/lib.rs
@@ -23,7 +23,7 @@ type ModelName = String;
 #[derive(Clone)]
 pub struct LocalLlmEngine {
     registry: PathBuf,
-    inferencing_models: HashMap<ModelName, Arc<dyn CachedInferencingModel>>,
+    inferencing_models: HashMap<ModelName, Arc<dyn InferencingModel>>,
     embeddings_models: HashMap<String, Arc<(tokenizers::Tokenizer, BertModel)>>,
 }

@@ -43,11 +43,11 @@ impl FromStr for InferencingModelArch {
     }
 }
 
-/// `CachedInferencingModel` implies that the model is prepared and cached after
-/// loading, allowing faster future requests by avoiding repeated file reads
-/// and decoding. This trait does not specify anything about if the results are cached.
+/// A model that is prepared and cached after loading.
+///
+/// This trait does not specify anything about whether the results are cached.
 #[async_trait]
-trait CachedInferencingModel: Send + Sync {
+trait InferencingModel: Send + Sync {
     async fn infer(
         &self,
         prompt: String,
@@ -143,7 +143,7 @@ impl LocalLlmEngine {
     async fn inferencing_model(
         &mut self,
         model: wasi_llm::InferencingModel,
-    ) -> Result<Arc<dyn CachedInferencingModel>, wasi_llm::Error> {
+    ) -> Result<Arc<dyn InferencingModel>, wasi_llm::Error> {
         let model = match self.inferencing_models.entry(model.clone()) {
             Entry::Occupied(o) => o.get().clone(),
             Entry::Vacant(v) => {
@@ -328,7 +328,6 @@ fn load_tokenizer(tokenizer_file: &Path) -> anyhow::Result<tokenizers::Tokenizer
 
 fn load_model(model_file: &Path) -> anyhow::Result<BertModel> {
     let device = &candle::Device::Cpu;
-    // TODO: Check if there is a safe way to load the model from the file
     let data = std::fs::read(model_file)?;
     let tensors = load_buffer(&data, device)?;
     let vb = VarBuilder::from_tensors(tensors, DType::F32, device);
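
For context on the caching the renamed trait describes: the engine keeps loaded models in a HashMap and reuses them through the entry API, loading only on a cache miss (see the third hunk above). Below is a minimal, standalone sketch of that cache-on-first-use pattern; the FakeModel type and the synchronous infer method are hypothetical stand-ins for the crate's real async trait and model loading.

use std::collections::{hash_map::Entry, HashMap};
use std::sync::Arc;

// Hypothetical stand-in for the crate's async InferencingModel trait.
trait InferencingModel: Send + Sync {
    fn infer(&self, prompt: &str) -> String;
}

struct FakeModel;

impl InferencingModel for FakeModel {
    fn infer(&self, prompt: &str) -> String {
        format!("echo: {prompt}")
    }
}

struct Engine {
    models: HashMap<String, Arc<dyn InferencingModel>>,
}

impl Engine {
    // Return a cached model, constructing and caching it only on the first request.
    fn model(&mut self, name: &str) -> Arc<dyn InferencingModel> {
        match self.models.entry(name.to_string()) {
            Entry::Occupied(o) => o.get().clone(),
            Entry::Vacant(v) => {
                let model: Arc<dyn InferencingModel> = Arc::new(FakeModel);
                v.insert(model.clone());
                model
            }
        }
    }
}

fn main() {
    let mut engine = Engine { models: HashMap::new() };
    let model = engine.model("llama");
    println!("{}", model.infer("hello")); // a second engine.model("llama") reuses the cached Arc
}
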
9 changes: 5 additions & 4 deletions crates/llm-local/src/llama.rs
@@ -1,5 +1,5 @@
-use crate::{token_output_stream, CachedInferencingModel};
-use anyhow::{anyhow, bail, Result};
+use crate::{token_output_stream, InferencingModel};
+use anyhow::{anyhow, bail, Context, Result};
 use candle::{safetensors::load_buffer, utils, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::{
@@ -76,7 +76,7 @@ impl LlamaModels {
 }
 
 #[async_trait]
-impl CachedInferencingModel for LlamaModels {
+impl InferencingModel for LlamaModels {
     async fn infer(
         &self,
         prompt: String,
@@ -185,7 +185,8 @@ impl CachedInferencingModel for LlamaModels {
 /// path to the model index JSON file relative to the model folder.
 fn load_safetensors(model_dir: &Path, json_file: &str) -> Result<Vec<std::path::PathBuf>> {
     let json_file = model_dir.join(json_file);
-    let json_file = std::fs::File::open(json_file)?;
+    let json_file = std::fs::File::open(json_file)
+        .with_context(|| format!("Could not read model index file: {json_file:?}"))?;
     let json: serde_json::Value =
         serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
     let weight_map = match json.get("weight_map") {
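
The last hunk attaches an anyhow context message to the file-open error. For readers unfamiliar with the API: with_context takes a closure, so the message is only formatted when an error actually occurs, whereas context takes an eager value. A small self-contained sketch of the same pattern, using a hypothetical index path:

use anyhow::{Context, Result};
use std::path::Path;

fn open_index(path: &Path) -> Result<std::fs::File> {
    // The closure is only called on the error path, so the format! cost is deferred.
    std::fs::File::open(path)
        .with_context(|| format!("Could not read model index file: {path:?}"))
}

fn main() {
    // Hypothetical path, expected to fail; "{:#}" prints the full error chain.
    if let Err(e) = open_index(Path::new("does-not-exist/index.json")) {
        eprintln!("{e:#}");
    }
}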
