Skip to content

Commit

Permalink
Move monero-serai to simple-request
Browse files Browse the repository at this point in the history
Deduplicates code across the entire repo, letting us make improvements in a
single place.
  • Loading branch information
kayabaNerve committed Nov 6, 2023
1 parent b680bb5 commit 84a0bca
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 79 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions coins/monero/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ base58-monero = { version = "2", default-features = false, features = ["check"]

# Used for the provided HTTP RPC
digest_auth = { version = "0.3", default-features = false, optional = true }
# Deprecated here means to enable deprecated warnings, not to restore deprecated APIs
hyper = { version = "0.14", default-features = false, features = ["http1", "tcp", "client", "backports", "deprecated"], optional = true }
hyper-rustls = { version = "0.24", default-features = false, features = ["http1", "native-tokio"], optional = true }
simple-request = { path = "../../common/request", version = "0.1", default-features = false, optional = true }
tokio = { version = "1", default-features = false, optional = true }

[build-dependencies]
Expand Down Expand Up @@ -102,7 +100,7 @@ std = [
"base58-monero/std",
]

http-rpc = ["digest_auth", "hyper", "hyper-rustls", "tokio/time", "tokio/rt"]
http-rpc = ["digest_auth", "simple-request", "tokio"]
multisig = ["transcript", "frost", "dleq", "std"]
binaries = ["tokio/rt-multi-thread", "tokio/macros", "http-rpc"]
experimental = []
Expand Down
108 changes: 41 additions & 67 deletions coins/monero/src/rpc/http.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
use core::str::FromStr;
use std::io::Read;

use async_trait::async_trait;

use digest_auth::AuthContext;
use hyper::{
Uri, header::HeaderValue, Request, service::Service, client::connect::HttpConnector, Client,
use simple_request::{
hyper::{header::HeaderValue, Request},
Client,
};
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};

use crate::rpc::{RpcError, RpcConnection, Rpc};

#[derive(Clone, Debug)]
enum Authentication {
// If unauthenticated, reuse a single client
Unauthenticated(Client<HttpsConnector<HttpConnector>>),
Unauthenticated(Client),
// If authenticated, don't reuse clients so that each connection makes its own connection
// This ensures that if a nonce is requested, another caller doesn't make a request invalidating
// it
// We could acquire a mutex over the client, yet creating a new client is preferred for the
// possibility of parallelism
Authenticated(HttpsConnector<HttpConnector>, String, String),
Authenticated { username: String, password: String },
}

/// An HTTP(S) transport for the RPC.
Expand All @@ -37,9 +37,6 @@ impl HttpRpc {
/// A daemon requiring authentication can be used via including the username and password in the
/// URL.
pub fn new(mut url: String) -> Result<Rpc<HttpRpc>, RpcError> {
let https_builder =
HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build();

let authentication = if url.contains('@') {
// Parse out the username and password
let url_clone = url;
Expand All @@ -64,13 +61,12 @@ impl HttpRpc {
if split_userpass.len() > 2 {
Err(RpcError::ConnectionError("invalid amount of passwords".to_string()))?;
}
Authentication::Authenticated(
https_builder,
split_userpass[0].to_string(),
split_userpass.get(1).unwrap_or(&"").to_string(),
)
Authentication::Authenticated {
username: split_userpass[0].to_string(),
password: split_userpass.get(1).unwrap_or(&"").to_string(),
}
} else {
Authentication::Unauthenticated(Client::builder().build(https_builder))
Authentication::Unauthenticated(Client::with_connection_pool())
};

Ok(Rpc(HttpRpc { authentication, url }))
Expand All @@ -79,45 +75,26 @@ impl HttpRpc {

impl HttpRpc {
async fn inner_post(&self, route: &str, body: Vec<u8>) -> Result<Vec<u8>, RpcError> {
let request = |uri| {
Request::post(uri)
.header(hyper::header::HOST, {
let mut host = self.url.clone();
if let Some(protocol_pos) = host.find("://") {
host.drain(0 .. (protocol_pos + 3));
}
host
})
.body(body.clone().into())
.unwrap()
};
let request = |uri| Request::post(uri).body(body.clone().into()).unwrap();

let mut connection_task_handle = None;
let mut connection = None;
let response = match &self.authentication {
Authentication::Unauthenticated(client) => client
.request(request(self.url.clone() + "/" + route))
.await
.map_err(|e| RpcError::ConnectionError(e.to_string()))?,
Authentication::Authenticated(https_builder, user, pass) => {
let connection = https_builder
.clone()
.call(
self
.url
.parse()
.map_err(|e: <Uri as FromStr>::Err| RpcError::ConnectionError(e.to_string()))?,
)
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?,
Authentication::Authenticated { username, password } => {
// This Client will drop and replace its connection on error, when monero-serai requires
// a single socket for the lifetime of this function
// Since dropping the connection will raise an error, and this function aborts on any
// error, this is fine
let client = Client::without_connection_pool(self.url.clone())
.map_err(|_| RpcError::ConnectionError("invalid URL".to_string()))?;
let mut response = client
.request(request("/".to_string() + route))
.await
.map_err(|e| RpcError::ConnectionError(e.to_string()))?;
let (mut requester, connection) = hyper::client::conn::http1::handshake(connection)
.await
.map_err(|e| RpcError::ConnectionError(e.to_string()))?;
let connection_task = tokio::spawn(connection);
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?;

let mut response = requester
.send_request(request("/".to_string() + route))
.await
.map_err(|e| RpcError::ConnectionError(e.to_string()))?;
// Only provide authentication if this daemon actually expects it
if let Some(header) = response.headers().get("www-authenticate") {
let mut request = request("/".to_string() + route);
Expand All @@ -131,8 +108,8 @@ impl HttpRpc {
)
.map_err(|_| RpcError::InvalidNode("invalid digest-auth response"))?
.respond(&AuthContext::new_post::<_, _, _, &[u8]>(
user,
pass,
username,
password,
"/".to_string() + route,
None,
))
Expand All @@ -142,19 +119,16 @@ impl HttpRpc {
.unwrap(),
);

// Wait for the connection to be ready again
requester.ready().await.map_err(|e| RpcError::ConnectionError(e.to_string()))?;

// Make the request with the response challenge
response = requester
.send_request(request)
response = client
.request(request)
.await
.map_err(|e| RpcError::ConnectionError(e.to_string()))?;

// Also embed the requester so it's not dropped, causing the connection to close
connection_task_handle = Some((requester, connection_task.abort_handle()));
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?;
}

// Store the client so it's not dropped yet
connection = Some(client);

response
}
};
Expand All @@ -177,19 +151,19 @@ impl HttpRpc {
let mut body = response.into_body();
while res.len() < length {
let Some(data) = body.data().await else { break };
res.extend(data.map_err(|e| RpcError::ConnectionError(e.to_string()))?.as_ref());
res.extend(data.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?.as_ref());
}
*/

let res = hyper::body::to_bytes(response.into_body())
let mut res = Vec::with_capacity(128);
response
.body()
.await
.map_err(|e| RpcError::ConnectionError(e.to_string()))?
.to_vec();
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?
.read_to_end(&mut res)
.unwrap();

if let Some((_, connection_task)) = connection_task_handle {
// Clean up the connection task
connection_task.abort();
}
drop(connection);

Ok(res)
}
Expand All @@ -201,6 +175,6 @@ impl RpcConnection for HttpRpc {
// TODO: Make this timeout configurable
tokio::time::timeout(core::time::Duration::from_secs(30), self.inner_post(route, body))
.await
.map_err(|e| RpcError::ConnectionError(e.to_string()))?
.map_err(|e| RpcError::ConnectionError(format!("{e:?}")))?
}
}
98 changes: 92 additions & 6 deletions common/request/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![doc = include_str!("../README.md")]

use std::sync::Arc;

use tokio::sync::Mutex;

use hyper_rustls::{HttpsConnectorBuilder, HttpsConnector};
use hyper::{header::HeaderValue, client::HttpConnector};
use hyper::{
Uri,
header::HeaderValue,
body::Body,
service::Service,
client::{HttpConnector, conn::http1::SendRequest},
};
pub use hyper;

mod request;
Expand All @@ -14,12 +24,20 @@ pub use response::*;
#[derive(Debug)]
pub enum Error {
InvalidUri,
MissingHost,
InconsistentHost,
SslError,
Hyper(hyper::Error),
}

#[derive(Clone, Debug)]
enum Connection {
ConnectionPool(hyper::Client<HttpsConnector<HttpConnector>>),
Connection {
https_builder: HttpsConnector<HttpConnector>,
host: Uri,
connection: Arc<Mutex<Option<SendRequest<Body>>>>,
},
}

#[derive(Clone, Debug)]
Expand All @@ -38,22 +56,90 @@ impl Client {
}
}

/*
fn without_connection_pool() -> Client {}
*/
pub fn without_connection_pool(host: String) -> Result<Client, Error> {
Ok(Client {
connection: Connection::Connection {
https_builder: HttpsConnectorBuilder::new()
.with_native_roots()
.https_or_http()
.enable_http1()
.build(),
host: {
let uri: Uri = host.parse().map_err(|_| Error::InvalidUri)?;
if uri.host().is_none() {
Err(Error::MissingHost)?;
};
uri
},
connection: Arc::new(Mutex::new(None)),
},
})
}

pub async fn request<R: Into<Request>>(&self, request: R) -> Result<Response, Error> {
let request: Request = request.into();
let mut request = request.0;
if request.headers().get(hyper::header::HOST).is_none() {
let host = request.uri().host().ok_or(Error::InvalidUri)?.to_string();
if let Some(header_host) = request.headers().get(hyper::header::HOST) {
match &self.connection {
Connection::ConnectionPool(_) => {}
Connection::Connection { host, .. } => {
if header_host.to_str().map_err(|_| Error::InvalidUri)? != host.host().unwrap() {
Err(Error::InconsistentHost)?;
}
}
}
} else {
let host = match &self.connection {
Connection::ConnectionPool(_) => {
request.uri().host().ok_or(Error::MissingHost)?.to_string()
}
Connection::Connection { host, .. } => {
let host_str = host.host().unwrap();
if let Some(uri_host) = request.uri().host() {
if host_str != uri_host {
Err(Error::InconsistentHost)?;
}
}
host_str.to_string()
}
};
request
.headers_mut()
.insert(hyper::header::HOST, HeaderValue::from_str(&host).map_err(|_| Error::InvalidUri)?);
}

Ok(Response(match &self.connection {
Connection::ConnectionPool(client) => client.request(request).await.map_err(Error::Hyper)?,
Connection::Connection { https_builder, host, connection } => {
let mut connection_lock = connection.lock().await;

// If there's not a connection...
if connection_lock.is_none() {
let (requester, connection) = hyper::client::conn::http1::handshake(
https_builder.clone().call(host.clone()).await.map_err(|_| Error::SslError)?,
)
.await
.map_err(Error::Hyper)?;
// This will die when we drop the requester, so we don't need to track an AbortHandle for
// it
tokio::spawn(connection);
*connection_lock = Some(requester);
}

let connection = connection_lock.as_mut().unwrap();
let mut err = connection.ready().await.err();
if err.is_none() {
// Send the request
let res = connection.send_request(request).await;
if let Ok(res) = res {
return Ok(Response(res));
}
err = res.err();
}
// Since this connection has been put into an error state, drop it
*connection_lock = None;
Err(Error::Hyper(err.unwrap()))?
}
}))
}
}

0 comments on commit 84a0bca

Please sign in to comment.