From 7a31ad2ae1363b3dcf986ea1a40c4c2f0cf75095 Mon Sep 17 00:00:00 2001
From: kristjanpeterson
Date: Wed, 8 Jan 2025 15:06:54 +0200
Subject: [PATCH 1/2] feat: add Galadriel API integration

---
 rig-core/examples/agent_with_galadriel.rs |  23 ++
 rig-core/src/providers/galadriel.rs       | 400 ++++++++++++++++++++++
 rig-core/src/providers/mod.rs             |   1 +
 3 files changed, 424 insertions(+)
 create mode 100644 rig-core/examples/agent_with_galadriel.rs
 create mode 100644 rig-core/src/providers/galadriel.rs

diff --git a/rig-core/examples/agent_with_galadriel.rs b/rig-core/examples/agent_with_galadriel.rs
new file mode 100644
index 00000000..f2deb741
--- /dev/null
+++ b/rig-core/examples/agent_with_galadriel.rs
@@ -0,0 +1,23 @@
+use rig::{completion::Prompt, providers};
+use std::env;
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Create Galadriel client
+    let client = providers::galadriel::Client::new(
+        &env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set"),
+        env::var("GALADRIEL_FINE_TUNE_API_KEY").ok().as_deref(),
+    );
+
+    // Create agent with a single context prompt
+    let comedian_agent = client
+        .agent("gpt-4o")
+        .preamble("You are a comedian here to entertain the user using humour and jokes.")
+        .build();
+
+    // Prompt the agent and print the response
+    let response = comedian_agent.prompt("Entertain me!").await?;
+    println!("{}", response);
+
+    Ok(())
+}
diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs
new file mode 100644
index 00000000..34690dc7
--- /dev/null
+++ b/rig-core/src/providers/galadriel.rs
@@ -0,0 +1,400 @@
+//! Galadriel API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::galadriel;
+//!
+//! let client = galadriel::Client::new("YOUR_API_KEY", None);
+//! // to use a fine-tuned model
+//! // let client = galadriel::Client::new("YOUR_API_KEY", Some("FINE_TUNE_API_KEY"));
+//!
+//! let gpt4o = client.completion_model(galadriel::GPT_4O);
+//! ```
+use crate::{
+    agent::AgentBuilder,
+    completion::{self, CompletionError, CompletionRequest},
+    extractor::ExtractorBuilder,
+    json_utils,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+// ================================================================
+// Main Galadriel Client
+// ================================================================
+const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    /// Create a new Galadriel client with the given API key and optional fine-tune API key.
+    pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
+        Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
+    }
+
+    /// Create a new Galadriel client with the given API key, base API URL, and optional fine-tune API key.
+    pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
+        Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
+    }
+
+    /// Create a client for `base_url`; the fine-tune header is only sent when a key is given.
+    pub fn from_url_with_optional_key(
+        api_key: &str,
+        base_url: &str,
+        fine_tune_api_key: Option<&str>,
+    ) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Bearer token should parse"),
+                    );
+                    if let Some(key) = fine_tune_api_key {
+                        headers.insert(
+                            "Fine-Tune-Authorization",
+                            format!("Bearer {}", key)
+                                .parse()
+                                .expect("Bearer token should parse"),
+                        );
+                    }
+                    headers
+                })
+                .build()
+                .expect("Galadriel reqwest client should build"),
+        }
+    }
+
+    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
+    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
+    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
+        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
+        Self::new(&api_key, fine_tune_api_key.as_deref())
+    }
+    /// Build a POST request for `path`, resolved against the client's base URL.
+    fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        // Trim only at the seam; replacing every `//` would also rewrite `https://`.
+        let (base, path) = (self.base_url.trim_end_matches('/'), path.trim_start_matches('/'));
+        self.http_client.post(format!("{base}/{path}"))
+    }
+
+    /// Create a completion model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::galadriel::{Client, self};
+    ///
+    /// // Initialize the Galadriel client
+    /// let galadriel = Client::new("your-galadriel-api-key", None);
+    ///
+    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
+    /// ```
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    /// Create an agent builder with the given completion model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::galadriel::{Client, self};
+    ///
+    /// // Initialize the Galadriel client
+    /// let galadriel = Client::new("your-galadriel-api-key", None);
+    ///
+    /// let agent = galadriel.agent(galadriel::GPT_4)
+    ///    .preamble("You are comedian AI with a mission to make people laugh.")
+    ///    .temperature(0.0)
+    ///    .build();
+    /// ```
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    /// Create an extractor builder with the given completion model.
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct ApiErrorResponse {
+    message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub total_tokens: usize,
+}
+
+impl std::fmt::Display for Usage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Prompt tokens: {} Total tokens: {}",
+            self.prompt_tokens, self.total_tokens
+        )
+    }
+}
+
+// ================================================================
+// Galadriel Completion API
+// ================================================================
+/// `o1-preview` completion model
+pub const O1_PREVIEW: &str = "o1-preview";
+/// `o1-preview-2024-09-12` completion model
+pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
+/// `o1-mini` completion model
+pub const O1_MINI: &str = "o1-mini";
+/// `o1-mini-2024-09-12` completion model
+pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
+/// `gpt-4o` completion model
+pub const GPT_4O: &str = "gpt-4o";
+/// `gpt-4o-2024-05-13` completion model
+pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
+/// `gpt-4-turbo` completion model
+pub const GPT_4_TURBO: &str = "gpt-4-turbo";
+/// `gpt-4-turbo-2024-04-09` completion model
+pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
+/// `gpt-4-turbo-preview` completion model
+pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
+/// `gpt-4-0125-preview` completion model
+pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
+/// `gpt-4-1106-preview` completion model
+pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
+/// `gpt-4-vision-preview` completion model
+pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
+/// `gpt-4-1106-vision-preview` completion model
+pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
+/// `gpt-4` completion model
+pub const GPT_4: &str = "gpt-4";
+/// `gpt-4-0613` completion model
+pub const GPT_4_0613: &str = "gpt-4-0613";
+/// `gpt-4-32k` completion model
+pub const GPT_4_32K: &str = "gpt-4-32k";
+/// `gpt-4-32k-0613` completion model
+pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
+/// `gpt-3.5-turbo` completion model
+pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
+/// `gpt-3.5-turbo-0125` completion model
+pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
+/// `gpt-3.5-turbo-1106` completion model
+pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
+/// `gpt-3.5-turbo-instruct` completion model
+pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
+
+#[derive(Debug, Deserialize)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: Option<String>,
+    pub choices: Vec<Choice>,
+    pub usage: Option<Usage>,
+}
+
+impl From<ApiErrorResponse> for CompletionError {
+    fn from(err: ApiErrorResponse) -> Self {
+        CompletionError::ProviderError(err.message)
+    }
+}
+
+impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
+    type Error = CompletionError;
+
+    fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
+        match value.choices.as_slice() {
+            [Choice {
+                message:
+                    Message {
+                        tool_calls: Some(calls),
+                        ..
+                    },
+                ..
+            }, ..] => {
+                let call = calls.first().ok_or(CompletionError::ResponseError(
+                    "Tool selection is empty".into(),
+                ))?;
+
+                Ok(completion::CompletionResponse {
+                    choice: completion::ModelChoice::ToolCall(
+                        call.function.name.clone(),
+                        serde_json::from_str(&call.function.arguments)?,
+                    ),
+                    raw_response: value,
+                })
+            }
+            [Choice {
+                message:
+                    Message {
+                        content: Some(content),
+                        ..
+                    },
+                ..
+            }, ..] => Ok(completion::CompletionResponse {
+                choice: completion::ModelChoice::Message(content.to_string()),
+                raw_response: value,
+            }),
+            _ => Err(CompletionError::ResponseError(
+                "Response did not contain a message or tool call".into(),
+            )),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Choice {
+    pub index: usize,
+    pub message: Message,
+    pub logprobs: Option<serde_json::Value>,
+    pub finish_reason: String,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Message {
+    pub role: String,
+    pub content: Option<String>,
+    pub tool_calls: Option<Vec<ToolCall>>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ToolCall {
+    pub id: String,
+    pub r#type: String,
+    pub function: Function,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct ToolDefinition {
+    pub r#type: String,
+    pub function: completion::ToolDefinition,
+}
+
+impl From<completion::ToolDefinition> for ToolDefinition {
+    fn from(tool: completion::ToolDefinition) -> Self {
+        Self {
+            r#type: "function".into(),
+            function: tool,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Function {
+    pub name: String,
+    pub arguments: String,
+}
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    async fn completion(
+        &self,
+        mut completion_request: CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        // Add preamble to chat history (if available)
+        let mut full_history = if let Some(preamble) = &completion_request.preamble {
+            vec![completion::Message {
+                role: "system".into(),
+                content: preamble.clone(),
+            }]
+        } else {
+            vec![]
+        };
+
+        // Extend existing chat history
+        full_history.append(&mut completion_request.chat_history);
+
+        // Add context documents to chat history
+        let prompt_with_context = completion_request.prompt_with_context();
+
+        // Add the user prompt (with its context documents) as the final message
+        full_history.push(completion::Message {
+            role: "user".into(),
+            content: prompt_with_context,
+        });
+
+        let request = if completion_request.tools.is_empty() {
+            json!({
+                "model": self.model,
+                "messages": full_history,
+                "temperature": completion_request.temperature,
+            })
+        } else {
+            json!({
+                "model": self.model,
+                "messages": full_history,
+                "temperature": completion_request.temperature,
+                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
+                "tool_choice": "auto",
+            })
+        };
+
+        let response = self
+            .client
+            .post("/chat/completions")
+            .json(
+                &if let Some(params) = completion_request.additional_params {
+                    json_utils::merge(request, params)
+                } else {
+                    request
+                },
+            )
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            match response.json::<ApiResponse<CompletionResponse>>().await? {
+                ApiResponse::Ok(response) => {
+                    tracing::info!(target: "rig",
+                        "Galadriel completion token usage: {:?}",
+                        response.usage.as_ref().map(ToString::to_string).unwrap_or_else(|| "N/A".to_string())
+                    );
+                    response.try_into()
+                }
+                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
+            }
+        } else {
+            Err(CompletionError::ProviderError(response.text().await?))
+        }
+    }
+}
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 23d4d181..201d14e5 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -42,6 +42,7 @@
 //! be used with the Cohere provider client.
pub mod anthropic; pub mod cohere; +pub mod galadriel; pub mod gemini; pub mod openai; pub mod perplexity; From 3325d72de5c49e363fc28666a7eacf1e045729e8 Mon Sep 17 00:00:00 2001 From: kristjanpeterson Date: Thu, 23 Jan 2025 09:50:24 +0100 Subject: [PATCH 2/2] Fix toolCall for the Galadriel LLM provider --- rig-core/src/providers/galadriel.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs index 34690dc7..98898bc4 100644 --- a/rig-core/src/providers/galadriel.rs +++ b/rig-core/src/providers/galadriel.rs @@ -239,6 +239,7 @@ impl TryFrom for completion::CompletionResponse