diff --git a/proto/sourcetransform.proto b/proto/sourcetransform.proto index 18e045c..90a2a64 100644 --- a/proto/sourcetransform.proto +++ b/proto/sourcetransform.proto @@ -9,23 +9,38 @@ service SourceTransform { // SourceTransformFn applies a function to each request element. // In addition to map function, SourceTransformFn also supports assigning a new event time to response. // SourceTransformFn can be used only at source vertex by source data transformer. - rpc SourceTransformFn(SourceTransformRequest) returns (SourceTransformResponse); + rpc SourceTransformFn(stream SourceTransformRequest) returns (stream SourceTransformResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); } +/* + * Handshake message between client and server to indicate the start of transmission. + */ + message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; +} + /** * SourceTransformerRequest represents a request element. */ message SourceTransformRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + // This ID is used to uniquely identify a transform request + string id = 6; + } + Request request = 1; + optional Handshake handshake = 2; } + /** * SourceTransformerResponse represents a response element. */ @@ -37,6 +52,10 @@ message SourceTransformResponse { repeated string tags = 4; } repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; + // Handshake message between client and server to indicate the start of transmission. + optional Handshake handshake = 3; } /** diff --git a/src/sourcetransform.rs b/src/sourcetransform.rs index 25f06c7..9d9a6b4 100644 --- a/src/sourcetransform.rs +++ b/src/sourcetransform.rs @@ -1,18 +1,25 @@ -use crate::error::Error::SourceTransformerError; -use crate::error::ErrorKind::UserDefinedError; -use crate::shared::{self, prost_timestamp_from_utc}; -use chrono::{DateTime, Utc}; use std::collections::HashMap; use std::fs; use std::path::PathBuf; use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use proto::SourceTransformResponse; use tokio::sync::{mpsc, oneshot}; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; -use tonic::{async_trait, Request, Response, Status}; +use tonic::{async_trait, Request, Response, Status, Streaming}; +use tracing::{error, info, warn}; + +use crate::error::Error::{self, SourceTransformerError}; +use crate::error::ErrorKind; +use crate::shared::{self, prost_timestamp_from_utc, utc_from_timestamp}; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sourcetransform.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcetransformer-server-info"; +const DEFAULT_CHANNEL_SIZE: usize = 1000; const DROP: &str = "U+005C__DROP__"; @@ -113,7 +120,7 @@ impl Message { /// # Arguments /// /// * `event_time` - The `DateTime` that specifies when the event occurred. Event time is required because, even though a message is dropped, - /// it is still considered as being processed, hence the watermark should be updated accordingly using the provided event time. + /// it is still considered as being processed, hence the watermark should be updated accordingly using the provided event time. /// /// # Examples /// @@ -218,14 +225,14 @@ impl From for proto::source_transform_response::Result { } } -impl From for SourceTransformRequest { - fn from(value: proto::SourceTransformRequest) -> Self { - Self { - keys: value.keys, - value: value.value, - watermark: shared::utc_from_timestamp(value.watermark), - eventtime: shared::utc_from_timestamp(value.event_time), - headers: value.headers, +impl From for SourceTransformRequest { + fn from(request: proto::source_transform_request::Request) -> Self { + SourceTransformRequest { + keys: request.keys, + value: request.value, + watermark: utc_from_timestamp(request.watermark), + eventtime: utc_from_timestamp(request.event_time), + headers: request.headers, } } } @@ -235,37 +242,61 @@ impl proto::source_transform_server::SourceTransform for SourceTransformerSer where T: SourceTransformer + Send + Sync + 'static, { + type SourceTransformFnStream = ReceiverStream>; + async fn source_transform_fn( &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); + request: Request>, + ) -> Result, Status> { + let mut stream = request.into_inner(); let handler = Arc::clone(&self.handler); - let handle = tokio::spawn(async move { handler.transform(request.into()).await }); - let shutdown_tx = self.shutdown_tx.clone(); - let cancellation_token = self.cancellation_token.clone(); - - // Wait for the handler to finish processing the request. If the server is shutting down(token will be cancelled), - // then return an error. - tokio::select! { - result = handle => { - match result { - Ok(messages) => Ok(Response::new(proto::SourceTransformResponse { - results: messages.into_iter().map(|msg| msg.into()).collect(), - })), - Err(e) => { - tracing::error!("Error in source transform handler: {:?}", e); - // Send a shutdown signal to the server to do a graceful shutdown because there was - // a panic in the handler. - shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server"); - Err(Status::internal(SourceTransformerError(UserDefinedError(e.to_string())).to_string())) - } - } - }, - _ = cancellation_token.cancelled() => { - Err(Status::internal(SourceTransformerError(UserDefinedError("Server is shutting down".to_string())).to_string())) - }, + + let (stream_response_tx, stream_response_rx) = + mpsc::channel::>(DEFAULT_CHANNEL_SIZE); + + // do the handshake first to let the client know that we are ready to receive transformation requests. + let handshake_request = stream + .message() + .await + .map_err(|e| Status::internal(format!("handshake failed {}", e)))? + .ok_or_else(|| Status::internal("stream closed before handshake"))?; + + if let Some(handshake) = handshake_request.handshake { + stream_response_tx + .send(Ok(SourceTransformResponse { + results: vec![], + id: "".to_string(), + handshake: Some(handshake), + })) + .await + .map_err(|e| { + Status::internal(format!("failed to send handshake response {}", e)) + })?; + } else { + return Err(Status::invalid_argument("Handshake not present")); } + + let (error_tx, error_rx) = mpsc::channel::(1); + + // Spawn a task to continuously receive messages from the client over the gRPC stream. + // For each message received from the stream, a new task is spawned to call the transform function and send the response back to the client + let handle: JoinHandle<()> = tokio::spawn(handle_stream_requests( + handler.clone(), + stream, + stream_response_tx.clone(), + error_tx.clone(), + self.cancellation_token.child_token(), + )); + + tokio::spawn(manage_gprc_stream( + handle, + self.cancellation_token.clone(), + stream_response_tx, + error_rx, + self.shutdown_tx.clone(), + )); + + Ok(Response::new(ReceiverStream::new(stream_response_rx))) } async fn is_ready(&self, _: Request<()>) -> Result, Status> { @@ -273,6 +304,169 @@ where } } +// shutdown the gRPC server on first error +async fn manage_gprc_stream( + request_handler: JoinHandle<()>, + token: CancellationToken, + stream_response_tx: mpsc::Sender>, + mut error_rx: mpsc::Receiver, + server_shutdown_tx: mpsc::Sender<()>, +) { + let err = tokio::select! { + _ = request_handler => { + token.cancel(); + return; + }, + err = error_rx.recv() => err, + }; + + token.cancel(); + let Some(err) = err else { + return; + }; + error!("Shutting down gRPC channel: {err:?}"); + stream_response_tx + .send(Err(Status::internal(err.to_string()))) + .await + .expect("Sending error message to gRPC response channel"); + server_shutdown_tx + .send(()) + .await + .expect("Writing to shutdown channel"); +} + +// Receives messages from the stream. +// For each message received from the stream, a new task is spawned to call the transform function and send the response back to the client +async fn handle_stream_requests( + handler: Arc, + mut stream: Streaming, + stream_response_tx: mpsc::Sender>, + error_tx: mpsc::Sender, + token: CancellationToken, +) where + T: SourceTransformer + Send + Sync + 'static, +{ + let mut stream_open = true; + while stream_open { + stream_open = tokio::select! { + transform_request = stream.message() => handle_request( + handler.clone(), + transform_request, + stream_response_tx.clone(), + error_tx.clone(), + token.clone(), + ).await, + _ = token.cancelled() => { + info!("Cancellation token is cancelled, shutting down"); + break; + } + } + } +} + +// The return boolean value indicates whether a task was created to handle the request. +// If the return value is false, either client sent an error gRPC status or the stream was closed. +async fn handle_request( + handler: Arc, + transform_request: Result, Status>, + stream_response_tx: mpsc::Sender>, + error_tx: mpsc::Sender, + token: CancellationToken, +) -> bool +where + T: SourceTransformer + Send + Sync + 'static, +{ + let transform_request = match transform_request { + Ok(None) => return false, + Ok(Some(val)) => val, + Err(val) => { + error!("Received gRPC error from sender: {val:?}"); + return false; + } + }; + tokio::spawn(run_transform( + handler, + transform_request, + stream_response_tx, + error_tx, + token, + )); + true +} + +// Calls the user implemented transform function on the request. +async fn run_transform( + handler: Arc, + transform_request: proto::SourceTransformRequest, + stream_response_tx: mpsc::Sender>, + error_tx: mpsc::Sender, + token: CancellationToken, +) where + T: SourceTransformer + Send + Sync + 'static, +{ + let Some(request) = transform_request.request else { + error_tx + .send(SourceTransformerError(ErrorKind::InternalError( + "Transform request can not be none".to_string(), + ))) + .await + .expect("Sending error on channel"); + return; + }; + + let message_id = request.id.clone(); + + // A new task is spawned to catch the panic + let udf_transform_task = tokio::spawn({ + let handler = handler.clone(); + let token = token.child_token(); + async move { + tokio::select! { + _ = token.cancelled() => None, + messages = handler.transform(request.into()) => Some(messages), + } + } + }); + + let messages = match udf_transform_task.await { + Ok(messages) => messages, + Err(e) => { + tracing::error!("Failed to run transform function: {e:?}"); + error_tx + .send(SourceTransformerError(ErrorKind::UserDefinedError( + "panic in transform UDF".to_string(), + ))) + .await + .expect("Sending error on channel"); + return; + } + }; + + let Some(messages) = messages else { + // CancellationToken is cancelled + return; + }; + + let send_response_result = stream_response_tx + .send(Ok(SourceTransformResponse { + results: messages.into_iter().map(|msg| msg.into()).collect(), + id: message_id, + handshake: None, + })) + .await; + + let Err(e) = send_response_result else { + return; + }; + + error_tx + .send(SourceTransformerError(ErrorKind::InternalError(format!( + "sending source transform response over gRPC stream: {e:?}" + )))) + .await + .expect("Sending error on channel"); +} + /// gRPC server to start a sourcetransform service #[derive(Debug)] pub struct Server { @@ -390,12 +584,15 @@ mod tests { use tempfile::TempDir; use tokio::net::UnixStream; - use tokio::sync::oneshot; + use tokio::sync::{mpsc, oneshot}; + use tokio_stream::wrappers::ReceiverStream; use tonic::transport::Uri; use tower::service_fn; - use crate::sourcetransform; - use crate::sourcetransform::proto::source_transform_client::SourceTransformClient; + use crate::sourcetransform::{ + self, + proto::{self, source_transform_client::SourceTransformClient}, + }; #[tokio::test] async fn source_transformer_server() -> Result<(), Box> { @@ -447,21 +644,59 @@ mod tests { .await?; let mut client = SourceTransformClient::new(channel); - let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest { - keys: vec!["first".into(), "second".into()], - value: "hello".into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - headers: Default::default(), - }); - - let resp = client.source_transform_fn(request).await?; - let resp = resp.into_inner(); + + let (tx, rx) = mpsc::channel(2); + + let handshake_request = proto::SourceTransformRequest { + request: None, + handshake: Some(proto::Handshake { sot: true }), + }; + tx.send(handshake_request).await.unwrap(); + + let mut stream = tokio::time::timeout( + Duration::from_secs(2), + client.source_transform_fn(ReceiverStream::new(rx)), + ) + .await + .map_err(|_| "timeout while getting stream for source_transform_fn")?? + .into_inner(); + + let handshake_resp = stream.message().await?.unwrap(); + assert!( + handshake_resp.results.is_empty(), + "The handshake response should not contain any messages" + ); + assert!( + handshake_resp.id.is_empty(), + "The message id of the handshake response should be empty" + ); + assert!( + handshake_resp.handshake.is_some(), + "Not a valid response for handshake request" + ); + + let request = sourcetransform::proto::SourceTransformRequest { + request: Some(proto::source_transform_request::Request { + id: "1".to_string(), + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + headers: Default::default(), + }), + handshake: None, + }; + + tx.send(request).await.unwrap(); + + let resp = stream.message().await?.unwrap(); assert_eq!(resp.results.len(), 1, "Expected single message from server"); let msg = &resp.results[0]; assert_eq!(msg.keys.first(), Some(&"first".to_owned())); assert_eq!(msg.value, "hello".as_bytes()); + drop(tx); + shutdown_tx .send(()) .expect("Sending shutdown signal to gRPC server"); @@ -515,21 +750,40 @@ mod tests { .await?; let mut client = SourceTransformClient::new(channel); - let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest { - keys: vec!["first".into(), "second".into()], - value: "hello".into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - headers: Default::default(), - }); - - let resp = client.source_transform_fn(request).await; - assert!(resp.is_err(), "Expected error from server"); - - if let Err(e) = resp { - assert_eq!(e.code(), tonic::Code::Internal); - assert!(e.message().contains("User Defined Error")); - } + + let (tx, rx) = mpsc::channel(2); + let handshake_request = proto::SourceTransformRequest { + request: None, + handshake: Some(proto::Handshake { sot: true }), + }; + tx.send(handshake_request).await.unwrap(); + + let mut stream = tokio::time::timeout( + Duration::from_secs(2), + client.source_transform_fn(ReceiverStream::new(rx)), + ) + .await + .map_err(|_| "timeout while getting stream for source_transform_fn")?? + .into_inner(); + + let handshake_resp = stream.message().await?.unwrap(); + assert!( + handshake_resp.handshake.is_some(), + "Not a valid response for handshake request" + ); + + let request = proto::SourceTransformRequest { + request: Some(proto::source_transform_request::Request { + id: "1".to_string(), + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + headers: Default::default(), + }), + handshake: None, + }; + tx.send(request).await.unwrap(); // server should shut down gracefully because there was a panic in the handler. for _ in 0..10 {