address some PR comments
Signed-off-by: karthik2804 <[email protected]>
karthik2804 committed Sep 16, 2024
1 parent 5abb8ca commit 08f1611
Showing 9 changed files with 110 additions and 160 deletions.
144 changes: 72 additions & 72 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 6 additions & 12 deletions crates/factor-llm/src/spin.rs
@@ -44,7 +44,6 @@ mod local {
/// The default engine creator for the LLM factor when used in the Spin CLI.
pub fn default_engine_creator(
state_dir: Option<PathBuf>,
use_gpu: bool,
) -> anyhow::Result<impl LlmEngineCreator + 'static> {
#[cfg(feature = "llm")]
let engine = {
@@ -53,11 +52,11 @@ pub fn default_engine_creator(
Some(ref dir) => dir.clone(),
None => std::env::current_dir().context("failed to get current working directory")?,
};
spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"), use_gpu)
spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"))
};
#[cfg(not(feature = "llm"))]
let engine = {
let _ = (state_dir, use_gpu);
let _ = (state_dir);
noop::NoopLlmEngine
};
let engine = Arc::new(Mutex::new(engine)) as Arc<Mutex<dyn LlmEngine>>;
@@ -91,15 +90,14 @@ impl LlmEngine for RemoteHttpLlmEngine {
pub fn runtime_config_from_toml(
table: &impl GetTomlValue,
state_dir: Option<PathBuf>,
use_gpu: bool,
) -> anyhow::Result<Option<RuntimeConfig>> {
let Some(value) = table.get("llm_compute") else {
return Ok(None);
};
let config: LlmCompute = value.clone().try_into()?;

Ok(Some(RuntimeConfig {
engine: config.into_engine(state_dir, use_gpu)?,
engine: config.into_engine(state_dir)?,
}))
}

@@ -111,19 +109,15 @@ pub enum LlmCompute {
}

impl LlmCompute {
fn into_engine(
self,
state_dir: Option<PathBuf>,
use_gpu: bool,
) -> anyhow::Result<Arc<Mutex<dyn LlmEngine>>> {
fn into_engine(self, state_dir: Option<PathBuf>) -> anyhow::Result<Arc<Mutex<dyn LlmEngine>>> {
let engine: Arc<Mutex<dyn LlmEngine>> = match self {
#[cfg(not(feature = "llm"))]
LlmCompute::Spin => {
let _ = (state_dir, use_gpu);
let _ = (state_dir);
Arc::new(Mutex::new(noop::NoopLlmEngine))
}
#[cfg(feature = "llm")]
LlmCompute::Spin => default_engine_creator(state_dir, use_gpu)?.create(),
LlmCompute::Spin => default_engine_creator(state_dir)?.create(),
LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
config.url,
config.auth_token,
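With the `use_gpu` flag gone, the default engine path is driven only by the optional state directory. A minimal sketch of the simplified call, not code from this PR; the wrapper function is hypothetical:

```rust
use std::path::PathBuf;
use spin_factor_llm::spin::default_engine_creator;

// Hypothetical caller: models resolve under `<state_dir>/ai-models`, or under
// the current working directory when no state dir is configured.
fn build_default_engine(state_dir: Option<PathBuf>) -> anyhow::Result<()> {
    let _creator = default_engine_creator(state_dir)?;
    // The creator's `create()` yields the Arc<Mutex<dyn LlmEngine>> consumed by LlmFactor.
    Ok(())
}
```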
8 changes: 1 addition & 7 deletions crates/llm-local/src/lib.rs
@@ -1,7 +1,6 @@
mod bert;
mod llama;
mod token_output_stream;
mod utils;

use anyhow::Context;
use bert::{BertModel, Config};
@@ -24,7 +23,6 @@ type ModelName = String;
#[derive(Clone)]
pub struct LocalLlmEngine {
registry: PathBuf,
_use_gpu: bool,
inferencing_models: HashMap<ModelName, Arc<dyn CachedInferencingModel>>,
embeddings_models: HashMap<String, Arc<(tokenizers::Tokenizer, BertModel)>>,
}
@@ -61,7 +59,6 @@ impl LocalLlmEngine {
prompt: String,
params: wasi_llm::InferencingParams,
) -> Result<wasi_llm::InferencingResult, wasi_llm::Error> {
// return self.inference(model).await;
let model = self.inferencing_model(model).await?;

model
@@ -83,10 +80,9 @@ impl LocalLlmEngine {
}

impl LocalLlmEngine {
pub fn new(registry: PathBuf, _use_gpu: bool) -> Self {
pub fn new(registry: PathBuf) -> Self {
Self {
registry,
_use_gpu,
inferencing_models: Default::default(),
embeddings_models: Default::default(),
}
@@ -145,8 +141,6 @@ impl LocalLlmEngine {
&mut self,
model: wasi_llm::InferencingModel,
) -> Result<Arc<dyn CachedInferencingModel>, wasi_llm::Error> {
// let use_gpu = self.use_gpu;

let model = match self.inferencing_models.entry(model.clone()) {
Entry::Occupied(o) => o.get().clone(),
Entry::Vacant(v) => {
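After this change the local engine is constructed from just a model registry path; device selection no longer flows through the constructor. A minimal sketch, with the registry path as a placeholder:

```rust
use std::path::PathBuf;
use spin_llm_local::LocalLlmEngine;

fn main() {
    // Placeholder path; in the Spin CLI this is `<state_dir>/ai-models`.
    let registry = PathBuf::from("/var/lib/spin/ai-models");
    let _engine = LocalLlmEngine::new(registry);
}
```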
23 changes: 20 additions & 3 deletions crates/llm-local/src/llama.rs
@@ -1,15 +1,16 @@
use crate::{token_output_stream, utils::load_safetensors, CachedInferencingModel};
use anyhow::{anyhow, Result};
use crate::{token_output_stream, CachedInferencingModel};
use anyhow::{anyhow, Context, Result};
use candle::{utils, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
generation::{LogitsProcessor, Sampling},
models::llama::{self, Cache, Config, Llama, LlamaConfig},
};
use rand::{RngCore, SeedableRng};
use serde::Deserialize;
use spin_core::async_trait;
use spin_world::v2::llm::{self as wasi_llm, InferencingUsage};
use std::{fs, path::Path, sync::Arc};
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use tokenizers::Tokenizer;

const TOKENIZER_FILENAME: &str = "tokenizer.json";
@@ -166,3 +167,19 @@ impl CachedInferencingModel for LlamaModels {
})
}
}

#[derive(Deserialize)]
struct SafeTensorsJson {
weight_map: HashMap<String, String>,
}

fn load_safetensors(model_dir: &Path, json_file: &str) -> Result<Vec<std::path::PathBuf>> {
let json_file = model_dir.join(json_file);
let json_file = std::fs::File::open(&json_file)
.with_context(|| format!("Error while opening {json_file:?}"))?;

let json: SafeTensorsJson = serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let mut safetensors_files = Vec::new();
safetensors_files.extend(json.weight_map.values().map(|v| model_dir.join(v)));
Ok(safetensors_files)
}
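The new private `load_safetensors` helper reads a safetensors index JSON, whose `weight_map` maps tensor names to shard files, and returns the shard paths to load. A hypothetical in-module call site follows; the index file name is an assumption based on the usual Hugging Face sharded-checkpoint convention, not something taken from this PR:

```rust
// Hypothetical helper inside llama.rs (load_safetensors is private to this module).
fn shard_paths(model_dir: &std::path::Path) -> anyhow::Result<Vec<std::path::PathBuf>> {
    load_safetensors(model_dir, "model.safetensors.index.json")
}
```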
35 changes: 8 additions & 27 deletions crates/llm-local/src/token_output_stream.rs
@@ -1,6 +1,9 @@
// Implementation for TokenOutputStream Code is borrow from
// https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs
// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711)
/// Implementation for TokenOutputStream Code is borrowed from
/// https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs
/// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711)
///
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
tokenizer: tokenizers::Tokenizer,
tokens: Vec<u32>,
@@ -18,10 +21,6 @@ impl TokenOutputStream {
}
}

pub fn _into_inner(self) -> tokenizers::Tokenizer {
self.tokenizer
}

fn decode(&self, tokens: &[u32]) -> anyhow::Result<String> {
match self.tokenizer.decode(tokens, true) {
Ok(str) => Ok(str),
@@ -40,10 +39,10 @@ impl TokenOutputStream {
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
let text = text.split_at(prev_text.len());
let (_, text) = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();
Ok(Some(text.1.to_string()))
Ok(Some(text.to_string()))
} else {
Ok(None)
}
@@ -64,22 +63,4 @@ impl TokenOutputStream {
Ok(None)
}
}

pub fn _decode_all(&self) -> anyhow::Result<String> {
self.decode(&self.tokens)
}

pub fn _get_token(&self, token_s: &str) -> Option<u32> {
self.tokenizer.get_vocab(true).get(token_s).copied()
}

pub fn _tokenizer(&self) -> &tokenizers::Tokenizer {
&self.tokenizer
}

pub fn _clear(&mut self) {
self.tokens.clear();
self.prev_index = 0;
self.current_index = 0;
}
}
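What remains of `TokenOutputStream` is the streaming decode path: push token ids one at a time and emit text only once the decoded prefix is stable. A sketch of a decode loop elsewhere in the llm-local crate; the `new` and `decode_rest` signatures are assumed to match the upstream candle helper this file is borrowed from:

```rust
use crate::token_output_stream::TokenOutputStream;

fn stream_decode(tokenizer: tokenizers::Tokenizer, token_ids: &[u32]) -> anyhow::Result<()> {
    let mut stream = TokenOutputStream::new(tokenizer);
    for &id in token_ids {
        // A fragment is returned only when the decoded text grows and ends on an
        // alphanumeric character, mirroring `next_token` above.
        if let Some(fragment) = stream.next_token(id)? {
            print!("{fragment}");
        }
    }
    // Flush whatever is still buffered after the last token.
    if let Some(rest) = stream.decode_rest()? {
        print!("{rest}");
    }
    Ok(())
}
```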
25 changes: 0 additions & 25 deletions crates/llm-local/src/utils.rs

This file was deleted.

10 changes: 2 additions & 8 deletions crates/runtime-config/src/lib.rs
@@ -97,7 +96,6 @@ where
local_app_dir: Option<PathBuf>,
provided_state_dir: UserProvidedPath,
provided_log_dir: UserProvidedPath,
use_gpu: bool,
) -> anyhow::Result<Self> {
let toml = match runtime_config_path {
Some(runtime_config_path) => {
@@ -119,14 +118,13 @@ where
let toml_resolver =
TomlResolver::new(&toml, local_app_dir, provided_state_dir, provided_log_dir);

Self::new(toml_resolver, runtime_config_path, use_gpu)
Self::new(toml_resolver, runtime_config_path)
}

/// Creates a new resolved runtime configuration from a TOML table.
pub fn new(
toml_resolver: TomlResolver<'_>,
runtime_config_path: Option<&Path>,
use_gpu: bool,
) -> anyhow::Result<Self> {
let runtime_config_dir = runtime_config_path
.and_then(Path::parent)
@@ -142,7 +140,6 @@ where
&key_value_config_resolver,
tls_resolver.as_ref(),
&sqlite_config_resolver,
use_gpu,
);
let runtime_config: T = source.try_into().map_err(Into::into)?;

@@ -275,7 +272,6 @@ pub struct TomlRuntimeConfigSource<'a, 'b> {
key_value: &'a key_value::RuntimeConfigResolver,
tls: Option<&'a SpinTlsRuntimeConfig>,
sqlite: &'a sqlite::RuntimeConfigResolver,
use_gpu: bool,
}

impl<'a, 'b> TomlRuntimeConfigSource<'a, 'b> {
@@ -284,14 +280,12 @@ impl<'a, 'b> TomlRuntimeConfigSource<'a, 'b> {
key_value: &'a key_value::RuntimeConfigResolver,
tls: Option<&'a SpinTlsRuntimeConfig>,
sqlite: &'a sqlite::RuntimeConfigResolver,
use_gpu: bool,
) -> Self {
Self {
toml: toml_resolver,
key_value,
tls,
sqlite,
use_gpu,
}
}
}
@@ -338,7 +332,7 @@ impl FactorRuntimeConfigSource<OutboundMysqlFactor> for TomlRuntimeConfigSource<

impl FactorRuntimeConfigSource<LlmFactor> for TomlRuntimeConfigSource<'_, '_> {
fn get_runtime_config(&mut self) -> anyhow::Result<Option<spin_factor_llm::RuntimeConfig>> {
llm::runtime_config_from_toml(&self.toml.table, self.toml.state_dir()?, self.use_gpu)
llm::runtime_config_from_toml(&self.toml.table, self.toml.state_dir()?)
}
}

4 changes: 0 additions & 4 deletions crates/runtime-factors/src/build.rs
@@ -22,14 +22,11 @@ impl RuntimeFactorsBuilder for FactorsBuilder {
config: &FactorsConfig,
args: &Self::CliArgs,
) -> anyhow::Result<(Self::Factors, Self::RuntimeConfig)> {
// Hardcode `use_gpu` to true for now
let use_gpu = true;
let runtime_config = ResolvedRuntimeConfig::<TriggerFactorsRuntimeConfig>::from_file(
config.runtime_config_file.clone().as_deref(),
config.local_app_dir.clone().map(PathBuf::from),
config.state_dir.clone(),
config.log_dir.clone(),
use_gpu,
)?;

runtime_config.summarize(config.runtime_config_file.as_deref());
@@ -40,7 +37,6 @@ impl RuntimeFactorsBuilder for FactorsBuilder {
args.allow_transient_write,
runtime_config.key_value_resolver.clone(),
runtime_config.sqlite_resolver.clone(),
use_gpu,
)
.context("failed to create factors")?;
Ok((factors, runtime_config))
3 changes: 1 addition & 2 deletions crates/runtime-factors/src/lib.rs
@@ -42,7 +42,6 @@ impl TriggerFactors {
allow_transient_writes: bool,
default_key_value_label_resolver: impl spin_factor_key_value::DefaultLabelResolver + 'static,
default_sqlite_label_resolver: impl spin_factor_sqlite::DefaultLabelResolver + 'static,
use_gpu: bool,
) -> anyhow::Result<Self> {
Ok(Self {
wasi: wasi_factor(working_dir, allow_transient_writes),
@@ -56,7 +55,7 @@ impl TriggerFactors {
pg: OutboundPgFactor::new(),
mysql: OutboundMysqlFactor::new(),
llm: LlmFactor::new(
spin_factor_llm::spin::default_engine_creator(state_dir, use_gpu)
spin_factor_llm::spin::default_engine_creator(state_dir)
.context("failed to configure LLM factor")?,
),
})
