From a1b52dc4bd8db4733cd1421cc6770d9f9e0201ff Mon Sep 17 00:00:00 2001 From: Sreekanth <prsreekanth920@gmail.com> Date: Thu, 16 Jan 2025 15:50:29 +0530 Subject: [PATCH] Move callback handler to tracker Signed-off-by: Sreekanth <prsreekanth920@gmail.com> --- rust/numaflow-core/src/mapper/map.rs | 33 +- rust/numaflow-core/src/monovertex.rs | 27 +- .../numaflow-core/src/monovertex/forwarder.rs | 20 +- rust/numaflow-core/src/pipeline.rs | 46 +-- .../src/pipeline/forwarder/sink_forwarder.rs | 12 +- .../pipeline/forwarder/source_forwarder.rs | 3 +- .../src/pipeline/isb/jetstream/reader.rs | 6 +- .../src/pipeline/isb/jetstream/writer.rs | 73 +--- rust/numaflow-core/src/sink.rs | 47 +-- rust/numaflow-core/src/source.rs | 6 +- rust/numaflow-core/src/tracker.rs | 353 +++++++++++++++--- rust/numaflow-core/src/transformer.rs | 12 +- rust/serving/src/app/callback.rs | 23 +- rust/serving/src/app/callback/state.rs | 2 +- .../src/app/callback/store/memstore.rs | 2 +- .../src/app/callback/store/redisstore.rs | 2 +- rust/serving/src/app/jetstream_proxy.rs | 9 +- rust/serving/src/app/tracker.rs | 4 +- rust/serving/src/callback.rs | 118 ++---- rust/serving/src/config.rs | 4 +- rust/serving/src/lib.rs | 2 +- rust/serving/src/source.rs | 3 + 22 files changed, 424 insertions(+), 383 deletions(-) diff --git a/rust/numaflow-core/src/mapper/map.rs b/rust/numaflow-core/src/mapper/map.rs index 6a9b40d56e..c9682a48b1 100644 --- a/rust/numaflow-core/src/mapper/map.rs +++ b/rust/numaflow-core/src/mapper/map.rs @@ -344,14 +344,7 @@ impl MapHandle { match receiver.await { Ok(Ok(mut mapped_messages)) => { // update the tracker with the number of messages sent and send the mapped messages - if let Err(e) = tracker_handle - .update( - read_msg.id.offset.clone(), - mapped_messages.len() as u32, - true, - ) - .await - { + if let Err(e) = tracker_handle.update_many(&mapped_messages, true).await { error_tx.send(e).await.expect("failed to send error"); return; } @@ -398,10 +391,7 @@ impl MapHandle { for receiver in receivers { match receiver.await { Ok(Ok(mut mapped_messages)) => { - let offset = mapped_messages.first().unwrap().id.offset.clone(); - tracker_handle - .update(offset.clone(), mapped_messages.len() as u32, true) - .await?; + tracker_handle.update_many(&mapped_messages, true).await?; for mapped_message in mapped_messages.drain(..) { output_tx .send(mapped_message) @@ -455,8 +445,7 @@ impl MapHandle { while let Some(result) = receiver.recv().await { match result { Ok(mapped_message) => { - let offset = mapped_message.id.offset.clone(); - if let Err(e) = tracker_handle.update(offset.clone(), 1, false).await { + if let Err(e) = tracker_handle.update(&mapped_message).await { error_tx.send(e).await.expect("failed to send error"); return; } @@ -475,7 +464,7 @@ impl MapHandle { } } - if let Err(e) = tracker_handle.update(read_msg.id.offset, 0, true).await { + if let Err(e) = tracker_handle.update_eof(read_msg.id.offset).await { error_tx.send(e).await.expect("failed to send error"); } }); @@ -530,7 +519,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = MapClient::new(create_rpc_channel(sock_file).await?); let mapper = MapHandle::new( @@ -623,7 +612,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = MapClient::new(create_rpc_channel(sock_file).await?); let mapper = MapHandle::new( MapMode::Unary, @@ -714,7 +703,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = MapClient::new(create_rpc_channel(sock_file).await?); let mapper = MapHandle::new( MapMode::Unary, @@ -810,7 +799,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = MapClient::new(create_rpc_channel(sock_file).await?); let mapper = MapHandle::new( @@ -923,7 +912,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = MapClient::new(create_rpc_channel(sock_file).await?); let mapper = MapHandle::new( MapMode::Batch, @@ -1035,7 +1024,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = MapClient::new(create_rpc_channel(sock_file).await?); let mapper = MapHandle::new( @@ -1134,7 +1123,7 @@ mod tests { tokio::time::sleep(Duration::from_millis(100)).await; let client = MapClient::new(create_rpc_channel(sock_file).await?); - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let mapper = MapHandle::new( MapMode::Stream, 500, diff --git a/rust/numaflow-core/src/monovertex.rs b/rust/numaflow-core/src/monovertex.rs index 6c9cca7b64..937963feac 100644 --- a/rust/numaflow-core/src/monovertex.rs +++ b/rust/numaflow-core/src/monovertex.rs @@ -26,7 +26,11 @@ pub(crate) async fn start_forwarder( cln_token: CancellationToken, config: &MonovertexConfig, ) -> error::Result<()> { - let tracker_handle = TrackerHandle::new(); + let callback_handler = config + .callback_config + .as_ref() + .map(|cb_cfg| CallbackHandler::new(config.name.clone(), cb_cfg.callback_concurrency)); + let tracker_handle = TrackerHandle::new(callback_handler); let (source, source_grpc_client) = create_components::create_source( config.batch_size, config.read_timeout, @@ -70,23 +74,7 @@ pub(crate) async fn start_forwarder( // FIXME: what to do with the handle shared::metrics::start_metrics_server(config.metrics_config.clone(), metrics_state).await; - let callback_handler = match config.callback_config { - Some(ref cb_cfg) => Some(CallbackHandler::new( - config.name.clone(), - cb_cfg.callback_concurrency, - )), - None => None, - }; - - start( - config.clone(), - source, - sink_writer, - transformer, - cln_token, - callback_handler, - ) - .await?; + start(config.clone(), source, sink_writer, transformer, cln_token).await?; Ok(()) } @@ -97,7 +85,6 @@ async fn start( sink: SinkWriter, transformer: Option<Transformer>, cln_token: CancellationToken, - callback_handler: Option<CallbackHandler>, ) -> error::Result<()> { // start the pending reader to publish pending metrics let pending_reader = shared::metrics::create_pending_reader( @@ -107,7 +94,7 @@ async fn start( .await; let _pending_reader_handle = pending_reader.start(is_mono_vertex()).await; - let mut forwarder_builder = ForwarderBuilder::new(source, sink, cln_token, callback_handler); + let mut forwarder_builder = ForwarderBuilder::new(source, sink, cln_token); // add transformer if exists if let Some(transformer_client) = transformer { diff --git a/rust/numaflow-core/src/monovertex/forwarder.rs b/rust/numaflow-core/src/monovertex/forwarder.rs index ea14a662b5..fb0be9b3eb 100644 --- a/rust/numaflow-core/src/monovertex/forwarder.rs +++ b/rust/numaflow-core/src/monovertex/forwarder.rs @@ -28,7 +28,6 @@ //! [Stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html //! [Actor Pattern]: https://ryhl.io/blog/actors-with-tokio/ -use serving::callback::CallbackHandler; use tokio_util::sync::CancellationToken; use crate::error; @@ -45,7 +44,6 @@ pub(crate) struct Forwarder { transformer: Option<Transformer>, sink_writer: SinkWriter, cln_token: CancellationToken, - callback_handler: Option<CallbackHandler>, } pub(crate) struct ForwarderBuilder { @@ -53,7 +51,6 @@ pub(crate) struct ForwarderBuilder { sink_writer: SinkWriter, cln_token: CancellationToken, transformer: Option<Transformer>, - callback_handler: Option<CallbackHandler>, } impl ForwarderBuilder { @@ -62,14 +59,12 @@ impl ForwarderBuilder { streaming_source: Source, streaming_sink: SinkWriter, cln_token: CancellationToken, - callback_handler: Option<CallbackHandler>, ) -> Self { Self { source: streaming_source, sink_writer: streaming_sink, cln_token, transformer: None, - callback_handler, } } @@ -87,7 +82,6 @@ impl ForwarderBuilder { sink_writer: self.sink_writer, transformer: self.transformer, cln_token: self.cln_token, - callback_handler: self.callback_handler, } } } @@ -108,11 +102,7 @@ impl Forwarder { let sink_writer_handle = self .sink_writer - .streaming_write( - transformed_messages_stream, - self.cln_token.clone(), - self.callback_handler.clone(), - ) + .streaming_write(transformed_messages_stream, self.cln_token.clone()) .await?; match tokio::try_join!( @@ -273,7 +263,7 @@ mod tests { .await .map_err(|e| panic!("failed to create source reader: {:?}", e)) .unwrap(); - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let source = Source::new( 5, SourceType::UserDefinedSource(src_read, src_ack, lag_reader), @@ -317,7 +307,7 @@ mod tests { .unwrap(); // create the forwarder with the source, transformer, and writer - let forwarder = ForwarderBuilder::new(source.clone(), sink_writer, cln_token.clone(), None) + let forwarder = ForwarderBuilder::new(source.clone(), sink_writer, cln_token.clone()) .transformer(transformer) .build(); @@ -372,7 +362,7 @@ mod tests { #[tokio::test] async fn test_flatmap_operation() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); // create the source which produces x number of messages let cln_token = CancellationToken::new(); @@ -447,7 +437,7 @@ mod tests { .unwrap(); // create the forwarder with the source, transformer, and writer - let forwarder = ForwarderBuilder::new(source.clone(), sink_writer, cln_token.clone(), None) + let forwarder = ForwarderBuilder::new(source.clone(), sink_writer, cln_token.clone()) .transformer(transformer) .build(); diff --git a/rust/numaflow-core/src/pipeline.rs b/rust/numaflow-core/src/pipeline.rs index b83864dd8f..9de727a0d5 100644 --- a/rust/numaflow-core/src/pipeline.rs +++ b/rust/numaflow-core/src/pipeline.rs @@ -51,23 +51,17 @@ async fn start_source_forwarder( config: PipelineConfig, source_config: SourceVtxConfig, ) -> Result<()> { - let tracker_handle = TrackerHandle::new(); + let callback_handler = config.callback_config.as_ref().map(|cb_cfg| { + CallbackHandler::new(config.vertex_name.clone(), cb_cfg.callback_concurrency) + }); + let tracker_handle = TrackerHandle::new(callback_handler); let js_context = create_js_context(config.js_client_config.clone()).await?; - let callback_handler = match config.callback_config { - Some(ref cb_cfg) => Some(CallbackHandler::new( - config.vertex_name.clone(), - cb_cfg.callback_concurrency, - )), - None => None, - }; - let buffer_writer = create_buffer_writer( &config, js_context.clone(), tracker_handle.clone(), cln_token.clone(), - callback_handler, ) .await; @@ -136,8 +130,12 @@ async fn start_map_forwarder( let mut mapper_grpc_client = None; let mut isb_lag_readers = vec![]; + let callback_handler = config.callback_config.as_ref().map(|cb_cfg| { + CallbackHandler::new(config.vertex_name.clone(), cb_cfg.callback_concurrency) + }); + for stream in reader_config.streams.clone() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(callback_handler.clone()); let buffer_reader = create_buffer_reader( stream, @@ -163,20 +161,11 @@ async fn start_map_forwarder( mapper_grpc_client = Some(mapper_rpc_client); } - let callback_handler = match config.callback_config { - Some(ref cb_cfg) => Some(CallbackHandler::new( - config.vertex_name.clone(), - cb_cfg.callback_concurrency, - )), - None => None, - }; - let buffer_writer = create_buffer_writer( &config, js_context.clone(), tracker_handle.clone(), cln_token.clone(), - callback_handler, ) .await; forwarder_components.push((buffer_reader, buffer_writer, mapper)); @@ -237,11 +226,15 @@ async fn start_sink_forwarder( .ok_or_else(|| error::Error::Config("No from vertex config found".to_string()))? .reader_config; + let callback_handler = config.callback_config.as_ref().map(|cb_cfg| { + CallbackHandler::new(config.vertex_name.clone(), cb_cfg.callback_concurrency) + }); + // Create sink writers and buffer readers for each stream let mut sink_writers = vec![]; let mut buffer_readers = vec![]; for stream in reader_config.streams.clone() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(callback_handler.clone()); let buffer_reader = create_buffer_reader( stream, @@ -284,14 +277,6 @@ async fn start_sink_forwarder( .await; } - let callback_handler = match config.callback_config { - Some(ref cb_cfg) => Some(CallbackHandler::new( - config.vertex_name.clone(), - cb_cfg.callback_concurrency, - )), - None => None, - }; - // Start a new forwarder for each buffer reader let mut forwarder_tasks = Vec::new(); for (buffer_reader, (sink_writer, _, _)) in buffer_readers.into_iter().zip(sink_writers) { @@ -300,7 +285,6 @@ async fn start_sink_forwarder( buffer_reader, sink_writer, cln_token.clone(), - callback_handler.clone(), ) .await; @@ -326,7 +310,6 @@ async fn create_buffer_writer( js_context: Context, tracker_handle: TrackerHandle, cln_token: CancellationToken, - callback_handler: Option<CallbackHandler>, ) -> JetstreamWriter { JetstreamWriter::new( config.to_vertex_config.clone(), @@ -334,7 +317,6 @@ async fn create_buffer_writer( config.paf_concurrency, tracker_handle, cln_token, - callback_handler, ) } diff --git a/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs index 4610235b49..1d560e94e5 100644 --- a/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs +++ b/rust/numaflow-core/src/pipeline/forwarder/sink_forwarder.rs @@ -1,4 +1,3 @@ -use serving::callback::CallbackHandler; use tokio_util::sync::CancellationToken; use crate::error::Error; @@ -12,7 +11,6 @@ pub(crate) struct SinkForwarder { jetstream_reader: JetstreamReader, sink_writer: SinkWriter, cln_token: CancellationToken, - callback_handler: Option<CallbackHandler>, } impl SinkForwarder { @@ -20,13 +18,11 @@ impl SinkForwarder { jetstream_reader: JetstreamReader, sink_writer: SinkWriter, cln_token: CancellationToken, - callback_handler: Option<CallbackHandler>, ) -> Self { Self { jetstream_reader, sink_writer, cln_token, - callback_handler, } } @@ -38,15 +34,9 @@ impl SinkForwarder { .streaming_read(reader_cancellation_token.clone()) .await?; - let callback_handler = self.callback_handler.clone(); - let sink_writer_handle = self .sink_writer - .streaming_write( - read_messages_stream, - self.cln_token.clone(), - callback_handler, - ) + .streaming_write(read_messages_stream, self.cln_token.clone()) .await?; // Join the reader and sink writer diff --git a/rust/numaflow-core/src/pipeline/forwarder/source_forwarder.rs b/rust/numaflow-core/src/pipeline/forwarder/source_forwarder.rs index ec717b8c8a..c164dafc8d 100644 --- a/rust/numaflow-core/src/pipeline/forwarder/source_forwarder.rs +++ b/rust/numaflow-core/src/pipeline/forwarder/source_forwarder.rs @@ -208,7 +208,7 @@ mod tests { #[cfg(feature = "nats-tests")] #[tokio::test] async fn test_source_forwarder() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); // create the source which produces x number of messages let cln_token = CancellationToken::new(); @@ -292,7 +292,6 @@ mod tests { 100, tracker_handle.clone(), cln_token.clone(), - None, ); // create a transformer diff --git a/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs b/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs index ca4c9e831d..97aa5f9864 100644 --- a/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs +++ b/rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs @@ -166,7 +166,7 @@ impl JetstreamReader { // Insert the message into the tracker and wait for the ack to be sent back. let (ack_tx, ack_rx) = oneshot::channel(); - tracker_handle.insert(message_id.offset.clone(), ack_tx).await?; + tracker_handle.insert(&message, ack_tx).await?; tokio::spawn(Self::start_work_in_progress( jetstream_message, @@ -341,7 +341,7 @@ mod tests { 0, context.clone(), buf_reader_config, - TrackerHandle::new(), + TrackerHandle::new(None), 500, ) .await @@ -402,7 +402,7 @@ mod tests { // Create JetStream context let client = async_nats::connect(js_url).await.unwrap(); let context = jetstream::new(client); - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let stream_name = "test_ack"; // Delete stream if it exists diff --git a/rust/numaflow-core/src/pipeline/isb/jetstream/writer.rs b/rust/numaflow-core/src/pipeline/isb/jetstream/writer.rs index 7baba7d11e..063af6391a 100644 --- a/rust/numaflow-core/src/pipeline/isb/jetstream/writer.rs +++ b/rust/numaflow-core/src/pipeline/isb/jetstream/writer.rs @@ -10,7 +10,6 @@ use async_nats::jetstream::publish::PublishAck; use async_nats::jetstream::stream::RetentionPolicy::Limits; use async_nats::jetstream::Context; use bytes::{Bytes, BytesMut}; -use serving::callback::CallbackHandler; use tokio::sync::Semaphore; use tokio::task::JoinHandle; use tokio::time; @@ -45,7 +44,6 @@ pub(crate) struct JetstreamWriter { cancel_token: CancellationToken, tracker_handle: TrackerHandle, sem: Arc<Semaphore>, - callback_handler: Option<CallbackHandler>, } impl JetstreamWriter { @@ -57,7 +55,6 @@ impl JetstreamWriter { paf_concurrency: usize, tracker_handle: TrackerHandle, cancel_token: CancellationToken, - callback_handler: Option<CallbackHandler>, ) -> Self { let streams = config .iter() @@ -76,7 +73,6 @@ impl JetstreamWriter { cancel_token, tracker_handle, sem: Arc::new(Semaphore::new(paf_concurrency)), - callback_handler, }; // spawn a task for checking whether buffer is_full @@ -245,8 +241,7 @@ impl JetstreamWriter { continue; } - this.resolve_pafs(pafs, message, this.callback_handler.clone()) - .await?; + this.resolve_pafs(pafs, message).await?; processed_msgs_count += 1; if last_logged_at.elapsed().as_secs() >= 1 { @@ -343,7 +338,6 @@ impl JetstreamWriter { &self, pafs: Vec<((String, u16), PublishAckFuture)>, message: Message, - callback_handler: Option<CallbackHandler>, ) -> Result<()> { let start_time = Instant::now(); let permit = Arc::clone(&self.sem) @@ -375,9 +369,6 @@ impl JetstreamWriter { .delete(message.id.offset.clone()) .await .expect("Failed to delete offset from tracker"); - if let Err(e) = do_callback(callback_handler.as_ref(), &message).await { - tracing::error!(?e, "Failed to send callback request"); - } } Err(e) => { error!( @@ -404,11 +395,6 @@ impl JetstreamWriter { stream.clone(), Offset::Int(IntOffset::new(ack.sequence, stream.1)), )); - if let Err(e) = - do_callback(callback_handler.as_ref(), &message).await - { - tracing::error!(?e, "Failed to send callback request"); - } } Err(e) => { error!(?e, "Blocking write failed for stream {}", stream.0); @@ -476,25 +462,6 @@ impl JetstreamWriter { } } -async fn do_callback(callback_handler: Option<&CallbackHandler>, message: &Message) -> Result<()> { - let Some(callback_handler) = callback_handler else { - return Ok(()); - }; - - let metadata = message.metadata.as_ref().ok_or_else(|| { - Error::Source("Message does not contain previous vertex name in the metadata".to_owned()) - })?; - - callback_handler - .callback( - &message.headers, - &message.tags, - metadata.previous_vertex.clone(), - ) - .await - .map_err(|e| Error::Source(format!("Failed to send callback for message: {e:?}"))) -} - #[cfg(test)] mod tests { use std::collections::HashMap; @@ -515,7 +482,7 @@ mod tests { #[cfg(feature = "nats-tests")] #[tokio::test] async fn test_async_write() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let cln_token = CancellationToken::new(); let js_url = "localhost:4222"; // Create JetStream context @@ -559,7 +526,6 @@ mod tests { 100, tracker_handle, cln_token.clone(), - None, ); let message = Message { @@ -656,7 +622,7 @@ mod tests { #[cfg(feature = "nats-tests")] #[tokio::test] async fn test_write_with_cancellation() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let js_url = "localhost:4222"; // Create JetStream context let client = async_nats::connect(js_url).await.unwrap(); @@ -702,7 +668,6 @@ mod tests { 100, tracker_handle, cancel_token.clone(), - None, ); let mut result_receivers = Vec::new(); @@ -857,7 +822,7 @@ mod tests { #[cfg(feature = "nats-tests")] #[tokio::test] async fn test_check_stream_status() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let js_url = "localhost:4222"; // Create JetStream context let client = async_nats::connect(js_url).await.unwrap(); @@ -906,7 +871,6 @@ mod tests { 100, tracker_handle, cancel_token.clone(), - None, ); let mut js_writer = writer.clone(); @@ -956,7 +920,7 @@ mod tests { // Create JetStream context let client = async_nats::connect(js_url).await.unwrap(); let context = jetstream::new(client); - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let stream_name = "test_publish_messages"; // Delete stream if it exists @@ -997,7 +961,6 @@ mod tests { 100, tracker_handle.clone(), cln_token.clone(), - None, ); let (messages_tx, messages_rx) = tokio::sync::mpsc::channel(500); @@ -1019,10 +982,7 @@ mod tests { metadata: None, }; let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); - tracker_handle - .insert(message.id.offset.clone(), ack_tx) - .await - .unwrap(); + tracker_handle.insert(&message, ack_tx).await.unwrap(); ack_rxs.push(ack_rx); messages_tx.send(message).await.unwrap(); } @@ -1046,7 +1006,7 @@ mod tests { // Create JetStream context let client = async_nats::connect(js_url).await.unwrap(); let context = jetstream::new(client); - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let stream_name = "test_publish_cancellation"; // Delete stream if it exists @@ -1087,7 +1047,6 @@ mod tests { 100, tracker_handle.clone(), cancel_token.clone(), - None, ); let (tx, rx) = tokio::sync::mpsc::channel(500); @@ -1109,10 +1068,7 @@ mod tests { metadata: None, }; let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); - tracker_handle - .insert(message.id.offset.clone(), ack_tx) - .await - .unwrap(); + tracker_handle.insert(&message, ack_tx).await.unwrap(); ack_rxs.push(ack_rx); tx.send(message).await.unwrap(); } @@ -1137,10 +1093,7 @@ mod tests { metadata: None, }; let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); - tracker_handle - .insert("offset_101".to_string().into(), ack_tx) - .await - .unwrap(); + tracker_handle.insert(&message, ack_tx).await.unwrap(); ack_rxs.push(ack_rx); tx.send(message).await.unwrap(); drop(tx); @@ -1168,7 +1121,7 @@ mod tests { let js_url = "localhost:4222"; let client = async_nats::connect(js_url).await.unwrap(); let context = jetstream::new(client); - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let cln_token = CancellationToken::new(); let vertex1_streams = vec!["vertex1-0", "vertex1-1"]; @@ -1231,7 +1184,6 @@ mod tests { 100, tracker_handle.clone(), cln_token.clone(), - None, ); let (messages_tx, messages_rx) = tokio::sync::mpsc::channel(500); @@ -1252,10 +1204,7 @@ mod tests { metadata: None, }; let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); - tracker_handle - .insert(message.id.offset.clone(), ack_tx) - .await - .unwrap(); + tracker_handle.insert(&message, ack_tx).await.unwrap(); ack_rxs.push(ack_rx); messages_tx.send(message).await.unwrap(); } diff --git a/rust/numaflow-core/src/sink.rs b/rust/numaflow-core/src/sink.rs index 027f426082..fdbfbe21c5 100644 --- a/rust/numaflow-core/src/sink.rs +++ b/rust/numaflow-core/src/sink.rs @@ -4,7 +4,6 @@ use std::time::Duration; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::sink::sink_response; use numaflow_pb::clients::sink::Status::{Failure, Fallback, Success}; -use serving::callback::CallbackHandler; use tokio::sync::mpsc::Receiver; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; @@ -256,7 +255,6 @@ impl SinkWriter { &self, messages_stream: ReceiverStream<Message>, cancellation_token: CancellationToken, - callback_handler: Option<CallbackHandler>, ) -> Result<JoinHandle<Result<()>>> { let handle: JoinHandle<Result<()>> = tokio::spawn({ let mut this = self.clone(); @@ -311,22 +309,6 @@ impl SinkWriter { } } - if let Some(ref callback_handler) = callback_handler { - for message in batch { - let metadata = message.metadata.ok_or_else(|| { - Error::Source(format!( - "Writing to Sink: message does not contain previous vertex name in the metadata" - )) - })?; - if let Err(e) = callback_handler - .callback(&message.headers, &message.tags, metadata.previous_vertex) - .await - { - tracing::error!(?e, "Failed to send callback for message"); - } - } - }; - // publish sink metrics if is_mono_vertex() { monovertex_metrics() @@ -769,7 +751,7 @@ mod tests { 10, Duration::from_secs(1), SinkClientType::Log, - TrackerHandle::new(), + TrackerHandle::new(None), ) .build() .await @@ -800,7 +782,7 @@ mod tests { #[tokio::test] async fn test_streaming_write() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let sink_writer = SinkWriterBuilder::new( 10, Duration::from_millis(100), @@ -833,16 +815,13 @@ mod tests { for msg in messages { let (ack_tx, ack_rx) = oneshot::channel(); ack_rxs.push(ack_rx); - tracker_handle - .insert(msg.id.offset.clone(), ack_tx) - .await - .unwrap(); + tracker_handle.insert(&msg, ack_tx).await.unwrap(); let _ = tx.send(msg).await; } drop(tx); let handle = sink_writer - .streaming_write(ReceiverStream::new(rx), CancellationToken::new(), None) + .streaming_write(ReceiverStream::new(rx), CancellationToken::new()) .await .unwrap(); @@ -856,7 +835,7 @@ mod tests { #[tokio::test] async fn test_streaming_write_error() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); // start the server let (_shutdown_tx, shutdown_rx) = oneshot::channel(); let tmp_dir = tempfile::TempDir::new().unwrap(); @@ -912,16 +891,13 @@ mod tests { for msg in messages { let (ack_tx, ack_rx) = oneshot::channel(); ack_rxs.push(ack_rx); - tracker_handle - .insert(msg.id.offset.clone(), ack_tx) - .await - .unwrap(); + tracker_handle.insert(&msg, ack_tx).await.unwrap(); let _ = tx.send(msg).await; } drop(tx); let cln_token = CancellationToken::new(); let handle = sink_writer - .streaming_write(ReceiverStream::new(rx), cln_token.clone(), None) + .streaming_write(ReceiverStream::new(rx), cln_token.clone()) .await .unwrap(); @@ -942,7 +918,7 @@ mod tests { #[tokio::test] async fn test_fallback_write() { - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); // start the server let (_shutdown_tx, shutdown_rx) = oneshot::channel(); @@ -999,17 +975,14 @@ mod tests { let mut ack_rxs = vec![]; for msg in messages { let (ack_tx, ack_rx) = oneshot::channel(); - tracker_handle - .insert(msg.id.offset.clone(), ack_tx) - .await - .unwrap(); + tracker_handle.insert(&msg, ack_tx).await.unwrap(); ack_rxs.push(ack_rx); let _ = tx.send(msg).await; } drop(tx); let cln_token = CancellationToken::new(); let handle = sink_writer - .streaming_write(ReceiverStream::new(rx), cln_token.clone(), None) + .streaming_write(ReceiverStream::new(rx), cln_token.clone()) .await .unwrap(); diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs index d48c37ab63..54ca9f90cf 100644 --- a/rust/numaflow-core/src/source.rs +++ b/rust/numaflow-core/src/source.rs @@ -321,9 +321,7 @@ impl Source { let offset = message.offset.clone().expect("offset can never be none"); // insert the offset and the ack one shot in the tracker. - tracker_handle - .insert(message.id.offset.clone(), resp_ack_tx) - .await?; + tracker_handle.insert(&message, resp_ack_tx).await?; // store the ack one shot in the batch to invoke ack later. ack_batch.push((offset, resp_ack_rx)); @@ -546,7 +544,7 @@ mod tests { let source = Source::new( 5, SourceType::UserDefinedSource(src_read, src_ack, lag_reader), - TrackerHandle::new(), + TrackerHandle::new(None), true, ); diff --git a/rust/numaflow-core/src/tracker.rs b/rust/numaflow-core/src/tracker.rs index 6a568ea910..875bb6a1af 100644 --- a/rust/numaflow-core/src/tracker.rs +++ b/rust/numaflow-core/src/tracker.rs @@ -11,10 +11,12 @@ use std::collections::HashMap; use bytes::Bytes; +use serving::callback::CallbackHandler; +use serving::{DEFAULT_CALLBACK_URL_HEADER_KEY, DEFAULT_ID_HEADER}; use tokio::sync::{mpsc, oneshot}; use crate::error::Error; -use crate::message::ReadAck; +use crate::message::{Message, ReadAck}; use crate::Result; /// TrackerEntry represents the state of a tracked message. @@ -23,6 +25,7 @@ struct TrackerEntry { ack_send: oneshot::Sender<ReadAck>, count: u32, eof: bool, + callback_info: Option<CallbackInfo>, } /// ActorMessage represents the messages that can be sent to the Tracker actor. @@ -30,12 +33,20 @@ enum ActorMessage { Insert { offset: Bytes, ack_send: oneshot::Sender<ReadAck>, + callback_info: Option<CallbackInfo>, }, Update { offset: Bytes, - count: u32, + responses: Option<Vec<String>>, + }, + UpdateMany { + offset: Bytes, + responses: Vec<Option<Vec<String>>>, eof: bool, }, + UpdateEOF { + offset: Bytes, + }, Delete { offset: Bytes, }, @@ -54,6 +65,61 @@ enum ActorMessage { struct Tracker { entries: HashMap<Bytes, TrackerEntry>, receiver: mpsc::Receiver<ActorMessage>, + callback_handler: Option<CallbackHandler>, +} + +#[derive(Debug)] +struct CallbackInfo { + id: String, + callback_url: String, + from_vertex: String, + responses: Vec<Option<Vec<String>>>, +} + +impl TryFrom<&Message> for CallbackInfo { + type Error = Error; + + fn try_from(message: &Message) -> std::result::Result<Self, Self::Error> { + let callback_url = message + .headers + .get(DEFAULT_CALLBACK_URL_HEADER_KEY) + .ok_or_else(|| { + Error::Source(format!( + "{DEFAULT_CALLBACK_URL_HEADER_KEY} header is not present in the message headers", + )) + })? + .to_owned(); + let uuid = message + .headers + .get(DEFAULT_ID_HEADER) + .ok_or_else(|| { + Error::Source(format!( + "{DEFAULT_ID_HEADER} is not found in message headers", + )) + })? + .to_owned(); + + let from_vertex = message + .metadata + .as_ref() + .ok_or_else(|| Error::Source("Metadata field is empty in the message".into()))? + .previous_vertex + .clone(); + + // FIXME: empty message tags + let mut msg_tags = None; + if let Some(ref tags) = message.tags { + if !tags.is_empty() { + msg_tags = Some(tags.iter().cloned().collect()); + } + }; + Ok(CallbackInfo { + id: uuid, + callback_url, + from_vertex, + responses: vec![msg_tags], + }) + } } impl Drop for Tracker { @@ -70,10 +136,14 @@ impl Drop for Tracker { impl Tracker { /// Creates a new Tracker instance with the given receiver for actor messages. - fn new(receiver: mpsc::Receiver<ActorMessage>) -> Self { + fn new( + receiver: mpsc::Receiver<ActorMessage>, + callback_handler: impl Into<Option<CallbackHandler>>, + ) -> Self { Self { entries: HashMap::new(), receiver, + callback_handler: callback_handler.into(), } } @@ -90,11 +160,22 @@ impl Tracker { ActorMessage::Insert { offset, ack_send: respond_to, + callback_info, + } => { + self.handle_insert(offset, callback_info, respond_to); + } + ActorMessage::Update { offset, responses } => { + self.handle_update(offset, responses); + } + ActorMessage::UpdateMany { + offset, + responses, + eof, } => { - self.handle_insert(offset, respond_to); + self.handle_update_many(offset, responses, eof); } - ActorMessage::Update { offset, count, eof } => { - self.handle_update(offset, count, eof); + ActorMessage::UpdateEOF { offset } => { + self.handle_update_eof(offset); } ActorMessage::Delete { offset } => { self.handle_delete(offset); @@ -114,22 +195,73 @@ impl Tracker { } /// Inserts a new entry into the tracker with the given offset and ack sender. - fn handle_insert(&mut self, offset: Bytes, respond_to: oneshot::Sender<ReadAck>) { + fn handle_insert( + &mut self, + offset: Bytes, + callback_info: Option<CallbackInfo>, + respond_to: oneshot::Sender<ReadAck>, + ) { self.entries.insert( offset, TrackerEntry { ack_send: respond_to, count: 0, eof: true, + callback_info, }, ); } /// Updates an existing entry in the tracker with the number of expected messages and EOF status. - fn handle_update(&mut self, offset: Bytes, count: u32, eof: bool) { + fn handle_update(&mut self, offset: Bytes, responses: Option<Vec<String>>) { + if let Some(entry) = self.entries.get_mut(&offset) { + entry.count += 1; + entry + .callback_info + .as_mut() + .map(|cb| cb.responses.push(responses)); + // if the count is zero, we can send an ack immediately + // this is case where map stream will send eof true after + // receiving all the messages. + if entry.count == 0 { + let entry = self.entries.remove(&offset).unwrap(); + entry + .ack_send + .send(ReadAck::Ack) + .expect("Failed to send ack"); + } + } + } + + fn handle_update_eof(&mut self, offset: Bytes) { + if let Some(entry) = self.entries.get_mut(&offset) { + entry.eof = true; + // if the count is zero, we can send an ack immediately + // this is case where map stream will send eof true after + // receiving all the messages. + if entry.count == 0 { + let entry = self.entries.remove(&offset).unwrap(); + entry + .ack_send + .send(ReadAck::Ack) + .expect("Failed to send ack"); + } + } + } + + fn handle_update_many( + &mut self, + offset: Bytes, + responses: Vec<Option<Vec<String>>>, + eof: bool, + ) { if let Some(entry) = self.entries.get_mut(&offset) { - entry.count += count; + entry.count += responses.len() as u32; entry.eof = eof; + entry + .callback_info + .as_mut() + .map(|cb| cb.responses.extend(responses)); // if the count is zero, we can send an ack immediately // this is case where map stream will send eof true after // receiving all the messages. @@ -187,24 +319,68 @@ impl Tracker { #[derive(Clone)] pub(crate) struct TrackerHandle { sender: mpsc::Sender<ActorMessage>, + enable_callbacks: bool, } impl TrackerHandle { /// Creates a new TrackerHandle instance and spawns the Tracker. - pub(crate) fn new() -> Self { + pub(crate) fn new(callback_handler: Option<CallbackHandler>) -> Self { + let enable_callbacks = callback_handler.is_some(); let (sender, receiver) = mpsc::channel(100); - let tracker = Tracker::new(receiver); + let tracker = Tracker::new(receiver, callback_handler); tokio::spawn(tracker.run()); - Self { sender } + Self { + sender, + enable_callbacks, + } } /// Inserts a new message into the Tracker with the given offset and acknowledgment sender. pub(crate) async fn insert( &self, - offset: Bytes, + message: &Message, ack_send: oneshot::Sender<ReadAck>, ) -> Result<()> { - let message = ActorMessage::Insert { offset, ack_send }; + let offset = message.id.offset.clone(); + let mut callback_info = None; + if self.enable_callbacks { + callback_info = Some(message.try_into()?); + } + let message = ActorMessage::Insert { + offset, + ack_send, + callback_info, + }; + self.sender + .send(message) + .await + .map_err(|e| Error::Tracker(format!("{:?}", e)))?; + Ok(()) + } + + /// Updates an existing message in the Tracker with the given offset, count, and EOF status. + pub(crate) async fn update(&self, message: &Message) -> Result<()> { + let offset = message.id.offset.clone(); + let mut responses: Option<Vec<String>> = None; + if self.enable_callbacks { + // FIXME: empty message tags + if let Some(ref tags) = message.tags { + if !tags.is_empty() { + responses = Some(tags.iter().cloned().collect()); + } + }; + } + let message = ActorMessage::Update { offset, responses }; + self.sender + .send(message) + .await + .map_err(|e| Error::Tracker(format!("{:?}", e)))?; + Ok(()) + } + + /// Updates an existing message in the Tracker with the given offset, count, and EOF status. + pub(crate) async fn update_eof(&self, offset: Bytes) -> Result<()> { + let message = ActorMessage::UpdateEOF { offset }; self.sender .send(message) .await @@ -213,8 +389,29 @@ impl TrackerHandle { } /// Updates an existing message in the Tracker with the given offset, count, and EOF status. - pub(crate) async fn update(&self, offset: Bytes, count: u32, eof: bool) -> Result<()> { - let message = ActorMessage::Update { offset, count, eof }; + pub(crate) async fn update_many(&self, messages: &[Message], eof: bool) -> Result<()> { + if messages.is_empty() { + return Ok(()); + } + let offset = messages.first().unwrap().id.offset.clone(); + let mut responses: Vec<Option<Vec<String>>> = vec![]; + // if self.enable_callbacks { + // FIXME: empty message tags + for message in messages { + let mut response: Option<Vec<String>> = None; + if let Some(ref tags) = message.tags { + if !tags.is_empty() { + response = Some(tags.iter().cloned().collect()); + } + }; + responses.push(response); + } + // } + let message = ActorMessage::UpdateMany { + offset, + responses, + eof, + }; self.sender .send(message) .await @@ -268,27 +465,42 @@ impl TrackerHandle { #[cfg(test)] mod tests { + use std::sync::Arc; + use tokio::sync::oneshot; use tokio::time::{timeout, Duration}; + use crate::message::MessageID; + use super::*; #[tokio::test] async fn test_insert_update_delete() { - let handle = TrackerHandle::new(); + let handle = TrackerHandle::new(None); let (ack_send, ack_recv) = oneshot::channel(); + let offset = Bytes::from_static(b"offset1"); + let message = Message { + keys: Arc::from([]), + tags: None, + value: Bytes::from_static(b"test"), + offset: None, + event_time: Default::default(), + id: MessageID { + vertex_name: "in".into(), + offset: offset.clone(), + index: 1, + }, + headers: HashMap::new(), + metadata: None, + }; + // Insert a new message - handle - .insert("offset1".to_string().into(), ack_send) - .await - .unwrap(); + handle.insert(&message, ack_send).await.unwrap(); // Update the message - handle - .update("offset1".to_string().into(), 1, true) - .await - .unwrap(); + handle.update(&message).await.unwrap(); + handle.update_eof(offset).await.unwrap(); // Delete the message handle.delete("offset1".to_string().into()).await.unwrap(); @@ -302,25 +514,36 @@ mod tests { #[tokio::test] async fn test_update_with_multiple_deletes() { - let handle = TrackerHandle::new(); + let handle = TrackerHandle::new(None); let (ack_send, ack_recv) = oneshot::channel(); + let offset = Bytes::from_static(b"offset1"); + let message = Message { + keys: Arc::from([]), + tags: None, + value: Bytes::from_static(b"test"), + offset: None, + event_time: Default::default(), + id: MessageID { + vertex_name: "in".into(), + offset: offset.clone(), + index: 1, + }, + headers: HashMap::new(), + metadata: None, + }; + // Insert a new message - handle - .insert("offset1".to_string().into(), ack_send) - .await - .unwrap(); + handle.insert(&message, ack_send).await.unwrap(); + let messages: Vec<Message> = std::iter::repeat(message).take(3).collect(); // Update the message with a count of 3 - handle - .update("offset1".to_string().into(), 3, true) - .await - .unwrap(); + handle.update_many(&messages, true).await.unwrap(); // Delete the message three times - handle.delete("offset1".to_string().into()).await.unwrap(); - handle.delete("offset1".to_string().into()).await.unwrap(); - handle.delete("offset1".to_string().into()).await.unwrap(); + handle.delete(offset.clone()).await.unwrap(); + handle.delete(offset.clone()).await.unwrap(); + handle.delete(offset).await.unwrap(); // Verify that the message was deleted and ack was received after the third delete let result = timeout(Duration::from_secs(1), ack_recv).await.unwrap(); @@ -331,17 +554,30 @@ mod tests { #[tokio::test] async fn test_discard() { - let handle = TrackerHandle::new(); + let handle = TrackerHandle::new(None); let (ack_send, ack_recv) = oneshot::channel(); + let offset = Bytes::from_static(b"offset1"); + let message = Message { + keys: Arc::from([]), + tags: None, + value: Bytes::from_static(b"test"), + offset: None, + event_time: Default::default(), + id: MessageID { + vertex_name: "in".into(), + offset: offset.clone(), + index: 1, + }, + headers: HashMap::new(), + metadata: None, + }; + // Insert a new message - handle - .insert("offset1".to_string().into(), ack_send) - .await - .unwrap(); + handle.insert(&message, ack_send).await.unwrap(); // Discard the message - handle.discard("offset1".to_string().into()).await.unwrap(); + handle.discard(offset).await.unwrap(); // Verify that the message was discarded and nak was received let result = timeout(Duration::from_secs(1), ack_recv).await.unwrap(); @@ -352,23 +588,34 @@ mod tests { #[tokio::test] async fn test_discard_after_update_with_higher_count() { - let handle = TrackerHandle::new(); + let handle = TrackerHandle::new(None); let (ack_send, ack_recv) = oneshot::channel(); + let offset = Bytes::from_static(b"offset1"); + let message = Message { + keys: Arc::from([]), + tags: None, + value: Bytes::from_static(b"test"), + offset: None, + event_time: Default::default(), + id: MessageID { + vertex_name: "in".into(), + offset: offset.clone(), + index: 1, + }, + headers: HashMap::new(), + metadata: None, + }; + // Insert a new message - handle - .insert("offset1".to_string().into(), ack_send) - .await - .unwrap(); + handle.insert(&message, ack_send).await.unwrap(); + let messages: Vec<Message> = std::iter::repeat(message).take(3).collect(); // Update the message with a count of 3 - handle - .update("offset1".to_string().into(), 3, false) - .await - .unwrap(); + handle.update_many(&messages, false).await.unwrap(); // Discard the message - handle.discard("offset1".to_string().into()).await.unwrap(); + handle.discard(offset).await.unwrap(); // Verify that the message was discarded and nak was received let result = timeout(Duration::from_secs(1), ack_recv).await.unwrap(); diff --git a/rust/numaflow-core/src/transformer.rs b/rust/numaflow-core/src/transformer.rs index a44e838c0d..21bdc19c5b 100644 --- a/rust/numaflow-core/src/transformer.rs +++ b/rust/numaflow-core/src/transformer.rs @@ -138,11 +138,7 @@ impl Transformer { match receiver.await { Ok(Ok(mut transformed_messages)) => { if let Err(e) = tracker_handle - .update( - read_msg.id.offset.clone(), - transformed_messages.len() as u32, - true, - ) + .update_many(&transformed_messages, true) .await { let _ = error_tx.send(e).await; @@ -278,7 +274,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = SourceTransformClient::new(create_rpc_channel(sock_file).await?); let transformer = Transformer::new(500, 10, client, tracker_handle.clone()).await?; @@ -355,7 +351,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = SourceTransformClient::new(create_rpc_channel(sock_file).await?); let transformer = Transformer::new(500, 10, client, tracker_handle.clone()).await?; @@ -441,7 +437,7 @@ mod tests { // wait for the server to start tokio::time::sleep(Duration::from_millis(100)).await; - let tracker_handle = TrackerHandle::new(); + let tracker_handle = TrackerHandle::new(None); let client = SourceTransformClient::new(create_rpc_channel(sock_file).await?); let transformer = Transformer::new(500, 10, client, tracker_handle.clone()).await?; diff --git a/rust/serving/src/app/callback.rs b/rust/serving/src/app/callback.rs index d5708a4e62..7d8de8815e 100644 --- a/rust/serving/src/app/callback.rs +++ b/rust/serving/src/app/callback.rs @@ -1,9 +1,9 @@ use axum::{body::Bytes, extract::State, http::HeaderMap, routing, Json, Router}; -use serde::{Deserialize, Serialize}; use tracing::error; use self::store::Store; use crate::app::response::ApiError; +use crate::callback::Callback; /// in-memory state store including connection tracking pub(crate) mod state; @@ -12,26 +12,6 @@ use state::State as CallbackState; /// store for storing the state pub(crate) mod store; -/// As message passes through each component (map, transformer, sink, etc.). it emits a beacon via callback -/// to inform that message has been processed by this component. -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct Callback { - pub(crate) id: String, - pub(crate) vertex: String, - pub(crate) cb_time: u64, - pub(crate) from_vertex: String, - /// Due to flat-map operation, we can have 0 or more responses. - pub(crate) responses: Vec<Response>, -} - -/// It contains details about the `To` vertex via tags (conditional forwarding). -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct Response { - /// If tags is None, the message is forwarded to all vertices, if len(Vec) == 0, it means that - /// the message has been dropped. - pub(crate) tags: Option<Vec<String>>, -} - #[derive(Clone)] struct CallbackAppState<T: Clone> { tid_header: String, @@ -106,6 +86,7 @@ mod tests { use crate::app::callback::state::State as CallbackState; use crate::app::callback::store::memstore::InMemoryStore; use crate::app::tracker::MessageGraph; + use crate::callback::Response; use crate::pipeline::PipelineDCG; const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; diff --git a/rust/serving/src/app/callback/state.rs b/rust/serving/src/app/callback/state.rs index 419a40f60f..2354899aa9 100644 --- a/rust/serving/src/app/callback/state.rs +++ b/rust/serving/src/app/callback/state.rs @@ -228,7 +228,7 @@ where mod tests { use super::*; use crate::app::callback::store::memstore::InMemoryStore; - use crate::app::callback::Response; + use crate::callback::Response; use crate::pipeline::PipelineDCG; use axum::body::Bytes; diff --git a/rust/serving/src/app/callback/store/memstore.rs b/rust/serving/src/app/callback/store/memstore.rs index 9e470a7c42..0e7135b8e2 100644 --- a/rust/serving/src/app/callback/store/memstore.rs +++ b/rust/serving/src/app/callback/store/memstore.rs @@ -98,7 +98,7 @@ mod tests { use super::*; use crate::app::callback::store::{PayloadToSave, Store}; - use crate::app::callback::{Callback, Response}; + use crate::callback::{Callback, Response}; #[tokio::test] async fn test_save_and_retrieve_callbacks() { diff --git a/rust/serving/src/app/callback/store/redisstore.rs b/rust/serving/src/app/callback/store/redisstore.rs index d4fe501cdb..490b7be940 100644 --- a/rust/serving/src/app/callback/store/redisstore.rs +++ b/rust/serving/src/app/callback/store/redisstore.rs @@ -203,7 +203,7 @@ impl super::Store for RedisConnection { mod tests { use super::*; use crate::app::callback::store::LocalStore; - use crate::app::callback::Response; + use crate::callback::Response; use axum::body::Bytes; use redis::AsyncCommands; diff --git a/rust/serving/src/app/jetstream_proxy.rs b/rust/serving/src/app/jetstream_proxy.rs index d4302a3937..ad4db52ca6 100644 --- a/rust/serving/src/app/jetstream_proxy.rs +++ b/rust/serving/src/app/jetstream_proxy.rs @@ -260,12 +260,11 @@ mod tests { use tower::ServiceExt; use super::*; - use crate::app::callback; use crate::app::callback::state::State as CallbackState; use crate::app::callback::store::memstore::InMemoryStore; use crate::app::callback::store::PayloadToSave; - use crate::app::callback::Callback; use crate::app::tracker::MessageGraph; + use crate::callback::{Callback, Response}; use crate::config::DEFAULT_ID_HEADER; use crate::pipeline::PipelineDCG; use crate::{Error, Settings}; @@ -361,21 +360,21 @@ mod tests { vertex: "in".to_string(), cb_time: 12345, from_vertex: "in".to_string(), - responses: vec![callback::Response { tags: None }], + responses: vec![Response { tags: None }], }, Callback { id: id.to_string(), vertex: "cat".to_string(), cb_time: 12345, from_vertex: "in".to_string(), - responses: vec![callback::Response { tags: None }], + responses: vec![Response { tags: None }], }, Callback { id: id.to_string(), vertex: "out".to_string(), cb_time: 12345, from_vertex: "cat".to_string(), - responses: vec![callback::Response { tags: None }], + responses: vec![Response { tags: None }], }, ] } diff --git a/rust/serving/src/app/tracker.rs b/rust/serving/src/app/tracker.rs index 5f3b24db7b..e1b23d35d1 100644 --- a/rust/serving/src/app/tracker.rs +++ b/rust/serving/src/app/tracker.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; -use crate::app::callback::Callback; +use crate::callback::Callback; use crate::pipeline::{Edge, OperatorType, PipelineDCG}; use crate::Error; @@ -250,7 +250,7 @@ impl MessageGraph { #[cfg(test)] mod tests { use super::*; - use crate::app::callback::Response; + use crate::callback::Response; use crate::pipeline::{Conditions, Tag, Vertex}; #[test] diff --git a/rust/serving/src/callback.rs b/rust/serving/src/callback.rs index 75d5522b8e..4c6db654c0 100644 --- a/rust/serving/src/callback.rs +++ b/rust/serving/src/callback.rs @@ -1,29 +1,32 @@ use std::{ - collections::HashMap, sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}, }; use reqwest::Client; +use serde::{Deserialize, Serialize}; use tokio::sync::Semaphore; -use crate::config::DEFAULT_CALLBACK_URL_HEADER_KEY; use crate::config::DEFAULT_ID_HEADER; -use crate::Error; -/// The data to be sent in the POST request -#[derive(serde::Serialize)] -struct CallbackPayload { - /// Unique identifier of the message - id: String, - /// Name of the vertex - vertex: String, - /// Time when the callback was made - cb_time: u64, - /// Name of the vertex from which the message was sent - from_vertex: String, - /// List of tags associated with the message - tags: Option<Vec<String>>, +/// As message passes through each component (map, transformer, sink, etc.). it emits a beacon via callback +/// to inform that message has been processed by this component. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Callback { + pub(crate) id: String, + pub(crate) vertex: String, + pub(crate) cb_time: u64, + pub(crate) from_vertex: String, + /// Due to flat-map operation, we can have 0 or more responses. + pub(crate) responses: Vec<Response>, +} + +/// It contains details about the `To` vertex via tags (conditional forwarding). +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Response { + /// If tags is None, the message is forwarded to all vertices, if len(Vec) == 0, it means that + /// the message has been dropped. + pub(crate) tags: Option<Vec<String>>, } #[derive(Clone)] @@ -52,43 +55,26 @@ impl CallbackHandler { pub async fn callback( &self, - message_headers: &HashMap<String, String>, - message_tags: &Option<Arc<[String]>>, + id: String, + callback_url: String, previous_vertex: String, + responses: Vec<Option<Vec<String>>>, ) -> crate::Result<()> { - let callback_url = message_headers - .get(DEFAULT_CALLBACK_URL_HEADER_KEY) - .ok_or_else(|| { - Error::Source(format!( - "{DEFAULT_CALLBACK_URL_HEADER_KEY} header is not present in the message headers", - )) - })? - .to_owned(); - let uuid = message_headers - .get(DEFAULT_ID_HEADER) - .ok_or_else(|| { - Error::Source(format!( - "{DEFAULT_ID_HEADER} is not found in message headers", - )) - })? - .to_owned(); let cb_time = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("System time is older than Unix epoch time") .as_millis() as u64; - let mut msg_tags = None; - if let Some(tags) = message_tags { - if !tags.is_empty() { - msg_tags = Some(tags.iter().cloned().collect()); - } - }; + let responses = responses + .into_iter() + .map(|tags| Response { tags }) + .collect(); - let callback_payload = CallbackPayload { + let callback_payload = Callback { vertex: self.vertex_name.clone(), - id: uuid.clone(), + id: id.clone(), cb_time, - tags: msg_tags, + responses, from_vertex: previous_vertex, }; @@ -104,7 +90,7 @@ impl CallbackHandler { for i in 1..=TOTAL_ATTEMPTS { let resp = client .post(&callback_url) - .header(DEFAULT_ID_HEADER, uuid.clone()) + .header(DEFAULT_ID_HEADER, id.clone()) .json(&[&callback_payload]) .send() .await; @@ -175,17 +161,17 @@ mod tests { use crate::app::callback::store::memstore::InMemoryStore; use crate::app::start_main_server; use crate::app::tracker::MessageGraph; - use crate::callback::{CallbackHandler, DEFAULT_CALLBACK_URL_HEADER_KEY, DEFAULT_ID_HEADER}; + use crate::callback::CallbackHandler; use crate::config::generate_certs; use crate::pipeline::PipelineDCG; use crate::{AppState, Settings}; use axum_server::tls_rustls::RustlsConfig; - use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc; type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>; + #[tokio::test] async fn test_callback() -> Result<()> { // Set up the CryptoProvider (controls core cryptography used by rustls) for the process @@ -241,22 +227,16 @@ mod tests { assert!(server_ready, "Server is not ready"); let callback_handler = CallbackHandler::new("test".into(), 10); - let message_headers: HashMap<String, String> = [ - ( - DEFAULT_CALLBACK_URL_HEADER_KEY, - "https://localhost:3003/v1/process/callback", - ), - (DEFAULT_ID_HEADER, ID_VALUE), - ] - .into_iter() - .map(|(k, v)| (k.into(), v.into())) - .collect(); - let tags = Arc::from(vec!["tag1".to_owned()]); // On the server, this fails with SubGraphInvalidInput("Invalid callback: 1234, vertex: in") // We get 200 OK response from the server, since we already registered this request ID in the store. callback_handler - .callback(&message_headers, &Some(tags), "in".into()) + .callback( + ID_VALUE.into(), + "https://localhost:3003/v1/process/callback".into(), + "in".into(), + vec![], + ) .await?; let mut data = None; for _ in 0..10 { @@ -273,26 +253,4 @@ mod tests { server_handle.abort(); Ok(()) } - - #[tokio::test] - async fn test_callback_missing_headers() -> Result<()> { - let callback_handler = CallbackHandler::new("test".into(), 10); - let message_headers: HashMap<String, String> = HashMap::new(); - let result = callback_handler - .callback(&message_headers, &None, "in".into()) - .await; - assert!(result.is_err()); - - let mut message_headers: HashMap<String, String> = HashMap::new(); - message_headers.insert( - DEFAULT_CALLBACK_URL_HEADER_KEY.into(), - "https://localhost:3003/v1/process/callback".into(), - ); - let result = callback_handler - .callback(&message_headers, &None, "in".into()) - .await; - assert!(result.is_err()); - - Ok(()) - } } diff --git a/rust/serving/src/config.rs b/rust/serving/src/config.rs index 116b7c2579..d5e93ab6f2 100644 --- a/rust/serving/src/config.rs +++ b/rust/serving/src/config.rs @@ -18,8 +18,8 @@ const ENV_NUMAFLOW_SERVING_APP_PORT: &str = "NUMAFLOW_SERVING_APP_LISTEN_PORT"; const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; -pub(crate) const DEFAULT_ID_HEADER: &str = "X-Numaflow-Id"; -pub(crate) const DEFAULT_CALLBACK_URL_HEADER_KEY: &str = "X-Numaflow-Callback-Url"; +pub const DEFAULT_ID_HEADER: &str = "X-Numaflow-Id"; +pub const DEFAULT_CALLBACK_URL_HEADER_KEY: &str = "X-Numaflow-Callback-Url"; pub fn generate_certs() -> std::result::Result<(Certificate, KeyPair), String> { let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) diff --git a/rust/serving/src/lib.rs b/rust/serving/src/lib.rs index 1292fe64e2..f49d0c9164 100644 --- a/rust/serving/src/lib.rs +++ b/rust/serving/src/lib.rs @@ -15,7 +15,7 @@ use crate::metrics::start_https_metrics_server; mod app; mod config; -pub use config::Settings; +pub use {config::Settings, config::DEFAULT_CALLBACK_URL_HEADER_KEY, config::DEFAULT_ID_HEADER}; mod consts; mod error; diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs index 25d36c9ebb..4c1d7ae35c 100644 --- a/rust/serving/src/source.rs +++ b/rust/serving/src/source.rs @@ -247,6 +247,9 @@ mod tests { type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>; #[tokio::test] async fn test_serving_source() -> Result<()> { + // Setup the CryptoProvider (controls core cryptography used by rustls) for the process + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + let settings = Arc::new(Settings::default()); let serving_source = ServingSource::new(Arc::clone(&settings), 10, Duration::from_millis(1), 0).await?;