From dcbb26834153b84853d9757e25395d92a1314d4a Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Thu, 12 Sep 2024 17:55:57 +0530 Subject: [PATCH] bidirectional source Signed-off-by: Yashash H L --- examples/simple-source/src/main.rs | 8 +- src/source.rs | 324 +++++++++++++++-------------- 2 files changed, 172 insertions(+), 160 deletions(-) diff --git a/examples/simple-source/src/main.rs b/examples/simple-source/src/main.rs index e1328bd..9127211 100644 --- a/examples/simple-source/src/main.rs +++ b/examples/simple-source/src/main.rs @@ -60,11 +60,9 @@ pub(crate) mod simple_source { self.yet_to_ack.write().unwrap().extend(message_offsets) } - async fn ack(&self, offsets: Vec) { - for offset in offsets { - let x = &String::from_utf8(offset.offset).unwrap(); - self.yet_to_ack.write().unwrap().remove(x); - } + async fn ack(&self, offset: Offset) { + let x = &String::from_utf8(offset.offset).unwrap(); + self.yet_to_ack.write().unwrap().remove(x); } async fn pending(&self) -> usize { diff --git a/src/source.rs b/src/source.rs index c4ce627..9196c37 100644 --- a/src/source.rs +++ b/src/source.rs @@ -29,8 +29,8 @@ pub mod proto { struct SourceService { handler: Arc, - _shutdown_tx: Sender<()>, - _cancellation_token: CancellationToken, + shutdown_tx: Sender<()>, + cancellation_token: CancellationToken, } // FIXME: remove async_trait @@ -98,78 +98,99 @@ where let handler_fn = Arc::clone(&self.handler); + let grpc_tx = tx.clone(); + let cln_token = self.cancellation_token.clone(); let grpc_read_handle: JoinHandle> = tokio::spawn(async move { - while let Some(read_request) = sr - .message() - .await - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? - { - // tx,rx pair for sending data over to user-defined source - let (stx, mut srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); - - let Some(request) = read_request.request else { - panic!("request cannot be empty"); - }; - - let grpc_resp_tx = tx.clone(); - // start the ud-source rx asynchronously and start populating the gRPC response, so it can be streamed to the gRPC client (numaflow). - let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { - while let Some(resp) = srx.recv().await { - grpc_resp_tx - .send(Ok(proto::ReadResponse { - result: Some(proto::read_response::Result { - payload: resp.value, - offset: Some(proto::Offset { - offset: resp.offset.offset, - partition_id: resp.offset.partition_id, + loop { + tokio::select! { + read_request = sr.message() => { + let read_request = read_request + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? + .ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; + + // tx,rx pair for sending data over to user-defined source + let (stx, mut srx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); + + let Some(request) = read_request.request else { + panic!("request cannot be empty"); + }; + + let grpc_resp_tx = grpc_tx.clone(); + // start the ud-source rx asynchronously and start populating the gRPC response, so it can be streamed to the gRPC client (numaflow). + let grpc_writer_handle: JoinHandle> = tokio::spawn(async move { + while let Some(resp) = srx.recv().await { + grpc_resp_tx + .send(Ok(proto::ReadResponse { + result: Some(proto::read_response::Result { + payload: resp.value, + offset: Some(proto::Offset { + offset: resp.offset.offset, + partition_id: resp.offset.partition_id, + }), + event_time: prost_timestamp_from_utc(resp.event_time), + keys: resp.keys, + headers: Default::default(), + }), + status: None, + })) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + } + + grpc_resp_tx + .send(Ok(proto::ReadResponse { + result: None, + status: Some(proto::read_response::Status { + eot: true, + code: 0, + error: 0, + msg: None, }), - event_time: prost_timestamp_from_utc(resp.event_time), - keys: resp.keys, - headers: Default::default(), - }), - status: None, - })) + })) + .await + .map_err(|e| Error::SourceError(ErrorKind::InternalError(e.to_string())))?; + + Ok(()) + }); + + handler_fn + .read( + SourceReadRequest { + count: request.num_records as usize, + timeout: Duration::from_millis(request.timeout_in_ms as u64), + }, + stx, + ) + .await; + + grpc_writer_handle .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; } - - grpc_resp_tx - .send(Ok(proto::ReadResponse { - result: None, - status: Some(proto::read_response::Status { - eot: true, - code: 0, - error: 0, - msg: None, - }), - })) - .await - .map_err(|e| Error::SourceError(ErrorKind::InternalError(e.to_string())))?; - - Ok(()) - }); - - handler_fn - .read( - SourceReadRequest { - count: request.num_records as usize, - timeout: Duration::from_millis(request.timeout_in_ms as u64), - }, - stx, - ) - .await; - - grpc_writer_handle - .await - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? - .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))?; + _ = cln_token.cancelled() => { + eprintln!("Cancellation token triggered, shutting down"); + break; + } + } } Ok(()) }); - // we want to start streaming to the server as soon as possible + let shutdown_tx = self.shutdown_tx.clone(); tokio::spawn(async move { - // user-defined source read handler + // wait for grpc read handle, if there are any errors write to the grpc response channel + if let Err(e) = grpc_read_handle.await { + tx.send(Err(Status::internal(e.to_string()))) + .await + .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string()))) + .expect("writing error to grpc response channel should never fail"); + + shutdown_tx + .send(()) + .await + .expect("write to shutdown channel should never fail"); + } }); Ok(Response::new(ReceiverStream::new(rx))) @@ -179,18 +200,26 @@ where &self, request: Request>, ) -> Result, Status> { - let mut acks = request.into_inner(); - while let Some(ack_request) = acks.message().await? { - let offset = ack_request.request.unwrap().offset; + let mut ack_stream = request.into_inner(); + while let Some(ack_request) = ack_stream.message().await? { + // the request is not there send back status as invalid argument + let Some(request) = ack_request.request else { + return Err(Status::invalid_argument("request is empty")); + }; + + let Some(offset) = request.offset else { + return Err(Status::invalid_argument("offset is not present")); + }; + self.handler .ack(Offset { - offset: offset.clone().unwrap().offset, - partition_id: offset.unwrap().partition_id, + offset: offset.clone().offset, + partition_id: offset.partition_id, }) .await; } Ok(Response::new(AckResponse { - result: Some(proto::ack_response::Result { success: None }), + result: Some(proto::ack_response::Result { success: Some(()) }), })) } @@ -312,8 +341,8 @@ impl Server { let source_service = SourceService { handler: Arc::new(handler), - _shutdown_tx: internal_shutdown_tx, - _cancellation_token: cln_token.clone(), + shutdown_tx: internal_shutdown_tx, + cancellation_token: cln_token.clone(), }; let source_svc = proto::source_server::SourceServer::new(source_service) @@ -354,18 +383,19 @@ impl Drop for Server { #[cfg(test)] mod tests { use super::{proto, Message, Offset, SourceReadRequest}; + use crate::source; use chrono::Utc; use std::collections::{HashMap, HashSet}; + use std::error::Error; + use std::time::Duration; use std::vec; - use std::{error::Error, time::Duration}; - - use crate::source; use tempfile::TempDir; use tokio::net::UnixStream; use tokio::sync::mpsc::Sender; - use tokio::sync::oneshot; - use tokio_stream::StreamExt; + use tokio::sync::{mpsc, oneshot}; + use tokio_stream::wrappers::ReceiverStream; use tonic::transport::Uri; + use tonic::Request; use tower::service_fn; use uuid::Uuid; @@ -413,13 +443,11 @@ mod tests { self.yet_to_ack.write().unwrap().extend(message_offsets) } - async fn ack(&self, offsets: Vec) { - for offset in offsets { - self.yet_to_ack - .write() - .unwrap() - .remove(&String::from_utf8(offset.offset).unwrap()); - } + async fn ack(&self, offset: Offset) { + self.yet_to_ack + .write() + .unwrap() + .remove(&String::from_utf8(offset.offset).unwrap()); } async fn pending(&self) -> usize { @@ -469,83 +497,69 @@ mod tests { .await?; let mut client = proto::source_client::SourceClient::new(channel); - let request = tonic::Request::new(proto::ReadRequest { + + // Test read_fn with bidirectional streaming + let (read_tx, read_rx) = mpsc::channel(4); + let read_request = proto::ReadRequest { request: Some(proto::read_request::Request { num_records: 5, - timeout_in_ms: 500, + timeout_in_ms: 1000, }), - }); + }; + read_tx.send(read_request).await.unwrap(); + drop(read_tx); // Close the sender to indicate no more requests - let resp = client.read_fn(request).await?; - let resp = resp.into_inner(); - let result: Vec = resp - .map(|item| item.unwrap().result.unwrap()) - .collect() - .await; - let response_values: Vec = result - .iter() - .map(|item| { - usize::from_le_bytes( - item.payload - .clone() - .try_into() - .expect("expected Vec length to be 8"), - ) - }) - .collect(); - assert_eq!(response_values, vec![8, 8, 8, 8, 8]); - - let pending_before_ack = client - .pending_fn(tonic::Request::new(())) - .await - .unwrap() - .into_inner(); - assert_eq!( - pending_before_ack.result.unwrap().count, - 5, - "Expected pending messages to be 5 before ACK" - ); - - let offsets_to_ack: Vec = result - .iter() - .map(|item| item.clone().offset.unwrap()) - .collect(); - let ack_request = tonic::Request::new(proto::AckRequest { - request: Some(proto::ack_request::Request { - offsets: offsets_to_ack, - }), - }); - let resp = client.ack_fn(ack_request).await.unwrap().into_inner(); - assert!( - resp.result.unwrap().success.is_some(), - "Expected acknowledgement request to be successful" - ); - - let pending_before_ack = client - .pending_fn(tonic::Request::new(())) - .await - .unwrap() + let mut response_stream = client + .read_fn(Request::new(ReceiverStream::new(read_rx))) + .await? .into_inner(); - assert_eq!( - pending_before_ack.result.unwrap().count, - 0, - "Expected pending messages to be 0 after ACK" - ); - - let partitions = client - .partitions_fn(tonic::Request::new(())) - .await - .unwrap() + let mut response_values = Vec::new(); + + while let Some(response) = response_stream.message().await? { + if let Some(status) = response.status { + if status.eot { + break; + } + } + + if let Some(result) = response.result { + response_values.push(result); + } + } + assert_eq!(response_values.len(), 5); + + // Test pending_fn + let pending_before_ack = client.pending_fn(Request::new(())).await?.into_inner(); + assert_eq!(pending_before_ack.result.unwrap().count, 5); + + // Test ack_fn with client-side streaming + let (ack_tx, ack_rx) = mpsc::channel(10); + for resp in response_values.iter() { + let ack_request = proto::AckRequest { + request: Some(proto::ack_request::Request { + offset: Some(proto::Offset { + offset: resp.offset.clone().unwrap().offset, + partition_id: resp.offset.clone().unwrap().partition_id, + }), + }), + }; + ack_tx.send(ack_request).await.unwrap(); + } + drop(ack_tx); // Close the sender to indicate no more requests + + let ack_response = client + .ack_fn(Request::new(ReceiverStream::new(ack_rx))) + .await? .into_inner(); - assert_eq!( - partitions.result.unwrap().partitions, - vec![2], - "Expected number of partitions to be 2" - ); - - shutdown_tx - .send(()) - .expect("Sending shutdown signal to gRPC server"); + assert!(ack_response.result.unwrap().success.is_some()); + + let pending_after_ack = client.pending_fn(Request::new(())).await?.into_inner(); + assert_eq!(pending_after_ack.result.unwrap().count, 0); + + let partitions = client.partitions_fn(Request::new(())).await?.into_inner(); + assert_eq!(partitions.result.unwrap().partitions, vec![2]); + + shutdown_tx.send(()).unwrap(); tokio::time::sleep(Duration::from_millis(50)).await; assert!(task.is_finished(), "gRPC server is still running"); Ok(())