From 867ea8c798abec3075d6bce48a81c402f42c04a8 Mon Sep 17 00:00:00 2001 From: Sidhant Kohli Date: Sat, 10 Aug 2024 09:26:42 -0700 Subject: [PATCH] chore: update Serverinfo to use struct (#71) Signed-off-by: Sidhant Kohli --- src/batchmap.rs | 5 +-- src/map.rs | 6 +-- src/reduce.rs | 2 +- src/shared.rs | 88 ++++++++++++++++++++++++++++++------------ src/sideinput.rs | 2 +- src/sink.rs | 2 +- src/source.rs | 2 +- src/sourcetransform.rs | 2 +- 8 files changed, 73 insertions(+), 36 deletions(-) diff --git a/src/batchmap.rs b/src/batchmap.rs index d82f22a..43f43e2 100644 --- a/src/batchmap.rs +++ b/src/batchmap.rs @@ -471,10 +471,9 @@ impl crate::batchmap::Server { where T: BatchMapper + Send + Sync + 'static, { - let mut info = shared::default_info_file(); + let mut info = shared::ServerInfo::default(); // update the info json metadata field, and add the map mode - info["metadata"][shared::MAP_MODE_KEY] = - serde_json::Value::String(shared::BATCH_MAP.to_string()); + info.set_metadata(shared::MAP_MODE_KEY, shared::BATCH_MAP); let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file, info)?; let handler = self.svc.take().unwrap(); diff --git a/src/map.rs b/src/map.rs index 5e43c56..3499828 100644 --- a/src/map.rs +++ b/src/map.rs @@ -308,10 +308,10 @@ impl Server { where T: Mapper + Send + Sync + 'static, { - let mut info = shared::default_info_file(); + let mut info = shared::ServerInfo::default(); // update the info json metadata field, and add the map mode key value pair - info["metadata"][shared::MAP_MODE_KEY] = - serde_json::Value::String(shared::UNARY_MAP.to_string()); + info.set_metadata(shared::MAP_MODE_KEY, shared::UNARY_MAP); + let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file, info)?; let handler = self.svc.take().unwrap(); diff --git a/src/reduce.rs b/src/reduce.rs index 7e51c51..938e25c 100644 --- a/src/reduce.rs +++ b/src/reduce.rs @@ -820,7 +820,7 @@ impl Server { let listener = shared::create_listener_stream( &self.sock_addr, &self.server_info_file, - shared::default_info_file(), + shared::ServerInfo::default(), )?; let creator = self.creator.take().unwrap(); let (internal_shutdown_tx, internal_shutdown_rx) = channel(1); diff --git a/src/shared.rs b/src/shared.rs index 4f03ae6..8fab51c 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -4,6 +4,7 @@ use std::{collections::HashMap, io}; use chrono::{DateTime, TimeZone, Timelike, Utc}; use prost_types::Timestamp; +use serde::{Deserialize, Serialize}; use tokio::net::UnixListener; use tokio::signal; use tokio::sync::{mpsc, oneshot}; @@ -16,35 +17,73 @@ pub(crate) const STREAM_MAP: &str = "stream-map"; pub(crate) const BATCH_MAP: &str = "batch-map"; const MINIMUM_NUMAFLOW_VERSION: &str = "1.2.0-rc4"; -// default_info_file is a function to get a default server info json -// file content. This is used to write the server info file. -// This function is used in the write_info_file function. -// This function is not exposed to the user. -pub fn default_info_file() -> serde_json::Value { - let metadata: HashMap = HashMap::new(); - serde_json::json!({ - "protocol": "uds", - "language": "rust", - "version": "0.0.1", - "metadata": metadata, - "minimum-numaflow-version": MINIMUM_NUMAFLOW_VERSION, - }) +// ServerInfo structure to store server-related information +#[derive(Serialize, Deserialize, Debug)] +pub(crate) struct ServerInfo { + #[serde(default)] + protocol: String, + #[serde(default)] + language: String, + #[serde(default)] + minimum_numaflow_version: String, + #[serde(default)] + version: String, + #[serde(default)] + metadata: Option>, // Metadata is optional +} +impl ServerInfo { + // default_info_file is a function to get a default server info json + // file content. This is used to write the server info file. + // This function is used in the write_info_file function. + // This function is not exposed to the user. + pub fn default() -> Self { + let metadata: HashMap = HashMap::new(); + // Return the default server info json content + // Create a ServerInfo object with default values + ServerInfo { + protocol: "uds".to_string(), + language: "rust".to_string(), + minimum_numaflow_version: MINIMUM_NUMAFLOW_VERSION.to_string(), + version: "0.0.1".to_string(), + metadata: Option::from(metadata), + } + } + + // Check if the struct is empty + pub fn is_empty(&self) -> bool { + self.protocol.is_empty() + && self.language.is_empty() + && self.minimum_numaflow_version.is_empty() + && self.version.is_empty() + && self.metadata.is_none() + } + + // Set metadata key-value pair + pub fn set_metadata(&mut self, key: &str, value: &str) { + if let Some(metadata) = &mut self.metadata { + metadata.insert(key.to_string(), value.to_string()); + } else { + let mut metadata = HashMap::new(); + metadata.insert(key.to_string(), value.to_string()); + self.metadata = Some(metadata); + } + } } + // #[tracing::instrument(skip(path), fields(path = ?path.as_ref()))] #[tracing::instrument(fields(path = ? path.as_ref()))] -fn write_info_file(path: impl AsRef, mut server_info: serde_json::Value) -> io::Result<()> { +fn write_info_file(path: impl AsRef, mut server_info: ServerInfo) -> io::Result<()> { let parent = path.as_ref().parent().unwrap(); fs::create_dir_all(parent)?; // TODO: make port-number and CPU meta-data configurable, e.g., ("CPU_LIMIT", "1") - - // if server_info object is not provided, use the default one - if server_info.is_null() { - server_info = default_info_file(); + // If the server_info is empty, set it to the default + if server_info.is_empty() { + server_info = ServerInfo::default(); } - // Convert to a string of JSON and print it out - let content = format!("{}U+005C__END__", server_info); + let serialized = serde_json::to_string(&server_info).unwrap(); + let content = format!("{}U+005C__END__", serialized); info!(content, "Writing to file"); fs::write(path, content) } @@ -52,7 +91,7 @@ fn write_info_file(path: impl AsRef, mut server_info: serde_json::Value) - pub(crate) fn create_listener_stream( socket_file: impl AsRef, server_info_file: impl AsRef, - server_info: serde_json::Value, + server_info: ServerInfo, ) -> Result> { write_info_file(server_info_file, server_info) .map_err(|e| format!("writing info file: {e:?}"))?; @@ -170,10 +209,9 @@ mod tests { let temp_file = NamedTempFile::new()?; // Get a default server info file content - // let server_info = default_info_file(); - let mut info = default_info_file(); + let mut info = ServerInfo::default(); // update the info json metadata field, and add the map mode key value pair - info["metadata"][MAP_MODE_KEY] = serde_json::Value::String(BATCH_MAP.to_string()); + info.set_metadata(MAP_MODE_KEY, BATCH_MAP); // Call write_info_file with the path of the temporary file write_info_file(temp_file.path(), info)?; @@ -188,7 +226,7 @@ mod tests { assert!(contents.contains(r#""language":"rust""#)); assert!(contents.contains(r#""version":"0.0.1""#)); assert!(contents.contains(r#""metadata":{"MAP_MODE":"batch-map"}"#)); - assert!(contents.contains(r#""minimum-numaflow-version":"1.2.0-rc4""#)); + assert!(contents.contains(r#""minimum_numaflow_version":"1.2.0-rc4""#)); Ok(()) } diff --git a/src/sideinput.rs b/src/sideinput.rs index 3c82706..3ea9bbf 100644 --- a/src/sideinput.rs +++ b/src/sideinput.rs @@ -197,7 +197,7 @@ impl Server { let listener = shared::create_listener_stream( &self.sock_addr, &self.server_info_file, - shared::default_info_file(), + shared::ServerInfo::default(), )?; let handler = self.svc.take().unwrap(); let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); diff --git a/src/sink.rs b/src/sink.rs index 2e5eaea..2cda073 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -326,7 +326,7 @@ impl Server { let listener = shared::create_listener_stream( &self.sock_addr, &self.server_info_file, - shared::default_info_file(), + shared::ServerInfo::default(), )?; let handler = self.svc.take().unwrap(); let cln_token = CancellationToken::new(); diff --git a/src/source.rs b/src/source.rs index fa54be0..b19d1ec 100644 --- a/src/source.rs +++ b/src/source.rs @@ -267,7 +267,7 @@ impl Server { let listener = shared::create_listener_stream( &self.sock_addr, &self.server_info_file, - shared::default_info_file(), + shared::ServerInfo::default(), )?; let handler = self.svc.take().unwrap(); let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); diff --git a/src/sourcetransform.rs b/src/sourcetransform.rs index 1372586..25f06c7 100644 --- a/src/sourcetransform.rs +++ b/src/sourcetransform.rs @@ -337,7 +337,7 @@ impl Server { let listener = shared::create_listener_stream( &self.sock_addr, &self.server_info_file, - shared::default_info_file(), + shared::ServerInfo::default(), )?; let handler = self.svc.take().unwrap(); let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);