Skip to content

Commit

Permalink
Move callback handler to tracker
Browse files Browse the repository at this point in the history
Signed-off-by: Sreekanth <[email protected]>
  • Loading branch information
BulkBeing committed Jan 16, 2025
1 parent abf5b89 commit a1b52dc
Show file tree
Hide file tree
Showing 22 changed files with 424 additions and 383 deletions.
33 changes: 11 additions & 22 deletions rust/numaflow-core/src/mapper/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,7 @@ impl MapHandle {
match receiver.await {
Ok(Ok(mut mapped_messages)) => {
// update the tracker with the number of messages sent and send the mapped messages
if let Err(e) = tracker_handle
.update(
read_msg.id.offset.clone(),
mapped_messages.len() as u32,
true,
)
.await
{
if let Err(e) = tracker_handle.update_many(&mapped_messages, true).await {
error_tx.send(e).await.expect("failed to send error");
return;
}
Expand Down Expand Up @@ -398,10 +391,7 @@ impl MapHandle {
for receiver in receivers {
match receiver.await {
Ok(Ok(mut mapped_messages)) => {
let offset = mapped_messages.first().unwrap().id.offset.clone();
tracker_handle
.update(offset.clone(), mapped_messages.len() as u32, true)
.await?;
tracker_handle.update_many(&mapped_messages, true).await?;
for mapped_message in mapped_messages.drain(..) {
output_tx
.send(mapped_message)
Expand Down Expand Up @@ -455,8 +445,7 @@ impl MapHandle {
while let Some(result) = receiver.recv().await {
match result {
Ok(mapped_message) => {
let offset = mapped_message.id.offset.clone();
if let Err(e) = tracker_handle.update(offset.clone(), 1, false).await {
if let Err(e) = tracker_handle.update(&mapped_message).await {
error_tx.send(e).await.expect("failed to send error");
return;
}
Expand All @@ -475,7 +464,7 @@ impl MapHandle {
}
}

if let Err(e) = tracker_handle.update(read_msg.id.offset, 0, true).await {
if let Err(e) = tracker_handle.update_eof(read_msg.id.offset).await {
error_tx.send(e).await.expect("failed to send error");
}
});
Expand Down Expand Up @@ -530,7 +519,7 @@ mod tests {

// wait for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;
let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);

let client = MapClient::new(create_rpc_channel(sock_file).await?);
let mapper = MapHandle::new(
Expand Down Expand Up @@ -623,7 +612,7 @@ mod tests {
// wait for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;

let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);
let client = MapClient::new(create_rpc_channel(sock_file).await?);
let mapper = MapHandle::new(
MapMode::Unary,
Expand Down Expand Up @@ -714,7 +703,7 @@ mod tests {
// wait for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;

let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);
let client = MapClient::new(create_rpc_channel(sock_file).await?);
let mapper = MapHandle::new(
MapMode::Unary,
Expand Down Expand Up @@ -810,7 +799,7 @@ mod tests {

// wait for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;
let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);

let client = MapClient::new(create_rpc_channel(sock_file).await?);
let mapper = MapHandle::new(
Expand Down Expand Up @@ -923,7 +912,7 @@ mod tests {
// wait for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;

let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);
let client = MapClient::new(create_rpc_channel(sock_file).await?);
let mapper = MapHandle::new(
MapMode::Batch,
Expand Down Expand Up @@ -1035,7 +1024,7 @@ mod tests {

// wait for the server to start
tokio::time::sleep(Duration::from_millis(100)).await;
let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);

let client = MapClient::new(create_rpc_channel(sock_file).await?);
let mapper = MapHandle::new(
Expand Down Expand Up @@ -1134,7 +1123,7 @@ mod tests {
tokio::time::sleep(Duration::from_millis(100)).await;

let client = MapClient::new(create_rpc_channel(sock_file).await?);
let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);
let mapper = MapHandle::new(
MapMode::Stream,
500,
Expand Down
27 changes: 7 additions & 20 deletions rust/numaflow-core/src/monovertex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ pub(crate) async fn start_forwarder(
cln_token: CancellationToken,
config: &MonovertexConfig,
) -> error::Result<()> {
let tracker_handle = TrackerHandle::new();
let callback_handler = config
.callback_config
.as_ref()
.map(|cb_cfg| CallbackHandler::new(config.name.clone(), cb_cfg.callback_concurrency));
let tracker_handle = TrackerHandle::new(callback_handler);
let (source, source_grpc_client) = create_components::create_source(
config.batch_size,
config.read_timeout,
Expand Down Expand Up @@ -70,23 +74,7 @@ pub(crate) async fn start_forwarder(
// FIXME: what to do with the handle
shared::metrics::start_metrics_server(config.metrics_config.clone(), metrics_state).await;

let callback_handler = match config.callback_config {
Some(ref cb_cfg) => Some(CallbackHandler::new(
config.name.clone(),
cb_cfg.callback_concurrency,
)),
None => None,
};

start(
config.clone(),
source,
sink_writer,
transformer,
cln_token,
callback_handler,
)
.await?;
start(config.clone(), source, sink_writer, transformer, cln_token).await?;

Ok(())
}
Expand All @@ -97,7 +85,6 @@ async fn start(
sink: SinkWriter,
transformer: Option<Transformer>,
cln_token: CancellationToken,
callback_handler: Option<CallbackHandler>,
) -> error::Result<()> {
// start the pending reader to publish pending metrics
let pending_reader = shared::metrics::create_pending_reader(
Expand All @@ -107,7 +94,7 @@ async fn start(
.await;
let _pending_reader_handle = pending_reader.start(is_mono_vertex()).await;

let mut forwarder_builder = ForwarderBuilder::new(source, sink, cln_token, callback_handler);
let mut forwarder_builder = ForwarderBuilder::new(source, sink, cln_token);

// add transformer if exists
if let Some(transformer_client) = transformer {
Expand Down
20 changes: 5 additions & 15 deletions rust/numaflow-core/src/monovertex/forwarder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
//! [Stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
//! [Actor Pattern]: https://ryhl.io/blog/actors-with-tokio/
use serving::callback::CallbackHandler;
use tokio_util::sync::CancellationToken;

use crate::error;
Expand All @@ -45,15 +44,13 @@ pub(crate) struct Forwarder {
transformer: Option<Transformer>,
sink_writer: SinkWriter,
cln_token: CancellationToken,
callback_handler: Option<CallbackHandler>,
}

pub(crate) struct ForwarderBuilder {
source: Source,
sink_writer: SinkWriter,
cln_token: CancellationToken,
transformer: Option<Transformer>,
callback_handler: Option<CallbackHandler>,
}

impl ForwarderBuilder {
Expand All @@ -62,14 +59,12 @@ impl ForwarderBuilder {
streaming_source: Source,
streaming_sink: SinkWriter,
cln_token: CancellationToken,
callback_handler: Option<CallbackHandler>,
) -> Self {
Self {
source: streaming_source,
sink_writer: streaming_sink,
cln_token,
transformer: None,
callback_handler,
}
}

Expand All @@ -87,7 +82,6 @@ impl ForwarderBuilder {
sink_writer: self.sink_writer,
transformer: self.transformer,
cln_token: self.cln_token,
callback_handler: self.callback_handler,
}
}
}
Expand All @@ -108,11 +102,7 @@ impl Forwarder {

let sink_writer_handle = self
.sink_writer
.streaming_write(
transformed_messages_stream,
self.cln_token.clone(),
self.callback_handler.clone(),
)
.streaming_write(transformed_messages_stream, self.cln_token.clone())
.await?;

match tokio::try_join!(
Expand Down Expand Up @@ -273,7 +263,7 @@ mod tests {
.await
.map_err(|e| panic!("failed to create source reader: {:?}", e))
.unwrap();
let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);
let source = Source::new(
5,
SourceType::UserDefinedSource(src_read, src_ack, lag_reader),
Expand Down Expand Up @@ -317,7 +307,7 @@ mod tests {
.unwrap();

// create the forwarder with the source, transformer, and writer
let forwarder = ForwarderBuilder::new(source.clone(), sink_writer, cln_token.clone(), None)
let forwarder = ForwarderBuilder::new(source.clone(), sink_writer, cln_token.clone())
.transformer(transformer)
.build();

Expand Down Expand Up @@ -372,7 +362,7 @@ mod tests {

#[tokio::test]
async fn test_flatmap_operation() {
let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(None);
// create the source which produces x number of messages
let cln_token = CancellationToken::new();

Expand Down Expand Up @@ -447,7 +437,7 @@ mod tests {
.unwrap();

// create the forwarder with the source, transformer, and writer
let forwarder = ForwarderBuilder::new(source.clone(), sink_writer, cln_token.clone(), None)
let forwarder = ForwarderBuilder::new(source.clone(), sink_writer, cln_token.clone())
.transformer(transformer)
.build();

Expand Down
46 changes: 14 additions & 32 deletions rust/numaflow-core/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,17 @@ async fn start_source_forwarder(
config: PipelineConfig,
source_config: SourceVtxConfig,
) -> Result<()> {
let tracker_handle = TrackerHandle::new();
let callback_handler = config.callback_config.as_ref().map(|cb_cfg| {
CallbackHandler::new(config.vertex_name.clone(), cb_cfg.callback_concurrency)
});
let tracker_handle = TrackerHandle::new(callback_handler);
let js_context = create_js_context(config.js_client_config.clone()).await?;

let callback_handler = match config.callback_config {
Some(ref cb_cfg) => Some(CallbackHandler::new(
config.vertex_name.clone(),
cb_cfg.callback_concurrency,
)),
None => None,
};

let buffer_writer = create_buffer_writer(
&config,
js_context.clone(),
tracker_handle.clone(),
cln_token.clone(),
callback_handler,
)
.await;

Expand Down Expand Up @@ -136,8 +130,12 @@ async fn start_map_forwarder(
let mut mapper_grpc_client = None;
let mut isb_lag_readers = vec![];

let callback_handler = config.callback_config.as_ref().map(|cb_cfg| {
CallbackHandler::new(config.vertex_name.clone(), cb_cfg.callback_concurrency)
});

for stream in reader_config.streams.clone() {
let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(callback_handler.clone());

let buffer_reader = create_buffer_reader(
stream,
Expand All @@ -163,20 +161,11 @@ async fn start_map_forwarder(
mapper_grpc_client = Some(mapper_rpc_client);
}

let callback_handler = match config.callback_config {
Some(ref cb_cfg) => Some(CallbackHandler::new(
config.vertex_name.clone(),
cb_cfg.callback_concurrency,
)),
None => None,
};

let buffer_writer = create_buffer_writer(
&config,
js_context.clone(),
tracker_handle.clone(),
cln_token.clone(),
callback_handler,
)
.await;
forwarder_components.push((buffer_reader, buffer_writer, mapper));
Expand Down Expand Up @@ -237,11 +226,15 @@ async fn start_sink_forwarder(
.ok_or_else(|| error::Error::Config("No from vertex config found".to_string()))?
.reader_config;

let callback_handler = config.callback_config.as_ref().map(|cb_cfg| {
CallbackHandler::new(config.vertex_name.clone(), cb_cfg.callback_concurrency)
});

// Create sink writers and buffer readers for each stream
let mut sink_writers = vec![];
let mut buffer_readers = vec![];
for stream in reader_config.streams.clone() {
let tracker_handle = TrackerHandle::new();
let tracker_handle = TrackerHandle::new(callback_handler.clone());

let buffer_reader = create_buffer_reader(
stream,
Expand Down Expand Up @@ -284,14 +277,6 @@ async fn start_sink_forwarder(
.await;
}

let callback_handler = match config.callback_config {
Some(ref cb_cfg) => Some(CallbackHandler::new(
config.vertex_name.clone(),
cb_cfg.callback_concurrency,
)),
None => None,
};

// Start a new forwarder for each buffer reader
let mut forwarder_tasks = Vec::new();
for (buffer_reader, (sink_writer, _, _)) in buffer_readers.into_iter().zip(sink_writers) {
Expand All @@ -300,7 +285,6 @@ async fn start_sink_forwarder(
buffer_reader,
sink_writer,
cln_token.clone(),
callback_handler.clone(),
)
.await;

Expand All @@ -326,15 +310,13 @@ async fn create_buffer_writer(
js_context: Context,
tracker_handle: TrackerHandle,
cln_token: CancellationToken,
callback_handler: Option<CallbackHandler>,
) -> JetstreamWriter {
JetstreamWriter::new(
config.to_vertex_config.clone(),
js_context,
config.paf_concurrency,
tracker_handle,
cln_token,
callback_handler,
)
}

Expand Down
Loading

0 comments on commit a1b52dc

Please sign in to comment.