improve codebase based on tower-http's efforts (David)
glendc committed Nov 20, 2023
1 parent 13b210c commit 347f34d
Showing 16 changed files with 287 additions and 246 deletions.
4 changes: 3 additions & 1 deletion tower-async-http/src/auth/add_authorization.rs
@@ -178,6 +178,8 @@ where
 
 #[cfg(test)]
 mod tests {
+    use std::convert::Infallible;
+
     #[allow(unused_imports)]
     use super::*;
 
@@ -225,7 +227,7 @@ mod tests {
             let auth = request.headers().get(http::header::AUTHORIZATION).unwrap();
             assert!(auth.is_sensitive());
 
-            Ok::<_, hyper::Error>(Response::new(Body::empty()))
+            Ok::<_, Infallible>(Response::new(Body::empty()))
         });
 
         let client = AddAuthorization::bearer(svc, "foo").as_sensitive(true);
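
The swap from `hyper::Error` to `Infallible` records in the type system that the test service can never fail. A minimal sketch of the idea, using a hypothetical standalone handler rather than the crate's own test setup:

    use std::convert::Infallible;

    use bytes::Bytes;
    use http::{Request, Response};
    use http_body_util::Empty;

    // Hypothetical handler: the response is constructed locally, so no
    // error can occur and `Infallible` is the honest error type. Using
    // `hyper::Error` here would advertise a failure mode the function
    // never actually produces.
    async fn handle(_req: Request<Empty<Bytes>>) -> Result<Response<Empty<Bytes>>, Infallible> {
        Ok(Response::new(Empty::new()))
    }
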
3 changes: 1 addition & 2 deletions tower-async-http/src/classify/grpc_errors_as_failures.rs
@@ -125,8 +125,7 @@ impl GrpcCodeBitmask {
 ///
 /// Responses are considered successful if
 ///
-/// - `grpc-status` header value matches the defines success codes in [`GrpcErrorsAsFailures`] (only `Ok` by
-///   default).
+/// - `grpc-status` header value contains a success value.
 /// - `grpc-status` header is missing.
 /// - `grpc-status` header value isn't a valid `String`.
 /// - `grpc-status` header value can't be parsed into an `i32`.
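
The four bullet points above translate almost mechanically into a header check. A sketch of that logic under the stated rules, with a caller-supplied `is_success_code` predicate standing in for the bitmask kept by `GrpcErrorsAsFailures` (not the crate's actual code):

    use http::HeaderMap;

    // Returns true when the response counts as successful per the doc rules:
    // a missing, non-UTF-8, or non-numeric `grpc-status` header is treated
    // as success; otherwise the parsed code is checked against the success set.
    fn classify(headers: &HeaderMap, is_success_code: impl Fn(i32) -> bool) -> bool {
        match headers.get("grpc-status") {
            None => true, // header is missing
            Some(value) => match value.to_str() {
                Err(_) => true, // value isn't a valid `String`
                Ok(raw) => match raw.parse::<i32>() {
                    Err(_) => true, // value can't be parsed into an `i32`
                    Ok(code) => is_success_code(code),
                },
            },
        }
    }
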
11 changes: 4 additions & 7 deletions tower-async-http/src/compression/body.rs
@@ -157,13 +157,10 @@ where
             #[cfg(feature = "compression-zstd")]
             BodyInnerProj::Zstd { inner } => inner.poll_frame(cx),
             BodyInnerProj::Identity { inner } => match ready!(inner.poll_frame(cx)) {
-                Some(Ok(frame)) => match frame.into_data() {
-                    Ok(mut buf) => {
-                        let bytes = buf.copy_to_bytes(buf.remaining());
-                        Poll::Ready(Some(Ok(Frame::data(bytes))))
-                    }
-                    Err(_) => Poll::Ready(None),
-                },
+                Some(Ok(frame)) => {
+                    let frame = frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()));
+                    Poll::Ready(Some(Ok(frame)))
+                }
                 Some(Err(err)) => Poll::Ready(Some(Err(err.into()))),
                 None => Poll::Ready(None),
             },
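
The old `into_data` branch silently ended the stream on any non-data frame, dropping trailers. `Frame::map_data` converts data frames while passing other frames through. A small sketch of the difference (standalone, not the crate's code):

    use bytes::{Buf, Bytes};
    use http_body::Frame;

    // Converts a data frame's buffer into contiguous `Bytes` while leaving
    // trailer frames untouched; `into_data` would instead return `Err(frame)`
    // for a trailer frame, which the old code turned into `None`,
    // terminating the body early.
    fn normalize<B: Buf>(frame: Frame<B>) -> Frame<Bytes> {
        frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
    }
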
25 changes: 7 additions & 18 deletions tower-async-http/src/compression/layer.rs
@@ -124,13 +124,12 @@ impl CompressionLayer {
 mod tests {
     use super::*;
 
-    use crate::test_helpers::{Body, TowerHttpBodyExt};
+    use crate::test_helpers::Body;
 
     use http::{header::ACCEPT_ENCODING, Request, Response};
-    use tokio::fs::File;
-    // for Body::data
-    use bytes::{Bytes, BytesMut};
+    use http_body_util::BodyExt;
+    use std::convert::Infallible;
+    use tokio::fs::File;
     use tokio_util::io::ReaderStream;
     use tower_async::{Service, ServiceBuilder};
 
@@ -167,13 +166,8 @@ mod tests {
         assert_eq!(response.headers()["content-encoding"], "deflate");
 
         // Read the body
-        let mut body = response.into_body();
-        let mut bytes = BytesMut::new();
-        while let Some(chunk) = body.data().await {
-            let chunk = chunk?;
-            bytes.extend_from_slice(&chunk[..]);
-        }
-        let bytes: Bytes = bytes.freeze();
+        let body = response.into_body();
+        let bytes = body.collect().await.unwrap().to_bytes();
 
         let deflate_bytes_len = bytes.len();
 
@@ -197,13 +191,8 @@ mod tests {
         assert_eq!(response.headers()["content-encoding"], "br");
 
         // Read the body
-        let mut body = response.into_body();
-        let mut bytes = BytesMut::new();
-        while let Some(chunk) = body.data().await {
-            let chunk = chunk?;
-            bytes.extend_from_slice(&chunk[..]);
-        }
-        let bytes: Bytes = bytes.freeze();
+        let body = response.into_body();
+        let bytes = body.collect().await.unwrap().to_bytes();
 
         let br_byte_length = bytes.len();
 
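
Every test in this file previously hand-rolled the same chunk-accumulation loop; `http_body_util::BodyExt::collect` replaces all of it. A sketch of the shared idiom, assuming only that the body's error type is debuggable:

    use http_body::Body;
    use http_body_util::BodyExt;

    // Drives the body to completion, buffering every data frame (and any
    // trailers), then flattens the buffered frames into a single `Bytes`.
    async fn read_body<B>(body: B) -> bytes::Bytes
    where
        B: Body,
        B::Error: std::fmt::Debug,
    {
        body.collect().await.unwrap().to_bytes()
    }
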
76 changes: 29 additions & 47 deletions tower-async-http/src/compression/mod.rs
@@ -88,13 +88,14 @@ mod tests {
     use super::*;
 
     use crate::compression::predicate::SizeAbove;
-    use crate::test_helpers::{Body, TowerHttpBodyExt};
+    use crate::test_helpers::{Body, WithTrailers};
 
     use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
-    use bytes::BytesMut;
     use flate2::read::GzDecoder;
     use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_TYPE};
-    use hyper::{Error, Request, Response};
+    use http::{HeaderMap, HeaderName, Request, Response};
+    use http_body_util::BodyExt;
+    use std::convert::Infallible;
     use std::io::Read;
     use std::sync::{Arc, RwLock};
     use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -127,13 +128,9 @@ mod tests {
         let res = svc.call(req).await.unwrap();
 
         // read the compressed body
-        let mut body = res.into_body();
-        let mut data = BytesMut::new();
-        while let Some(chunk) = body.data().await {
-            let chunk = chunk.unwrap();
-            data.extend_from_slice(&chunk[..]);
-        }
-        let compressed_data = data.freeze().to_vec();
+        let collected = res.into_body().collect().await.unwrap();
+        let trailers = collected.trailers().cloned().unwrap();
+        let compressed_data = collected.to_bytes();
 
         // decompress the body
         // doing this with flate2 as that is much easier than async-compression and blocking during
@@ -143,6 +140,9 @@ mod tests {
         decoder.read_to_string(&mut decompressed).unwrap();
 
         assert_eq!(decompressed, "Hello, World!");
+
+        // trailers are maintained
+        assert_eq!(trailers["foo"], "bar");
     }
 
     #[tokio::test]
@@ -158,13 +158,8 @@ mod tests {
         let res = svc.call(req).await.unwrap();
 
         // read the compressed body
-        let mut body = res.into_body();
-        let mut data = BytesMut::new();
-        while let Some(chunk) = body.data().await {
-            let chunk = chunk.unwrap();
-            data.extend_from_slice(&chunk[..]);
-        }
-        let compressed_data = data.freeze().to_vec();
+        let body = res.into_body();
+        let compressed_data = body.collect().await.unwrap().to_bytes();
 
         // decompress the body
         let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
@@ -215,12 +210,8 @@ mod tests {
         );
 
         // read the compressed body
-        let mut body = res.into_body();
-        let mut data = BytesMut::new();
-        while let Some(chunk) = body.data().await {
-            let chunk = chunk.unwrap();
-            data.extend_from_slice(&chunk[..]);
-        }
+        let body = res.into_body();
+        let data = body.collect().await.unwrap().to_bytes();
 
         // decompress the body
         let data = {
@@ -237,8 +228,11 @@ mod tests {
         assert_eq!(data, DATA.as_bytes());
     }
 
-    async fn handle(_req: Request<Body>) -> Result<Response<Body>, Error> {
-        Ok(Response::new(Body::from("Hello, World!")))
+    async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> {
+        let mut trailers = HeaderMap::new();
+        trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap());
+        let body = Body::from("Hello, World!").with_trailers(trailers);
+        Ok(Response::builder().body(body).unwrap())
     }
 
     #[tokio::test]
@@ -259,6 +253,7 @@ mod tests {
     #[derive(Default, Clone)]
     struct EveryOtherResponse(Arc<RwLock<u64>>);
 
+    #[allow(clippy::dbg_macro)]
     impl Predicate for EveryOtherResponse {
         fn should_compress<B>(&self, _: &http::Response<B>) -> bool
         where
@@ -279,12 +274,8 @@ mod tests {
         let res = svc.call(req).await.unwrap();
 
         // read the uncompressed body
-        let mut body = res.into_body();
-        let mut data = BytesMut::new();
-        while let Some(chunk) = body.data().await {
-            let chunk = chunk.unwrap();
-            data.extend_from_slice(&chunk[..]);
-        }
+        let body = res.into_body();
+        let data = body.collect().await.unwrap().to_bytes();
         let still_uncompressed = String::from_utf8(data.to_vec()).unwrap();
         assert_eq!(DATA, &still_uncompressed);
 
@@ -296,18 +287,14 @@ mod tests {
         let res = svc.call(req).await.unwrap();
 
         // read the compressed body
-        let mut body = res.into_body();
-        let mut data = BytesMut::new();
-        while let Some(chunk) = body.data().await {
-            let chunk = chunk.unwrap();
-            data.extend_from_slice(&chunk[..]);
-        }
+        let body = res.into_body();
+        let data = body.collect().await.unwrap().to_bytes();
         assert!(String::from_utf8(data.to_vec()).is_err());
     }
 
     #[tokio::test]
     async fn doesnt_compress_images() {
-        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Error> {
+        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
             let mut res = Response::new(Body::from(
                 "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
             ));
@@ -332,7 +319,7 @@ mod tests {
 
     #[tokio::test]
     async fn does_compress_svg() {
-        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Error> {
+        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
             let mut res = Response::new(Body::from(
                 "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
             ));
@@ -377,13 +364,8 @@ mod tests {
         let res = svc.call(req).await.unwrap();
 
         // read the compressed body
-        let mut body = res.into_body();
-        let mut data = BytesMut::new();
-        while let Some(chunk) = body.data().await {
-            let chunk = chunk.unwrap();
-            data.extend_from_slice(&chunk[..]);
-        }
-        let compressed_data = data.freeze().to_vec();
+        let body = res.into_body();
+        let compressed_data = body.collect().await.unwrap().to_bytes();
 
         // build the compressed body with the same quality level
         let compressed_with_level = {
@@ -401,7 +383,7 @@ mod tests {
         };
 
         assert_eq!(
-            compressed_data.as_slice(),
+            compressed_data,
             compressed_with_level.as_slice(),
             "Compression level is not respected"
         );
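
The new trailer assertions rely on `collect` capturing trailers alongside the data frames. A sketch of that two-part read, mirroring the assertions above (names here are illustrative, not from the crate):

    use http_body::Body;
    use http_body_util::BodyExt;

    // Collects a body once, then splits the result into its flattened data
    // and its (optional) trailer map; trailers must be cloned out before
    // `to_bytes` consumes the `Collected` value.
    async fn read_with_trailers<B>(body: B) -> (bytes::Bytes, Option<http::HeaderMap>)
    where
        B: Body,
        B::Error: std::fmt::Debug,
    {
        let collected = body.collect().await.unwrap();
        let trailers = collected.trailers().cloned();
        (collected.to_bytes(), trailers)
    }
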
(Diff view truncated: the remaining 11 of 16 changed files are not shown.)
