Skip to content

Commit

Permalink
jsonrpc: use ServerConfig and ClientConfig as the inner field in
Browse files Browse the repository at this point in the history
`ServerBuilder` and `ClientBuilder`
  • Loading branch information
hozan23 committed Jun 22, 2024
1 parent 0a2c0db commit 6c793e7
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 219 deletions.
131 changes: 27 additions & 104 deletions jsonrpc/src/client/builder.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,17 @@
use std::sync::{atomic::AtomicBool, Arc};
use std::sync::Arc;

use karyon_core::async_util::TaskGroup;
use karyon_net::{Conn, Endpoint, ToEndpoint};
#[cfg(feature = "tcp")]
use karyon_net::Endpoint;
use karyon_net::ToEndpoint;

#[cfg(feature = "tls")]
use karyon_net::{async_rustls::rustls, tls::ClientTlsConfig};

#[cfg(feature = "ws")]
use karyon_net::ws::ClientWsConfig;

#[cfg(all(feature = "ws", feature = "tls"))]
use karyon_net::ws::ClientWssConfig;

#[cfg(feature = "ws")]
use crate::codec::WsJsonCodec;
use karyon_net::async_rustls::rustls;

use crate::Result;
#[cfg(feature = "tcp")]
use crate::TcpConfig;
use crate::{Error, TcpConfig};

use crate::{codec::JsonCodec, Error, Result};

use super::{Client, MessageDispatcher, Subscriptions};
use super::{Client, ClientConfig};

const DEFAULT_TIMEOUT: u64 = 3000; // 3s

Expand All @@ -44,26 +35,22 @@ impl Client {
pub fn builder(endpoint: impl ToEndpoint) -> Result<ClientBuilder> {
let endpoint = endpoint.to_endpoint()?;
Ok(ClientBuilder {
endpoint,
timeout: Some(DEFAULT_TIMEOUT),
#[cfg(feature = "tcp")]
tcp_config: Default::default(),
#[cfg(feature = "tls")]
tls_config: None,
subscription_buffer_size: DEFAULT_MAX_SUBSCRIPTION_BUFFER_SIZE,
inner: ClientConfig {
endpoint,
timeout: Some(DEFAULT_TIMEOUT),
#[cfg(feature = "tcp")]
tcp_config: Default::default(),
#[cfg(feature = "tls")]
tls_config: None,
subscription_buffer_size: DEFAULT_MAX_SUBSCRIPTION_BUFFER_SIZE,
},
})
}
}

/// Builder for constructing an RPC [`Client`].
pub struct ClientBuilder {
endpoint: Endpoint,
#[cfg(feature = "tcp")]
tcp_config: TcpConfig,
#[cfg(feature = "tls")]
tls_config: Option<(rustls::ClientConfig, String)>,
timeout: Option<u64>,
subscription_buffer_size: usize,
inner: ClientConfig,
}

