From d4c8494b0d5c12d0b4731fdba9ef31446992e6b8 Mon Sep 17 00:00:00 2001
From: Yashash H L <yashashhl25@gmail.com>
Date: Fri, 4 Oct 2024 15:26:51 +0530
Subject: [PATCH 1/2] feat: Use gRPC Bidirectional Streaming for Map

Signed-off-by: Yashash H L <yashashhl25@gmail.com>
---
 proto/map.proto        |  28 ++-
 src/batchmap.rs        |  25 +--
 src/map.rs             | 490 ++++++++++++++++++++++++++++++-----------
 src/shared.rs          |   3 +-
 src/sourcetransform.rs |  58 ++---
 5 files changed, 429 insertions(+), 175 deletions(-)

diff --git a/proto/map.proto b/proto/map.proto
index 07433dd..f3761d1 100644
--- a/proto/map.proto
+++ b/proto/map.proto
@@ -7,7 +7,7 @@ package map.v1;
 
 service Map {
   // MapFn applies a function to each map request element.
-  rpc MapFn(MapRequest) returns (MapResponse);
+  rpc MapFn(stream MapRequest) returns (stream MapResponse);
 
   // IsReady is the heartbeat endpoint for gRPC.
   rpc IsReady(google.protobuf.Empty) returns (ReadyResponse);
@@ -17,12 +17,25 @@ service Map {
  * MapRequest represents a request element.
  */
 message MapRequest {
-  repeated string keys = 1;
-  bytes value = 2;
-  google.protobuf.Timestamp event_time = 3;
-  google.protobuf.Timestamp watermark = 4;
-  map<string, string> headers = 5;
+  message Request {
+    repeated string keys = 1;
+    bytes value = 2;
+    google.protobuf.Timestamp event_time = 3;
+    google.protobuf.Timestamp watermark = 4;
+    map<string, string> headers = 5;
+  }
+  Request request = 1;
+  // This ID is used to uniquely identify a map request
+  string id = 2;
+  optional Handshake handshake = 3;
+}
 
+/*
+ * Handshake message between client and server to indicate the start of transmission.
+ */
+message Handshake {
+  // Required field indicating the start of transmission.
+  bool sot = 1;
 }
 
 /**
@@ -35,6 +48,9 @@ message MapResponse {
     repeated string tags = 3;
   }
   repeated Result results = 1;
+  // This ID is used to match each response to the request it corresponds to.
+  string id = 2;
+  optional Handshake handshake = 3;
 }
 
 /**
diff --git a/src/batchmap.rs b/src/batchmap.rs
index 7109fe9..9814a59 100644
--- a/src/batchmap.rs
+++ b/src/batchmap.rs
@@ -118,7 +118,7 @@ pub struct Message {
 }
 
 /// Represents a message that can be modified and forwarded.
-impl crate::batchmap::Message {
+impl Message {
     /// Creates a new message with the specified value.
     ///
     /// This constructor initializes the message with no keys, tags, or specific event time.
@@ -148,11 +148,11 @@ impl crate::batchmap::Message {
     /// use numaflow::batchmap::Message;
     /// let dropped_message = Message::message_to_drop();
     /// ```
-    pub fn message_to_drop() -> crate::batchmap::Message {
-        crate::batchmap::Message {
+    pub fn message_to_drop() -> Message {
+        Message {
             keys: None,
             value: vec![],
-            tags: Some(vec![crate::batchmap::DROP.to_string()]),
+            tags: Some(vec![DROP.to_string()]),
         }
     }
 
@@ -245,11 +245,8 @@ impl<T> BatchMap for BatchMapService<T>
 where
     T: BatchMapper + Send + Sync + 'static,
 {
-    async fn is_ready(
-        &self,
-        _: Request<()>,
-    ) -> Result<tonic::Response<proto::ReadyResponse>, Status> {
-        Ok(tonic::Response::new(proto::ReadyResponse { ready: true }))
+    async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
+        Ok(Response::new(proto::ReadyResponse { ready: true }))
     }
 
     type BatchMapFnStream = ReceiverStream<Result<proto::BatchMapResponse, Status>>;
@@ -261,7 +258,7 @@ where
         let mut stream = request.into_inner();
 
         // Create a channel to send the messages to the user defined function.
-        let (tx, rx) = mpsc::channel::<Datum>(1);
+        let (tx, rx) = channel::<Datum>(1);
 
         // Create a channel to send the response back to the grpc client.
         let (grpc_response_tx, grpc_response_rx) =
@@ -418,9 +415,9 @@ pub struct Server<T> {
     server_info_file: PathBuf,
     svc: Option<T>,
 }
-impl<T> crate::batchmap::Server<T> {
+impl<T> Server<T> {
     pub fn new(batch_map_svc: T) -> Self {
-        crate::batchmap::Server {
+        Server {
             sock_addr: DEFAULT_SOCK_ADDR.into(),
             max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
             server_info_file: DEFAULT_SERVER_INFO_FILE.into(),
@@ -478,8 +475,8 @@ impl<T> crate::batchmap::Server<T> {
         let cln_token = CancellationToken::new();
 
         // Create a channel to send shutdown signal to the server to do graceful shutdown in case of non retryable errors.
-        let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1);
-        let map_svc = crate::batchmap::BatchMapService {
+        let (internal_shutdown_tx, internal_shutdown_rx) = channel(1);
+        let map_svc = BatchMapService {
             handler: Arc::new(handler),
             _shutdown_tx: internal_shutdown_tx,
             cancellation_token: cln_token.clone(),
diff --git a/src/map.rs b/src/map.rs
index eec2294..2ba1fca 100644
--- a/src/map.rs
+++ b/src/map.rs
@@ -1,5 +1,5 @@
-use crate::error::Error::MapError;
-use crate::error::ErrorKind::{InternalError, UserDefinedError};
+use crate::error::{Error, ErrorKind};
+use crate::map::proto::MapResponse;
 use crate::shared::{self, shutdown_signal, ContainerType};
 use chrono::{DateTime, Utc};
 use std::collections::HashMap;
@@ -7,9 +7,13 @@ use std::fs;
 use std::path::PathBuf;
 use std::sync::Arc;
 use tokio::sync::{mpsc, oneshot};
+use tokio::task::JoinHandle;
+use tokio_stream::wrappers::ReceiverStream;
 use tokio_util::sync::CancellationToken;
-use tonic::{async_trait, Request, Response, Status};
+use tonic::{async_trait, Request, Response, Status, Streaming};
+use tracing::{error, info};
 
+const DEFAULT_CHANNEL_SIZE: usize = 1000;
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/map.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/mapper-server-info";
@@ -60,53 +64,6 @@ pub trait Mapper {
     async fn map(&self, input: MapRequest) -> Vec<Message>;
 }
 
-#[async_trait]
-impl<T> proto::map_server::Map for MapService<T>
-where
-    T: Mapper + Send + Sync + 'static,
-{
-    async fn map_fn(
-        &self,
-        request: Request<proto::MapRequest>,
-    ) -> Result<Response<proto::MapResponse>, Status> {
-        let request = request.into_inner();
-        let handler = Arc::clone(&self.handler);
-        let handle = tokio::spawn(async move { handler.map(request.into()).await });
-        let shutdown_tx = self.shutdown_tx.clone();
-        let cancellation_token = self.cancellation_token.clone();
-
-        // Wait for the handler to finish processing the request. If the server is shutting down(token will be cancelled),
-        // then return an error.
-        tokio::select! {
-            result = handle => {
-                match result {
-                    Ok(result) => Ok(Response::new(proto::MapResponse {
-                        results: result.into_iter().map(|msg| msg.into()).collect(),
-                    })),
-                    Err(e) => {
-                        tracing::error!("Error in map handler: {:?}", e);
-                        // Send a shutdown signal to the server to do a graceful shutdown because there was
-                        // a panic in the handler.
-                        shutdown_tx
-                            .send(())
-                            .await
-                            .expect("Sending shutdown signal to gRPC server");
-                        Err(Status::internal(MapError(UserDefinedError(e.to_string())).to_string()))
-                    }
-                }
-            },
-
-            _ = cancellation_token.cancelled() => {
-                Err(Status::internal(MapError(InternalError("Server is shutting down".to_string())).to_string()))
-            },
-        }
-    }
-
-    async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
-        Ok(Response::new(proto::ReadyResponse { ready: true }))
-    }
-}
-
 /// Message is the response struct from the [`Mapper::map`] .
 #[derive(Debug, PartialEq)]
 pub struct Message {
@@ -234,8 +191,8 @@ pub struct MapRequest {
     pub headers: HashMap<String, String>,
 }
 
-impl From<proto::MapRequest> for MapRequest {
-    fn from(value: proto::MapRequest) -> Self {
+impl From<proto::map_request::Request> for MapRequest {
+    fn from(value: proto::map_request::Request) -> Self {
         Self {
             keys: value.keys,
             value: value.value,
@@ -246,6 +203,235 @@ impl From<proto::MapRequest> for MapRequest {
     }
 }
 
+#[async_trait]
+impl<T> proto::map_server::Map for MapService<T>
+where
+    T: Mapper + Send + Sync + 'static,
+{
+    type MapFnStream = ReceiverStream<Result<MapResponse, Status>>;
+
+    async fn map_fn(
+        &self,
+        request: Request<Streaming<proto::MapRequest>>,
+    ) -> Result<Response<Self::MapFnStream>, Status> {
+        let mut stream = request.into_inner();
+        let handler = Arc::clone(&self.handler);
+
+        let (stream_response_tx, stream_response_rx) =
+            mpsc::channel::<Result<MapResponse, Status>>(DEFAULT_CHANNEL_SIZE);
+
+        // The first message on the stream must be a handshake; it is echoed back to the client.
+        perform_handshake(&mut stream, &stream_response_tx).await?;
+
+        let (error_tx, error_rx) = mpsc::channel::<Error>(1);
+
+        // Spawn a task that reads requests from the stream and spawns a map task for each one.
+        let handle: JoinHandle<()> = tokio::spawn(handle_stream_requests(
+            handler.clone(),
+            stream,
+            stream_response_tx.clone(),
+            error_tx.clone(),
+            self.cancellation_token.child_token(),
+        ));
+        // Watch the reader task and the error channel; shut the server down on the first error.
+        tokio::spawn(manage_grpc_stream(
+            handle,
+            self.cancellation_token.clone(),
+            stream_response_tx,
+            error_rx,
+            self.shutdown_tx.clone(),
+        ));
+
+        Ok(Response::new(ReceiverStream::new(stream_response_rx)))
+    }
+
+    async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
+        Ok(Response::new(proto::ReadyResponse { ready: true }))
+    }
+}
+
+// Receives messages from the stream. For each message received from the stream,
+// a new task is spawned to call the map function and send the response back to the client.
+async fn handle_stream_requests<T>(
+    handler: Arc<T>,
+    mut stream: Streaming<proto::MapRequest>,
+    stream_response_tx: mpsc::Sender<Result<MapResponse, Status>>,
+    error_tx: mpsc::Sender<Error>,
+    token: CancellationToken,
+) where
+    T: Mapper + Send + Sync + 'static,
+{
+    loop {
+        let next_message = tokio::select! {
+            message = stream.message() => message,
+            _ = token.cancelled() => {
+                info!("Cancellation token is cancelled, shutting down");
+                return;
+            }
+        };
+        let (h, tx, etx) = (handler.clone(), stream_response_tx.clone(), error_tx.clone());
+        // `false` means the stream has ended or failed, so stop reading.
+        if !handle_request(h, next_message, tx, etx, token.clone()).await {
+            return;
+        }
+    }
+}
+
+// Shuts down the gRPC server on the first error reported by a map task.
+async fn manage_grpc_stream(
+    request_handler: JoinHandle<()>,
+    token: CancellationToken,
+    stream_response_tx: mpsc::Sender<Result<MapResponse, Status>>,
+    mut error_rx: mpsc::Receiver<Error>,
+    server_shutdown_tx: mpsc::Sender<()>,
+) {
+    let err = tokio::select! {
+        _ = request_handler => {
+            token.cancel();
+            return;
+        },
+        err = error_rx.recv() => err,
+    };
+
+    token.cancel();
+    let Some(err) = err else {
+        return;
+    };
+    error!("Shutting down gRPC channel: {err:?}");
+    // The client may have dropped the response stream; still signal the server to shut down.
+    if let Err(e) = stream_response_tx.send(Err(Status::internal(err.to_string()))).await {
+        error!("Failed to send error to gRPC response channel: {e:?}");
+    }
+    if let Err(e) = server_shutdown_tx.send(()).await {
+        error!("Failed to write to shutdown channel: {e:?}");
+    }
+}
+
+// Handles one item read from the request stream. A request is handed to a newly spawned
+// map task and `true` is returned; `false` is returned when the stream has ended or
+// yielded a gRPC error, which tells the caller to stop reading.
+async fn handle_request<T>(
+    handler: Arc<T>,
+    map_request: Result<Option<proto::MapRequest>, Status>,
+    stream_response_tx: mpsc::Sender<Result<MapResponse, Status>>,
+    error_tx: mpsc::Sender<Error>,
+    token: CancellationToken,
+) -> bool
+where
+    T: Mapper + Send + Sync + 'static,
+{
+    let request = match map_request {
+        Ok(Some(request)) => request,
+        Ok(None) => return false,
+        Err(status) => {
+            error!("Received gRPC error from sender: {status:?}");
+            return false;
+        }
+    };
+
+    // The task is detached; it reports its outcome through the response and error channels.
+    let map_task = run_map(handler, request, stream_response_tx, error_tx, token);
+    tokio::spawn(map_task);
+    true
+}
+
+// Reports a map failure to the stream manager. A closed channel means the manager has
+// already stopped listening (an earlier error or the end of the stream), so it is only logged.
+async fn report_map_error(error_tx: &mpsc::Sender<Error>, reason: String) {
+    if let Err(e) = error_tx.send(Error::MapError(ErrorKind::InternalError(reason))).await {
+        error!("Failed to report map error, receiver is closed: {e:?}");
+    }
+}
+
+// Runs the user-defined map function for a single request.
+//
+// The handler runs in its own task so that a panic surfaces as a join error instead of
+// unwinding this task, and it is raced against the cancellation token, in which case no
+// response is sent. On success the response, tagged with the request id, is written to
+// `stream_response_tx`; every failure is reported through `error_tx`.
+async fn run_map<T>(
+    handler: Arc<T>,
+    map_request: proto::MapRequest,
+    stream_response_tx: mpsc::Sender<Result<MapResponse, Status>>,
+    error_tx: mpsc::Sender<Error>,
+    token: CancellationToken,
+) where
+    T: Mapper + Send + Sync + 'static,
+{
+    let Some(request) = map_request.request else {
+        report_map_error(&error_tx, "Request not present".to_string()).await;
+        return;
+    };
+
+    // Only `request` was moved out of `map_request`, so the id can be moved too (no clone).
+    let message_id = map_request.id;
+
+    // A new task is spawned to catch the panic
+    let udf_map_task = tokio::spawn({
+        let handler = handler.clone();
+        let token = token.child_token();
+        async move {
+            tokio::select! {
+                _ = token.cancelled() => None,
+                messages = handler.map(request.into()) => Some(messages),
+            }
+        }
+    });
+
+    let messages = match udf_map_task.await {
+        Ok(messages) => messages,
+        Err(e) => {
+            error!("Failed to run map function: {e:?}");
+            report_map_error(&error_tx, format!("panicked: {e:?}")).await;
+            return;
+        }
+    };
+
+    let Some(messages) = messages else {
+        // CancellationToken is cancelled
+        return;
+    };
+
+    let send_response_result = stream_response_tx
+        .send(Ok(MapResponse {
+            results: messages.into_iter().map(|msg| msg.into()).collect(),
+            id: message_id,
+            handshake: None,
+        }))
+        .await;
+
+    let Err(e) = send_response_result else {
+        return;
+    };
+
+    report_map_error(&error_tx, format!("Failed to send response: {e:?}")).await;
+}
+
+// Expects a handshake as the first stream message and echoes it back to the client.
+async fn perform_handshake(
+    stream: &mut Streaming<proto::MapRequest>,
+    stream_response_tx: &mpsc::Sender<Result<MapResponse, Status>>,
+) -> Result<(), Status> {
+    let first_message = stream
+        .message()
+        .await
+        .map_err(|e| Status::internal(format!("Handshake failed: {}", e)))?
+        .ok_or_else(|| Status::internal("Stream closed before handshake"))?;
+
+    let Some(handshake) = first_message.handshake else {
+        return Err(Status::invalid_argument("Handshake not present"));
+    };
+    let handshake_response = MapResponse {
+        results: Vec::new(),
+        id: String::new(),
+        handshake: Some(handshake),
+    };
+    stream_response_tx
+        .send(Ok(handshake_response))
+        .await
+        .map_err(|e| Status::internal(format!("Failed to send handshake response: {}", e)))
+}
+
 /// gRPC server to start a map service
 #[derive(Debug)]
 pub struct Server<T> {
@@ -362,10 +548,11 @@ mod tests {
     use crate::map::proto::map_client::MapClient;
     use std::{error::Error, time::Duration};
 
+    use crate::map::proto;
     use tempfile::TempDir;
     use tokio::net::UnixStream;
-    use tokio::sync::oneshot;
-    use tokio::time::sleep;
+    use tokio::sync::{mpsc, oneshot};
+    use tokio_stream::wrappers::ReceiverStream;
     use tonic::transport::Uri;
     use tower::service_fn;
 
@@ -415,21 +602,51 @@ mod tests {
             .await?;
 
         let mut client = MapClient::new(channel);
-        let request = tonic::Request::new(map::proto::MapRequest {
-            keys: vec!["first".into(), "second".into()],
-            value: "hello".into(),
-            watermark: Some(prost_types::Timestamp::default()),
-            event_time: Some(prost_types::Timestamp::default()),
-            headers: Default::default(),
-        });
-
-        let resp = client.map_fn(request).await?;
-        let resp = resp.into_inner();
-        assert_eq!(resp.results.len(), 1, "Expected single message from server");
-        let msg = &resp.results[0];
+        let request = proto::MapRequest {
+            request: Some(proto::map_request::Request {
+                keys: vec!["first".into(), "second".into()],
+                value: "hello".into(),
+                watermark: Some(prost_types::Timestamp::default()),
+                event_time: Some(prost_types::Timestamp::default()),
+                headers: Default::default(),
+            }),
+            id: "".to_string(),
+            handshake: None,
+        };
+
+        let (tx, rx) = mpsc::channel(2);
+        let handshake_request = proto::MapRequest {
+            request: None,
+            id: "".to_string(),
+            handshake: Some(proto::Handshake { sot: true }),
+        };
+
+        tx.send(handshake_request).await?;
+        tx.send(request).await?;
+
+        let resp = client.map_fn(ReceiverStream::new(rx)).await?;
+        let mut resp = resp.into_inner();
+
+        let handshake_response = resp.message().await?;
+        assert!(handshake_response.is_some());
+
+        let handshake_response = handshake_response.unwrap();
+        assert!(handshake_response.handshake.is_some());
+
+        let actual_response = resp.message().await?;
+        assert!(actual_response.is_some());
+
+        let actual_response = actual_response.unwrap();
+        assert_eq!(
+            actual_response.results.len(),
+            1,
+            "Expected single message from server"
+        );
+        let msg = &actual_response.results[0];
         assert_eq!(msg.keys.first(), Some(&"first".to_owned()));
         assert_eq!(msg.value, "hello".as_bytes());
 
+        drop(tx);
         shutdown_tx
             .send(())
             .expect("Sending shutdown signal to gRPC server");
@@ -440,11 +657,11 @@ mod tests {
 
     #[tokio::test]
     async fn map_server_panic() -> Result<(), Box<dyn Error>> {
-        struct PanicCat;
+        struct PanicMapper;
         #[tonic::async_trait]
-        impl map::Mapper for PanicCat {
-            async fn map(&self, _input: map::MapRequest) -> Vec<map::Message> {
-                panic!("PanicCat panicking");
+        impl map::Mapper for PanicMapper {
+            async fn map(&self, _: map::MapRequest) -> Vec<map::Message> {
+                panic!("Panic in mapper");
             }
         }
 
@@ -452,7 +669,7 @@ mod tests {
         let sock_file = tmp_dir.path().join("map.sock");
         let server_info_file = tmp_dir.path().join("mapper-server-info");
 
-        let mut server = map::Server::new(PanicCat)
+        let mut server = map::Server::new(PanicMapper)
             .with_server_info_file(&server_info_file)
             .with_socket_file(&sock_file)
             .with_max_message_size(10240);
@@ -466,8 +683,10 @@ mod tests {
 
         tokio::time::sleep(Duration::from_millis(50)).await;
 
+        // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs
         let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
             .connect_with_connector(service_fn(move |_: Uri| {
+                // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes
                 let sock_file = sock_file.clone();
                 async move {
                     Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
@@ -478,22 +697,41 @@ mod tests {
             .await?;
 
         let mut client = MapClient::new(channel);
-        let request = tonic::Request::new(map::proto::MapRequest {
-            keys: vec!["first".into(), "second".into()],
-            value: "hello".into(),
-            watermark: Some(prost_types::Timestamp::default()),
-            event_time: Some(prost_types::Timestamp::default()),
-            headers: Default::default(),
-        });
-
-        // server should return an error because of the panic.
-        let resp = client.map_fn(request).await;
-        assert!(resp.is_err(), "Expected error from server");
-
-        if let Err(e) = resp {
-            assert_eq!(e.code(), tonic::Code::Internal);
-            assert!(e.message().contains("User Defined Error"));
-        }
+
+        let (tx, rx) = mpsc::channel(2);
+        let handshake_request = proto::MapRequest {
+            request: None,
+            id: "".to_string(),
+            handshake: Some(proto::Handshake { sot: true }),
+        };
+        tx.send(handshake_request).await.unwrap();
+
+        let mut stream = tokio::time::timeout(
+            Duration::from_secs(2),
+            client.map_fn(ReceiverStream::new(rx)),
+        )
+        .await
+        .map_err(|_| "timeout while getting stream for map_fn")??
+        .into_inner();
+
+        let handshake_resp = stream.message().await?.unwrap();
+        assert!(
+            handshake_resp.handshake.is_some(),
+            "Not a valid response for handshake request"
+        );
+
+        let request = proto::MapRequest {
+            request: Some(proto::map_request::Request {
+                keys: vec!["three".into(), "four".into()],
+                value: "hello".into(),
+                watermark: Some(prost_types::Timestamp::default()),
+                event_time: Some(prost_types::Timestamp::default()),
+                headers: Default::default(),
+            }),
+            id: "".to_string(),
+            handshake: None,
+        };
+        tx.send(request).await.unwrap();
 
         // server should shut down gracefully because there was a panic in the handler.
         for _ in 0..10 {
@@ -511,17 +749,11 @@ mod tests {
     // should shut down gracefully.
     #[tokio::test]
     async fn panic_with_multiple_requests() -> Result<(), Box<dyn Error>> {
-        struct PanicCat;
+        struct PanicMapper;
         #[tonic::async_trait]
-        impl map::Mapper for PanicCat {
-            async fn map(&self, input: map::MapRequest) -> Vec<map::Message> {
-                if !input.keys.is_empty() && input.keys[0] == "key1" {
-                    sleep(Duration::from_millis(20)).await;
-                    panic!("Cat panicked");
-                }
-                // assume each request takes 100ms to process
-                sleep(Duration::from_millis(100)).await;
-                vec![]
+        impl map::Mapper for PanicMapper {
+            async fn map(&self, _: map::MapRequest) -> Vec<map::Message> {
+                panic!("Panic in mapper");
             }
         }
 
@@ -529,7 +761,7 @@ mod tests {
         let sock_file = tmp_dir.path().join("map.sock");
         let server_info_file = tmp_dir.path().join("mapper-server-info");
 
-        let mut server = map::Server::new(PanicCat)
+        let mut server = map::Server::new(PanicMapper)
             .with_server_info_file(&server_info_file)
             .with_socket_file(&sock_file)
             .with_max_message_size(10240);
@@ -543,6 +775,7 @@ mod tests {
 
         tokio::time::sleep(Duration::from_millis(50)).await;
 
+        // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs
         let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
             .connect_with_connector(service_fn(move |_: Uri| {
                 let sock_file = sock_file.clone();
@@ -556,47 +789,48 @@ mod tests {
 
         let mut client = MapClient::new(channel);
 
-        let mut client_one = client.clone();
-        tokio::spawn(async move {
-            let request = tonic::Request::new(map::proto::MapRequest {
-                keys: vec!["key2".into()],
+        let (tx, rx) = mpsc::channel(2);
+        let handshake_request = proto::MapRequest {
+            request: None,
+            id: "".to_string(),
+            handshake: Some(proto::Handshake { sot: true }),
+        };
+        tx.send(handshake_request).await.unwrap();
+
+        let mut stream = tokio::time::timeout(
+            Duration::from_secs(2),
+            client.map_fn(ReceiverStream::new(rx)),
+        )
+        .await
+        .map_err(|_| "timeout while getting stream for map_fn")??
+        .into_inner();
+
+        let handshake_resp = stream.message().await?.unwrap();
+        assert!(
+            handshake_resp.handshake.is_some(),
+            "Not a valid response for handshake request"
+        );
+
+        let request = proto::MapRequest {
+            request: Some(proto::map_request::Request {
+                keys: vec!["five".into(), "six".into()],
                 value: "hello".into(),
                 watermark: Some(prost_types::Timestamp::default()),
                 event_time: Some(prost_types::Timestamp::default()),
                 headers: Default::default(),
-            });
-
-            // panic is only for requests with key "key1", since we have graceful shutdown
-            // the request should get processed.
-            let resp = client_one.map_fn(request).await;
-            assert!(resp.is_ok(), "Expected ok from server");
-        });
-
-        let request = tonic::Request::new(map::proto::MapRequest {
-            keys: vec!["key1".into()],
-            value: "hello".into(),
-            watermark: Some(prost_types::Timestamp::default()),
-            event_time: Some(prost_types::Timestamp::default()),
-            headers: Default::default(),
-        });
-
-        // panic happens for the key1 request, so we should expect error on the client side.
-        let resp = client.map_fn(request).await;
-        assert!(resp.is_err(), "Expected error from server");
-
-        if let Err(e) = resp {
-            assert_eq!(e.code(), tonic::Code::Internal);
-            assert!(e.message().contains("User Defined Error"));
-        }
+            }),
+            id: "".to_string(),
+            handshake: None,
+        };
+        tx.send(request).await.unwrap();
 
-        // but since there is a panic, the server should shutdown.
+        // server should shut down gracefully because there was a panic in the handler.
         for _ in 0..10 {
             tokio::time::sleep(Duration::from_millis(10)).await;
             if task.is_finished() {
                 break;
             }
         }
-
         assert!(task.is_finished(), "gRPC server is still running");
         Ok(())
     }
diff --git a/src/shared.rs b/src/shared.rs
index 2754507..c5380e6 100644
--- a/src/shared.rs
+++ b/src/shared.rs
@@ -91,7 +91,7 @@ impl ServerInfo {
             minimum_numaflow_version: MINIMUM_NUMAFLOW_VERSION
                 .get(&container_type)
                 .map(|&version| version.to_string())
-                .unwrap_or_else(String::new),
+                .unwrap_or_default(),
             version: SDK_VERSION.to_string(),
             metadata: Option::from(metadata),
         }
@@ -138,6 +138,7 @@ pub(crate) fn prost_timestamp_from_utc(t: DateTime<Utc>) -> Option<Timestamp> {
 /// shuts downs the gRPC server. This happens in 2 cases
 /// 1. there has been an internal error (one of the tasks failed) and we need to shutdown
 /// 2. user is explicitly asking us to shutdown
+///
 /// Once the request for shutdown has be invoked, server will broadcast shutdown to all tasks
 /// through the cancellation-token.
 pub(crate) async fn shutdown_signal(
diff --git a/src/sourcetransform.rs b/src/sourcetransform.rs
index 9d8fe04..b3ad454 100644
--- a/src/sourcetransform.rs
+++ b/src/sourcetransform.rs
@@ -255,26 +255,7 @@ where
             mpsc::channel::<Result<SourceTransformResponse, Status>>(DEFAULT_CHANNEL_SIZE);
 
         // do the handshake first to let the client know that we are ready to receive transformation requests.
-        let handshake_request = stream
-            .message()
-            .await
-            .map_err(|e| Status::internal(format!("handshake failed {}", e)))?
-            .ok_or_else(|| Status::internal("stream closed before handshake"))?;
-
-        if let Some(handshake) = handshake_request.handshake {
-            stream_response_tx
-                .send(Ok(SourceTransformResponse {
-                    results: vec![],
-                    id: "".to_string(),
-                    handshake: Some(handshake),
-                }))
-                .await
-                .map_err(|e| {
-                    Status::internal(format!("failed to send handshake response {}", e))
-                })?;
-        } else {
-            return Err(Status::invalid_argument("Handshake not present"));
-        }
+        perform_handshake(&mut stream, &stream_response_tx).await?;
 
         let (error_tx, error_rx) = mpsc::channel::<Error>(1);
 
@@ -288,7 +269,7 @@ where
             self.cancellation_token.child_token(),
         ));
 
-        tokio::spawn(manage_gprc_stream(
+        tokio::spawn(manage_grpc_stream(
             handle,
             self.cancellation_token.clone(),
             stream_response_tx,
@@ -304,8 +285,33 @@ where
     }
 }
 
+// Expects a handshake as the first stream message and echoes it back to the client.
+async fn perform_handshake(
+    stream: &mut Streaming<proto::SourceTransformRequest>,
+    stream_response_tx: &mpsc::Sender<Result<SourceTransformResponse, Status>>,
+) -> Result<(), Status> {
+    let first_message = stream
+        .message()
+        .await
+        .map_err(|e| Status::internal(format!("Handshake failed: {}", e)))?
+        .ok_or_else(|| Status::internal("Stream closed before handshake"))?;
+
+    let Some(handshake) = first_message.handshake else {
+        return Err(Status::invalid_argument("Handshake not present"));
+    };
+    let handshake_response = SourceTransformResponse {
+        results: Vec::new(),
+        id: String::new(),
+        handshake: Some(handshake),
+    };
+    stream_response_tx
+        .send(Ok(handshake_response))
+        .await
+        .map_err(|e| Status::internal(format!("Failed to send handshake response: {}", e)))
+}
+
 // shutdown the gRPC server on first error
-async fn manage_gprc_stream(
+async fn manage_grpc_stream(
     request_handler: JoinHandle<()>,
     token: CancellationToken,
     stream_response_tx: mpsc::Sender<Result<SourceTransformResponse, Status>>,
@@ -335,8 +341,8 @@ async fn manage_gprc_stream(
         .expect("Writing to shutdown channel");
 }
 
-// Receives messages from the stream.
-// For each message received from the stream, a new task is spawned to call the transform function and send the response back to the client
+// Receives messages from the stream. For each message received from the stream,
+// a new task is spawned to call the transform function and send the response back to the client
 async fn handle_stream_requests<T>(
     handler: Arc<T>,
     mut stream: Streaming<proto::SourceTransformRequest>,
@@ -673,7 +679,7 @@ mod tests {
             "Not a valid response for handshake request"
         );
 
-        let request = sourcetransform::proto::SourceTransformRequest {
+        let request = proto::SourceTransformRequest {
             request: Some(proto::source_transform_request::Request {
                 id: "1".to_string(),
                 keys: vec!["first".into(), "second".into()],
@@ -772,7 +778,7 @@ mod tests {
 
         let request = proto::SourceTransformRequest {
             request: Some(proto::source_transform_request::Request {
-                id: "1".to_string(),
+                id: "2".to_string(),
                 keys: vec!["first".into(), "second".into()],
                 value: "hello".into(),
                 watermark: Some(prost_types::Timestamp::default()),

From 19ab43ec7fc97b14992d8ec45269d313388ccd8b Mon Sep 17 00:00:00 2001
From: Yashash H L <yashashhl25@gmail.com>
Date: Tue, 8 Oct 2024 13:33:56 +0530
Subject: [PATCH 2/2] format imports, review comments

Signed-off-by: Yashash H L <yashashhl25@gmail.com>
---
 src/batchmap.rs  |  3 ++-
 src/map.rs       | 17 ++++++++++-------
 src/shared.rs    | 11 +++++++----
 src/sideinput.rs |  8 +++++---
 src/sink.rs      | 15 ++++++++-------
 src/source.rs    | 18 ++++++++++--------
 6 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/src/batchmap.rs b/src/batchmap.rs
index 9814a59..4112618 100644
--- a/src/batchmap.rs
+++ b/src/batchmap.rs
@@ -1,9 +1,10 @@
-use chrono::{DateTime, Utc};
 use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+
+use chrono::{DateTime, Utc};
 use tokio::sync::mpsc::channel;
 use tokio::sync::{mpsc, oneshot};
 use tokio_stream::wrappers::ReceiverStream;
diff --git a/src/map.rs b/src/map.rs
index 2ba1fca..a12340e 100644
--- a/src/map.rs
+++ b/src/map.rs
@@ -1,11 +1,9 @@
-use crate::error::{Error, ErrorKind};
-use crate::map::proto::MapResponse;
-use crate::shared::{self, shutdown_signal, ContainerType};
-use chrono::{DateTime, Utc};
 use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
 use std::sync::Arc;
+
+use chrono::{DateTime, Utc};
 use tokio::sync::{mpsc, oneshot};
 use tokio::task::JoinHandle;
 use tokio_stream::wrappers::ReceiverStream;
@@ -13,6 +11,10 @@ use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status, Streaming};
 use tracing::{error, info};
 
+use crate::error::{Error, ErrorKind};
+use crate::map::proto::MapResponse;
+use crate::shared::{self, shutdown_signal, ContainerType};
+
 const DEFAULT_CHANNEL_SIZE: usize = 1000;
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/map.sock";
@@ -544,11 +546,8 @@ impl<C> Drop for Server<C> {
 
 #[cfg(test)]
 mod tests {
-    use crate::map;
-    use crate::map::proto::map_client::MapClient;
     use std::{error::Error, time::Duration};
 
-    use crate::map::proto;
     use tempfile::TempDir;
     use tokio::net::UnixStream;
     use tokio::sync::{mpsc, oneshot};
@@ -556,6 +555,10 @@ mod tests {
     use tonic::transport::Uri;
     use tower::service_fn;
 
+    use crate::map;
+    use crate::map::proto;
+    use crate::map::proto::map_client::MapClient;
+
     #[tokio::test]
     async fn map_server() -> Result<(), Box<dyn Error>> {
         struct Cat;
diff --git a/src/shared.rs b/src/shared.rs
index c5380e6..b22f585 100644
--- a/src/shared.rs
+++ b/src/shared.rs
@@ -1,10 +1,11 @@
-use chrono::{DateTime, TimeZone, Timelike, Utc};
-use prost_types::Timestamp;
-use serde::{Deserialize, Serialize};
 use std::fs;
 use std::path::Path;
 use std::sync::LazyLock;
 use std::{collections::HashMap, io};
+
+use chrono::{DateTime, TimeZone, Timelike, Utc};
+use prost_types::Timestamp;
+use serde::{Deserialize, Serialize};
 use tokio::net::UnixListener;
 use tokio::signal;
 use tokio::sync::{mpsc, oneshot};
@@ -178,11 +179,13 @@ pub(crate) async fn shutdown_signal(
 
 #[cfg(test)]
 mod tests {
-    use super::*;
     use std::fs::File;
     use std::io::Read;
+
     use tempfile::NamedTempFile;
 
+    use super::*;
+
     #[test]
     fn test_utc_from_timestamp() {
         let specific_date = Utc.with_ymd_and_hms(2022, 7, 2, 2, 0, 0).unwrap();
diff --git a/src/sideinput.rs b/src/sideinput.rs
index c5e8c0e..93a6bd0 100644
--- a/src/sideinput.rs
+++ b/src/sideinput.rs
@@ -1,13 +1,15 @@
-use crate::error::Error::SideInputError;
-use crate::error::ErrorKind::{InternalError, UserDefinedError};
-use crate::shared::{self, shutdown_signal, ContainerType};
 use std::fs;
 use std::path::PathBuf;
 use std::sync::Arc;
+
 use tokio::sync::{mpsc, oneshot};
 use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status};
 
+use crate::error::Error::SideInputError;
+use crate::error::ErrorKind::{InternalError, UserDefinedError};
+use crate::shared::{self, shutdown_signal, ContainerType};
+
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sideinput.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sideinput-server-info";
diff --git a/src/sink.rs b/src/sink.rs
index 824d0af..329a3dc 100644
--- a/src/sink.rs
+++ b/src/sink.rs
@@ -1,14 +1,9 @@
-use crate::error::Error;
-use crate::error::Error::SinkError;
-use crate::error::ErrorKind::{InternalError, UserDefinedError};
-use crate::shared::{self, ContainerType};
-use crate::sink::sink_pb::SinkResponse;
-
-use chrono::{DateTime, Utc};
 use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::{env, fs};
+
+use chrono::{DateTime, Utc};
 use tokio::sync::{mpsc, oneshot};
 use tokio::task::JoinHandle;
 use tokio_stream::wrappers::ReceiverStream;
@@ -16,6 +11,12 @@ use tokio_util::sync::CancellationToken;
 use tonic::{Request, Status, Streaming};
 use tracing::{debug, info};
 
+use crate::error::Error;
+use crate::error::Error::SinkError;
+use crate::error::ErrorKind::{InternalError, UserDefinedError};
+use crate::shared::{self, ContainerType};
+use crate::sink::sink_pb::SinkResponse;
+
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sink.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sinker-server-info";
diff --git a/src/source.rs b/src/source.rs
index c50894a..6a67fbc 100644
--- a/src/source.rs
+++ b/src/source.rs
@@ -4,11 +4,6 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::time::Duration;
 
-use crate::error::Error::SourceError;
-use crate::error::{Error, ErrorKind};
-use crate::shared::{self, prost_timestamp_from_utc, ContainerType};
-use crate::source::proto::{AckRequest, AckResponse, ReadRequest, ReadResponse};
-
 use chrono::{DateTime, Utc};
 use tokio::sync::mpsc::{self, Receiver, Sender};
 use tokio::sync::oneshot;
@@ -18,6 +13,11 @@ use tokio_util::sync::CancellationToken;
 use tonic::{async_trait, Request, Response, Status, Streaming};
 use tracing::{error, info};
 
+use crate::error::Error::SourceError;
+use crate::error::{Error, ErrorKind};
+use crate::shared::{self, prost_timestamp_from_utc, ContainerType};
+use crate::source::proto::{AckRequest, AckResponse, ReadRequest, ReadResponse};
+
 const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
 const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/source.sock";
 const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcer-server-info";
@@ -543,13 +543,12 @@ impl<C> Drop for Server<C> {
 
 #[cfg(test)]
 mod tests {
-    use super::{proto, Message, Offset, SourceReadRequest};
-    use crate::source;
-    use chrono::Utc;
     use std::collections::{HashMap, HashSet};
     use std::error::Error;
     use std::time::Duration;
     use std::vec;
+
+    use chrono::Utc;
     use tempfile::TempDir;
     use tokio::net::UnixStream;
     use tokio::sync::mpsc::Sender;
@@ -560,6 +559,9 @@ mod tests {
     use tower::service_fn;
     use uuid::Uuid;
 
+    use super::{proto, Message, Offset, SourceReadRequest};
+    use crate::source;
+
     // A source that repeats the `num` for the requested count
     struct Repeater {
         num: usize,