Skip to content

Commit

Permalink
refactor: extract proxy logic
Browse files Browse the repository at this point in the history
  • Loading branch information
omjadas committed Sep 1, 2021
1 parent 77d996d commit 1294ed7
Show file tree
Hide file tree
Showing 2 changed files with 279 additions and 311 deletions.
327 changes: 16 additions & 311 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,22 @@
mod certificate_authority;
mod error;
mod proxy;
mod rewind;

use futures::{sink::SinkExt, stream::StreamExt};
use http::uri::PathAndQuery;
use hyper::{
client::HttpConnector,
server::conn::Http,
service::{make_service_fn, service_fn},
upgrade::Upgraded,
Body, Client, Method, Request, Response, Server,
Body, Client, Request, Response, Server,
};
use hyper_proxy::{Proxy as UpstreamProxy, ProxyConnector};
use hyper_rustls::HttpsConnector;
use log::*;
use rewind::Rewind;
use proxy::Proxy;
use rustls::ClientConfig;
use std::{convert::Infallible, future::Future, net::SocketAddr, sync::Arc};
use tokio::io::AsyncReadExt;
use tokio_rustls::TlsAcceptor;
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream};
use std::{convert::Infallible, future::Future, net::SocketAddr};
use tokio_tungstenite::tungstenite::Message;

pub(crate) use rewind::Rewind;

pub use certificate_authority::CertificateAuthority;
pub use error::Error;
Expand Down Expand Up @@ -109,22 +105,6 @@ where
pub upstream_proxy: Option<UpstreamProxy>,
}

#[derive(Clone)]
struct ProxyState<R1, R2, W1, W2>
where
R1: RequestHandler,
R2: ResponseHandler,
W1: MessageHandler,
W2: MessageHandler,
{
pub ca: CertificateAuthority,
pub client: MaybeProxyClient,
pub request_handler: R1,
pub response_handler: R2,
pub incoming_message_handler: W1,
pub outgoing_message_handler: W2,
}

/// Attempts to start a proxy server using the provided configuration options.
///
/// This will fail if the proxy server is unable to be started.
Expand Down Expand Up @@ -158,17 +138,15 @@ where
let outgoing_message_handler = outgoing_message_handler.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
proxy(
ProxyState {
ca: ca.clone(),
client: client.clone(),
request_handler: request_handler.clone(),
response_handler: response_handler.clone(),
incoming_message_handler: incoming_message_handler.clone(),
outgoing_message_handler: outgoing_message_handler.clone(),
},
req,
)
Proxy {
ca: ca.clone(),
client: client.clone(),
request_handler: request_handler.clone(),
response_handler: response_handler.clone(),
incoming_message_handler: incoming_message_handler.clone(),
outgoing_message_handler: outgoing_message_handler.clone(),
}
.proxy(req)
}))
}
});
Expand Down Expand Up @@ -216,276 +194,3 @@ fn gen_client(upstream_proxy: Option<UpstreamProxy>) -> MaybeProxyClient {
)
}
}

async fn proxy<R1, R2, W1, W2>(
state: ProxyState<R1, R2, W1, W2>,
req: Request<Body>,
) -> Result<Response<Body>, hyper::Error>
where
R1: RequestHandler,
R2: ResponseHandler,
W1: MessageHandler,
W2: MessageHandler,
{
if req.method() == Method::CONNECT {
process_connect(state, req).await
} else {
process_request(state, req).await
}
}

async fn process_request<R1, R2, W1, W2>(
mut state: ProxyState<R1, R2, W1, W2>,
req: Request<Body>,
) -> Result<Response<Body>, hyper::Error>
where
R1: RequestHandler,
R2: ResponseHandler,
W1: MessageHandler,
W2: MessageHandler,
{
let req = match (state.request_handler)(req) {
RequestOrResponse::Request(req) => req,
RequestOrResponse::Response(res) => return Ok(res),
};

if hyper_tungstenite::is_upgrade_request(&req) {
let scheme =
if req.uri().scheme().unwrap_or(&http::uri::Scheme::HTTP) == &http::uri::Scheme::HTTP {
"ws"
} else {
"wss"
};

let uri = http::uri::Builder::new()
.scheme(scheme)
.authority(
req.uri()
.authority()
.expect("Authority not included in request")
.to_owned(),
)
.path_and_query(
req.uri()
.path_and_query()
.unwrap_or(&PathAndQuery::from_static("/"))
.to_owned(),
)
.build()
.expect("Failed to build URI for websocket connection");

let (res, websocket) =
hyper_tungstenite::upgrade(req, None).expect("Request has missing headers");

tokio::spawn(async move {
let server_socket = websocket
.await
.unwrap_or_else(|_| panic!("Failed to upgrade websocket connection for {}", uri));
handle_websocket(state, server_socket, &uri).await;
});

return Ok(res);
}

let res = match state.client {
MaybeProxyClient::Proxy(client) => client.request(req).await?,
MaybeProxyClient::Https(client) => client.request(req).await?,
};

Ok((state.response_handler)(res))
}

