Skip to content

Commit

Permalink
feat: provide decode_response function
Browse files Browse the repository at this point in the history
  • Loading branch information
omjadas committed Sep 15, 2021
1 parent b355656 commit 26cd405
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 3 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ categories = ["network-programming"]
exclude = [".github/"]

[dependencies]
async-compression = { version = "0.3", features = ["tokio", "brotli", "gzip", "zlib"] }
async-trait = "0.1"
bytes = "1"
chrono = "0.4"
Expand All @@ -31,6 +32,7 @@ thiserror = "1"
tokio = { version = "1", features = ["full"] }
tokio-rustls = "0.22"
tokio-tungstenite = { version = "0.15", features = ["rustls-tls"] }
tokio-util = { version = "0.6", features = ["io"] }
webpki-roots = "0.21"

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion examples/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::net::SocketAddr;
async fn shutdown_signal() {
tokio::signal::ctrl_c()
.await
.expect("failed to install CTRL+C signal handler");
.expect("Failed to install CTRL+C signal handler");
}

#[derive(Clone)]
Expand Down
2 changes: 1 addition & 1 deletion examples/noop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::net::SocketAddr;
async fn shutdown_signal() {
tokio::signal::ctrl_c()
.await
.expect("failed to install CTRL+C signal handler");
.expect("Failed to install CTRL+C signal handler");
}

#[derive(Clone)]
Expand Down
2 changes: 1 addition & 1 deletion examples/upstream_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::net::SocketAddr;
async fn shutdown_signal() {
tokio::signal::ctrl_c()
.await
.expect("failed to install CTRL+C signal handler");
.expect("Failed to install CTRL+C signal handler");
}

#[derive(Clone)]
Expand Down
105 changes: 105 additions & 0 deletions src/decoder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use crate::Error;
use async_compression::tokio::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder};
use bytes::Bytes;
use futures::{Stream, TryStreamExt};
use http::header::{CONTENT_ENCODING, CONTENT_LENGTH};
use hyper::{Body, Error as HyperError, Response};
use std::{
io,
io::Error as IoError,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncBufRead, AsyncRead, BufReader};
use tokio_util::io::{ReaderStream, StreamReader};

struct IoStream<T: Stream<Item = Result<Bytes, HyperError>> + Unpin>(T);

impl<T: Stream<Item = Result<Bytes, HyperError>> + Unpin> Stream for IoStream<T> {
type Item = Result<Bytes, IoError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match futures::ready!(Pin::new(&mut self.0).poll_next(cx)) {
Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk))),
Some(Err(err)) => {
Poll::Ready(Some(Err(IoError::new(io::ErrorKind::Other, err))))
}
None => Poll::Ready(None),
}
}
}

enum Decoder {
Body(Body),
Decoder(Box<dyn AsyncRead + Send + Unpin>),
}

impl Decoder {
pub fn decode(self, encoding: &str) -> Result<Self, Error> {
let reader: Box<dyn AsyncBufRead + Send + Unpin> = match self {
Decoder::Body(body) => Box::new(StreamReader::new(IoStream(body.into_stream()))),
Decoder::Decoder(decoder) => Box::new(BufReader::new(decoder)),
};

let decoder: Box<dyn AsyncRead + Send + Unpin> = match encoding {
"gzip" | "x-gzip" => Box::new(GzipDecoder::new(reader)),
"deflate" => Box::new(ZlibDecoder::new(reader)),
"br" => Box::new(BrotliDecoder::new(reader)),
_ => return Err(Error::Decode),
};

Ok(Decoder::Decoder(decoder))
}

pub fn into_inner(self) -> Result<Box<dyn AsyncRead + Send + Unpin>, Error> {
match self {
Decoder::Body(_) => Err(Error::Decode),
Decoder::Decoder(decoder) => Ok(decoder),
}
}
}

/// Decode the body of a response.
///
/// This will fail if either of the `content-encoding` or `content-length` headers are unable to be
/// parsed, or if one of the values specified in the `content-encoding` header is not supported.
pub fn decode_response(res: Response<Body>) -> Result<Response<Body>, Error> {
let (mut parts, body) = res.into_parts();
let mut encodings: Vec<String> = vec![];

for val in parts.headers.get_all(CONTENT_ENCODING) {
match val.to_str() {
Ok(val) => {
encodings.extend(val.split(',').map(|v| String::from(v.trim())));
}
Err(_) => return Err(Error::Decode),
}
}

parts.headers.remove(CONTENT_ENCODING);

if let Some(val) = parts.headers.remove(CONTENT_LENGTH) {
match val.to_str() {
Ok("0") => return Ok(Response::from_parts(parts, body)),
Err(_) => return Err(Error::Decode),
_ => (),
}
}

if encodings.is_empty() {
return Ok(Response::from_parts(parts, body));
}

let mut decoder = Decoder::Body(body);

while let Some(encoding) = encodings.pop() {
decoder = decoder.decode(&encoding)?;
}

Ok(Response::from_parts(
parts,
Body::wrap_stream(ReaderStream::new(
decoder.into_inner().expect("Should not be Err"),
)),
))
}
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pub enum Error {
Tls(#[from] RcgenError),
#[error("network error")]
Network(#[from] hyper::Error),
#[error("unable to decode response body")]
Decode,
#[error("unknown error")]
Unknown,
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! - Modify websocket messages
mod certificate_authority;
mod decoder;
mod error;
mod proxy;
mod rewind;
Expand All @@ -26,6 +27,7 @@ pub(crate) use rewind::Rewind;

pub use async_trait;
pub use certificate_authority::CertificateAuthority;
pub use decoder::decode_response;
pub use error::Error;
pub use hyper;
pub use hyper_proxy;
Expand Down

0 comments on commit 26cd405

Please sign in to comment.