diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2e7ee8d --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +# Description: Makefile for Rust projects + +# perform a cargo fmt on all directories containing a Cargo.toml file +.PHONY: lint +# find all directories containing Cargo.toml files +DIRS := $(shell find . -type f -name Cargo.toml -exec dirname {} \; | sort -u) +$(info Included directories: $(DIRS)) +lint: + @for dir in $(DIRS); do \ + echo "Formatting code in $$dir"; \ + cargo fmt --all --manifest-path "$$dir/Cargo.toml"; \ + done + +# run cargo test on the repository root +.PHONY: test +test: + cargo test --workspace diff --git a/build.rs b/build.rs index 0df8df8..45943b8 100644 --- a/build.rs +++ b/build.rs @@ -9,6 +9,7 @@ fn main() { "proto/reduce.proto", "proto/sink.proto", "proto/sideinput.proto", + "proto/batchmap.proto", ], &["proto"], ) diff --git a/examples/batchmap-cat/.dockerignore b/examples/batchmap-cat/.dockerignore new file mode 100644 index 0000000..9f97022 --- /dev/null +++ b/examples/batchmap-cat/.dockerignore @@ -0,0 +1 @@ +target/ \ No newline at end of file diff --git a/examples/batchmap-cat/Cargo.toml b/examples/batchmap-cat/Cargo.toml new file mode 100644 index 0000000..781f868 --- /dev/null +++ b/examples/batchmap-cat/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "batchmap-cat" +version = "0.1.0" +edition = "2021" + + +[[bin]] +name = "server" +path = "src/main.rs" + +[dependencies] +tonic = "0.12.0" +tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } +numaflow = { path = "../../" } diff --git a/examples/batchmap-cat/Dockerfile b/examples/batchmap-cat/Dockerfile new file mode 100644 index 0000000..a496d1d --- /dev/null +++ b/examples/batchmap-cat/Dockerfile @@ -0,0 +1,20 @@ +FROM rust:1.75-bookworm AS build + +RUN apt-get update +RUN apt-get install protobuf-compiler -y + +WORKDIR /numaflow-rs +COPY ./ ./ +WORKDIR /numaflow-rs/examples/batchmap-cat + +# build for release +RUN cargo build --release + +# our final base +FROM debian:bookworm AS map-cat + +# copy the build artifact from the build stage +COPY --from=build /numaflow-rs/examples/batchmap-cat/target/release/server . + +# set the startup command to run your binary +CMD ["./server"] diff --git a/examples/batchmap-cat/Makefile b/examples/batchmap-cat/Makefile new file mode 100644 index 0000000..8f9436a --- /dev/null +++ b/examples/batchmap-cat/Makefile @@ -0,0 +1,20 @@ +TAG ?= stable +PUSH ?= false +IMAGE_REGISTRY = quay.io/numaio/numaflow-rs/batchmap-cat:${TAG} +DOCKER_FILE_PATH = examples/batchmap-cat/Dockerfile + +.PHONY: update +update: + cargo check + cargo update + +.PHONY: image +image: update + cd ../../ && docker build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} . + @if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi + +.PHONY: clean +clean: + -rm -rf target diff --git a/examples/batchmap-cat/manifests/simple-batchmap-cat.yaml b/examples/batchmap-cat/manifests/simple-batchmap-cat.yaml new file mode 100644 index 0000000..60fc23c --- /dev/null +++ b/examples/batchmap-cat/manifests/simple-batchmap-cat.yaml @@ -0,0 +1,31 @@ +apiVersion: numaflow.numaproj.io/v1alpha1 +kind: Pipeline +metadata: + name: rust-batchmap-cat +spec: + vertices: + - name: in + source: + # A self data generating source + generator: + rpu: 300 + duration: 1s + keyCount: 5 + value: 5 + - name: cat + scale: + min: 1 + udf: + container: + image: quay.io/numaio/numaflow-rs/batchmap-cat:stable + - name: out + sink: + # A simple log printing sink + log: { } + edges: + - from: in + to: cat + - from: cat + to: out + + diff --git a/examples/batchmap-cat/src/main.rs b/examples/batchmap-cat/src/main.rs new file mode 100644 index 0000000..9bf4c73 --- /dev/null +++ b/examples/batchmap-cat/src/main.rs @@ -0,0 +1,26 @@ +use numaflow::batchmap; +use numaflow::batchmap::{BatchResponse, Datum, Message}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + batchmap::Server::new(Cat).start().await +} + +struct Cat; + +#[tonic::async_trait] +impl batchmap::BatchMapper for Cat { + async fn batchmap(&self, mut input: tokio::sync::mpsc::Receiver) -> Vec { + let mut responses: Vec = Vec::new(); + while let Some(datum) = input.recv().await { + let mut response = BatchResponse::from_id(datum.id); + response.append(Message { + keys: Some(datum.keys), + value: datum.value, + tags: None, + }); + responses.push(response); + } + responses + } +} diff --git a/examples/reduce-counter/src/main.rs b/examples/reduce-counter/src/main.rs index 77fa50b..83146ab 100644 --- a/examples/reduce-counter/src/main.rs +++ b/examples/reduce-counter/src/main.rs @@ -9,7 +9,7 @@ async fn main() -> Result<(), Box> { mod counter { use numaflow::reduce::{Message, ReduceRequest}; - use numaflow::reduce::{Reducer, Metadata}; + use numaflow::reduce::{Metadata, Reducer}; use tokio::sync::mpsc::Receiver; use tonic::async_trait; @@ -44,8 +44,10 @@ mod counter { while input.recv().await.is_some() { counter += 1; } - let message = Message::new(counter.to_string().into_bytes()).tags(vec![]).keys(keys.clone()); + let message = Message::new(counter.to_string().into_bytes()) + .tags(vec![]) + .keys(keys.clone()); vec![message] } } -} \ No newline at end of file +} diff --git a/examples/sideinput/src/main.rs b/examples/sideinput/src/main.rs index 448a152..39af6dd 100644 --- a/examples/sideinput/src/main.rs +++ b/examples/sideinput/src/main.rs @@ -1,7 +1,6 @@ +use numaflow::sideinput::{self, SideInputer}; use std::sync::Mutex; use std::time::{SystemTime, UNIX_EPOCH}; -use numaflow::sideinput::{self, SideInputer}; - use tonic::async_trait; @@ -37,5 +36,7 @@ impl SideInputer for SideInputHandler { #[tokio::main] async fn main() -> Result<(), Box> { - sideinput::Server::new(SideInputHandler::new()).start().await + sideinput::Server::new(SideInputHandler::new()) + .start() + .await } diff --git a/examples/sideinput/udf/src/main.rs b/examples/sideinput/udf/src/main.rs index c998b76..f8d3bb5 100644 --- a/examples/sideinput/udf/src/main.rs +++ b/examples/sideinput/udf/src/main.rs @@ -1,7 +1,7 @@ use std::path::Path; use notify::{RecursiveMode, Result, Watcher}; -use numaflow::map::{Mapper, MapRequest, Message, Server}; +use numaflow::map::{MapRequest, Mapper, Message, Server}; use tokio::spawn; use tonic::async_trait; diff --git a/proto/batchmap.proto b/proto/batchmap.proto new file mode 100644 index 0000000..1670bdc --- /dev/null +++ b/proto/batchmap.proto @@ -0,0 +1,50 @@ +syntax = "proto3"; + +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; + +package batchmap.v1; + +service BatchMap { + // IsReady is the heartbeat endpoint for gRPC. + rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); + + // BatchMapFn is a bi-directional streaming rpc which applies a + // Map function on each BatchMapRequest element of the stream and then returns streams + // back MapResponse elements. + rpc BatchMapFn(stream BatchMapRequest) returns (stream BatchMapResponse); +} + +/** + * BatchMapRequest represents a request element. + */ +message BatchMapRequest { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + // This ID is used uniquely identify a map request + string id = 6; +} + +/** + * BatchMapResponse represents a response element. + */ +message BatchMapResponse { + message Result { + repeated string keys = 1; + bytes value = 2; + repeated string tags = 3; + } + repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; +} + +/** + * ReadyResponse is the health check result. + */ +message ReadyResponse { + bool ready = 1; +} \ No newline at end of file diff --git a/src/batchmap.rs b/src/batchmap.rs new file mode 100644 index 0000000..d82f22a --- /dev/null +++ b/src/batchmap.rs @@ -0,0 +1,780 @@ +use chrono::{DateTime, Utc}; +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc::channel; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; +use tonic::{Request, Response, Status, Streaming}; + +use crate::batchmap::proto::batch_map_server::BatchMap; +use crate::error::Error; +use crate::error::Error::BatchMapError; +use crate::error::ErrorKind::{InternalError, UserDefinedError}; +use crate::shared; +use crate::shared::shutdown_signal; + +const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; +const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/batchmap.sock"; +const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/mapper-server-info"; +const DROP: &str = "U+005C__DROP__"; +/// Numaflow Batch Map Proto definitions. +pub mod proto { + tonic::include_proto!("batchmap.v1"); +} + +struct BatchMapService { + handler: Arc, + _shutdown_tx: mpsc::Sender<()>, + cancellation_token: CancellationToken, +} + +/// BatchMapper trait for implementing batch mode user defined function. +/// +/// Types implementing this trait can be passed as user defined batch map handle. +#[tonic::async_trait] +pub trait BatchMapper { + /// The batch map handle is given a stream of [`Datum`] and the result is + /// Vec of [`BatchResponse`]. + /// Here it's important to note that the size of the vec for the responses + /// should be equal to the number of elements in the input stream. + /// + /// # Example + /// + /// A simple batch map. + /// + /// ```no_run + /// use numaflow::batchmap::{self, BatchResponse, Datum, Message}; + /// use std::error::Error; + /// + /// struct FlatMap; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// batchmap::Server::new(FlatMap).start().await + /// } + /// + /// #[tonic::async_trait] + /// impl batchmap::BatchMapper for FlatMap { + /// + /// async fn batchmap(&self, mut input: tokio::sync::mpsc::Receiver) -> Vec { + /// let mut responses: Vec = Vec::new(); + /// while let Some(datum) = input.recv().await { + /// let mut response = BatchResponse::from_id(datum.id); + /// response.append(Message { + /// keys: Option::from(datum.keys), + /// value: datum.value, + /// tags: None, + /// }); + /// responses.push(response); + /// } + /// responses + /// } + /// } + /// ``` + async fn batchmap(&self, input: mpsc::Receiver) -> Vec; +} + +/// Incoming request into the handler of [`BatchMapper`]. +pub struct Datum { + /// Set of keys in the (key, value) terminology of map/reduce paradigm. + pub keys: Vec, + /// The value in the (key, value) terminology of map/reduce paradigm. + pub value: Vec, + /// [watermark](https://numaflow.numaproj.io/core-concepts/watermarks/) represented by time is a guarantee that we will not see an element older than this time. + pub watermark: DateTime, + /// Time of the element as seen at source or aligned after a reduce operation. + pub event_time: DateTime, + /// ID is the unique id of the message to be sent to the Batch Map. + pub id: String, + /// Headers for the message. + pub headers: HashMap, +} + +impl From for Datum { + fn from(sr: proto::BatchMapRequest) -> Self { + Self { + keys: sr.keys, + value: sr.value, + watermark: shared::utc_from_timestamp(sr.watermark), + event_time: shared::utc_from_timestamp(sr.event_time), + id: sr.id, + headers: sr.headers, + } + } +} +/// Message is the response struct from the [`Mapper::map`] . +#[derive(Debug, PartialEq)] +pub struct Message { + /// Keys are a collection of strings which will be passed on to the next vertex as is. It can + /// be an empty collection. + pub keys: Option>, + /// Value is the value passed to the next vertex. + pub value: Vec, + /// Tags are used for [conditional forwarding](https://numaflow.numaproj.io/user-guide/reference/conditional-forwarding/). + pub tags: Option>, +} + +/// Represents a message that can be modified and forwarded. +impl crate::batchmap::Message { + /// Creates a new message with the specified value. + /// + /// This constructor initializes the message with no keys, tags, or specific event time. + /// + /// # Arguments + /// + /// * `value` - A vector of bytes representing the message's payload. + /// + /// # Examples + /// + /// ``` + /// use numaflow::batchmap::Message; + /// let message = Message::new(vec![1, 2, 3, 4]); + /// ``` + pub fn new(value: Vec) -> Self { + Self { + value, + keys: None, + tags: None, + } + } + /// Marks the message to be dropped by creating a new `Message` with an empty value and a special "DROP" tag. + /// + /// # Examples + /// + /// ``` + /// use numaflow::batchmap::Message; + /// let dropped_message = Message::message_to_drop(); + /// ``` + pub fn message_to_drop() -> crate::batchmap::Message { + crate::batchmap::Message { + keys: None, + value: vec![], + tags: Some(vec![crate::batchmap::DROP.to_string()]), + } + } + + /// Sets or replaces the keys associated with this message. + /// + /// # Arguments + /// + /// * `keys` - A vector of strings representing the keys. + /// + /// # Examples + /// + /// ``` + /// use numaflow::batchmap::Message; + /// let message = Message::new(vec![1, 2, 3]).keys(vec!["key1".to_string(), "key2".to_string()]); + /// ``` + pub fn keys(mut self, keys: Vec) -> Self { + self.keys = Some(keys); + self + } + + /// Sets or replaces the tags associated with this message. + /// + /// # Arguments + /// + /// * `tags` - A vector of strings representing the tags. + /// + /// # Examples + /// + /// ``` + /// use numaflow::batchmap::Message; + /// let message = Message::new(vec![1, 2, 3]).tags(vec!["tag1".to_string(), "tag2".to_string()]); + /// ``` + pub fn tags(mut self, tags: Vec) -> Self { + self.tags = Some(tags); + self + } + + /// Replaces the value of the message. + /// + /// # Arguments + /// + /// * `value` - A new vector of bytes that replaces the current message value. + /// + /// # Examples + /// + /// ``` + /// use numaflow::batchmap::Message; + /// let message = Message::new(vec![1, 2, 3]).value(vec![4, 5, 6]); + /// ``` + pub fn value(mut self, value: Vec) -> Self { + self.value = value; + self + } +} +/// The result of the call to [`BatchMapper::batchmap`] method. +pub struct BatchResponse { + /// id is the unique ID of the message. + pub id: String, + // message is the response from the batch map. + pub message: Vec, +} + +impl BatchResponse { + /// Creates a new `BatchResponse` for a given id and empty message. + pub fn from_id(id: String) -> Self { + Self { + id, + message: Vec::new(), + } + } + + /// append a message to the response. + pub fn append(&mut self, message: Message) { + self.message.push(message); + } +} + +impl From for proto::batch_map_response::Result { + fn from(value: Message) -> Self { + proto::batch_map_response::Result { + keys: value.keys.unwrap_or_default(), + value: value.value, + tags: value.tags.unwrap_or_default(), + } + } +} + +#[tonic::async_trait] +impl BatchMap for BatchMapService +where + T: BatchMapper + Send + Sync + 'static, +{ + async fn is_ready( + &self, + _: Request<()>, + ) -> Result, Status> { + Ok(tonic::Response::new(proto::ReadyResponse { ready: true })) + } + + type BatchMapFnStream = ReceiverStream>; + + async fn batch_map_fn( + &self, + request: Request>, + ) -> Result, Status> { + let mut stream = request.into_inner(); + + // Create a channel to send the messages to the user defined function. + let (tx, rx) = mpsc::channel::(1); + + // Create a channel to send the response back to the grpc client. + let (grpc_response_tx, grpc_response_rx) = + channel::>(1); + + // clone the shutdown_tx to be used in the writer spawn + let shutdown_tx = self._shutdown_tx.clone(); + + // clone the cancellation token to be used in the writer spawn + let writer_cln_token = self.cancellation_token.clone(); + + // counter to keep track of the number of messages received + let total_messages_recvd = Arc::new(AtomicUsize::new(0)); + + // clone the counter to be used in the request spawn + let counter = Arc::clone(&total_messages_recvd); + + // clone the shutdown_tx to be used in the request spawn + let read_shutdown_tx = shutdown_tx.clone(); + // read the messages from the grpc client and send it to the user defined function + let read_handler = tokio::spawn(async move { + loop { + match stream.message().await { + Ok(Some(message)) => { + let datum = Datum::from(message); + if let Err(e) = tx.send(datum).await { + tracing::error!("Failed to send message: {}", e); + break; + } + counter.fetch_add(1, Ordering::Relaxed); + } + // If there's an error or the stream ends, break the loop to stop the task. + // and send a shutdown signal to the grpc server. + Ok(None) => break, + Err(e) => { + tracing::error!("Error reading message: {}", e); + read_shutdown_tx + .send(()) + .await + .expect("shutdown_tx send failed"); + break; + } + } + } + }); + + // Create a channel for receiving the response from the user defined function. + let (response_tx, mut response_rx) = channel::>(1); + + let handler = Arc::clone(&self.handler); + + let udf_response_tx = response_tx.clone(); + // spawn a task to invoke the user defined function + let udf_task_handle = tokio::spawn(async move { + // wait for the messages to be received + let messages = handler.batchmap(rx).await; + + let counter = total_messages_recvd.load(Ordering::Relaxed); + // check if the number of responses matches the number of messages received + // if not send an error back to the grpc client. + if counter != messages.len() { + let _ = udf_response_tx + .send(Err(BatchMapError(InternalError( + "number of responses does not match the number of messages received" + .to_string(), + )))) + .await; + return; + } + + // send the response back to the grpc client + for response in messages { + let send_result = udf_response_tx + .send(Ok(proto::BatchMapResponse { + results: response.message.into_iter().map(|m| m.into()).collect(), + id: response.id, + })) + .await; + // if there's an error sending the response back, send an error back to the grpc client. + if let Err(e) = send_result { + let _ = udf_response_tx + .send(Err(BatchMapError(InternalError(format!( + "Failed to send response back: {}", + e + ))))) + .await; + return; + } + } + }); + + // Spawn a task to handle the error from the user defined function + let error_handle = tokio::spawn(async move { + // if there was an error while executing the user defined function spawn, + // send an error back to the grpc client. + if let Err(e) = udf_task_handle.await { + let _ = response_tx + .send(Err(BatchMapError(UserDefinedError(format!(" {}", e))))) + .await; + } + }); + + // Spawn a task to write the response to the grpc client, we also need to check if the cancel token is set + // in that case we need to stop the task. + tokio::spawn(async move { + // wait for the batch map handle to respond + loop { + tokio::select! { + response = response_rx.recv() => { + match response { + Some(Ok(response)) => { + grpc_response_tx + .send(Ok(response)) + .await + .expect("send to grpc response channel failed"); + }, + Some(Err(error)) => { + tracing::error!("Error from UDF: {:?}", error); + grpc_response_tx + .send(Err(Status::internal(error.to_string()))) + .await + .expect("send to grpc response channel failed"); + // Send a shutdown signal to the grpc server. + shutdown_tx.send(()).await.expect("shutdown_tx send failed"); + } + None => { + // TODO: What should be for None? Is this reachable + break; + } + } + } + // If the cancellation token is set, stop the task. + _ = writer_cln_token.cancelled() => { + tracing::info!("token cancelled!, shutting down"); + // Send an abort signal to the task executor to abort all the tasks. + error_handle.abort(); + read_handler.abort(); + break; + } + } + } + }); + + // Return the receiver stream to the client + Ok(Response::new(ReceiverStream::new(grpc_response_rx))) + } +} + +/// gRPC server to start a batch map service +#[derive(Debug)] +pub struct Server { + sock_addr: PathBuf, + max_message_size: usize, + server_info_file: PathBuf, + svc: Option, +} +impl crate::batchmap::Server { + pub fn new(batch_map_svc: T) -> Self { + crate::batchmap::Server { + sock_addr: DEFAULT_SOCK_ADDR.into(), + max_message_size: DEFAULT_MAX_MESSAGE_SIZE, + server_info_file: DEFAULT_SERVER_INFO_FILE.into(), + svc: Some(batch_map_svc), + } + } + + /// Set the unix domain socket file path used by the gRPC server to listen for incoming connections. + /// Default value is `/var/run/numaflow/batchmap.sock` + pub fn with_socket_file(mut self, file: impl Into) -> Self { + self.sock_addr = file.into(); + self + } + + /// Get the unix domain socket file path where gRPC server listens for incoming connections. Default value is `/var/run/numaflow/batchmap.sock` + pub fn socket_file(&self) -> &std::path::Path { + self.sock_addr.as_path() + } + + /// Set the maximum size of an encoded and decoded gRPC message. The value of `message_size` is in bytes. Default value is 64MB. + pub fn with_max_message_size(mut self, message_size: usize) -> Self { + self.max_message_size = message_size; + self + } + + /// Get the maximum size of an encoded and decoded gRPC message in bytes. Default value is 64MB. + pub fn max_message_size(&self) -> usize { + self.max_message_size + } + + /// Change the file in which numaflow server information is stored on start up to the new value. Default value is `/var/run/numaflow/batchmapper-server-info` + pub fn with_server_info_file(mut self, file: impl Into) -> Self { + self.server_info_file = file.into(); + self + } + + /// Get the path to the file where numaflow server info is stored. Default value is `/var/run/numaflow/mapper-server-info` + pub fn server_info_file(&self) -> &std::path::Path { + self.server_info_file.as_path() + } + + /// Starts the gRPC server. When message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated. + pub async fn start_with_shutdown( + &mut self, + shutdown_rx: oneshot::Receiver<()>, + ) -> Result<(), Box> + where + T: BatchMapper + Send + Sync + 'static, + { + let mut info = shared::default_info_file(); + // update the info json metadata field, and add the map mode + info["metadata"][shared::MAP_MODE_KEY] = + serde_json::Value::String(shared::BATCH_MAP.to_string()); + let listener = + shared::create_listener_stream(&self.sock_addr, &self.server_info_file, info)?; + let handler = self.svc.take().unwrap(); + + let cln_token = CancellationToken::new(); + + // Create a channel to send shutdown signal to the server to do graceful shutdown in case of non retryable errors. + let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); + let map_svc = crate::batchmap::BatchMapService { + handler: Arc::new(handler), + _shutdown_tx: internal_shutdown_tx, + cancellation_token: cln_token.clone(), + }; + + let map_svc = proto::batch_map_server::BatchMapServer::new(map_svc) + .max_encoding_message_size(self.max_message_size) + .max_decoding_message_size(self.max_message_size); + + let shutdown = shutdown_signal(internal_shutdown_rx, Some(shutdown_rx)); + + tonic::transport::Server::builder() + .add_service(map_svc) + .serve_with_incoming_shutdown(listener, shutdown) + .await?; + + Ok(()) + } + + /// Starts the gRPC server. Automatically registers signal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the signal arrives. + pub async fn start(&mut self) -> Result<(), Box> + where + T: BatchMapper + Send + Sync + 'static, + { + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + self.start_with_shutdown(shutdown_rx).await + } +} + +impl Drop for Server { + // Cleanup the socket file when the server is dropped so that when the server is restarted, it can bind to the + // same address. UnixListener doesn't implement Drop trait, so we have to manually remove the socket file. + fn drop(&mut self) { + let _ = fs::remove_file(&self.sock_addr); + } +} + +#[cfg(test)] +mod tests { + use std::{error::Error, time::Duration}; + + use tempfile::TempDir; + use tokio::net::UnixStream; + use tokio::sync::mpsc::Receiver; + use tokio::sync::oneshot; + use tonic::transport::Uri; + use tower::service_fn; + + use crate::batchmap; + use crate::batchmap::proto::batch_map_client::BatchMapClient; + use crate::batchmap::{BatchResponse, Datum, Message}; + + #[tokio::test] + async fn batch_map_server() -> Result<(), Box> { + struct Logger; + #[tonic::async_trait] + impl batchmap::BatchMapper for Logger { + async fn batchmap(&self, mut input: Receiver) -> Vec { + let mut responses: Vec = Vec::new(); + while let Some(datum) = input.recv().await { + let mut response = BatchResponse::from_id(datum.id); + response.append(Message { + keys: Option::from(datum.keys), + value: datum.value, + tags: None, + }); + responses.push(response); + } + responses + } + } + + let tmp_dir = TempDir::new()?; + let sock_file = tmp_dir.path().join("batchmap.sock"); + let server_info_file = tmp_dir.path().join("batchmapper-server-info"); + + let mut server = batchmap::Server::new(Logger) + .with_server_info_file(&server_info_file) + .with_socket_file(&sock_file) + .with_max_message_size(10240); + + assert_eq!(server.max_message_size(), 10240); + assert_eq!(server.server_info_file(), server_info_file); + assert_eq!(server.socket_file(), sock_file); + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs + let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? + .connect_with_connector(service_fn(move |_: Uri| { + // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes + let sock_file = sock_file.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + UnixStream::connect(sock_file).await?, + )) + } + })) + .await?; + + let mut client = BatchMapClient::new(channel); + let request = batchmap::proto::BatchMapRequest { + keys: vec!["first".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + id: "1".to_string(), + headers: Default::default(), + }; + + let request2 = batchmap::proto::BatchMapRequest { + keys: vec!["second".into()], + value: "hello2".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + id: "2".to_string(), + headers: Default::default(), + }; + + let resp = client + .batch_map_fn(tokio_stream::iter(vec![request, request2])) + .await?; + let mut r = resp.into_inner(); + let mut responses = Vec::new(); + + while let Some(response) = r.message().await? { + responses.push(response); + } + + assert_eq!(responses.len(), 2, "Expected two message from server"); + assert_eq!(&responses[0].id, "1"); + assert_eq!(&responses[1].id, "2"); + + shutdown_tx + .send(()) + .expect("Sending shutdown signal to gRPC server"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!(task.is_finished(), "gRPC server is still running"); + Ok(()) + } + + #[tokio::test] + async fn error_length() -> Result<(), Box> { + struct Logger; + #[tonic::async_trait] + impl batchmap::BatchMapper for Logger { + async fn batchmap(&self, mut input: Receiver) -> Vec { + let responses: Vec = Vec::new(); + while let Some(_datum) = input.recv().await {} + responses + } + } + + let tmp_dir = TempDir::new()?; + let sock_file = tmp_dir.path().join("batchmap.sock"); + let server_info_file = tmp_dir.path().join("batchmapper-server-info"); + + let mut server = batchmap::Server::new(Logger) + .with_server_info_file(&server_info_file) + .with_socket_file(&sock_file) + .with_max_message_size(10240); + + assert_eq!(server.max_message_size(), 10240); + assert_eq!(server.server_info_file(), server_info_file); + assert_eq!(server.socket_file(), sock_file); + + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs + let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? + .connect_with_connector(service_fn(move |_: Uri| { + // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes + let sock_file = sock_file.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + UnixStream::connect(sock_file).await?, + )) + } + })) + .await?; + + let mut client = BatchMapClient::new(channel); + let request = batchmap::proto::BatchMapRequest { + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + id: "1".to_string(), + headers: Default::default(), + }; + + let resp = client + .batch_map_fn(tokio_stream::iter(vec![request])) + .await?; + let mut r = resp.into_inner(); + + let Err(server_err) = r.message().await else { + return Err("Expected error from server".into()); + }; + + assert_eq!(server_err.code(), tonic::Code::Internal); + assert!(server_err.message().contains( + "number of responses does not \ + match the number of messages received" + )); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert!(task.is_finished(), "gRPC server is still running"); + Ok(()) + } + #[tokio::test] + async fn batchmap_panic() -> Result<(), Box> { + struct PanicBatch; + #[tonic::async_trait] + impl batchmap::BatchMapper for PanicBatch { + async fn batchmap(&self, _input: Receiver) -> Vec { + panic!("Should not cross 5"); + } + } + + let tmp_dir = TempDir::new()?; + let sock_file = tmp_dir.path().join("batchmap.sock"); + let server_info_file = tmp_dir.path().join("mapper-server-info"); + + let mut server = batchmap::Server::new(PanicBatch) + .with_server_info_file(&server_info_file) + .with_socket_file(&sock_file) + .with_max_message_size(10240); + + assert_eq!(server.max_message_size(), 10240); + assert_eq!(server.server_info_file(), server_info_file); + assert_eq!(server.socket_file(), sock_file); + + let (_shutdown_tx, shutdown_rx) = oneshot::channel(); + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs + let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? + .connect_with_connector(service_fn(move |_: Uri| { + // https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes + let sock_file = sock_file.clone(); + async move { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + UnixStream::connect(sock_file).await?, + )) + } + })) + .await?; + + let mut client = BatchMapClient::new(channel); + let mut requests = Vec::new(); + + for i in 0..10 { + let request = batchmap::proto::BatchMapRequest { + keys: vec!["first".into(), "second".into()], + value: format!("hello {}", i).into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + id: i.to_string(), + headers: Default::default(), + }; + requests.push(request); + } + + let resp = client.batch_map_fn(tokio_stream::iter(requests)).await?; + let mut response_stream = resp.into_inner(); + + if let Err(e) = response_stream.message().await { + println!("Error: {:?}", e); + assert_eq!(e.code(), tonic::Code::Internal); + assert!(e.message().contains("User Defined Error")) + } else { + return Err("Expected error from server".into()); + }; + + // server should shut down gracefully because there was a panic in the handler. + for _ in 0..10 { + tokio::time::sleep(Duration::from_millis(10)).await; + if task.is_finished() { + break; + } + } + assert!(task.is_finished(), "gRPC server is still running"); + Ok(()) + } +} diff --git a/src/error.rs b/src/error.rs index d517799..e33102f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,6 +23,9 @@ pub enum Error { #[error("Source Error - {0}")] SourceError(ErrorKind), + #[error("BatchMap Error - {0}")] + BatchMapError(ErrorKind), + #[error("Source Transformer Error: {0}")] SourceTransformerError(ErrorKind), diff --git a/src/lib.rs b/src/lib.rs index c7d7320..0ad867d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,9 @@ pub mod sink; /// building [side input](https://numaflow.numaproj.io/user-guide/reference/side-inputs/) pub mod sideinput; +/// batchmap is for writing the [batch map mode](https://numaflow.numaproj.io/user-guide/user-defined-functions/map/batchmap/) handlers. +pub mod batchmap; + // Error handling on Numaflow SDKs! // // Any non-recoverable error will cause the process to shutdown with a non-zero exit status. All errors are non-recoverable. diff --git a/src/map.rs b/src/map.rs index 56c5345..5e43c56 100644 --- a/src/map.rs +++ b/src/map.rs @@ -308,7 +308,12 @@ impl Server { where T: Mapper + Send + Sync + 'static, { - let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; + let mut info = shared::default_info_file(); + // update the info json metadata field, and add the map mode key value pair + info["metadata"][shared::MAP_MODE_KEY] = + serde_json::Value::String(shared::UNARY_MAP.to_string()); + let listener = + shared::create_listener_stream(&self.sock_addr, &self.server_info_file, info)?; let handler = self.svc.take().unwrap(); let cln_token = CancellationToken::new(); diff --git a/src/reduce.rs b/src/reduce.rs index 554c8bd..7e51c51 100644 --- a/src/reduce.rs +++ b/src/reduce.rs @@ -817,7 +817,11 @@ impl Server { where C: ReducerCreator + Send + Sync + 'static, { - let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; + let listener = shared::create_listener_stream( + &self.sock_addr, + &self.server_info_file, + shared::default_info_file(), + )?; let creator = self.creator.take().unwrap(); let (internal_shutdown_tx, internal_shutdown_rx) = channel(1); let cln_token = CancellationToken::new(); diff --git a/src/shared.rs b/src/shared.rs index 6a4123c..c127ae5 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -10,23 +10,39 @@ use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::UnixListenerStream; use tracing::info; +// default_info_file is a function to get a default server info json +// file content. This is used to write the server info file. +// This function is used in the write_info_file function. +// This function is not exposed to the user. +pub fn default_info_file() -> serde_json::Value { + let metadata: HashMap = HashMap::new(); + serde_json::json!({ + "protocol": "uds", + "language": "rust", + "version": "0.0.1", + "metadata": metadata, + }) +} +pub(crate) const MAP_MODE_KEY: &str = "MAP_MODE"; +pub(crate) const UNARY_MAP: &str = "unary-map"; +pub(crate) const STREAM_MAP: &str = "stream-map"; +pub(crate) const BATCH_MAP: &str = "batch-map"; + // #[tracing::instrument(skip(path), fields(path = ?path.as_ref()))] #[tracing::instrument(fields(path = ? path.as_ref()))] -fn write_info_file(path: impl AsRef) -> io::Result<()> { +fn write_info_file(path: impl AsRef, mut server_info: serde_json::Value) -> io::Result<()> { let parent = path.as_ref().parent().unwrap(); fs::create_dir_all(parent)?; // TODO: make port-number and CPU meta-data configurable, e.g., ("CPU_LIMIT", "1") - let metadata: HashMap = HashMap::new(); - let info = serde_json::json!({ - "protocol": "uds", - "language": "rust", - "version": "0.0.1", - "metadata": metadata, - }); + + // if server_info object is not provided, use the default one + if server_info.is_null() { + server_info = default_info_file(); + } // Convert to a string of JSON and print it out - let content = format!("{}U+005C__END__", info); + let content = format!("{}U+005C__END__", server_info); info!(content, "Writing to file"); fs::write(path, content) } @@ -34,8 +50,10 @@ fn write_info_file(path: impl AsRef) -> io::Result<()> { pub(crate) fn create_listener_stream( socket_file: impl AsRef, server_info_file: impl AsRef, + server_info: serde_json::Value, ) -> Result> { - write_info_file(server_info_file).map_err(|e| format!("writing info file: {e:?}"))?; + write_info_file(server_info_file, server_info) + .map_err(|e| format!("writing info file: {e:?}"))?; let uds_stream = UnixListener::bind(socket_file)?; Ok(UnixListenerStream::new(uds_stream)) @@ -149,8 +167,14 @@ mod tests { // Create a temporary file let temp_file = NamedTempFile::new()?; + // Get a default server info file content + // let server_info = default_info_file(); + let mut info = default_info_file(); + // update the info json metadata field, and add the map mode key value pair + info["metadata"][MAP_MODE_KEY] = serde_json::Value::String(BATCH_MAP.to_string()); + // Call write_info_file with the path of the temporary file - write_info_file(temp_file.path())?; + write_info_file(temp_file.path(), info)?; // Open the file and read its contents let mut file = File::open(temp_file.path())?; @@ -161,7 +185,7 @@ mod tests { assert!(contents.contains(r#""protocol":"uds""#)); assert!(contents.contains(r#""language":"rust""#)); assert!(contents.contains(r#""version":"0.0.1""#)); - assert!(contents.contains(r#""metadata":{}"#)); + assert!(contents.contains(r#""metadata":{"MAP_MODE":"batch-map"}"#)); Ok(()) } diff --git a/src/sideinput.rs b/src/sideinput.rs index 6256a46..3c82706 100644 --- a/src/sideinput.rs +++ b/src/sideinput.rs @@ -194,7 +194,11 @@ impl Server { where T: SideInputer + Send + Sync + 'static, { - let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; + let listener = shared::create_listener_stream( + &self.sock_addr, + &self.server_info_file, + shared::default_info_file(), + )?; let handler = self.svc.take().unwrap(); let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); let cln_token = CancellationToken::new(); diff --git a/src/sink.rs b/src/sink.rs index 618bbee..2e5eaea 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -323,7 +323,11 @@ impl Server { where T: Sinker + Send + Sync + 'static, { - let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; + let listener = shared::create_listener_stream( + &self.sock_addr, + &self.server_info_file, + shared::default_info_file(), + )?; let handler = self.svc.take().unwrap(); let cln_token = CancellationToken::new(); let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); diff --git a/src/source.rs b/src/source.rs index abd8d41..fa54be0 100644 --- a/src/source.rs +++ b/src/source.rs @@ -264,7 +264,11 @@ impl Server { where T: Sourcer + Send + Sync + 'static, { - let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; + let listener = shared::create_listener_stream( + &self.sock_addr, + &self.server_info_file, + shared::default_info_file(), + )?; let handler = self.svc.take().unwrap(); let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); let cln_token = CancellationToken::new(); diff --git a/src/sourcetransform.rs b/src/sourcetransform.rs index b95fe72..1372586 100644 --- a/src/sourcetransform.rs +++ b/src/sourcetransform.rs @@ -334,7 +334,11 @@ impl Server { where T: SourceTransformer + Send + Sync + 'static, { - let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; + let listener = shared::create_listener_stream( + &self.sock_addr, + &self.server_info_file, + shared::default_info_file(), + )?; let handler = self.svc.take().unwrap(); let (internal_shutdown_tx, internal_shutdown_rx) = mpsc::channel(1); let cln_token = CancellationToken::new();