diff --git a/examples/simple-source/Makefile b/examples/simple-source/Makefile index b7769e8..80ba511 100644 --- a/examples/simple-source/Makefile +++ b/examples/simple-source/Makefile @@ -10,10 +10,9 @@ update: .PHONY: image image: update - cd ../../ && docker build \ + cd ../../ && docker buildx build \ -f ${DOCKER_FILE_PATH} \ - -t ${IMAGE_REGISTRY} . - @if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi + -t ${IMAGE_REGISTRY} . --platform linux/amd64,linux/arm64 --push .PHONY: clean clean: diff --git a/examples/simple-source/src/main.rs b/examples/simple-source/src/main.rs index 9127211..8e78970 100644 --- a/examples/simple-source/src/main.rs +++ b/examples/simple-source/src/main.rs @@ -2,14 +2,14 @@ #[tokio::main] async fn main() -> Result<(), Box> { - let source_handle = simple_source::SimpleSource::new("Hello World!".to_string()); + let source_handle = simple_source::SimpleSource::new(); numaflow::source::Server::new(source_handle).start().await } pub(crate) mod simple_source { - use std::{collections::HashSet, sync::RwLock}; - use chrono::Utc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::{collections::HashSet, sync::RwLock}; use tokio::sync::mpsc::Sender; use numaflow::source::{Message, Offset, SourceReadRequest, Sourcer}; @@ -18,15 +18,15 @@ pub(crate) mod simple_source { /// or Atomics to provide concurrent access. 
Numaflow actually does not require concurrent access but we are forced to do this because the SDK /// does not provide a mutable reference as explained in [`Sourcer`] pub(crate) struct SimpleSource { - payload: String, yet_to_ack: RwLock>, + counter: AtomicUsize, } impl SimpleSource { - pub(crate) fn new(payload: String) -> Self { + pub(crate) fn new() -> Self { Self { - payload, yet_to_ack: RwLock::new(HashSet::new()), + counter: AtomicUsize::new(0), } } } @@ -42,9 +42,10 @@ pub(crate) mod simple_source { let mut message_offsets = Vec::with_capacity(request.count); for i in 0..request.count { let offset = format!("{}-{}", event_time.timestamp_nanos_opt().unwrap(), i); + let payload = self.counter.fetch_add(1, Ordering::Relaxed).to_string(); transmitter .send(Message { - value: format!("{}-{}", self.payload, event_time).into_bytes(), + value: payload.into_bytes(), event_time, offset: Offset { offset: offset.clone().into_bytes(), diff --git a/proto/source.proto b/proto/source.proto index dcaf253..8878ac6 100644 --- a/proto/source.proto +++ b/proto/source.proto @@ -7,9 +7,10 @@ package source.v1; service Source { // Read returns a stream of datum responses. - // The size of the returned ReadResponse is less than or equal to the num_records specified in each ReadRequest. - // If the request timeout is reached on the server side, the returned ReadResponse will contain all the datum that have been read (which could be an empty list). + // The size of the returned responses is less than or equal to the num_records specified in each ReadRequest. + // If the request timeout is reached on the server side, the returned responses will contain all the datum that have been read (which could be an empty list). // The server will continue to read and respond to subsequent ReadRequests until the client closes the stream. + // Once it has sent all the datum, the server will send a ReadResponse with the end of transmission flag set to true. 
rpc ReadFn(stream ReadRequest) returns (stream ReadResponse); // AckFn acknowledges a stream of datum offsets. @@ -17,7 +18,8 @@ service Source { // The caller (numa) expects the AckFn to be successful, and it does not expect any errors. // If there are some irrecoverable errors when the callee (UDSource) is processing the AckFn request, // then it is best to crash because there are no other retry mechanisms possible. - rpc AckFn(stream AckRequest) returns (AckResponse); + // The client sends n requests and expects n responses. + rpc AckFn(stream AckRequest) returns (stream AckResponse); // PendingFn returns the number of pending records at the user defined source. rpc PendingFn(google.protobuf.Empty) returns (PendingResponse); @@ -29,6 +31,14 @@ service Source { rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); } +/* + * Handshake message between client and server to indicate the start of transmission. + */ +message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; +} + /* * ReadRequest is the request for reading datum stream from user defined source. */ @@ -43,6 +53,7 @@ message ReadRequest { } // Required field indicating the request. Request request = 1; + optional Handshake handshake = 2; } /* @@ -82,14 +93,15 @@ message ReadResponse { // End of transmission flag. bool eot = 1; Code code = 2; - Error error = 3; + optional Error error = 3; optional string msg = 4; } // Required field holding the result. Result result = 1; // Status of the response. Holds the end of transmission flag and the status code. - // Status status = 2; + // Handshake message between client and server to indicate the start of transmission. + optional Handshake handshake = 3; } /* @@ -103,6 +115,7 @@ message AckRequest { } // Required field holding the request. The list will be ordered and will have the same order as the original Read response. 
Request request = 1; + optional Handshake handshake = 2; } /* @@ -122,6 +135,8 @@ message AckResponse { } // Required field holding the result. Result result = 1; + // Handshake message between client and server to indicate the start of transmission. + optional Handshake handshake = 2; } /* @@ -170,4 +185,4 @@ message Offset { // It is useful for sources that have multiple partitions. e.g. Kafka. // If the partition_id is not specified, it is assumed that the source has a single partition. int32 partition_id = 2; -} +} \ No newline at end of file diff --git a/src/source.rs b/src/source.rs index 94cc841..a36ca96 100644 --- a/src/source.rs +++ b/src/source.rs @@ -107,6 +107,7 @@ where headers: Default::default(), }), status: None, + handshake: None, })) .await .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; @@ -119,9 +120,10 @@ where status: Some(proto::read_response::Status { eot: true, code: 0, - error: 0, + error: None, msg: None, }), + handshake: None, })) .await .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; @@ -140,7 +142,7 @@ where let (stx, srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); // spawn the rx side so that when the handler is invoked, we can stream the handler's read data - // to the gprc response stream. + // to the grpc response stream. let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { Self::write_a_batch(grpc_resp_tx, srx).await }); @@ -172,7 +174,6 @@ where T: Sourcer + Send + Sync + 'static, { type ReadFnStream = ReceiverStream>; - async fn read_fn( &self, request: Request>, @@ -189,6 +190,28 @@ where let cln_token = self.cancellation_token.clone(); + // do the handshake first to let the client know that we are ready to receive read requests. + let handshake_request = sr + .message() + .await + .map_err(|e| Status::internal(format!("handshake failed {}", e)))? 
+ .ok_or_else(|| Status::internal("stream closed before handshake"))?; + + if let Some(handshake) = handshake_request.handshake { + grpc_tx + .send(Ok(ReadResponse { + result: None, + status: None, + handshake: Some(handshake), + })) + .await + .map_err(|e| { + Status::internal(format!("failed to send handshake response {}", e)) + })?; + } else { + return Err(Status::invalid_argument("Handshake not present")); + } + // this is the top-level stream consumer and this task will only exit when stream is closed (which // will happen when server and client are shutting down). let grpc_read_handle: JoinHandle> = tokio::spawn(async move { @@ -242,31 +265,98 @@ where Ok(Response::new(ReceiverStream::new(rx))) } + type AckFnStream = ReceiverStream>; + async fn ack_fn( &self, request: Request>, - ) -> Result, Status> { + ) -> Result, Status> { let mut ack_stream = request.into_inner(); - while let Some(ack_request) = ack_stream.message().await? { - // the request is not there send back status as invalid argument - let Some(request) = ack_request.request else { - return Err(Status::invalid_argument("request is empty")); - }; + let (ack_tx, ack_rx) = mpsc::channel::>(DEFAULT_CHANNEL_SIZE); - let Some(offset) = request.offset else { - return Err(Status::invalid_argument("offset is not present")); - }; + let handler_fn = Arc::clone(&self.handler); - self.handler - .ack(Offset { - offset: offset.clone().offset, - partition_id: offset.partition_id, - }) - .await; + // do the handshake first to let the client know that we are ready to receive ack requests. + let handshake_request = ack_stream + .message() + .await + .map_err(|e| Status::internal(format!("handshake failed {}", e)))? 
+ .ok_or_else(|| Status::internal("stream closed before handshake"))?; + + let ack_resp_tx = ack_tx.clone(); + if let Some(handshake) = handshake_request.handshake { + ack_resp_tx + .send(Ok(AckResponse { + result: None, + handshake: Some(handshake), + })) + .await + .map_err(|e| { + Status::internal(format!("failed to send handshake response {}", e)) + })?; + } else { + return Err(Status::invalid_argument("Handshake not present")); } - Ok(Response::new(AckResponse { - result: Some(proto::ack_response::Result { success: Some(()) }), - })) + + let cln_token = self.cancellation_token.clone(); + let grpc_read_handle: JoinHandle> = tokio::spawn(async move { + loop { + tokio::select! { + _ = cln_token.cancelled() => { + info!("Cancellation token triggered, shutting down"); + break; + } + ack_request = ack_stream.message() => { + let ack_request = ack_request + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? + .ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; + + let request = ack_request.request + .ok_or_else(|| SourceError(ErrorKind::InternalError("Invalid request, request is empty".to_string())))?; + + let offset = request.offset + .ok_or_else(|| SourceError(ErrorKind::InternalError("Invalid request, offset is empty".to_string())))?; + + handler_fn + .ack(Offset { + offset: offset.offset, + partition_id: offset.partition_id, + }) + .await; + + // the return of handler_fn implicitly means that the ack is successful; hence + // we are able to send success. There is no path for failure. 
+ ack_resp_tx + .send(Ok(AckResponse { + result: Some(proto::ack_response::Result { success: Some(()) }), + handshake: None, + })) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + } + } + } + Ok(()) + }); + + let shutdown_tx = self.shutdown_tx.clone(); + tokio::spawn(async move { + if let Err(e) = grpc_read_handle.await { + error!("shutting down the gRPC ack channel, {}", e); + ack_tx + .send(Err(Status::internal(e.to_string()))) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string()))) + .expect("writing error to grpc response channel should never fail"); + + shutdown_tx + .send(()) + .await + .expect("write to shutdown channel should never fail"); + } + }); + + Ok(Response::new(ReceiverStream::new(ack_rx))) } async fn pending_fn(&self, _: Request<()>) -> Result, Status> { @@ -528,11 +618,8 @@ mod tests { tokio::time::sleep(Duration::from_millis(50)).await; - // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs - // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? 
.connect_with_connector(service_fn(move |_: Uri| { - // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes let sock_file = sock_file.clone(); async move { Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( @@ -546,11 +633,18 @@ mod tests { // Test read_fn with bidirectional streaming let (read_tx, read_rx) = mpsc::channel(4); + let handshake_request = proto::ReadRequest { + request: None, + handshake: Some(proto::Handshake { sot: true }), + }; + read_tx.send(handshake_request).await.unwrap(); + let read_request = proto::ReadRequest { request: Some(proto::read_request::Request { num_records: 5, timeout_in_ms: 1000, }), + handshake: None, }; read_tx.send(read_request).await.unwrap(); drop(read_tx); // Close the sender to indicate no more requests @@ -580,6 +674,11 @@ mod tests { // Test ack_fn with client-side streaming let (ack_tx, ack_rx) = mpsc::channel(10); + let ack_handshake_request = proto::AckRequest { + request: None, + handshake: Some(proto::Handshake { sot: true }), + }; + ack_tx.send(ack_handshake_request).await.unwrap(); for resp in response_values.iter() { let ack_request = proto::AckRequest { request: Some(proto::ack_request::Request { @@ -588,16 +687,24 @@ mod tests { partition_id: resp.offset.clone().unwrap().partition_id, }), }), + handshake: None, }; ack_tx.send(ack_request).await.unwrap(); } drop(ack_tx); // Close the sender to indicate no more requests - let ack_response = client + let mut ack_response = client .ack_fn(Request::new(ReceiverStream::new(ack_rx))) .await? 
.into_inner(); - assert!(ack_response.result.unwrap().success.is_some()); + + // first response will be the handshake response + let ack_handshake_response = ack_response.message().await?.unwrap(); + assert!(ack_handshake_response.handshake.unwrap().sot); + + for _ in 0..5 { + assert!(ack_response.message().await?.is_some()); + } let pending_after_ack = client.pending_fn(Request::new(())).await?.into_inner(); assert_eq!(pending_after_ack.result.unwrap().count, 0);