Skip to content

Commit

Permalink
feat: add --socket-addr CLI option
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Liu <[email protected]>
  • Loading branch information
apepkuss committed Sep 23, 2024
1 parent 862a419 commit 8ff2a00
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
13 changes: 13 additions & 0 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ pub(crate) async fn chat_handler(
proxy_request(state.client, req, chat_url).await
}

pub(crate) async fn audio_handler(
State(state): State<AppState>,
req: Request<Body>,
) -> Result<Response<Body>, StatusCode> {
println!("In audio_handler");

let audio_url = state.audio_urls.read().unwrap().next();

proxy_request(state.client, req, audio_url).await
}

pub(crate) async fn image_handler(
State(state): State<AppState>,
req: Request<Body>,
Expand Down Expand Up @@ -70,6 +81,7 @@ pub(crate) async fn add_url_handler(

let url_type = match url_type.as_str() {
"chat" => UrlType::Chat,
"audio" => UrlType::Audio,
"image" => UrlType::Image,
_ => return Err(StatusCode::BAD_REQUEST),
};
Expand All @@ -91,6 +103,7 @@ pub(crate) async fn remove_url_handler(

let url_type = match url_type.as_str() {
"chat" => UrlType::Chat,
"audio" => UrlType::Audio,
"image" => UrlType::Image,
_ => return Err(StatusCode::BAD_REQUEST),
};
Expand Down
42 changes: 30 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@ mod utils;
use anyhow::Result;
use async_trait::async_trait;
use axum::{http::Uri, routing::post, Router};
use clap::Parser;
use clap::{ArgGroup, Parser};
use error::ServerError;
use handler::*;
use hyper::{client::HttpConnector, Client};
use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::RwLock;
use std::{
fmt,
net::SocketAddr,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, RwLock,
},
};
use tokio::net::TcpListener;
use utils::LogLevel;

Expand All @@ -26,9 +30,13 @@ const DEFAULT_PORT: &str = "8080";

#[derive(Debug, Parser)]
#[command(name = "LlamaEdge Gateway", version = env!("CARGO_PKG_VERSION"), author = env!("CARGO_PKG_AUTHORS"), about = "LlamaEdge Gateway")]
#[command(group = ArgGroup::new("socket_address_group").multiple(false).args(&["socket_addr", "port"]))]
struct Cli {
/// Socket address of Llama-Gateway instance. For example, `0.0.0.0:8080`.
#[arg(long, default_value = None, value_parser = clap::value_parser!(SocketAddr), group = "socket_address_group")]
socket_addr: Option<SocketAddr>,
/// Socket address of LlamaEdge API Server instance
#[arg(long, default_value = DEFAULT_PORT, value_parser = clap::value_parser!(u16))]
#[arg(long, default_value = DEFAULT_PORT, value_parser = clap::value_parser!(u16), group = "socket_address_group")]
port: u16,
}

Expand Down Expand Up @@ -62,16 +70,22 @@ async fn main() -> Result<(), ServerError> {
// Build our application with routes
let app = Router::new()
.route("/v1/chat/completions", post(chat_handler))
.route("/v1/image/generation", post(image_handler))
.route("/v1/audio/transcriptions", post(audio_handler))
.route("/v1/audio/translations", post(audio_handler))
.route("/v1/images/generations", post(image_handler))
.route("/admin/register/:type", post(add_url_handler))
.route("/admin/unregister/:type", post(remove_url_handler))
.with_state(app_state);

// Run it
let addr = format!("127.0.0.1:{}", cli.port);
let tcp_listener = TcpListener::bind(&addr).await.unwrap();
// socket address
let addr = match cli.socket_addr {
Some(addr) => addr,
None => SocketAddr::from(([0, 0, 0, 0], cli.port)),
};
let tcp_listener = TcpListener::bind(addr).await.unwrap();
info!(target: "stdout", "Listening on {}", addr);

// run
match axum::Server::from_tcp(tcp_listener.into_std().unwrap())
.unwrap()
.serve(app.into_make_service())
Expand Down Expand Up @@ -133,6 +147,7 @@ impl RoutingPolicy for Services {
struct AppState {
client: SharedClient,
chat_urls: Arc<RwLock<Services>>,
audio_urls: Arc<RwLock<Services>>,
image_urls: Arc<RwLock<Services>>,
}

Expand All @@ -141,22 +156,23 @@ impl AppState {
Self {
client,
chat_urls: Arc::new(RwLock::new(Services::default())),
audio_urls: Arc::new(RwLock::new(Services::default())),
image_urls: Arc::new(RwLock::new(Services::default())),
}
}

fn add_url(&self, url_type: UrlType, url: &Uri) {
match url_type {
UrlType::Chat => self.chat_urls.write().unwrap().push(url.clone()),
UrlType::Audio => self.audio_urls.write().unwrap().push(url.clone()),
UrlType::Image => self.image_urls.write().unwrap().push(url.clone()),
// UrlType::Chat => self.chat_urls.write().unwrap().push(url.clone()),
// UrlType::Image => self.image_urls.write().unwrap().push(url.clone()),
}
}

fn remove_url(&self, url_type: UrlType, url: &Uri) {
let services = match &url_type {
UrlType::Chat => &self.chat_urls,
UrlType::Audio => &self.audio_urls,
UrlType::Image => &self.image_urls,
};

Expand All @@ -174,13 +190,15 @@ impl AppState {

#[derive(Debug)]
enum UrlType {
Audio,
Chat,
Image,
}
impl fmt::Display for UrlType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
UrlType::Chat => write!(f, "Chat"),
UrlType::Audio => write!(f, "Audio"),
UrlType::Image => write!(f, "Image"),
}
}
Expand Down

0 comments on commit 8ff2a00

Please sign in to comment.