From b2a3a3de429cc4bfd48d1e60035cc220d7b8ea05 Mon Sep 17 00:00:00 2001 From: StellarisW Date: Wed, 23 Oct 2024 17:49:08 +0800 Subject: [PATCH] feat(http): add multipart for server --- Cargo.toml | 2 +- volo-http/src/server/extract.rs | 1 - volo-http/src/server/layer/body_limit.rs | 68 +++++++----------------- volo-http/src/server/utils/multipart.rs | 55 ++++++++----------- 4 files changed, 43 insertions(+), 83 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 36f3e518..2dce3fd7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -123,7 +123,7 @@ tower = "0.5" tracing = "0.1" tracing-subscriber = "0.3" update-informer = "1" -url="2.5.2" +url = "2.5.2" url_path = "0.1" walkdir = "2" diff --git a/volo-http/src/server/extract.rs b/volo-http/src/server/extract.rs index 5dc62a7e..e054939a 100644 --- a/volo-http/src/server/extract.rs +++ b/volo-http/src/server/extract.rs @@ -408,7 +408,6 @@ where parts: Parts, body: B, ) -> Result { - // TODO: add limited body let bytes = body .collect() .await diff --git a/volo-http/src/server/layer/body_limit.rs b/volo-http/src/server/layer/body_limit.rs index e34d110d..ab327204 100644 --- a/volo-http/src/server/layer/body_limit.rs +++ b/volo-http/src/server/layer/body_limit.rs @@ -6,26 +6,12 @@ use crate::{ context::ServerContext, request::ServerRequest, response::ServerResponse, server::IntoResponse, }; -#[derive(Debug, Clone, Copy)] -pub(crate) enum BodyLimitKind { - #[allow(dead_code)] - Disable, - #[allow(dead_code)] - Block(usize), -} - /// [`Layer`] for limiting body size /// -/// Get the body size by the priority: -/// -/// 1. [`http::header::CONTENT_LENGTH`] -/// -/// 2. [`http_body::Body::size_hint()`] -/// -/// See [`BodyLimitLayer::max`] for more details. +/// See [`BodyLimitLayer::new`] for more details. #[derive(Clone)] pub struct BodyLimitLayer { - kind: BodyLimitKind, + limit: usize, } impl BodyLimitLayer { @@ -48,22 +34,10 @@ impl BodyLimitLayer { /// /// let router: Router = Router::new() /// .route("/", post(handler)) - /// .layer(BodyLimitLayer::max(1024)); // limit body size to 1KB + /// .layer(BodyLimitLayer::new(1024)); // limit body size to 1KB /// ``` - pub fn max(body_limit: usize) -> Self { - Self { - kind: BodyLimitKind::Block(body_limit), - } - } - - /// Create a new [`BodyLimitLayer`] with `body_limit` disabled. - /// - /// It's unnecessary to use this method, because the `body_limit` is disabled by default. - #[allow(dead_code)] - fn disable() -> Self { - Self { - kind: BodyLimitKind::Disable, - } + pub fn new(body_limit: usize) -> Self { + Self { limit: body_limit } } } @@ -73,7 +47,7 @@ impl Layer for BodyLimitLayer { fn layer(self, inner: S) -> Self::Service { BodyLimitService { service: inner, - kind: self.kind, + limit: self.limit, } } } @@ -83,7 +57,7 @@ impl Layer for BodyLimitLayer { /// See [`BodyLimitLayer`] for more details. pub struct BodyLimitService { service: S, - kind: BodyLimitKind, + limit: usize, } impl Service for BodyLimitService @@ -101,21 +75,19 @@ where req: ServerRequest, ) -> Result { let (parts, body) = req.into_parts(); - if let BodyLimitKind::Block(limit) = self.kind { - // get body size from content length - if let Some(size) = parts - .headers - .get(http::header::CONTENT_LENGTH) - .and_then(|v| v.to_str().ok().and_then(|s| s.parse::().ok())) - { - if size > limit { - return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response()); - } - } else { - // get body size from stream - if body.size_hint().lower() > limit as u64 { - return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response()); - } + // get body size from content length + if let Some(size) = parts + .headers + .get(http::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok().and_then(|s| s.parse::().ok())) + { + if size > self.limit { + return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response()); + } + } else { + // get body size from stream + if body.size_hint().lower() > self.limit as u64 { + return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response()); } } diff --git a/volo-http/src/server/utils/multipart.rs b/volo-http/src/server/utils/multipart.rs index d861f36e..f9da2d5c 100644 --- a/volo-http/src/server/utils/multipart.rs +++ b/volo-http/src/server/utils/multipart.rs @@ -5,7 +5,7 @@ //! //! # Example //! -//! ```rust,no_run +//! ```rust //! use http::StatusCode; //! use volo_http::{ //! response::ServerResponse, @@ -57,7 +57,7 @@ use crate::{ /// /// # Example /// -/// ```rust,no_run +/// ```rust /// use http::StatusCode; /// use volo_http::{ /// response::ServerResponse, @@ -80,7 +80,7 @@ use crate::{ /// Since the body is unlimited, so it is recommended to use /// [`BodyLimitLayer`](crate::server::layer::BodyLimitLayer) to limit the size of the body. /// -/// ```rust,no_run +/// ```rust /// use http::StatusCode; /// use volo_http::{ /// Router, @@ -91,25 +91,25 @@ use crate::{ /// } /// }; /// -/// # async fn upload_handler(mut multipart: Multipart<'static>) -> Result { +/// # async fn upload_handler(mut multipart: Multipart) -> Result { /// # Ok(StatusCode::OK) /// # } /// /// let app: Router<_>= Router::new() /// .route("/",post(upload_handler)) -/// .layer( BodyLimitLayer::max(1024)); +/// .layer( BodyLimitLayer::new(1024)); /// ``` #[must_use] -pub struct Multipart<'r> { - inner: multer::Multipart<'r>, +pub struct Multipart { + inner: multer::Multipart<'static>, } -impl<'r> Multipart<'r> { +impl Multipart { /// Iterate over all [`Field`] in [`Multipart`] /// /// # Example /// - /// ```rust,no_run + /// ```rust /// # use volo_http::server::utils::multipart::Multipart; /// # let mut multipart: Multipart; /// // Extract each field from multipart by using while loop @@ -120,12 +120,12 @@ impl<'r> Multipart<'r> { /// } /// # } /// ``` - pub async fn next_field(&mut self) -> Result>, MultipartRejectionError> { + pub async fn next_field(&mut self) -> Result>, MultipartRejectionError> { Ok(self.inner.next_field().await?) } } -impl<'r> FromRequest for Multipart<'r> { +impl FromRequest for Multipart { type Rejection = MultipartRejectionError; async fn from_request( _: &mut ServerContext, @@ -136,8 +136,9 @@ impl<'r> FromRequest for Multipart<'r> { parts .headers .get(http::header::CONTENT_TYPE) - .map(|h| h.to_str().unwrap_or_default()) - .unwrap_or_default(), + .ok_or(multer::Error::NoMultipart)? + .to_str() + .map_err(|_| multer::Error::NoBoundary)?, )?; let multipart = multer::Multipart::new(body.into_data_stream(), boundary); @@ -175,20 +176,7 @@ fn status_code_from_multer_error(err: &multer::Error) -> StatusCode { multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => { StatusCode::PAYLOAD_TOO_LARGE } - multer::Error::StreamReadFailed(err) => { - if let Some(err) = err.downcast_ref::() { - return status_code_from_multer_error(err); - } - - if err - .downcast_ref::() - .is_some() - { - return StatusCode::PAYLOAD_TOO_LARGE; - } - - StatusCode::INTERNAL_SERVER_ERROR - } + multer::Error::StreamReadFailed(_) => StatusCode::INTERNAL_SERVER_ERROR, _ => StatusCode::INTERNAL_SERVER_ERROR, } } @@ -214,7 +202,7 @@ impl fmt::Display for MultipartRejectionError { impl IntoResponse for MultipartRejectionError { fn into_response(self) -> http::Response { - (self.to_status_code(), self.to_string()).into_response() + self.to_status_code().into_response() } } @@ -245,7 +233,7 @@ mod multipart_tests { }; fn _test_compile() { - async fn handler(_: Multipart<'_>) {} + async fn handler(_: Multipart) {} let app = test_helpers::to_service(handler); let addr = Address::Ip(SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), @@ -270,13 +258,14 @@ mod multipart_tests { tokio::time::sleep(std::time::Duration::from_secs(1)).await; } + #[tokio::test] async fn test_single_field_upload() { const BYTES: &[u8] = "🦀".as_bytes(); const FILE_NAME: &str = "index.html"; const CONTENT_TYPE: &str = "text/html; charset=utf-8"; - async fn handler(mut multipart: Multipart<'static>) -> impl IntoResponse { + async fn handler(mut multipart: Multipart) -> impl IntoResponse { let field = multipart.next_field().await.unwrap().unwrap(); assert_eq!(field.file_name().unwrap(), FILE_NAME); @@ -322,7 +311,7 @@ mod multipart_tests { const FILE_NAME1: &str = "index1.html"; const FILE_NAME2: &str = "index2.html"; - async fn handler(mut multipart: Multipart<'static>) -> Result<(), MultipartRejectionError> { + async fn handler(mut multipart: Multipart) -> Result<(), MultipartRejectionError> { while let Some(field) = multipart.next_field().await? { match field.name() { Some(FIELD_NAME1) => { @@ -382,7 +371,7 @@ mod multipart_tests { #[tokio::test] async fn test_large_field_upload() { - async fn handler(mut multipart: Multipart<'static>) -> Result<(), MultipartRejectionError> { + async fn handler(mut multipart: Multipart) -> Result<(), MultipartRejectionError> { while let Some(field) = multipart.next_field().await? { field.bytes().await?; } @@ -410,7 +399,7 @@ mod multipart_tests { let app: Router<_> = Router::new() .route("/", post(handler)) - .layer(BodyLimitLayer::max(1024)); + .layer(BodyLimitLayer::new(1024)); run_handler(app, 8003).await;