diff --git a/Cargo.lock b/Cargo.lock index d074e385..744cc44b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1236,6 +1236,23 @@ dependencies = [ "tokio-rustls 0.24.1", ] +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http 1.1.0", + "hyper 1.4.1", + "hyper-util", + "rustls 0.23.13", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.0", + "tower-service", +] + [[package]] name = "hyper-timeout" version = "0.5.1" @@ -1249,6 +1266,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.8" @@ -1673,6 +1706,23 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "multer" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http 1.1.0", + "httparse", + "memchr", + "mime", + "spin", + "version_check", +] + [[package]] name = "mur3" version = "0.1.0" @@ -2420,7 +2470,7 @@ dependencies = [ "http 0.2.12", "http-body 0.4.6", "hyper 0.14.30", - "hyper-rustls", + "hyper-rustls 0.24.2", "ipnet", "js-sys", "log", @@ -2434,7 +2484,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 0.1.2", - "system-configuration", + "system-configuration 0.5.1", "tokio", "tokio-rustls 0.24.1", "tower-service", @@ -2446,6 +2496,50 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest" +version = "0.12.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-rustls 0.27.3", + "hyper-tls", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "mime_guess", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls-pemfile 2.1.3", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "system-configuration 0.6.1", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "windows-registry", +] + [[package]] name = "resolv-conf" version = "0.7.0" @@ -2928,6 +3022,9 @@ name = "sync_wrapper" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] [[package]] name = "sysinfo" @@ -2951,7 +3048,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "system-configuration-sys 0.6.0", ] [[package]] @@ -2964,6 +3072,16 @@ dependencies = [ "libc", ] +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.12.0" @@ -3138,6 +3256,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls 0.23.13", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.16" @@ -3486,7 +3615,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2f8811797a24ff123db3c6e1087aa42551d03d772b3724be421ad063da1f5f3f" dependencies = [ "directories", - "reqwest", + "reqwest 0.11.27", "semver", "serde", "serde_json", @@ -3721,9 +3850,11 @@ dependencies = [ "mime", "mime_guess", "motore", + "multer", "parking_lot 0.12.3", "paste", "pin-project", + "reqwest 0.12.8", "scopeguard", "serde", "serde_urlencoded", @@ -3738,6 +3869,7 @@ dependencies = [ "tokio-util", "tracing", "tungstenite", + "url", "volo", ] @@ -3966,7 +4098,7 @@ checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" dependencies = [ "windows-implement", "windows-interface", - "windows-result", + "windows-result 0.1.2", "windows-targets 0.52.6", ] @@ -3992,6 +4124,17 @@ dependencies = [ "syn 2.0.77", ] +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result 0.2.0", + "windows-strings", + "windows-targets 0.52.6", +] + [[package]] name = "windows-result" version = "0.1.2" @@ -4001,6 +4144,25 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result 0.2.0", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index 6886bb77..23b1b4fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,7 @@ mime = "0.3" mime_guess = { version = "2", default-features = false } mockall = "0.13" mockall_double = "0.3" +multer = "3.1" mur3 = "0.1" nix = "0.29" nom = "7" @@ -96,6 +97,7 @@ proc-macro2 = "1" quote = "1" rand = "0.8" regex = "1" +reqwest = "0.12" run_script = "0.10" rustc-hash = { version = "2", features = ["rand"] } same-file = "1" @@ -121,6 +123,7 @@ tower = "0.5" tracing = "0.1" tracing-subscriber = "0.3" update-informer = "1" +url = "2.5" url_path = "0.1" walkdir = "2" diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index 825262ab..9ed50dfa 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -60,6 +60,7 @@ tokio-util = { workspace = true, features = ["io"] } tracing.workspace = true # =====optional===== +multer = { workspace = true, optional = true } # server optional matchit = { workspace = true, optional = true } @@ -84,19 +85,22 @@ sonic-rs = { workspace = true, optional = true } async-stream.workspace = true libc.workspace = true serde = { workspace = true, features = ["derive"] } +reqwest = { workspace = true, features = ["multipart"] } tokio-test.workspace = true +url.workspace = true [features] default = [] default_client = ["client", "json"] -default_server = ["server", "query", "form", "json"] +default_server = ["server", "query", "form", "json", "multipart"] -full = ["client", "server", "rustls", "cookie", "query", "form", "json", "tls", "ws"] +full = ["client", "server", "rustls", "cookie", "query", "form", "json", "multipart", "tls", "ws"] client = ["hyper/client", "hyper/http1"] # client core server = ["hyper/server", "hyper/http1", "dep:matchit"] # server core +multipart = ["dep:multer"] ws = ["dep:tungstenite", "dep:tokio-tungstenite"] tls = ["rustls"] diff --git a/volo-http/src/server/layer.rs b/volo-http/src/server/layer.rs deleted file mode 100644 index 8341995a..00000000 --- a/volo-http/src/server/layer.rs +++ /dev/null @@ -1,335 +0,0 @@ -//! Collections of some useful [`Layer`]s. -//! -//! See [`FilterLayer`] and [`TimeoutLayer`] for more details. - -use std::{marker::PhantomData, time::Duration}; - -use motore::{layer::Layer, service::Service}; - -use super::{handler::HandlerWithoutRequest, IntoResponse}; -use crate::{context::ServerContext, request::ServerRequest, response::ServerResponse}; - -/// [`Layer`] for filtering requests -/// -/// See [`FilterLayer::new`] for more details. -#[derive(Clone)] -pub struct FilterLayer { - handler: H, - _marker: PhantomData<(R, T)>, -} - -impl FilterLayer { - /// Create a new [`FilterLayer`] - /// - /// The `handler` is an async function with some params that implement - /// [`FromContext`](crate::server::extract::FromContext), and returns - /// `Result<(), impl IntoResponse>`. - /// - /// If the handler returns `Ok(())`, the request will proceed. However, if the handler returns - /// `Err` with an object that implements [`IntoResponse`], the request will be rejected with - /// the returned object as the response. - /// - /// # Examples - /// - /// ``` - /// use http::{method::Method, status::StatusCode}; - /// use volo_http::server::{ - /// layer::FilterLayer, - /// route::{get, Router}, - /// }; - /// - /// async fn reject_post(method: Method) -> Result<(), StatusCode> { - /// if method == Method::POST { - /// Err(StatusCode::METHOD_NOT_ALLOWED) - /// } else { - /// Ok(()) - /// } - /// } - /// - /// async fn handler() -> &'static str { - /// "Hello, World" - /// } - /// - /// let router: Router = Router::new() - /// .route("/", get(handler)) - /// .layer(FilterLayer::new(reject_post)); - /// ``` - pub fn new(h: H) -> Self { - Self { - handler: h, - _marker: PhantomData, - } - } -} - -impl Layer for FilterLayer -where - S: Send + Sync + 'static, - H: Clone + Send + Sync + 'static, - T: Sync, -{ - type Service = Filter; - - fn layer(self, inner: S) -> Self::Service { - Filter { - service: inner, - handler: self.handler, - _marker: PhantomData, - } - } -} - -/// [`FilterLayer`] generated [`Service`] -/// -/// See [`FilterLayer`] for more details. -#[derive(Clone)] -pub struct Filter { - service: S, - handler: H, - _marker: PhantomData<(R, T)>, -} - -impl Service> for Filter -where - S: Service> + Send + Sync + 'static, - S::Response: IntoResponse, - S::Error: IntoResponse, - B: Send, - H: HandlerWithoutRequest> + Clone + Send + Sync + 'static, - R: IntoResponse + Send + Sync, - T: Sync, -{ - type Response = ServerResponse; - type Error = S::Error; - - async fn call( - &self, - cx: &mut ServerContext, - req: ServerRequest, - ) -> Result { - let (mut parts, body) = req.into_parts(); - let res = self.handler.clone().handle(cx, &mut parts).await; - let req = ServerRequest::from_parts(parts, body); - match res { - // do not filter it, call the service - Ok(Ok(())) => self - .service - .call(cx, req) - .await - .map(IntoResponse::into_response), - // filter it and return the specified response - Ok(Err(res)) => Ok(res.into_response()), - // something wrong while extracting - Err(rej) => { - tracing::warn!("[Volo-HTTP] FilterLayer: something wrong while extracting"); - Ok(rej.into_response()) - } - } - } -} - -/// [`Layer`] for setting timeout to the request -/// -/// See [`TimeoutLayer::new`] for more details. -#[derive(Clone)] -pub struct TimeoutLayer { - duration: Duration, - handler: H, -} - -impl TimeoutLayer { - /// Create a new [`TimeoutLayer`] with given [`Duration`] and handler. - /// - /// The handler should be a sync function with [`&ServerContext`](ServerContext) as parameter, - /// and return anything that implement [`IntoResponse`]. - /// - /// # Examples - /// - /// ``` - /// use std::time::Duration; - /// - /// use http::status::StatusCode; - /// use volo_http::{ - /// context::ServerContext, - /// server::{ - /// layer::TimeoutLayer, - /// route::{get, Router}, - /// }, - /// }; - /// - /// async fn index() -> &'static str { - /// "Hello, World" - /// } - /// - /// fn timeout_handler(_: &ServerContext) -> StatusCode { - /// StatusCode::REQUEST_TIMEOUT - /// } - /// - /// let router: Router = Router::new() - /// .route("/", get(index)) - /// .layer(TimeoutLayer::new(Duration::from_secs(1), timeout_handler)); - /// ``` - pub fn new(duration: Duration, handler: H) -> Self { - Self { duration, handler } - } -} - -impl Layer for TimeoutLayer -where - S: Send + Sync + 'static, -{ - type Service = Timeout; - - fn layer(self, inner: S) -> Self::Service { - Timeout { - service: inner, - duration: self.duration, - handler: self.handler, - } - } -} - -trait TimeoutHandler<'r> { - fn call(self, cx: &'r ServerContext) -> ServerResponse; -} - -impl<'r, F, R> TimeoutHandler<'r> for F -where - F: FnOnce(&'r ServerContext) -> R + 'r, - R: IntoResponse + 'r, -{ - fn call(self, cx: &'r ServerContext) -> ServerResponse { - self(cx).into_response() - } -} - -/// [`TimeoutLayer`] generated [`Service`] -/// -/// See [`TimeoutLayer`] for more details. -#[derive(Clone)] -pub struct Timeout { - service: S, - duration: Duration, - handler: H, -} - -impl Service> for Timeout -where - S: Service> + Send + Sync + 'static, - S::Response: IntoResponse, - S::Error: IntoResponse, - B: Send, - H: for<'r> TimeoutHandler<'r> + Clone + Sync, -{ - type Response = ServerResponse; - type Error = S::Error; - - async fn call( - &self, - cx: &mut ServerContext, - req: ServerRequest, - ) -> Result { - let fut_service = self.service.call(cx, req); - let fut_timeout = tokio::time::sleep(self.duration); - - tokio::select! { - resp = fut_service => resp.map(IntoResponse::into_response), - _ = fut_timeout => { - Ok((self.handler.clone()).call(cx)) - }, - } - } -} - -#[cfg(test)] -mod layer_tests { - use http::{method::Method, status::StatusCode}; - use motore::{layer::Layer, service::Service}; - - use crate::{ - body::BodyConversion, - context::ServerContext, - server::{ - route::{any, get, Route}, - test_helpers::empty_cx, - }, - utils::test_helpers::simple_req, - }; - - #[tokio::test] - async fn test_filter_layer() { - use crate::server::layer::FilterLayer; - - async fn reject_post(method: Method) -> Result<(), StatusCode> { - if method == Method::POST { - Err(StatusCode::METHOD_NOT_ALLOWED) - } else { - Ok(()) - } - } - - async fn handler() -> &'static str { - "Hello, World" - } - - let filter_layer = FilterLayer::new(reject_post); - let route: Route<&str> = Route::new(any(handler)); - let service = filter_layer.layer(route); - - let mut cx = empty_cx(); - - // Test case 1: not filter - let req = simple_req(Method::GET, "/", ""); - let resp = service.call(&mut cx, req).await.unwrap(); - assert_eq!( - resp.into_body().into_string().await.unwrap(), - "Hello, World" - ); - - // Test case 2: filter - let req = simple_req(Method::POST, "/", ""); - let resp = service.call(&mut cx, req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); - } - - #[tokio::test] - async fn test_timeout_layer() { - use std::time::Duration; - - use crate::server::layer::TimeoutLayer; - - async fn index_handler() -> &'static str { - "Hello, World" - } - - async fn index_timeout_handler() -> &'static str { - tokio::time::sleep(Duration::from_secs_f64(1.5)).await; - "Hello, World" - } - - fn timeout_handler(_: &ServerContext) -> StatusCode { - StatusCode::REQUEST_TIMEOUT - } - - let timeout_layer = TimeoutLayer::new(Duration::from_secs(1), timeout_handler); - - let mut cx = empty_cx(); - - // Test case 1: timeout - let route: Route<&str> = Route::new(get(index_timeout_handler)); - let service = timeout_layer.clone().layer(route); - let req = simple_req(Method::GET, "/", ""); - let resp = service.call(&mut cx, req).await.unwrap(); - assert_eq!(resp.status(), StatusCode::REQUEST_TIMEOUT); - - // Test case 2: not timeout - let route: Route<&str> = Route::new(get(index_handler)); - let service = timeout_layer.clone().layer(route); - let req = simple_req(Method::GET, "/", ""); - let resp = service.call(&mut cx, req).await.unwrap(); - assert_eq!( - resp.into_body().into_string().await.unwrap(), - "Hello, World" - ); - } -} diff --git a/volo-http/src/server/layer/body_limit.rs b/volo-http/src/server/layer/body_limit.rs new file mode 100644 index 00000000..23381f1e --- /dev/null +++ b/volo-http/src/server/layer/body_limit.rs @@ -0,0 +1,117 @@ +use http::StatusCode; +use http_body::Body; +use motore::{layer::Layer, Service}; + +use crate::{ + context::ServerContext, request::ServerRequest, response::ServerResponse, server::IntoResponse, +}; + +/// [`Layer`] for limiting body size +/// +/// See [`BodyLimitLayer::new`] for more details. +#[derive(Clone)] +pub struct BodyLimitLayer { + limit: usize, +} + +impl BodyLimitLayer { + /// Create a new [`BodyLimitLayer`] with given `body_limit`. + /// + /// If the Body is larger than the `body_limit`, the request will be rejected. + pub fn new(body_limit: usize) -> Self { + Self { limit: body_limit } + } +} + +impl Layer for BodyLimitLayer { + type Service = BodyLimitService; + + fn layer(self, inner: S) -> Self::Service { + BodyLimitService { + service: inner, + limit: self.limit, + } + } +} + +/// [`BodyLimitLayer`] generated [`Service`] +/// +/// See [`BodyLimitLayer`] for more details. +pub struct BodyLimitService { + service: S, + limit: usize, +} + +impl Service> for BodyLimitService +where + S: Service> + Send + Sync + 'static, + S::Response: IntoResponse, + B: Body + Send, +{ + type Response = ServerResponse; + type Error = S::Error; + + async fn call( + &self, + cx: &mut ServerContext, + req: ServerRequest, + ) -> Result { + let (parts, body) = req.into_parts(); + // get body size from content length + if let Some(size) = parts + .headers + .get(http::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok().and_then(|s| s.parse::().ok())) + { + if size > self.limit { + return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response()); + } + } else { + // get body size from stream + if body.size_hint().lower() > self.limit as u64 { + return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response()); + } + } + + let req = ServerRequest::from_parts(parts, body); + Ok(self.service.call(cx, req).await?.into_response()) + } +} + +#[cfg(test)] +mod tests { + use http::{Method, StatusCode}; + use motore::{layer::Layer, Service}; + + use crate::{ + server::{ + layer::BodyLimitLayer, + route::{any, Route}, + test_helpers::empty_cx, + }, + utils::test_helpers::simple_req, + }; + + #[tokio::test] + async fn test_body_limit() { + async fn handler() -> &'static str { + "Hello, World" + } + + let body_limit_layer = BodyLimitLayer::new(8); + let route: Route<_> = Route::new(any(handler)); + let service = body_limit_layer.layer(route); + + let mut cx = empty_cx(); + + // Test case 1: reject + let req = simple_req(Method::GET, "/", "111111111".to_string()); + let res = service.call(&mut cx, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); + + // Test case 2: not reject + let req = simple_req(Method::GET, "/", "1".to_string()); + let res = service.call(&mut cx, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + } +} diff --git a/volo-http/src/server/layer/filter.rs b/volo-http/src/server/layer/filter.rs new file mode 100644 index 00000000..cb56797e --- /dev/null +++ b/volo-http/src/server/layer/filter.rs @@ -0,0 +1,180 @@ +use std::marker::PhantomData; + +use motore::{layer::Layer, Service}; + +use crate::{ + context::ServerContext, + request::ServerRequest, + response::ServerResponse, + server::{handler::HandlerWithoutRequest, IntoResponse}, +}; + +/// [`Layer`] for filtering requests +/// +/// See [`FilterLayer::new`] for more details. +#[derive(Clone)] +pub struct FilterLayer { + handler: H, + _marker: PhantomData<(R, T)>, +} + +impl FilterLayer { + /// Create a new [`FilterLayer`] + /// + /// The `handler` is an async function with some params that implement + /// [`FromContext`](crate::server::extract::FromContext), and returns + /// `Result<(), impl IntoResponse>`. + /// + /// If the handler returns `Ok(())`, the request will proceed. However, if the handler returns + /// `Err` with an object that implements [`IntoResponse`], the request will be rejected with + /// the returned object as the response. + /// + /// # Examples + /// + /// ``` + /// use http::{method::Method, status::StatusCode}; + /// use volo_http::server::{ + /// layer::FilterLayer, + /// route::{get, Router}, + /// }; + /// + /// async fn reject_post(method: Method) -> Result<(), StatusCode> { + /// if method == Method::POST { + /// Err(StatusCode::METHOD_NOT_ALLOWED) + /// } else { + /// Ok(()) + /// } + /// } + /// + /// async fn handler() -> &'static str { + /// "Hello, World" + /// } + /// + /// let router: Router = Router::new() + /// .route("/", get(handler)) + /// .layer(FilterLayer::new(reject_post)); + /// ``` + pub fn new(h: H) -> Self { + Self { + handler: h, + _marker: PhantomData, + } + } +} + +impl Layer for FilterLayer +where + S: Send + Sync + 'static, + H: Clone + Send + Sync + 'static, + T: Sync, +{ + type Service = Filter; + + fn layer(self, inner: S) -> Self::Service { + Filter { + service: inner, + handler: self.handler, + _marker: PhantomData, + } + } +} + +/// [`FilterLayer`] generated [`Service`] +/// +/// See [`FilterLayer`] for more details. +#[derive(Clone)] +pub struct Filter { + service: S, + handler: H, + _marker: PhantomData<(R, T)>, +} + +impl Service> for Filter +where + S: Service> + Send + Sync + 'static, + S::Response: IntoResponse, + S::Error: IntoResponse, + B: Send, + H: HandlerWithoutRequest> + Clone + Send + Sync + 'static, + R: IntoResponse + Send + Sync, + T: Sync, +{ + type Response = ServerResponse; + type Error = S::Error; + + async fn call( + &self, + cx: &mut ServerContext, + req: ServerRequest, + ) -> Result { + let (mut parts, body) = req.into_parts(); + let res = self.handler.clone().handle(cx, &mut parts).await; + let req = ServerRequest::from_parts(parts, body); + match res { + // do not filter it, call the service + Ok(Ok(())) => self + .service + .call(cx, req) + .await + .map(IntoResponse::into_response), + // filter it and return the specified response + Ok(Err(res)) => Ok(res.into_response()), + // something wrong while extracting + Err(rej) => { + tracing::warn!("[Volo-HTTP] FilterLayer: something wrong while extracting"); + Ok(rej.into_response()) + } + } + } +} + +#[cfg(test)] +mod filter_tests { + use http::{Method, StatusCode}; + use motore::{layer::Layer, Service}; + + use crate::{ + body::BodyConversion, + server::{ + route::{any, Route}, + test_helpers::empty_cx, + }, + utils::test_helpers::simple_req, + }; + + #[tokio::test] + async fn test_filter_layer() { + use crate::server::layer::FilterLayer; + + async fn reject_post(method: Method) -> Result<(), StatusCode> { + if method == Method::POST { + Err(StatusCode::METHOD_NOT_ALLOWED) + } else { + Ok(()) + } + } + + async fn handler() -> &'static str { + "Hello, World" + } + + let filter_layer = FilterLayer::new(reject_post); + let route: Route<&str> = Route::new(any(handler)); + let service = filter_layer.layer(route); + + let mut cx = empty_cx(); + + // Test case 1: not filter + let req = simple_req(Method::GET, "/", ""); + let resp = service.call(&mut cx, req).await.unwrap(); + assert_eq!( + resp.into_body().into_string().await.unwrap(), + "Hello, World" + ); + + // Test case 2: filter + let req = simple_req(Method::POST, "/", ""); + let resp = service.call(&mut cx, req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + } +} diff --git a/volo-http/src/server/layer/mod.rs b/volo-http/src/server/layer/mod.rs new file mode 100644 index 00000000..88c0cd8f --- /dev/null +++ b/volo-http/src/server/layer/mod.rs @@ -0,0 +1,11 @@ +//! Collections of some useful `Layer`s. +//! +//! See [`FilterLayer`] and [`TimeoutLayer`] for more details. + +pub(crate) mod body_limit; +mod filter; +mod timeout; + +pub use body_limit::BodyLimitLayer; +pub use filter::FilterLayer; +pub use timeout::TimeoutLayer; diff --git a/volo-http/src/server/layer/timeout.rs b/volo-http/src/server/layer/timeout.rs new file mode 100644 index 00000000..936206dc --- /dev/null +++ b/volo-http/src/server/layer/timeout.rs @@ -0,0 +1,177 @@ +use std::time::Duration; + +use motore::{layer::Layer, Service}; + +use crate::{ + context::ServerContext, request::ServerRequest, response::ServerResponse, server::IntoResponse, +}; + +/// [`Layer`] for setting timeout to the request +/// +/// See [`TimeoutLayer::new`] for more details. +#[derive(Clone)] +pub struct TimeoutLayer { + duration: Duration, + handler: H, +} + +impl TimeoutLayer { + /// Create a new [`TimeoutLayer`] with given [`Duration`] and handler. + /// + /// The handler should be a sync function with [`&ServerContext`](ServerContext) as parameter, + /// and return anything that implement [`IntoResponse`]. + /// + /// # Examples + /// + /// ``` + /// use std::time::Duration; + /// + /// use http::status::StatusCode; + /// use volo_http::{ + /// context::ServerContext, + /// server::{ + /// layer::TimeoutLayer, + /// route::{get, Router}, + /// }, + /// }; + /// + /// async fn index() -> &'static str { + /// "Hello, World" + /// } + /// + /// fn timeout_handler(_: &ServerContext) -> StatusCode { + /// StatusCode::REQUEST_TIMEOUT + /// } + /// + /// let router: Router = Router::new() + /// .route("/", get(index)) + /// .layer(TimeoutLayer::new(Duration::from_secs(1), timeout_handler)); + /// ``` + pub fn new(duration: Duration, handler: H) -> Self { + Self { duration, handler } + } +} + +impl Layer for TimeoutLayer +where + S: Send + Sync + 'static, +{ + type Service = Timeout; + + fn layer(self, inner: S) -> Self::Service { + Timeout { + service: inner, + duration: self.duration, + handler: self.handler, + } + } +} + +trait TimeoutHandler<'r> { + fn call(self, cx: &'r ServerContext) -> ServerResponse; +} + +impl<'r, F, R> TimeoutHandler<'r> for F +where + F: FnOnce(&'r ServerContext) -> R + 'r, + R: IntoResponse + 'r, +{ + fn call(self, cx: &'r ServerContext) -> ServerResponse { + self(cx).into_response() + } +} + +/// [`TimeoutLayer`] generated [`Service`] +/// +/// See [`TimeoutLayer`] for more details. +#[derive(Clone)] +pub struct Timeout { + service: S, + duration: Duration, + handler: H, +} + +impl Service> for Timeout +where + S: Service> + Send + Sync + 'static, + S::Response: IntoResponse, + S::Error: IntoResponse, + B: Send, + H: for<'r> TimeoutHandler<'r> + Clone + Sync, +{ + type Response = ServerResponse; + type Error = S::Error; + + async fn call( + &self, + cx: &mut ServerContext, + req: ServerRequest, + ) -> Result { + let fut_service = self.service.call(cx, req); + let fut_timeout = tokio::time::sleep(self.duration); + + tokio::select! { + resp = fut_service => resp.map(IntoResponse::into_response), + _ = fut_timeout => { + Ok((self.handler.clone()).call(cx)) + }, + } + } +} + +#[cfg(test)] +mod timeout_tests { + use http::{Method, StatusCode}; + use motore::{layer::Layer, Service}; + + use crate::{ + body::BodyConversion, + context::ServerContext, + server::{ + route::{get, Route}, + test_helpers::empty_cx, + }, + utils::test_helpers::simple_req, + }; + + #[tokio::test] + async fn test_timeout_layer() { + use std::time::Duration; + + use crate::server::layer::TimeoutLayer; + + async fn index_handler() -> &'static str { + "Hello, World" + } + + async fn index_timeout_handler() -> &'static str { + tokio::time::sleep(Duration::from_secs_f64(1.5)).await; + "Hello, World" + } + + fn timeout_handler(_: &ServerContext) -> StatusCode { + StatusCode::REQUEST_TIMEOUT + } + + let timeout_layer = TimeoutLayer::new(Duration::from_secs(1), timeout_handler); + + let mut cx = empty_cx(); + + // Test case 1: timeout + let route: Route<&str> = Route::new(get(index_timeout_handler)); + let service = timeout_layer.clone().layer(route); + let req = simple_req(Method::GET, "/", ""); + let resp = service.call(&mut cx, req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::REQUEST_TIMEOUT); + + // Test case 2: not timeout + let route: Route<&str> = Route::new(get(index_handler)); + let service = timeout_layer.clone().layer(route); + let req = simple_req(Method::GET, "/", ""); + let resp = service.call(&mut cx, req).await.unwrap(); + assert_eq!( + resp.into_body().into_string().await.unwrap(), + "Hello, World" + ); + } +} diff --git a/volo-http/src/server/utils/mod.rs b/volo-http/src/server/utils/mod.rs index a1a456d1..38f75ff8 100644 --- a/volo-http/src/server/utils/mod.rs +++ b/volo-http/src/server/utils/mod.rs @@ -5,5 +5,7 @@ mod serve_dir; pub use file_response::FileResponse; pub use serve_dir::ServeDir; +#[cfg(feature = "multipart")] +pub mod multipart; #[cfg(feature = "ws")] pub mod ws; diff --git a/volo-http/src/server/utils/multipart.rs b/volo-http/src/server/utils/multipart.rs new file mode 100644 index 00000000..f20c9de2 --- /dev/null +++ b/volo-http/src/server/utils/multipart.rs @@ -0,0 +1,365 @@ +//! Multipart implementation for server. +//! +//! This module provides utilities for extracting `multipart/form-data` formatted data from HTTP +//! requests. +//! +//! # Example +//! +//! ```rust +//! use http::StatusCode; +//! use volo_http::{ +//! response::ServerResponse, +//! server::{ +//! route::post, +//! utils::multipart::{Multipart, MultipartRejectionError}, +//! }, +//! Router, +//! }; +//! +//! async fn upload(mut multipart: Multipart) -> Result { +//! while let Some(field) = multipart.next_field().await? { +//! let name = field.name().unwrap().to_string(); +//! let value = field.bytes().await?; +//! +//! println!("The field {} has {} bytes", name, value.len()); +//! } +//! +//! Ok(StatusCode::OK) +//! } +//! +//! let app: Router = Router::new().route("/upload", post(upload)); +//! ``` +//! +//! See [`Multipart`] for more details. + +use std::{error::Error, fmt}; + +use http::{request::Parts, StatusCode}; +use http_body_util::BodyExt; +use multer::Field; + +use crate::{ + context::ServerContext, + server::{extract::FromRequest, IntoResponse}, +}; + +/// Extract a type from `multipart/form-data` HTTP requests. +/// +/// [`Multipart`] can be passed as an argument to a handler, which can be used to extract each +/// `multipart/form-data` field by calling [`Multipart::next_field`]. +/// +/// **Notice** +/// +/// Extracting `multipart/form-data` data will consume the body, hence [`Multipart`] must be the +/// last argument from the handler. +/// +/// # Example +/// +/// ```rust +/// use http::StatusCode; +/// use volo_http::{ +/// response::ServerResponse, +/// server::utils::multipart::{Multipart, MultipartRejectionError}, +/// }; +/// +/// async fn upload(mut multipart: Multipart) -> Result { +/// while let Some(field) = multipart.next_field().await? { +/// todo!() +/// } +/// +/// Ok(StatusCode::OK) +/// } +/// ``` +/// +/// # Body Limitation +/// +/// Since the body is unlimited, so it is recommended to use +/// [`BodyLimitLayer`](crate::server::layer::BodyLimitLayer) to limit the size of the body. +/// +/// ```rust +/// use http::StatusCode; +/// use volo_http::{ +/// server::{ +/// layer::BodyLimitLayer, +/// route::post, +/// utils::multipart::{Multipart, MultipartRejectionError}, +/// }, +/// Router, +/// }; +/// +/// async fn upload_handler( +/// mut multipart: Multipart, +/// ) -> Result { +/// Ok(StatusCode::OK) +/// } +/// +/// let app: Router<_> = Router::new() +/// .route("/", post(upload_handler)) +/// .layer(BodyLimitLayer::new(1024)); +/// ``` +#[must_use] +pub struct Multipart { + inner: multer::Multipart<'static>, +} + +impl Multipart { + /// Iterate over all [`Field`] in [`Multipart`] + /// + /// # Example + /// + /// ```rust + /// # use volo_http::server::utils::multipart::Multipart; + /// # let mut multipart: Multipart; + /// // Extract each field from multipart by using while loop + /// # async fn upload(mut multipart: Multipart) { + /// while let Some(field) = multipart.next_field().await.unwrap() { + /// let name = field.name().unwrap().to_string(); // Get field name + /// let data = field.bytes().await.unwrap(); // Get field data + /// } + /// # } + /// ``` + pub async fn next_field(&mut self) -> Result>, MultipartRejectionError> { + Ok(self.inner.next_field().await?) + } +} + +impl FromRequest for Multipart { + type Rejection = MultipartRejectionError; + async fn from_request( + _: &mut ServerContext, + parts: Parts, + body: crate::body::Body, + ) -> Result { + let boundary = multer::parse_boundary( + parts + .headers + .get(http::header::CONTENT_TYPE) + .ok_or(multer::Error::NoMultipart)? + .to_str() + .map_err(|_| multer::Error::NoBoundary)?, + )?; + + let multipart = multer::Multipart::new(body.into_data_stream(), boundary); + + Ok(Self { inner: multipart }) + } +} + +/// [`Error`]s while extracting [`Multipart`]. +/// +/// [`Error`]: Error +#[derive(Debug)] +pub struct MultipartRejectionError { + inner: multer::Error, +} + +impl From for MultipartRejectionError { + fn from(err: multer::Error) -> Self { + Self { inner: err } + } +} + +fn status_code_from_multer_error(err: &multer::Error) -> StatusCode { + match err { + multer::Error::UnknownField { .. } + | multer::Error::IncompleteFieldData { .. } + | multer::Error::IncompleteHeaders + | multer::Error::ReadHeaderFailed(..) + | multer::Error::DecodeHeaderName { .. } + | multer::Error::DecodeContentType(..) + | multer::Error::NoBoundary + | multer::Error::DecodeHeaderValue { .. } + | multer::Error::NoMultipart + | multer::Error::IncompleteStream => StatusCode::BAD_REQUEST, + multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => { + StatusCode::PAYLOAD_TOO_LARGE + } + multer::Error::StreamReadFailed(_) => StatusCode::INTERNAL_SERVER_ERROR, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } +} + +impl MultipartRejectionError { + /// Convert the [`MultipartRejectionError`] into a [`http::StatusCode`]. + pub fn to_status_code(&self) -> http::StatusCode { + status_code_from_multer_error(&self.inner) + } +} + +impl Error for MultipartRejectionError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.inner) + } +} + +impl fmt::Display for MultipartRejectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + std::fmt::Display::fmt(&self.inner, f) + } +} + +impl IntoResponse for MultipartRejectionError { + fn into_response(self) -> http::Response { + self.to_status_code().into_response() + } +} + +#[cfg(test)] +mod multipart_tests { + use std::{ + convert::Infallible, + net::{IpAddr, Ipv4Addr, SocketAddr}, + }; + + use motore::Service; + use reqwest::multipart::Form; + use volo::net::Address; + + use crate::{ + context::ServerContext, + request::ServerRequest, + response::ServerResponse, + server::{ + test_helpers, + utils::multipart::{Multipart, MultipartRejectionError}, + IntoResponse, + }, + Server, + }; + + fn _test_compile() { + async fn handler(_: Multipart) {} + let app = test_helpers::to_service(handler); + let addr = Address::Ip(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 25241, + )); + let _server = Server::new(app).run(addr); + } + + async fn run_handler(service: S, port: u16) + where + S: Service + + Send + + Sync + + 'static, + { + let addr = Address::Ip(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + port, + )); + + tokio::spawn(Server::new(service).run(addr)); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } + + #[tokio::test] + async fn test_single_field_upload() { + const BYTES: &[u8] = "🦀".as_bytes(); + const FILE_NAME: &str = "index.html"; + const CONTENT_TYPE: &str = "text/html; charset=utf-8"; + + async fn handler(mut multipart: Multipart) -> impl IntoResponse { + let field = multipart.next_field().await.unwrap().unwrap(); + + assert_eq!(field.file_name().unwrap(), FILE_NAME); + assert_eq!(field.content_type().unwrap().as_ref(), CONTENT_TYPE); + assert_eq!(field.headers()["foo"], "bar"); + assert_eq!(field.bytes().await.unwrap(), BYTES); + + assert!(multipart.next_field().await.unwrap().is_none()); + } + + let form = Form::new().part( + "file", + reqwest::multipart::Part::bytes(BYTES) + .file_name(FILE_NAME) + .mime_str(CONTENT_TYPE) + .unwrap() + .headers(reqwest::header::HeaderMap::from_iter([( + reqwest::header::HeaderName::from_static("foo"), + reqwest::header::HeaderValue::from_static("bar"), + )])), + ); + + run_handler(test_helpers::to_service(handler), 25241).await; + + let url_str = format!("http://127.0.0.1:{}", 25241); + let url = url::Url::parse(url_str.as_str()).unwrap(); + + reqwest::Client::new() + .post(url) + .multipart(form) + .send() + .await + .unwrap(); + } + + #[tokio::test] + async fn test_multiple_field_upload() { + const BYTES: &[u8] = "🦀".as_bytes(); + const CONTENT_TYPE: &str = "text/html; charset=utf-8"; + + const FIELD_NAME1: &str = "file1"; + const FIELD_NAME2: &str = "file2"; + const FILE_NAME1: &str = "index1.html"; + const FILE_NAME2: &str = "index2.html"; + + async fn handler(mut multipart: Multipart) -> Result<(), MultipartRejectionError> { + while let Some(field) = multipart.next_field().await? { + match field.name() { + Some(FIELD_NAME1) => { + assert_eq!(field.file_name().unwrap(), FILE_NAME1); + assert_eq!(field.headers()["foo1"], "bar1"); + } + Some(FIELD_NAME2) => { + assert_eq!(field.file_name().unwrap(), FILE_NAME2); + assert_eq!(field.headers()["foo2"], "bar2"); + } + _ => unreachable!(), + } + assert_eq!(field.content_type().unwrap().as_ref(), CONTENT_TYPE); + assert_eq!(field.bytes().await?, BYTES); + } + + Ok(()) + } + + let form = Form::new() + .part( + FIELD_NAME1, + reqwest::multipart::Part::bytes(BYTES) + .file_name(FILE_NAME1) + .mime_str(CONTENT_TYPE) + .unwrap() + .headers(reqwest::header::HeaderMap::from_iter([( + reqwest::header::HeaderName::from_static("foo1"), + reqwest::header::HeaderValue::from_static("bar1"), + )])), + ) + .part( + FIELD_NAME2, + reqwest::multipart::Part::bytes(BYTES) + .file_name(FILE_NAME2) + .mime_str(CONTENT_TYPE) + .unwrap() + .headers(reqwest::header::HeaderMap::from_iter([( + reqwest::header::HeaderName::from_static("foo2"), + reqwest::header::HeaderValue::from_static("bar2"), + )])), + ); + + run_handler(test_helpers::to_service(handler), 25242).await; + + let url_str = format!("http://127.0.0.1:{}", 25242); + let url = url::Url::parse(url_str.as_str()).unwrap(); + + reqwest::Client::new() + .post(url.clone()) + .multipart(form) + .send() + .await + .unwrap(); + } +} diff --git a/volo-http/src/server/utils/ws.rs b/volo-http/src/server/utils/ws.rs index f4f584fa..a80fa7bf 100644 --- a/volo-http/src/server/utils/ws.rs +++ b/volo-http/src/server/utils/ws.rs @@ -71,7 +71,7 @@ use crate::{ const HEADERVALUE_UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); const HEADERVALUE_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); -/// Handler request for establishing WebSocket connection. +/// Handle request for establishing WebSocket connection. /// /// [`WebSocketUpgrade`] can be passed as an argument to a handler, which will be called if the /// http connection making the request can be upgraded to a websocket connection.