Skip to content

Commit

Permalink
feat(http): add multipart for server
Browse files Browse the repository at this point in the history
  • Loading branch information
StellarisW committed Oct 23, 2024
1 parent 427fd20 commit b2a3a3d
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 83 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ tower = "0.5"
tracing = "0.1"
tracing-subscriber = "0.3"
update-informer = "1"
url="2.5.2"
url = "2.5.2"
url_path = "0.1"
walkdir = "2"

Expand Down
1 change: 0 additions & 1 deletion volo-http/src/server/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,6 @@ where
parts: Parts,
body: B,
) -> Result<Self, Self::Rejection> {
// TODO: add limited body
let bytes = body
.collect()
.await
Expand Down
68 changes: 20 additions & 48 deletions volo-http/src/server/layer/body_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,12 @@ use crate::{
context::ServerContext, request::ServerRequest, response::ServerResponse, server::IntoResponse,
};

#[derive(Debug, Clone, Copy)]
pub(crate) enum BodyLimitKind {
#[allow(dead_code)]
Disable,
#[allow(dead_code)]
Block(usize),
}

/// [`Layer`] for limiting body size
///
/// Get the body size by the priority:
///
/// 1. [`http::header::CONTENT_LENGTH`]
///
/// 2. [`http_body::Body::size_hint()`]
///
/// See [`BodyLimitLayer::max`] for more details.
/// See [`BodyLimitLayer::new`] for more details.
#[derive(Clone)]
pub struct BodyLimitLayer {
kind: BodyLimitKind,
limit: usize,
}

impl BodyLimitLayer {
Expand All @@ -48,22 +34,10 @@ impl BodyLimitLayer {
///
/// let router: Router = Router::new()
/// .route("/", post(handler))
/// .layer(BodyLimitLayer::max(1024)); // limit body size to 1KB
/// .layer(BodyLimitLayer::new(1024)); // limit body size to 1KB
/// ```
pub fn max(body_limit: usize) -> Self {
Self {
kind: BodyLimitKind::Block(body_limit),
}
}

/// Create a new [`BodyLimitLayer`] with `body_limit` disabled.
///
/// It's unnecessary to use this method, because the `body_limit` is disabled by default.
#[allow(dead_code)]
fn disable() -> Self {
Self {
kind: BodyLimitKind::Disable,
}
pub fn new(body_limit: usize) -> Self {
Self { limit: body_limit }
}
}

Expand All @@ -73,7 +47,7 @@ impl<S> Layer<S> for BodyLimitLayer {
fn layer(self, inner: S) -> Self::Service {
BodyLimitService {
service: inner,
kind: self.kind,
limit: self.limit,
}
}
}
Expand All @@ -83,7 +57,7 @@ impl<S> Layer<S> for BodyLimitLayer {
/// See [`BodyLimitLayer`] for more details.
pub struct BodyLimitService<S> {
service: S,
kind: BodyLimitKind,
limit: usize,
}

impl<S> Service<ServerContext, ServerRequest> for BodyLimitService<S>
Expand All @@ -101,21 +75,19 @@ where
req: ServerRequest,
) -> Result<Self::Response, Self::Error> {
let (parts, body) = req.into_parts();
if let BodyLimitKind::Block(limit) = self.kind {
// get body size from content length
if let Some(size) = parts
.headers
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok().and_then(|s| s.parse::<usize>().ok()))
{
if size > limit {
return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response());
}
} else {
// get body size from stream
if body.size_hint().lower() > limit as u64 {
return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response());
}
// get body size from content length
if let Some(size) = parts
.headers
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok().and_then(|s| s.parse::<usize>().ok()))
{
if size > self.limit {
return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response());
}
} else {
// get body size from stream
if body.size_hint().lower() > self.limit as u64 {
return Ok(StatusCode::PAYLOAD_TOO_LARGE.into_response());
}
}

Expand Down
55 changes: 22 additions & 33 deletions volo-http/src/server/utils/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//!
//! # Example
//!
//! ```rust,no_run
//! ```rust
//! use http::StatusCode;
//! use volo_http::{
//! response::ServerResponse,
Expand Down Expand Up @@ -57,7 +57,7 @@ use crate::{
///
/// # Example
///
/// ```rust,no_run
/// ```rust
/// use http::StatusCode;
/// use volo_http::{
/// response::ServerResponse,
Expand All @@ -80,7 +80,7 @@ use crate::{
/// Since the body is unlimited, so it is recommended to use
/// [`BodyLimitLayer`](crate::server::layer::BodyLimitLayer) to limit the size of the body.
///
/// ```rust,no_run
/// ```rust
/// use http::StatusCode;
/// use volo_http::{
/// Router,
Expand All @@ -91,25 +91,25 @@ use crate::{
/// }
/// };
///
/// # async fn upload_handler(mut multipart: Multipart<'static>) -> Result<StatusCode, MultipartRejectionError> {
/// # async fn upload_handler(mut multipart: Multipart) -> Result<StatusCode, MultipartRejectionError> {
/// # Ok(StatusCode::OK)
/// # }
///
/// let app: Router<_>= Router::new()
/// .route("/",post(upload_handler))
/// .layer( BodyLimitLayer::max(1024));
/// .layer( BodyLimitLayer::new(1024));
/// ```
#[must_use]
pub struct Multipart<'r> {
inner: multer::Multipart<'r>,
pub struct Multipart {
inner: multer::Multipart<'static>,
}

impl<'r> Multipart<'r> {
impl Multipart {
/// Iterate over all [`Field`] in [`Multipart`]
///
/// # Example
///
/// ```rust,no_run
/// ```rust
/// # use volo_http::server::utils::multipart::Multipart;
/// # let mut multipart: Multipart;
/// // Extract each field from multipart by using while loop
Expand All @@ -120,12 +120,12 @@ impl<'r> Multipart<'r> {
/// }
/// # }
/// ```
pub async fn next_field(&mut self) -> Result<Option<Field<'r>>, MultipartRejectionError> {
pub async fn next_field(&mut self) -> Result<Option<Field<'static>>, MultipartRejectionError> {
Ok(self.inner.next_field().await?)
}
}

impl<'r> FromRequest<crate::body::Body> for Multipart<'r> {
impl FromRequest<crate::body::Body> for Multipart {
type Rejection = MultipartRejectionError;
async fn from_request(
_: &mut ServerContext,
Expand All @@ -136,8 +136,9 @@ impl<'r> FromRequest<crate::body::Body> for Multipart<'r> {
parts
.headers
.get(http::header::CONTENT_TYPE)
.map(|h| h.to_str().unwrap_or_default())
.unwrap_or_default(),
.ok_or(multer::Error::NoMultipart)?
.to_str()
.map_err(|_| multer::Error::NoBoundary)?,
)?;

let multipart = multer::Multipart::new(body.into_data_stream(), boundary);
Expand Down Expand Up @@ -175,20 +176,7 @@ fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
StatusCode::PAYLOAD_TOO_LARGE
}
multer::Error::StreamReadFailed(err) => {
if let Some(err) = err.downcast_ref::<multer::Error>() {
return status_code_from_multer_error(err);
}

if err
.downcast_ref::<http_body_util::LengthLimitError>()
.is_some()
{
return StatusCode::PAYLOAD_TOO_LARGE;
}

StatusCode::INTERNAL_SERVER_ERROR
}
multer::Error::StreamReadFailed(_) => StatusCode::INTERNAL_SERVER_ERROR,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
Expand All @@ -214,7 +202,7 @@ impl fmt::Display for MultipartRejectionError {

impl IntoResponse for MultipartRejectionError {
fn into_response(self) -> http::Response<crate::body::Body> {
(self.to_status_code(), self.to_string()).into_response()
self.to_status_code().into_response()
}
}

Expand Down Expand Up @@ -245,7 +233,7 @@ mod multipart_tests {
};

fn _test_compile() {
async fn handler(_: Multipart<'_>) {}
async fn handler(_: Multipart) {}
let app = test_helpers::to_service(handler);
let addr = Address::Ip(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
Expand All @@ -270,13 +258,14 @@ mod multipart_tests {

tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}

#[tokio::test]
async fn test_single_field_upload() {
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
const FILE_NAME: &str = "index.html";
const CONTENT_TYPE: &str = "text/html; charset=utf-8";

async fn handler(mut multipart: Multipart<'static>) -> impl IntoResponse {
async fn handler(mut multipart: Multipart) -> impl IntoResponse {
let field = multipart.next_field().await.unwrap().unwrap();

assert_eq!(field.file_name().unwrap(), FILE_NAME);
Expand Down Expand Up @@ -322,7 +311,7 @@ mod multipart_tests {
const FILE_NAME1: &str = "index1.html";
const FILE_NAME2: &str = "index2.html";

async fn handler(mut multipart: Multipart<'static>) -> Result<(), MultipartRejectionError> {
async fn handler(mut multipart: Multipart) -> Result<(), MultipartRejectionError> {
while let Some(field) = multipart.next_field().await? {
match field.name() {
Some(FIELD_NAME1) => {
Expand Down Expand Up @@ -382,7 +371,7 @@ mod multipart_tests {

#[tokio::test]
async fn test_large_field_upload() {
async fn handler(mut multipart: Multipart<'static>) -> Result<(), MultipartRejectionError> {
async fn handler(mut multipart: Multipart) -> Result<(), MultipartRejectionError> {
while let Some(field) = multipart.next_field().await? {
field.bytes().await?;
}
Expand Down Expand Up @@ -410,7 +399,7 @@ mod multipart_tests {

let app: Router<_> = Router::new()
.route("/", post(handler))
.layer(BodyLimitLayer::max(1024));
.layer(BodyLimitLayer::new(1024));

run_handler(app, 8003).await;

Expand Down

0 comments on commit b2a3a3d

Please sign in to comment.