From 26cd40531497edb8e4af9b88f1d69f3765c0d1bc Mon Sep 17 00:00:00 2001 From: omjadas Date: Wed, 15 Sep 2021 19:41:14 +1000 Subject: [PATCH] feat: provide decode_response function --- Cargo.toml | 2 + examples/log.rs | 2 +- examples/noop.rs | 2 +- examples/upstream_proxy.rs | 2 +- src/decoder.rs | 105 +++++++++++++++++++++++++++++++++++++ src/error.rs | 2 + src/lib.rs | 2 + 7 files changed, 114 insertions(+), 3 deletions(-) create mode 100644 src/decoder.rs diff --git a/Cargo.toml b/Cargo.toml index 5e2c239..57b1a24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ categories = ["network-programming"] exclude = [".github/"] [dependencies] +async-compression = { version = "0.3", features = ["tokio", "brotli", "gzip", "zlib"] } async-trait = "0.1" bytes = "1" chrono = "0.4" @@ -31,6 +32,7 @@ thiserror = "1" tokio = { version = "1", features = ["full"] } tokio-rustls = "0.22" tokio-tungstenite = { version = "0.15", features = ["rustls-tls"] } +tokio-util = { version = "0.6", features = ["io"] } webpki-roots = "0.21" [dev-dependencies] diff --git a/examples/log.rs b/examples/log.rs index ab12ad9..49f69b7 100644 --- a/examples/log.rs +++ b/examples/log.rs @@ -11,7 +11,7 @@ use std::net::SocketAddr; async fn shutdown_signal() { tokio::signal::ctrl_c() .await - .expect("failed to install CTRL+C signal handler"); + .expect("Failed to install CTRL+C signal handler"); } #[derive(Clone)] diff --git a/examples/noop.rs b/examples/noop.rs index f03839a..a2646c1 100644 --- a/examples/noop.rs +++ b/examples/noop.rs @@ -11,7 +11,7 @@ use std::net::SocketAddr; async fn shutdown_signal() { tokio::signal::ctrl_c() .await - .expect("failed to install CTRL+C signal handler"); + .expect("Failed to install CTRL+C signal handler"); } #[derive(Clone)] diff --git a/examples/upstream_proxy.rs b/examples/upstream_proxy.rs index c7898af..7724851 100644 --- a/examples/upstream_proxy.rs +++ b/examples/upstream_proxy.rs @@ -13,7 +13,7 @@ use std::net::SocketAddr; async fn shutdown_signal() { tokio::signal::ctrl_c() .await - .expect("failed to install CTRL+C signal handler"); + .expect("Failed to install CTRL+C signal handler"); } #[derive(Clone)] diff --git a/src/decoder.rs b/src/decoder.rs new file mode 100644 index 0000000..a90f55f --- /dev/null +++ b/src/decoder.rs @@ -0,0 +1,105 @@ +use crate::Error; +use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder}; +use bytes::Bytes; +use futures::{Stream, TryStreamExt}; +use http::header::{CONTENT_ENCODING, CONTENT_LENGTH}; +use hyper::{Body, Error as HyperError, Response}; +use std::{ + io, + io::Error as IoError, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncBufRead, AsyncRead, BufReader}; +use tokio_util::io::{ReaderStream, StreamReader}; + +struct IoStream> + Unpin>(T); + +impl> + Unpin> Stream for IoStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match futures::ready!(Pin::new(&mut self.0).poll_next(cx)) { + Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk))), + Some(Err(err)) => { + Poll::Ready(Some(Err(IoError::new(io::ErrorKind::Other, err)))) + } + None => Poll::Ready(None), + } + } +} + +enum Decoder { + Body(Body), + Decoder(Box), +} + +impl Decoder { + pub fn decode(self, encoding: &str) -> Result { + let reader: Box = match self { + Decoder::Body(body) => Box::new(StreamReader::new(IoStream(body.into_stream()))), + Decoder::Decoder(decoder) => Box::new(BufReader::new(decoder)), + }; + + let decoder: Box = match encoding { + "gzip" | "x-gzip" => Box::new(GzipDecoder::new(reader)), + "deflate" => Box::new(ZlibDecoder::new(reader)), + "br" => Box::new(BrotliDecoder::new(reader)), + _ => return Err(Error::Decode), + }; + + Ok(Decoder::Decoder(decoder)) + } + + pub fn into_inner(self) -> Result, Error> { + match self { + Decoder::Body(_) => Err(Error::Decode), + Decoder::Decoder(decoder) => Ok(decoder), + } + } +} + +/// Decode the body of a response. +/// +/// This will fail if either of the `content-encoding` or `content-length` headers are unable to be +/// parsed, or if one of the values specified in the `content-encoding` header is not supported. +pub fn decode_response(res: Response) -> Result, Error> { + let (mut parts, body) = res.into_parts(); + let mut encodings: Vec = vec![]; + + for val in parts.headers.get_all(CONTENT_ENCODING) { + match val.to_str() { + Ok(val) => { + encodings.extend(val.split(',').map(|v| String::from(v.trim()))); + } + Err(_) => return Err(Error::Decode), + } + } + + parts.headers.remove(CONTENT_ENCODING); + + if let Some(val) = parts.headers.remove(CONTENT_LENGTH) { + match val.to_str() { + Ok("0") => return Ok(Response::from_parts(parts, body)), + Err(_) => return Err(Error::Decode), + _ => (), + } + } + + if encodings.is_empty() { + return Ok(Response::from_parts(parts, body)); + } + + let mut decoder = Decoder::Body(body); + + while let Some(encoding) = encodings.pop() { + decoder = decoder.decode(&encoding)?; + } + + Ok(Response::from_parts( + parts, + Body::wrap_stream(ReaderStream::new( + decoder.into_inner().expect("Should not be Err"), + )), + )) +} diff --git a/src/error.rs b/src/error.rs index 1d81bf2..c6b0cb5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,6 +7,8 @@ pub enum Error { Tls(#[from] RcgenError), #[error("network error")] Network(#[from] hyper::Error), + #[error("unable to decode response body")] + Decode, #[error("unknown error")] Unknown, } diff --git a/src/lib.rs b/src/lib.rs index c3aeff3..4709180 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ //! - Modify websocket messages mod certificate_authority; +mod decoder; mod error; mod proxy; mod rewind; @@ -26,6 +27,7 @@ pub(crate) use rewind::Rewind; pub use async_trait; pub use certificate_authority::CertificateAuthority; +pub use decoder::decode_response; pub use error::Error; pub use hyper; pub use hyper_proxy;