diff --git a/src/error.rs b/src/error.rs index e7e8af6..0314f7d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -82,4 +82,6 @@ pub enum ServerError { ArgumentError(String), #[error("{0}")] Operation(String), + #[error("{0}")] + DatabaseError(String), } diff --git a/src/main.rs b/src/main.rs index 783ba0e..c57a5c7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,7 +20,7 @@ use llama_core::MetadataBuilder; use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; use std::{collections::HashMap, net::SocketAddr, path::PathBuf}; -use utils::{is_valid_url, LogLevel}; +use utils::{is_valid_url, qdrant_up, LogLevel}; type Error = Box; @@ -212,11 +212,19 @@ async fn main() -> Result<(), ServerError> { // log { - error!(target: "server_config", "rag_prompt: {}", err_msg); + error!(target: "server_config", "qdrant_url: {}", err_msg); } return Err(ServerError::ArgumentError(err_msg)); } + if !qdrant_up(&cli.qdrant_url).await { + let err_msg = format!("[INFO] Qdrant not found at: {}", &cli.qdrant_url); + error!(target: "server_config", "qdrant_url: {}", err_msg); + + return Err(ServerError::DatabaseError(err_msg)); + } + + // log qdrant url info!(target: "server_config", "qdrant_url: {}", &cli.qdrant_url); // log qdrant collection name diff --git a/src/utils.rs b/src/utils.rs index 837da3f..b68e0d8 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,4 @@ +use hyper::Client; use serde::{Deserialize, Serialize}; use url::Url; @@ -5,6 +6,16 @@ pub(crate) fn is_valid_url(url: &str) -> bool { Url::parse(url).is_ok() } +//TODO: check json title field to check whether running service is really qdrant +pub(crate) async fn qdrant_up(url: &str) -> bool { + let client = Client::new(); + + match client.get(url.parse().unwrap()).await { + Ok(res) => res.status().is_success(), + Err(_) => false, + } +} + pub(crate) fn gen_chat_id() -> String { format!("chatcmpl-{}", uuid::Uuid::new_v4()) }