diff --git a/Cargo.toml b/Cargo.toml index 1deabe1..3ffdb43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ tokio = { version = "1.39.3", features = [ "macros", "rt-multi-thread", "tracing", + "signal" ] } tonic = "0.12.1" prost = "0.13.1" diff --git a/src/server.rs b/src/server.rs index d5bef88..ac2679d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,14 +10,19 @@ use redis_protocol::resp2::{ types::{OwnedFrame, Resp2Frame}, }; use std::net::SocketAddr; -use std::net::TcpListener; +// use std::net::TcpListener; use std::path::PathBuf; use std::sync::mpsc; use std::sync::Arc; use std::{ - io::{Read, Write}, + // io::{Read, Write}, sync::RwLock, }; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::signal; +use tokio::sync::Notify; use tokio::task::JoinSet; pub struct Server { @@ -30,12 +35,12 @@ impl Server { Server { config, store } } - async fn handle_client_connection(&self, mut stream: std::net::TcpStream) { + async fn handle_client_connection(&self, mut stream: TcpStream) { let mut buffer = [0u8; 4096]; loop { // 读取客户端发送的数据 - let n = match stream.read(&mut buffer) { + let n = match stream.read(&mut buffer).await { Ok(size) => size, Err(e) => { error!("Failed to read from stream: {}", e); @@ -53,11 +58,11 @@ impl Server { Ok(resp) => { let mut buf = vec![0; resp.encode_len()]; encode(&mut buf, &resp).unwrap(); - stream.write_all(&buf).unwrap(); + stream.write_all(&buf).await.unwrap(); } Err(e) => { let error_message = format!("-ERR {}\r\n", e); - stream.write_all(error_message.as_bytes()).unwrap(); + stream.write_all(error_message.as_bytes()).await.unwrap(); } }, Ok(None) => { @@ -851,33 +856,41 @@ impl Server { } } - pub fn server_start(self: Arc) -> anyhow::Result<()> { + pub async fn server_start(self: Arc, notify: Arc) -> anyhow::Result<()> { let addr = self.config.get_addr()?; - let listener = TcpListener::bind(addr)?; - + let listener = TcpListener::bind(addr).await?; println!("Listening on {}", addr); - for stream in listener.incoming() { - match stream { - Ok(stream) => { - debug!("New connection: {}", stream.peer_addr().unwrap()); - let server_clone = Arc::clone(&self); - - tokio::spawn(async move { - server_clone.handle_client_connection(stream).await; - }); - } - Err(e) => { - error!("Error: {}", e); + loop { + tokio::select! { + accept_result = listener.accept() => { + match accept_result { + Ok((stream, _)) => { + debug!("New connection: {}", stream.peer_addr().unwrap()); + let server_clone = Arc::clone(&self); + tokio::spawn(async move { + server_clone.handle_client_connection(stream).await; + }); + }, + Err(e) => { + error!("Failed to accept connection: {}", e); + } + } + }, + _ = notify.notified() => { + println!("Shutdown signal received. Stopping server..."); + break; } } } - Ok(()) } } pub async fn start_server(option: &Option) -> anyhow::Result<()> { + // 创建一个 Notify 对象,用于通知所有任务停止 + let notify = Arc::new(Notify::new()); + let conf = if let Some(file) = option { config::Config::try_from(file.as_path())? } else { @@ -903,22 +916,31 @@ pub async fn start_server(option: &Option) -> anyhow::Result<()> { // server.server_start(); // }); - join_set.spawn(async { + let notify_clone = Arc::clone(¬ify); + join_set.spawn(async move { // 使用 block_in_place 处理阻塞操作 - tokio::task::block_in_place(|| { - server.server_start().unwrap(); - }); + server.server_start(notify_clone).await.unwrap(); }); // 添加 gRPC 服务器任务(如果配置存在) if let Some(grpc_config) = the_config.get_grpc() { let store_clone = Arc::clone(&store); let addr = grpc_config.get_addr()?; + let notify_clone = Arc::clone(¬ify); join_set.spawn(async move { - run_grpc_server(addr, store_clone).await.unwrap(); + run_grpc_server(addr, store_clone, notify_clone) + .await + .unwrap(); }); } + // 监听 Ctrl+C 信号 + signal::ctrl_c().await.expect("Failed to listen for Ctrl+C"); + info!("Received Ctrl+C, shutting down..."); + + // 通知所有任务停止 + notify.notify_waiters(); + // 等待所有任务完成 while let Some(Ok(_)) = join_set.join_next().await {} @@ -929,12 +951,16 @@ pub async fn start_server(option: &Option) -> anyhow::Result<()> { async fn run_grpc_server( addr: SocketAddr, store: Arc>, + notify: Arc, ) -> anyhow::Result<()> { println!("gRPC Server Listening on {:?}", addr); tonic::transport::Server::builder() .add_service(StoreServer::new(StoreImpl::new(store))) - .serve(addr) + .serve_with_shutdown(addr, async { + notify.notified().await; + }) + // .serve(addr) .await?; Ok(())