async fn process_connect<R1, R2, W1, W2>(
state: ProxyState<R1, R2, W1, W2>,
req: Request<Body>,
) -> Result<Response<Body>, hyper::Error>
where
R1: RequestHandler,
R2: ResponseHandler,
W1: MessageHandler,
W2: MessageHandler,
{
tokio::task::spawn(async move {
let authority = req
.uri()
.authority()
.expect("URI does not contain authority");
let server_config = Arc::new(state.ca.gen_server_config(authority).await);

match hyper::upgrade::on(req).await {
Ok(mut upgraded) => {
let mut buffer = [0; 4];
let bytes_read = upgraded
.read(&mut buffer)
.await
.expect("Failed to read from upgraded connection");

let upgraded = Rewind::new_buffered(
upgraded,
bytes::Bytes::copy_from_slice(buffer[..bytes_read].as_ref()),
);

if bytes_read == 4 && buffer == *b"GET " {
if let Err(e) = serve_websocket(state, upgraded).await {
error!("websocket connect error: {}", e);
}
} else {
let stream = TlsAcceptor::from(server_config)
.accept(upgraded)
.await
.expect("Failed to establish TLS connection with client");

if let Err(e) = serve_https(state, stream).await {
let e_string = e.to_string();
if !e_string.starts_with("error shutting down connection") {
error!("https connect error: {}", e);
}
}
}
}
Err(e) => error!("upgrade error: {}", e),
};
});

Ok(Response::new(Body::empty()))
}

async fn handle_websocket<R1, R2, W1, W2>(
ProxyState {
mut incoming_message_handler,
mut outgoing_message_handler,
..
}: ProxyState<R1, R2, W1, W2>,
server_socket: WebSocketStream<Upgraded>,
uri: &http::Uri,
) where
R1: RequestHandler,
R2: ResponseHandler,
W1: MessageHandler,
W2: MessageHandler,
{
let (client_socket, _) = connect_async(uri)
.await
.unwrap_or_else(|_| panic!("Failed to open websocket connection to {}", uri));

let (mut server_sink, mut server_stream) = server_socket.split();
let (mut client_sink, mut client_stream) = client_socket.split();

tokio::spawn(async move {
while let Some(message) = server_stream.next().await {
match message {
Ok(message) => {
let message = incoming_message_handler(message);
match client_sink.send(message).await {
Err(tungstenite::Error::ConnectionClosed) => (),
Err(e) => error!("websocket send error: {}", e),
_ => (),
}
}
Err(e) => error!("websocket message error: {}", e),
}
}
});

tokio::spawn(async move {
while let Some(message) = client_stream.next().await {
match message {
Ok(message) => {
let message = outgoing_message_handler(message);
match server_sink.send(message).await {
Err(tungstenite::Error::ConnectionClosed) => (),
Err(e) => error!("websocket send error: {}", e),
_ => (),
}
}
Err(e) => error!("websocket message error: {}", e),
}
}
});
}

async fn serve_websocket<R1, R2, W1, W2>(
state: ProxyState<R1, R2, W1, W2>,
stream: Rewind<Upgraded>,
) -> Result<(), hyper::Error>
where
R1: RequestHandler,
R2: ResponseHandler,
W1: MessageHandler,
W2: MessageHandler,
{
let service = service_fn(|req| {
let authority = req
.headers()
.get(http::header::HOST)
.expect("Host is a required header")
.to_str()
.expect("Failed to convert host to str");

let uri = http::uri::Builder::new()
.scheme(http::uri::Scheme::HTTP)
.authority(authority)
.path_and_query(
req.uri()
.path_and_query()
.unwrap_or(&PathAndQuery::from_static("/"))
.to_owned(),
)
.build()
.expect("Failed to build URI");

let (mut parts, body) = req.into_parts();
parts.uri = uri;
let req = Request::from_parts(parts, body);
process_request(state.clone(), req)
});

Http::new()
.serve_connection(stream, service)
.with_upgrades()
.await
}

async fn serve_https<R1, R2, W1, W2>(
state: ProxyState<R1, R2, W1, W2>,
stream: tokio_rustls::server::TlsStream<Rewind<Upgraded>>,
) -> Result<(), hyper::Error>
where
R1: RequestHandler,
R2: ResponseHandler,
W1: MessageHandler,
W2: MessageHandler,
{
let service = service_fn(|mut req| {
if req.version() == http::Version::HTTP_11 {
let authority = req
.headers()
.get(http::header::HOST)
.expect("Host is a required header")
.to_str()
.expect("Failed to convert host to str");

let uri = http::uri::Builder::new()
.scheme(http::uri::Scheme::HTTPS)
.authority(authority)
.path_and_query(
req.uri()
.path_and_query()
.unwrap_or(&PathAndQuery::from_static("/"))
.to_owned(),
)
.build()
.expect("Failed to build URI");

let (mut parts, body) = req.into_parts();
parts.uri = uri;
req = Request::from_parts(parts, body)
};

process_request(state.clone(), req)
});
Http::new()
.serve_connection(stream, service)
.with_upgrades()
.await
}
Loading

0 comments on commit 1294ed7

Please sign in to comment.