Skip to content

Commit

Permalink
Rust: add a IPC/RPC client manager
Browse files Browse the repository at this point in the history
Signed-off-by: Tao He <[email protected]>
  • Loading branch information
sighingnow committed Aug 30, 2023
1 parent 12888eb commit c2d36ad
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 8 deletions.
49 changes: 49 additions & 0 deletions rust/vineyard/src/client/ipc_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use std::collections::HashMap;
use std::io;
use std::net::Shutdown;
use std::os::unix::net::UnixStream;
use std::sync::Arc;
use std::sync::Mutex;

use arrow_buffer::Buffer;
use parking_lot::ReentrantMutex;
Expand Down Expand Up @@ -194,6 +196,9 @@ impl Drop for IPCClient {
}
}

unsafe impl Send for IPCClient {}
unsafe impl Sync for IPCClient {}

impl Client for IPCClient {
fn disconnect(&mut self) {
if !self.connected() {
Expand Down Expand Up @@ -433,3 +438,47 @@ impl IPCClient {
return Ok(object);
}
}

pub struct IPCClientManager {}

impl IPCClientManager {
pub fn get_default() -> Result<Arc<Mutex<IPCClient>>> {
let default_ipc_socket = std::env::var(VINEYARD_IPC_SOCKET_KEY)?;
return IPCClientManager::get(default_ipc_socket);
}

pub fn get<S: Into<String>>(socket: S) -> Result<Arc<Mutex<IPCClient>>> {
let mut clients = IPCClientManager::get_clients().lock()?;
let socket = socket.into();
if let Some(client) = clients.get(&socket) {
return Ok(client.clone());
}
let client = Arc::new(Mutex::new(IPCClient::connect(&socket)?));
clients.insert(socket, client.clone());
return Ok(client);
}

pub fn close<S: Into<String>>(socket: S) -> Result<()> {
let mut clients = IPCClientManager::get_clients().lock()?;
let socket = socket.into();
if let Some(client) = clients.get(&socket) {
if Arc::strong_count(client) == 1 {
clients.remove(&socket);
}
return Ok(());
} else {
return Err(VineyardError::invalid(format!(
"Failed to close the client due to the unknown socket: {}",
socket
)));
}
}

fn get_clients() -> &'static Arc<Mutex<HashMap<String, Arc<Mutex<IPCClient>>>>> {
lazy_static! {
static ref CONNECTED_CLIENTS: Arc<Mutex<HashMap<String, Arc<Mutex<IPCClient>>>>> =
Arc::new(Mutex::new(HashMap::new()));
}
return &CONNECTED_CLIENTS;
}
}
9 changes: 9 additions & 0 deletions rust/vineyard/src/client/ipc_client_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,13 @@ mod tests {

return Ok(());
}

#[test]
fn test_ipc_client_manager() -> Result<()> {
let client = IPCClientManager::get_default()?;
let mut client = client.lock()?;
assert!(client.connected());

return Ok(());
}
}
76 changes: 68 additions & 8 deletions rust/vineyard/src/client/rpc_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::io;
use std::net::{Shutdown, TcpStream};
use std::sync::{Arc, Mutex};

use parking_lot::{ReentrantMutex, ReentrantMutexGuard};

Expand Down Expand Up @@ -142,18 +144,14 @@ impl Client for RPCClient {
}
}

unsafe impl Send for RPCClient {}
unsafe impl Sync for RPCClient {}

impl RPCClient {
#[allow(clippy::should_implement_trait)]
pub fn default() -> Result<RPCClient> {
let rpc_endpoint = std::env::var(VINEYARD_RPC_ENDPOINT_KEY)?;
let (host, port) = match rpc_endpoint.rfind(':') {
Some(idx) => (
&rpc_endpoint[..idx],
rpc_endpoint[idx + 1..].parse::<u16>()?,
),
None => (rpc_endpoint.as_str(), DEFAULT_RPC_PORT),
};
return RPCClient::connect(host, port);
return RPCClient::connect_with_endpoint(rpc_endpoint.as_str());
}

pub fn connect(host: &str, port: u16) -> Result<RPCClient> {
Expand All @@ -179,4 +177,66 @@ impl RPCClient {
lock: ReentrantMutex::new(()),
});
}

pub fn connect_with_endpoint(endpoint: &str) -> Result<RPCClient> {
let (host, port) = match endpoint.rfind(':') {
Some(idx) => (&endpoint[..idx], endpoint[idx + 1..].parse::<u16>()?),
None => (endpoint, DEFAULT_RPC_PORT),
};
return RPCClient::connect(host, port);
}
}

pub struct RPCClientManager {}

impl RPCClientManager {
pub fn get_default() -> Result<Arc<Mutex<RPCClient>>> {
let default_rpc_endpoint = std::env::var(VINEYARD_RPC_ENDPOINT_KEY)?;
return RPCClientManager::get_with_endpoint(default_rpc_endpoint);
}

pub fn get<S: Into<String>>(host: &str, port: u16) -> Result<Arc<Mutex<RPCClient>>> {
let endpoint = format!("{}:{}", host, port);
return RPCClientManager::get_with_endpoint(endpoint);
}

pub fn get_with_endpoint<S: Into<String>>(endpoint: S) -> Result<Arc<Mutex<RPCClient>>> {
let mut clients: std::sync::MutexGuard<'_, _> = RPCClientManager::get_clients().lock()?;
let endpoint: String = endpoint.into();
if let Some(client) = clients.get(endpoint.as_str()) {
return Ok(client.clone());
}
let client = Arc::new(Mutex::new(RPCClient::connect_with_endpoint(&endpoint)?));
clients.insert(endpoint, client.clone());
return Ok(client);
}

pub fn close<S: Into<String>>(host: &str, port: u16) -> Result<()> {
let endpoint = format!("{}:{}", host, port);
return RPCClientManager::close_with_endpoint(endpoint);
}

pub fn close_with_endpoint<S: Into<String>>(endpoint: S) -> Result<()> {
let mut clients = RPCClientManager::get_clients().lock()?;
let endpoint = endpoint.into();
if let Some(client) = clients.get(&endpoint) {
if Arc::strong_count(client) == 1 {
clients.remove(&endpoint);
}
return Ok(());
} else {
return Err(VineyardError::invalid(format!(
"Failed to close the client due to the unknown endpoint: {}",
endpoint
)));
}
}

fn get_clients() -> &'static Arc<Mutex<HashMap<String, Arc<Mutex<RPCClient>>>>> {
lazy_static! {
static ref CONNECTED_CLIENTS: Arc<Mutex<HashMap<String, Arc<Mutex<RPCClient>>>>> =
Arc::new(Mutex::new(HashMap::new()));
}
return &CONNECTED_CLIENTS;
}
}
1 change: 1 addition & 0 deletions rust/vineyard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#![allow(clippy::nonminimal_bool)]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::redundant_field_names)]
#![allow(clippy::type_complexity)]
#![allow(clippy::unnecessary_cast)]
#![allow(clippy::vec_box)]
#![allow(incomplete_features)]
Expand Down

0 comments on commit c2d36ad

Please sign in to comment.