From 19f975ad3fffb7c8804ae2f969cc0bd2c04b3d1d Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Thu, 29 Feb 2024 20:42:09 +0530 Subject: [PATCH] feat: simplified sourcer server API changes (#32) Signed-off-by: Sreekanth --- examples/simple-source/.dockerignore | 1 + examples/simple-source/Cargo.toml | 2 +- examples/simple-source/Dockerfile | 4 +- examples/simple-source/src/main.rs | 27 +- src/map.rs | 2 +- src/source.rs | 383 +++++++++++++++++++++++---- 6 files changed, 346 insertions(+), 73 deletions(-) create mode 100644 examples/simple-source/.dockerignore diff --git a/examples/simple-source/.dockerignore b/examples/simple-source/.dockerignore new file mode 100644 index 0000000..9f97022 --- /dev/null +++ b/examples/simple-source/.dockerignore @@ -0,0 +1 @@ +target/ \ No newline at end of file diff --git a/examples/simple-source/Cargo.toml b/examples/simple-source/Cargo.toml index 248b87b..7a4bcf4 100644 --- a/examples/simple-source/Cargo.toml +++ b/examples/simple-source/Cargo.toml @@ -8,7 +8,7 @@ name = "server" path = "src/main.rs" [dependencies] -tonic = "0.9" +tonic = "0.10.2" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", branch = "main" } chrono = "0.4.30" \ No newline at end of file diff --git a/examples/simple-source/Dockerfile b/examples/simple-source/Dockerfile index 5bc8583..9ff65d8 100644 --- a/examples/simple-source/Dockerfile +++ b/examples/simple-source/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.70 as build +FROM rust:1.76-bookworm as build RUN apt-get update RUN apt-get install protobuf-compiler -y @@ -16,7 +16,7 @@ COPY ./Cargo.lock ./Cargo.lock RUN cargo build --release # our final base -FROM rust +FROM debian:bookworm # copy the build artifact from the build stage COPY --from=build /examples/target/release/server . 
diff --git a/examples/simple-source/src/main.rs b/examples/simple-source/src/main.rs index ccaadb3..e4ed64f 100644 --- a/examples/simple-source/src/main.rs +++ b/examples/simple-source/src/main.rs @@ -1,16 +1,14 @@ ///! An example for simple User Defined Source. It generates a continuous increasing sequence of offsets and some data for each call to [`numaflow::source::sourcer::read`]. #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<(), Box> { let source_handle = simple_source::SimpleSource::new(); - numaflow::source::start_uds_server(source_handle).await?; - - Ok(()) + numaflow::source::Server::new(source_handle).start().await } pub(crate) mod simple_source { use std::{ - collections::HashMap, + collections::HashSet, sync::atomic::{AtomicUsize, Ordering}, sync::RwLock, }; @@ -24,14 +22,14 @@ pub(crate) mod simple_source { /// does not provide a mutable reference as explained in [`numaflow::source::Sourcer`] pub(crate) struct SimpleSource { read_idx: AtomicUsize, - yet_to_ack: RwLock>, + yet_to_ack: RwLock>, } impl SimpleSource { pub fn new() -> Self { Self { read_idx: AtomicUsize::new(0), - yet_to_ack: RwLock::new(HashMap::new()), + yet_to_ack: RwLock::new(HashSet::new()), } } } @@ -39,6 +37,9 @@ pub(crate) mod simple_source { #[async_trait] impl Sourcer for SimpleSource { async fn read(&self, source_request: SourceReadRequest, transmitter: Sender) { + if !self.yet_to_ack.read().unwrap().is_empty() { + return; + } let start = Instant::now(); for i in 1..=source_request.count { @@ -58,7 +59,7 @@ pub(crate) mod simple_source { value: format!("{i} at {offset}").into_bytes(), offset: Offset { offset: offset.to_be_bytes().to_vec(), - partition_id: "0".to_string(), + partition_id: 0, }, event_time: chrono::offset::Utc::now(), keys: vec![], @@ -67,10 +68,8 @@ pub(crate) mod simple_source { .unwrap(); // add the entry to hashmap to mark the offset as pending to-be-acked - match self.yet_to_ack.write() { - Ok(mut guard) => guard.insert(offset 
as u32, true), - Err(_) => panic!("lock has been poisoned!"), - }; + let mut yet_to_ack = self.yet_to_ack.write().expect("lock has been poisoned"); + yet_to_ack.insert(offset as u32); } } @@ -87,5 +86,9 @@ pub(crate) mod simple_source { // pending for simple source is zero since we are not reading from any external source 0 } + + async fn partitions(&self) -> Option> { + Some(vec![1]) + } } } diff --git a/src/map.rs b/src/map.rs index 3effc99..0ace85a 100644 --- a/src/map.rs +++ b/src/map.rs @@ -154,7 +154,7 @@ impl Server { self.sock_addr.as_path() } - /// Set the maximum size of an encoded and decoded gRPC message. The value of `message_size` is in bytes. Default value is 4MB. + /// Set the maximum size of an encoded and decoded gRPC message. The value of `message_size` is in bytes. Default value is 64MB. pub fn with_max_message_size(mut self, message_size: usize) -> Self { self.max_message_size = message_size; self diff --git a/src/source.rs b/src/source.rs index 639b7a3..1e05698 100644 --- a/src/source.rs +++ b/src/source.rs @@ -1,22 +1,17 @@ #![warn(missing_docs)] +use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; use crate::shared::{self, prost_timestamp_from_utc}; -use crate::source::sourcer::source_server::{Source, SourceServer}; -use crate::source::sourcer::{ - AckRequest, AckResponse, PendingResponse, ReadRequest, ReadResponse, ReadyResponse, -}; use chrono::{DateTime, Utc}; use tokio::sync::mpsc::{self, Sender}; +use tokio::sync::oneshot; use tokio_stream::wrappers::ReceiverStream; -use tonic::transport::Server; use tonic::{async_trait, Request, Response, Status}; -use self::sourcer::{partitions_response, PartitionsResponse}; - -mod sourcer { +mod proto { tonic::include_proto!("source.v1"); } @@ -71,30 +66,30 @@ pub struct Offset { } #[async_trait] -impl Source for SourceService +impl proto::source_server::Source for SourceService where T: Sourcer + Send + Sync + 'static, { - type ReadFnStream = ReceiverStream>; + type ReadFnStream = 
ReceiverStream>; async fn read_fn( &self, - request: Request, + request: Request, ) -> Result, Status> { let sr = request.into_inner().request.unwrap(); // tx,rx pair for sending data over to user-defined source let (stx, mut srx) = mpsc::channel::(1); // tx,rx pair for gRPC response - let (tx, rx) = mpsc::channel::>(1); + let (tx, rx) = mpsc::channel::>(1); // start the ud-source rx asynchronously and start populating the gRPC response so it can be streamed to the gRPC client (numaflow). tokio::spawn(async move { while let Some(resp) = srx.recv().await { - tx.send(Ok(ReadResponse { - result: Some(sourcer::read_response::Result { + tx.send(Ok(proto::ReadResponse { + result: Some(proto::read_response::Result { payload: resp.value, - offset: Some(sourcer::Offset { + offset: Some(proto::Offset { offset: resp.offset.offset, partition_id: resp.offset.partition_id, }), @@ -103,7 +98,7 @@ where }), })) .await - .unwrap(); + .expect("receiver dropped"); } }); @@ -125,13 +120,22 @@ where Ok(Response::new(ReceiverStream::new(rx))) } - async fn ack_fn(&self, request: Request) -> Result, Status> { - let ar: AckRequest = request.into_inner(); + async fn ack_fn( + &self, + request: Request, + ) -> Result, Status> { + let ar: proto::AckRequest = request.into_inner(); + + let success_response = Response::new(proto::AckResponse { + result: Some(proto::ack_response::Result { success: Some(()) }), + }); + + let Some(request) = ar.request else { + return Ok(success_response); + }; // invoke the user-defined source's ack handler - let offsets = ar - .request - .unwrap() + let offsets = request .offsets .into_iter() .map(|so| Offset { @@ -142,17 +146,15 @@ where self.handler.ack(offsets).await; - Ok(Response::new(AckResponse { - result: Some(sourcer::ack_response::Result { success: Some(()) }), - })) + Ok(success_response) } - async fn pending_fn(&self, _: Request<()>) -> Result, Status> { + async fn pending_fn(&self, _: Request<()>) -> Result, Status> { // invoke the user-defined 
source's pending handler let pending = self.handler.pending().await; - Ok(Response::new(PendingResponse { - result: Some(sourcer::pending_response::Result { + Ok(Response::new(proto::PendingResponse { + result: Some(proto::pending_response::Result { count: pending as i64, }), })) @@ -161,7 +163,7 @@ where async fn partitions_fn( &self, _request: Request<()>, - ) -> Result, Status> { + ) -> Result, Status> { let partitions = match self.handler.partitions().await { Some(v) => v, None => vec![std::env::var("NUMAFLOW_REPLICA") @@ -169,47 +171,314 @@ where .parse::() .unwrap_or_default()], }; - Ok(Response::new(PartitionsResponse { - result: Some(partitions_response::Result { partitions }), + Ok(Response::new(proto::PartitionsResponse { + result: Some(proto::partitions_response::Result { partitions }), })) } - async fn is_ready(&self, _: Request<()>) -> Result, Status> { - Ok(Response::new(ReadyResponse { ready: true })) + async fn is_ready(&self, _: Request<()>) -> Result, Status> { + Ok(Response::new(proto::ReadyResponse { ready: true })) } } /// Message is the response from the user's [`Sourcer::read`] pub struct Message { - /// Value is the value passed to the next vertex. + /// The value passed to the next vertex. pub value: Vec, - /// Offset is the offset of the message. When the message is acked, the offset is passed to the user's [`Sourcer::ack`]. + /// Offset of the message. When the message is acked, the offset is passed to the user's [`Sourcer::ack`]. pub offset: Offset, - /// EventTime is the time at which the message was generated. + /// The time at which the message was generated. pub event_time: DateTime, - /// Keys are the keys of the message. + /// Keys of the message. pub keys: Vec, } -/// Starts a gRPC server over an UDS (unix-domain-socket) endpoint. 
-pub async fn start_uds_server(m: T) -> Result<(), Box> -where - T: Sourcer + Send + Sync + 'static, -{ - let server_info_file = if std::env::var_os("NUMAFLOW_POD").is_some() { - "/var/run/numaflow/server-info" - } else { - "/tmp/numaflow.server-info" - }; - let socket_file = "/var/run/numaflow/source.sock"; - let listener = shared::create_listener_stream(socket_file, server_info_file)?; - let source_service = SourceService { - handler: Arc::new(m), - }; - - Server::builder() - .add_service(SourceServer::new(source_service)) - .serve_with_incoming(listener) - .await - .map_err(Into::into) +/// gRPC server for starting a [`Sourcer`] service +#[derive(Debug)] +pub struct Server { + sock_addr: PathBuf, + max_message_size: usize, + server_info_file: PathBuf, + svc: Option, +} + +impl Server { + /// Creates a new gRPC `Server` instance + pub fn new(source_svc: T) -> Self { + let server_info_file = if std::env::var_os("NUMAFLOW_POD").is_some() { + "/var/run/numaflow/server-info" + } else { + "/tmp/numaflow.server-info" + }; + Server { + sock_addr: "/var/run/numaflow/source.sock".into(), + max_message_size: 64 * 1024 * 1024, + server_info_file: server_info_file.into(), + svc: Some(source_svc), + } + } + + /// Set the unix domain socket file path used by the gRPC server to listen for incoming connections. + /// Default value is `/var/run/numaflow/source.sock` + pub fn with_socket_file(mut self, file: impl Into) -> Self { + self.sock_addr = file.into(); + self + } + + /// Get the unix domain socket file path where gRPC server listens for incoming connections. Default value is `/var/run/numaflow/source.sock` + pub fn socket_file(&self) -> &std::path::Path { + self.sock_addr.as_path() + } + + /// Set the maximum size of an encoded and decoded gRPC message. The value of `message_size` is in bytes. Default value is 64MB. 
+ pub fn with_max_message_size(mut self, message_size: usize) -> Self { + self.max_message_size = message_size; + self + } + + /// Get the maximum size of an encoded and decoded gRPC message in bytes. Default value is 64MB. + pub fn max_message_size(&self) -> usize { + self.max_message_size + } + + /// Change the file in which numaflow server information is stored on start up to the new value. Default value is `/tmp/numaflow.server-info` + pub fn with_server_info_file(mut self, file: impl Into) -> Self { + self.server_info_file = file.into(); + self + } + + /// Get the path to the file where numaflow server info is stored. Default value is `/tmp/numaflow.server-info` + pub fn server_info_file(&self) -> &std::path::Path { + self.server_info_file.as_path() + } + + /// Starts the gRPC server. When a message is received on the `shutdown` channel, graceful shutdown of the gRPC server will be initiated. + pub async fn start_with_shutdown( + &mut self, + shutdown: oneshot::Receiver<()>, + ) -> Result<(), Box> + where + T: Sourcer + Send + Sync + 'static, + { + let listener = shared::create_listener_stream(&self.sock_addr, &self.server_info_file)?; + let handler = self.svc.take().unwrap(); + let source_service = SourceService { + handler: Arc::new(handler), + }; + + let source_svc = proto::source_server::SourceServer::new(source_service) + .max_decoding_message_size(self.max_message_size) + .max_decoding_message_size(self.max_message_size); + + let shutdown = async { + shutdown + .await + .expect("Receiving message from shutdown channel"); + }; + tonic::transport::Server::builder() + .add_service(source_svc) + .serve_with_incoming_shutdown(listener, shutdown) + .await + .map_err(Into::into) + } + + /// Starts the gRPC server. Automatically registers signal handlers for SIGINT and SIGTERM and initiates graceful shutdown of gRPC server when either one of the signals arrives. 
+ pub async fn start(&mut self) -> Result<(), Box> + where + T: Sourcer + Send + Sync + 'static, + { + let (tx, rx) = oneshot::channel::<()>(); + tokio::spawn(shared::wait_for_signal(tx)); + self.start_with_shutdown(rx).await + } +} + +#[cfg(test)] +mod tests { + use super::proto; + use chrono::Utc; + use std::collections::HashSet; + use std::vec; + use std::{error::Error, time::Duration}; + use tokio_stream::StreamExt; + use tower::service_fn; + + use crate::source::{self, Message, Offset, SourceReadRequest}; + use tempfile::TempDir; + use tokio::sync::mpsc::Sender; + use tokio::sync::oneshot; + use tonic::transport::Uri; + + // A source that repeats the `num` for the requested count + struct Repeater { + num: usize, + yet_to_ack: std::sync::RwLock>, + } + + impl Repeater { + fn new(num: usize) -> Self { + Self { + num, + yet_to_ack: std::sync::RwLock::new(HashSet::new()), + } + } + } + + #[tonic::async_trait] + impl source::Sourcer for Repeater { + async fn read(&self, request: SourceReadRequest, transmitter: Sender) { + let event_time = Utc::now(); + let mut message_offsets = Vec::with_capacity(request.count); + for i in 0..request.count { + // we assume timestamp in nanoseconds would be unique on each read operation from our source + let offset = format!("{}-{}", event_time.timestamp_nanos_opt().unwrap(), i); + transmitter + .send(Message { + value: self.num.to_le_bytes().to_vec(), + event_time, + offset: Offset { + offset: offset.clone().into_bytes(), + partition_id: 0, + }, + keys: vec![], + }) + .await + .unwrap(); + message_offsets.push(offset) + } + self.yet_to_ack.write().unwrap().extend(message_offsets) + } + + async fn ack(&self, offsets: Vec) { + for offset in offsets { + self.yet_to_ack + .write() + .unwrap() + .remove(&String::from_utf8(offset.offset).unwrap()); + } + } + + async fn pending(&self) -> usize { + // The pending function should return the number of pending messages that can be read from the source. 
+ // However, for this source the pending messages will always be 0. + // For testing purposes, we return the number of messages that are not yet acknowledged as pending. + self.yet_to_ack.read().unwrap().len() + } + + async fn partitions(&self) -> Option> { + Some(vec![2]) + } + } + + #[tokio::test] + async fn source_server() -> Result<(), Box> { + let tmp_dir = TempDir::new()?; + let sock_file = tmp_dir.path().join("source.sock"); + let server_info_file = tmp_dir.path().join("server_info"); + + let mut server = source::Server::new(Repeater::new(8)) + .with_server_info_file(&server_info_file) + .with_socket_file(&sock_file) + .with_max_message_size(10240); + + assert_eq!(server.max_message_size(), 10240); + assert_eq!(server.server_info_file(), server_info_file); + assert_eq!(server.socket_file(), sock_file); + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + // https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs + let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")? 
+ .connect_with_connector(service_fn(move |_: Uri| { + // Connect to a Uds socket + let sock_file = sock_file.clone(); + tokio::net::UnixStream::connect(sock_file) + })) + .await?; + + let mut client = proto::source_client::SourceClient::new(channel); + let request = tonic::Request::new(proto::ReadRequest { + request: Some(proto::read_request::Request { + num_records: 5, + timeout_in_ms: 500, + }), + }); + + let resp = client.read_fn(request).await?; + let resp = resp.into_inner(); + let result: Vec = resp + .map(|item| item.unwrap().result.unwrap()) + .collect() + .await; + let response_values: Vec = result + .iter() + .map(|item| { + usize::from_le_bytes( + item.payload + .clone() + .try_into() + .expect("expected Vec length to be 8"), + ) + }) + .collect(); + assert_eq!(response_values, vec![8, 8, 8, 8, 8]); + + let pending_before_ack = client + .pending_fn(tonic::Request::new(())) + .await + .unwrap() + .into_inner(); + assert_eq!( + pending_before_ack.result.unwrap().count, + 5, + "Expected pending messages to be 5 before ACK" + ); + + let offsets_to_ack: Vec = result + .iter() + .map(|item| item.clone().offset.unwrap()) + .collect(); + let ack_request = tonic::Request::new(proto::AckRequest { + request: Some(proto::ack_request::Request { + offsets: offsets_to_ack, + }), + }); + let resp = client.ack_fn(ack_request).await.unwrap().into_inner(); + assert!( + resp.result.unwrap().success.is_some(), + "Expected acknowledgement request to be successful" + ); + + let pending_before_ack = client + .pending_fn(tonic::Request::new(())) + .await + .unwrap() + .into_inner(); + assert_eq!( + pending_before_ack.result.unwrap().count, + 0, + "Expected pending messages to be 0 after ACK" + ); + + let partitions = client + .partitions_fn(tonic::Request::new(())) + .await + .unwrap() + .into_inner(); + assert_eq!( + partitions.result.unwrap().partitions, + vec![2], + "Expected number of partitions to be 2" + ); + + shutdown_tx + .send(()) + .expect("Sending shutdown signal to 
gRPC server"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!(task.is_finished(), "gRPC server is still running"); + Ok(()) + } }