impl ClientBuilder {
Expand All @@ -82,7 +69,7 @@ impl ClientBuilder {
/// };
/// ```
pub fn set_timeout(mut self, timeout: u64) -> Self {
self.timeout = Some(timeout);
self.inner.timeout = Some(timeout);
self
}

Expand All @@ -106,7 +93,7 @@ impl ClientBuilder {
/// };
/// ```
pub fn set_max_subscription_buffer_size(mut self, size: usize) -> Self {
self.subscription_buffer_size = size;
self.inner.subscription_buffer_size = size;
self
}

Expand All @@ -128,12 +115,12 @@ impl ClientBuilder {
/// This function will return an error if the endpoint does not support TCP protocols.
#[cfg(feature = "tcp")]
pub fn tcp_config(mut self, config: TcpConfig) -> Result<Self> {
match self.endpoint {
match self.inner.endpoint {
Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
self.tcp_config = config;
self.inner.tcp_config = config;
Ok(self)
}
_ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
_ => Err(Error::UnsupportedProtocol(self.inner.endpoint.to_string())),
}
}

Expand All @@ -157,14 +144,14 @@ impl ClientBuilder {
/// This function will return an error if the endpoint does not support TLS protocols.
#[cfg(feature = "tls")]
pub fn tls_config(mut self, config: rustls::ClientConfig, dns_name: &str) -> Result<Self> {
match self.endpoint {
match self.inner.endpoint {
Endpoint::Tls(..) | Endpoint::Wss(..) => {
self.tls_config = Some((config, dns_name.to_string()));
self.inner.tls_config = Some((config, dns_name.to_string()));
Ok(self)
}
_ => Err(Error::UnsupportedProtocol(format!(
"Invalid tls config for endpoint: {}",
self.endpoint
self.inner.endpoint
))),
}
}
Expand All @@ -189,71 +176,7 @@ impl ClientBuilder {
///
/// ```
pub async fn build(self) -> Result<Arc<Client>> {
let conn: Conn<serde_json::Value> = match self.endpoint {
#[cfg(feature = "tcp")]
Endpoint::Tcp(..) => Box::new(
karyon_net::tcp::dial(&self.endpoint, self.tcp_config, JsonCodec {}).await?,
),
#[cfg(feature = "tls")]
Endpoint::Tls(..) => match self.tls_config {
Some((conf, dns_name)) => Box::new(
karyon_net::tls::dial(
&self.endpoint,
ClientTlsConfig {
dns_name,
client_config: conf,
tcp_config: self.tcp_config,
},
JsonCodec {},
)
.await?,
),
None => return Err(Error::TLSConfigRequired),
},
#[cfg(feature = "ws")]
Endpoint::Ws(..) => {
let config = ClientWsConfig {
tcp_config: self.tcp_config,
wss_config: None,
};
Box::new(karyon_net::ws::dial(&self.endpoint, config, WsJsonCodec {}).await?)
}
#[cfg(all(feature = "ws", feature = "tls"))]
Endpoint::Wss(..) => match self.tls_config {
Some((conf, dns_name)) => Box::new(
karyon_net::ws::dial(
&self.endpoint,
ClientWsConfig {
tcp_config: self.tcp_config,
wss_config: Some(ClientWssConfig {
dns_name,
client_config: conf,
}),
},
WsJsonCodec {},
)
.await?,
),
None => return Err(Error::TLSConfigRequired),
},
#[cfg(all(feature = "unix", target_family = "unix"))]
Endpoint::Unix(..) => Box::new(
karyon_net::unix::dial(&self.endpoint, Default::default(), JsonCodec {}).await?,
),
_ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
};

let send_chan = async_channel::bounded(10);

let client = Arc::new(Client {
timeout: self.timeout,
disconnect: AtomicBool::new(false),
send_chan,
message_dispatcher: MessageDispatcher::new(),
subscriptions: Subscriptions::new(self.subscription_buffer_size),
task_group: TaskGroup::new(),
});
client.start_background_loop(conn);
let client = Client::init(self.inner).await?;
Ok(client)
}
}
119 changes: 111 additions & 8 deletions jsonrpc/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,26 @@ use log::{debug, error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::json;

#[cfg(feature = "tcp")]
use karyon_net::tcp::TcpConfig;
#[cfg(feature = "ws")]
use karyon_net::ws::ClientWsConfig;
#[cfg(all(feature = "ws", feature = "tls"))]
use karyon_net::ws::ClientWssConfig;
#[cfg(feature = "tls")]
use karyon_net::{async_rustls::rustls, tls::ClientTlsConfig};
use karyon_net::{Conn, Endpoint};

use karyon_core::{
async_util::{select, timeout, Either, TaskGroup, TaskResult},
util::random_32,
};
use karyon_net::Conn;

#[cfg(feature = "ws")]
use crate::codec::WsJsonCodec;

use crate::{
codec::JsonCodec,
message::{self, SubscriptionID},
Error, Result,
};
Expand All @@ -32,14 +45,24 @@ use subscriptions::Subscriptions;

type RequestID = u32;

struct ClientConfig {
endpoint: Endpoint,
#[cfg(feature = "tcp")]
tcp_config: TcpConfig,
#[cfg(feature = "tls")]
tls_config: Option<(rustls::ClientConfig, String)>,
timeout: Option<u64>,
subscription_buffer_size: usize,
}

/// Represents an RPC client
pub struct Client {
timeout: Option<u64>,
disconnect: AtomicBool,
message_dispatcher: MessageDispatcher,
task_group: TaskGroup,
send_chan: (Sender<serde_json::Value>, Receiver<serde_json::Value>),
subscriptions: Arc<Subscriptions>,
send_chan: (Sender<serde_json::Value>, Receiver<serde_json::Value>),
task_group: TaskGroup,
config: ClientConfig,
}

#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -96,6 +119,11 @@ impl Client {
Ok(())
}

/// Disconnect the client
pub async fn stop(&self) {
self.task_group.cancel().await;
}

async fn send_request<T: Serialize + DeserializeOwned>(
&self,
method: &str,
Expand All @@ -116,7 +144,7 @@ impl Client {
let rx = self.message_dispatcher.register(id).await;

// Wait for the message dispatcher to send the response
let result = match self.timeout {
let result = match self.config.timeout {
Some(t) => timeout(Duration::from_millis(t), rx.recv()).await?,
None => rx.recv().await,
};
Expand Down Expand Up @@ -152,11 +180,86 @@ impl Client {
Ok(())
}

async fn init(config: ClientConfig) -> Result<Arc<Self>> {
let client = Arc::new(Client {
disconnect: AtomicBool::new(false),
subscriptions: Subscriptions::new(config.subscription_buffer_size),
send_chan: async_channel::bounded(10),
message_dispatcher: MessageDispatcher::new(),
task_group: TaskGroup::new(),
config,
});

let conn = client.connect().await?;
client.start_background_loop(conn);
Ok(client)
}

async fn connect(self: &Arc<Self>) -> Result<Conn<serde_json::Value>> {
let endpoint = self.config.endpoint.clone();
let conn: Conn<serde_json::Value> = match endpoint {
#[cfg(feature = "tcp")]
Endpoint::Tcp(..) => Box::new(
karyon_net::tcp::dial(&endpoint, self.config.tcp_config.clone(), JsonCodec {})
.await?,
),
#[cfg(feature = "tls")]
Endpoint::Tls(..) => match &self.config.tls_config {
Some((conf, dns_name)) => Box::new(
karyon_net::tls::dial(
&self.config.endpoint,
ClientTlsConfig {
dns_name: dns_name.to_string(),
client_config: conf.clone(),
tcp_config: self.config.tcp_config.clone(),
},
JsonCodec {},
)
.await?,
),
None => return Err(Error::TLSConfigRequired),
},
#[cfg(feature = "ws")]
Endpoint::Ws(..) => {
let config = ClientWsConfig {
tcp_config: self.config.tcp_config.clone(),
wss_config: None,
};
Box::new(karyon_net::ws::dial(&endpoint, config, WsJsonCodec {}).await?)
}
#[cfg(all(feature = "ws", feature = "tls"))]
Endpoint::Wss(..) => match &self.config.tls_config {
Some((conf, dns_name)) => Box::new(
karyon_net::ws::dial(
&endpoint,
ClientWsConfig {
tcp_config: self.config.tcp_config.clone(),
wss_config: Some(ClientWssConfig {
dns_name: dns_name.clone(),
client_config: conf.clone(),
}),
},
WsJsonCodec {},
)
.await?,
),
None => return Err(Error::TLSConfigRequired),
},
#[cfg(all(feature = "unix", target_family = "unix"))]
Endpoint::Unix(..) => {
Box::new(karyon_net::unix::dial(&endpoint, Default::default(), JsonCodec {}).await?)
}
_ => return Err(Error::UnsupportedProtocol(endpoint.to_string())),
};

Ok(conn)
}

fn start_background_loop(self: &Arc<Self>, conn: Conn<serde_json::Value>) {
let selfc = self.clone();
let on_complete = |result: TaskResult<Result<()>>| async move {
if let TaskResult::Completed(Err(err)) = result {
error!("Background loop stopped: {err}");
error!("Client stopped: {err}");
}
selfc.disconnect.store(true, Ordering::Relaxed);
selfc.subscriptions.clear().await;
Expand Down Expand Up @@ -201,8 +304,8 @@ impl Client {
self.subscriptions.notify(nt).await
}
},
Err(_) => {
error!("Receive unexpected msg: {msg}");
Err(err) => {
error!("Receive unexpected msg {msg}: {err}");
Err(Error::InvalidMsg("Unexpected msg"))
}
}
Expand Down
2 changes: 1 addition & 1 deletion jsonrpc/src/client/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl Subscription {
}

pub async fn recv(&self) -> Result<Value> {
self.rx.recv().await.map_err(Error::from)
self.rx.recv().await.map_err(|_| Error::SubscriptionClosed)
}

pub fn id(&self) -> SubscriptionID {
Expand Down
3 changes: 3 additions & 0 deletions jsonrpc/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pub enum Error {
#[error("Subscription exceeds the maximum buffer size")]
SubscriptionBufferFull,

#[error("Subscription closed")]
SubscriptionClosed,

#[error("ClientDisconnected")]
ClientDisconnected,

Expand Down
Loading

0 comments on commit 6c793e7

Please sign in to comment.