diff --git a/go.mod b/go.mod index 84ad2e952f..94ae56ddce 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe github.com/nats-io/nats-server/v2 v2.10.17 github.com/nats-io/nats.go v1.36.0 - github.com/numaproj/numaflow-go v0.8.2-0.20240917052911-ee2f3086d64e + github.com/numaproj/numaflow-go v0.8.2-0.20240918054944-0fd13d430793 github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_model v0.5.0 github.com/prometheus/common v0.45.0 diff --git a/go.sum b/go.sum index 236809e8c4..9fa9b85a66 100644 --- a/go.sum +++ b/go.sum @@ -485,8 +485,8 @@ github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDm github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/numaproj/numaflow-go v0.8.2-0.20240917052911-ee2f3086d64e h1:F3iujbel8y5X20bVMY0Am6XDyL5eDOC/6kxyI8uxfpg= -github.com/numaproj/numaflow-go v0.8.2-0.20240917052911-ee2f3086d64e/go.mod h1:g4JZOyUPhjfhv+kR0sX5d8taw/dasgKPXLvQBi39mJ4= +github.com/numaproj/numaflow-go v0.8.2-0.20240918054944-0fd13d430793 h1:kUQw1LsUvmTjqFfcia6DZOxy8qCQwvdY0TpOnR8w3Xg= +github.com/numaproj/numaflow-go v0.8.2-0.20240918054944-0fd13d430793/go.mod h1:g4JZOyUPhjfhv+kR0sX5d8taw/dasgKPXLvQBi39mJ4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= diff --git a/pkg/sdkclient/source/client.go b/pkg/sdkclient/source/client.go index 15be4c9317..305271d95e 100644 --- a/pkg/sdkclient/source/client.go +++ b/pkg/sdkclient/source/client.go @@ -86,23 +86,42 @@ waitUntilReady: return nil, fmt.Errorf("failed to create ack stream: %v", err) } - // Send handshake request - handshakeRequest := &sourcepb.ReadRequest{ + // Send handshake request for read stream + readHandshakeRequest := &sourcepb.ReadRequest{ Handshake: &sourcepb.Handshake{ Sot: true, }, } - if err := c.readStream.Send(handshakeRequest); err != nil { - return nil, fmt.Errorf("failed to send handshake request: %v", err) + if err := c.readStream.Send(readHandshakeRequest); err != nil { + return nil, fmt.Errorf("failed to send read handshake request: %v", err) } - // Wait for handshake response - handshakeResponse, err := c.readStream.Recv() + // Wait for handshake response for read stream + readHandshakeResponse, err := c.readStream.Recv() if err != nil { - return nil, fmt.Errorf("failed to receive handshake response: %v", err) + return nil, fmt.Errorf("failed to receive read handshake response: %v", err) } - if handshakeResponse.GetHandshake() == nil || !handshakeResponse.GetHandshake().GetSot() { - return nil, fmt.Errorf("invalid handshake response") + if readHandshakeResponse.GetHandshake() == nil || !readHandshakeResponse.GetHandshake().GetSot() { + return nil, fmt.Errorf("invalid read handshake response") + } + + // Send handshake request for ack stream + ackHandshakeRequest := &sourcepb.AckRequest{ + Handshake: &sourcepb.Handshake{ + Sot: true, + }, + } + if err := c.ackStream.Send(ackHandshakeRequest); err != nil { + return nil, fmt.Errorf("failed to send ack handshake request: %v", err) + } + + // Wait for handshake response for ack stream + ackHandshakeResponse, err := c.ackStream.Recv() + if err != nil { + return nil, fmt.Errorf("failed to receive ack handshake response: %v", err) + } + if ackHandshakeResponse.GetHandshake() == nil || !ackHandshakeResponse.GetHandshake().GetSot() { + return nil, fmt.Errorf("invalid ack handshake response") } return c, nil @@ -172,11 +191,19 @@ func (c *client) ReadFn(_ context.Context, req *sourcepb.ReadRequest, datumCh ch // AckFn acknowledges the data from the source. func (c *client) AckFn(_ context.Context, req *sourcepb.AckRequest) (*sourcepb.AckResponse, error) { + // Send the ack request err := c.ackStream.Send(req) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to send ack request: %v", err) } - return &sourcepb.AckResponse{}, nil + + // Wait for the ack response + resp, err := c.ackStream.Recv() + if err != nil { + return nil, fmt.Errorf("failed to receive ack response: %v", err) + } + + return resp, nil } // PendingFn returns the number of pending data from the source. diff --git a/pkg/sdkclient/source/client_test.go b/pkg/sdkclient/source/client_test.go index f1ae7d80f4..818c3c3430 100644 --- a/pkg/sdkclient/source/client_test.go +++ b/pkg/sdkclient/source/client_test.go @@ -150,22 +150,47 @@ func TestAckFn(t *testing.T) { defer ctrl.Finish() mockClient := sourcemock.NewMockSourceClient(ctrl) - mockStream := sourcemock.NewMockSource_AckFnClient(ctrl) + + // Handshake request and response + mockStream.EXPECT().Send(&sourcepb.AckRequest{ + Handshake: &sourcepb.Handshake{ + Sot: true, + }, + }).Return(nil) + mockStream.EXPECT().Recv().Return(&sourcepb.AckResponse{ + Handshake: &sourcepb.Handshake{ + Sot: true, + }, + }, nil) + + // Ack request and response mockStream.EXPECT().Send(gomock.Any()).Return(nil) - mockStream.EXPECT().Send(gomock.Any()).Return(fmt.Errorf("mock connection refused")) + mockStream.EXPECT().Recv().Return(&sourcepb.AckResponse{}, nil) testClient := client{ grpcClt: mockClient, ackStream: mockStream, } + // Perform handshake + ackHandshakeRequest := &sourcepb.AckRequest{ + Handshake: &sourcepb.Handshake{ + Sot: true, + }, + } + err := testClient.ackStream.Send(ackHandshakeRequest) + assert.NoError(t, err) + + ackHandshakeResponse, err := testClient.ackStream.Recv() + assert.NoError(t, err) + assert.NotNil(t, ackHandshakeResponse.GetHandshake()) + assert.True(t, ackHandshakeResponse.GetHandshake().GetSot()) + + // Test AckFn ack, err := testClient.AckFn(ctx, &sourcepb.AckRequest{}) assert.NoError(t, err) assert.Equal(t, &sourcepb.AckResponse{}, ack) - - _, err = testClient.AckFn(ctx, &sourcepb.AckRequest{}) - assert.EqualError(t, err, "mock connection refused") } func TestPendingFn(t *testing.T) { diff --git a/pkg/sources/udsource/grpc_udsource_test.go b/pkg/sources/udsource/grpc_udsource_test.go index 459f9ae39f..e0a0ab4ca5 100644 --- a/pkg/sources/udsource/grpc_udsource_test.go +++ b/pkg/sources/udsource/grpc_udsource_test.go @@ -288,6 +288,7 @@ func Test_gRPCBasedUDSource_ApplyAckWithMockClient(t *testing.T) { mockAckClient.EXPECT().Send(req1).Return(nil).Times(1) mockAckClient.EXPECT().Send(req2).Return(nil).Times(1) + mockAckClient.EXPECT().Recv().Return(&sourcepb.AckResponse{}, nil).Times(2) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -340,6 +341,6 @@ func Test_gRPCBasedUDSource_ApplyAckWithMockClient(t *testing.T) { NewUserDefinedSourceOffset(offset1), NewUserDefinedSourceOffset(offset2), }) - assert.ErrorIs(t, err, status.New(codes.DeadlineExceeded, "mock test err").Err()) + assert.Equal(t, err.Error(), fmt.Sprintf("failed to send ack request: %s", status.New(codes.DeadlineExceeded, "mock test err").Err())) }) } diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 10e16e60ee..624d5f14a8 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -43,9 +43,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.88" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" [[package]] name = "arc-swap" @@ -351,18 +351,18 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" dependencies = [ "serde", ] [[package]] name = "cc" -version = "1.1.18" +version = "1.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" dependencies = [ "jobserver", "libc", @@ -1098,7 +1098,7 @@ dependencies = [ "tokio", "tokio-rustls 0.26.0", "tower-service", - "webpki-roots 0.26.3", + "webpki-roots 0.26.5", ] [[package]] @@ -1136,9 +1136,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1478,7 +1478,7 @@ dependencies = [ "hyper-util", "kube", "log", - "numaflow 0.1.1 (git+https://github.com/numaproj/numaflow-rs.git?branch=handshake)", + "numaflow 0.1.1", "numaflow-models", "once_cell", "parking_lot", @@ -1605,29 +1605,7 @@ dependencies = [ [[package]] name = "numaflow" version = "0.1.1" -source = "git+https://github.com/numaproj/numaflow-rs.git?branch=handshake#f3061b039c877e9828c50fcd0424391727e4920d" -dependencies = [ - "chrono", - "futures-util", - "hyper-util", - "prost", - "prost-types", - "serde", - "serde_json", - "thiserror", - "tokio", - "tokio-stream", - "tokio-util", - "tonic", - "tonic-build", - "tracing", - "uuid", -] - -[[package]] -name = "numaflow" -version = "0.1.1" -source = "git+https://github.com/numaproj/numaflow-rs.git?branch=source-streaming#dcbb26834153b84853d9757e25395d92a1314d4a" +source = "git+https://github.com/numaproj/numaflow-rs.git?branch=handshake#baecc88456f317b08bc869f82596e2b746cf798b" dependencies = [ "chrono", "futures-util", @@ -2027,15 +2005,15 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bffec3605b73c6f1754535084a85229fa8a30f86014e6c81aeec4abb68b0285" +checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" dependencies = [ "libc", "once_cell", "socket2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2248,7 +2226,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots 0.26.3", + "webpki-roots 0.26.5", "windows-registry", ] @@ -2616,7 +2594,7 @@ dependencies = [ name = "servesink" version = "0.1.0" dependencies = [ - "numaflow 0.1.1 (git+https://github.com/numaproj/numaflow-rs.git?branch=source-streaming)", + "numaflow 0.1.1", "reqwest 0.12.7", "tokio", "tonic", @@ -3029,9 +3007,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.20" +version = "0.22.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" +checksum = "3b072cee73c449a636ffd6f32bd8de3a9f7119139aff882f44943ce2986dc5cf" dependencies = [ "indexmap 2.5.0", "serde", @@ -3260,9 +3238,9 @@ checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-normalization" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" dependencies = [ "tinyvec", ] @@ -3430,9 +3408,9 @@ checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "webpki-roots" -version = "0.26.3" +version = "0.26.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" +checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a" dependencies = [ "rustls-pki-types", ] diff --git a/rust/monovertex/proto/source.proto b/rust/monovertex/proto/source.proto index 31a762ac59..93c5f19278 100644 --- a/rust/monovertex/proto/source.proto +++ b/rust/monovertex/proto/source.proto @@ -17,7 +17,7 @@ service Source { // The caller (numa) expects the AckFn to be successful, and it does not expect any errors. // If there are some irrecoverable errors when the callee (UDSource) is processing the AckFn request, // then it is best to crash because there are no other retry mechanisms possible. - rpc AckFn(stream AckRequest) returns (AckResponse); + rpc AckFn(stream AckRequest) returns (stream AckResponse); // PendingFn returns the number of pending records at the user defined source. rpc PendingFn(google.protobuf.Empty) returns (PendingResponse); @@ -112,6 +112,7 @@ message AckRequest { } // Required field holding the request. The list will be ordered and will have the same order as the original Read response. Request request = 1; + optional Handshake handshake = 2; } /* @@ -131,6 +132,7 @@ message AckResponse { } // Required field holding the result. Result result = 1; + optional Handshake handshake = 2; } /* diff --git a/rust/monovertex/src/forwarder.rs b/rust/monovertex/src/forwarder.rs index 0b41cd3569..d60644b338 100644 --- a/rust/monovertex/src/forwarder.rs +++ b/rust/monovertex/src/forwarder.rs @@ -7,7 +7,7 @@ use crate::metrics; use crate::metrics::forward_metrics; use crate::sink::SinkWriter; use crate::sink_pb::Status::{Failure, Fallback, Success}; -use crate::source::SourceReader; +use crate::source::{SourceAcker, SourceReader}; use crate::transformer::SourceTransformer; use chrono::Utc; use tokio::task::JoinSet; @@ -21,6 +21,7 @@ use tracing::{debug, info}; /// back to the source. pub(crate) struct Forwarder { source_reader: SourceReader, + source_acker: SourceAcker, sink_writer: SinkWriter, source_transformer: Option, fb_sink_writer: Option, @@ -31,6 +32,7 @@ pub(crate) struct Forwarder { /// ForwarderBuilder is used to build a Forwarder instance with optional fields. pub(crate) struct ForwarderBuilder { source_reader: SourceReader, + source_acker: SourceAcker, sink_writer: SinkWriter, cln_token: CancellationToken, source_transformer: Option, @@ -40,13 +42,15 @@ pub(crate) struct ForwarderBuilder { impl ForwarderBuilder { /// Create a new builder with mandatory fields pub(crate) fn new( - source_client: SourceReader, - sink_client: SinkWriter, + source_reader: SourceReader, + source_acker: SourceAcker, + sink_writer: SinkWriter, cln_token: CancellationToken, ) -> Self { Self { - source_reader: source_client, - sink_writer: sink_client, + source_reader, + source_acker, + sink_writer, cln_token, source_transformer: None, fb_sink_writer: None, @@ -71,6 +75,7 @@ impl ForwarderBuilder { let common_labels = metrics::forward_metrics_labels().clone(); Forwarder { source_reader: self.source_reader, + source_acker: self.source_acker, sink_writer: self.sink_writer, source_transformer: self.source_transformer, fb_sink_writer: self.fb_sink_writer, @@ -525,7 +530,7 @@ impl Forwarder { let n = offsets.len(); let start_time = tokio::time::Instant::now(); - self.source_reader.ack(offsets).await?; + self.source_acker.ack(offsets).await?; debug!("Ack latency - {}ms", start_time.elapsed().as_millis()); @@ -557,7 +562,7 @@ mod tests { use crate::shared::create_rpc_channel; use crate::sink::SinkWriter; use crate::sink_pb::sink_client::SinkClient; - use crate::source::SourceReader; + use crate::source::{SourceAcker, SourceReader}; use crate::source_pb::source_client::SourceClient; use crate::sourcetransform_pb::source_transform_client::SourceTransformClient; use crate::transformer::SourceTransformer; @@ -649,10 +654,7 @@ mod tests { #[tonic::async_trait] impl sink::Sinker for InMemorySink { - async fn sink( - &self, - mut input: tokio::sync::mpsc::Receiver, - ) -> Vec { + async fn sink(&self, mut input: mpsc::Receiver) -> Vec { let mut responses: Vec = Vec::new(); while let Some(datum) = input.recv().await { let response = match std::str::from_utf8(&datum.value) { @@ -742,13 +744,19 @@ mod tests { let cln_token = CancellationToken::new(); - let source_client = SourceReader::new(SourceClient::new( + let source_reader = SourceReader::new(SourceClient::new( + create_rpc_channel(source_sock_file.clone()).await.unwrap(), + )) + .await + .expect("failed to connect to source server"); + + let source_acker = SourceAcker::new(SourceClient::new( create_rpc_channel(source_sock_file).await.unwrap(), )) .await .expect("failed to connect to source server"); - let sink_client = SinkWriter::new(SinkClient::new( + let sink_writer = SinkWriter::new(SinkClient::new( create_rpc_channel(sink_sock_file).await.unwrap(), )) .await @@ -760,9 +768,10 @@ mod tests { .await .expect("failed to connect to transformer server"); - let mut forwarder = ForwarderBuilder::new(source_client, sink_client, cln_token.clone()) - .source_transformer(transformer_client) - .build(); + let mut forwarder = + ForwarderBuilder::new(source_reader, source_acker, sink_writer, cln_token.clone()) + .source_transformer(transformer_client) + .build(); // Assert the received message in a different task let assert_handle = tokio::spawn(async move { @@ -864,20 +873,27 @@ mod tests { let cln_token = CancellationToken::new(); - let source_client = SourceReader::new(SourceClient::new( + let source_reader = SourceReader::new(SourceClient::new( + create_rpc_channel(source_sock_file.clone()).await.unwrap(), + )) + .await + .expect("failed to connect to source server"); + + let source_acker = SourceAcker::new(SourceClient::new( create_rpc_channel(source_sock_file).await.unwrap(), )) .await .expect("failed to connect to source server"); - let sink_client = SinkWriter::new(SinkClient::new( + let sink_writer = SinkWriter::new(SinkClient::new( create_rpc_channel(sink_sock_file).await.unwrap(), )) .await .expect("failed to connect to sink server"); let mut forwarder = - ForwarderBuilder::new(source_client, sink_client, cln_token.clone()).build(); + ForwarderBuilder::new(source_reader, source_acker, sink_writer, cln_token.clone()) + .build(); let cancel_handle = tokio::spawn(async move { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -910,10 +926,7 @@ mod tests { #[tonic::async_trait] impl sink::Sinker for FallbackSender { - async fn sink( - &self, - mut input: tokio::sync::mpsc::Receiver, - ) -> Vec { + async fn sink(&self, mut input: mpsc::Receiver) -> Vec { let mut responses = vec![]; while let Some(datum) = input.recv().await { responses.append(&mut vec![sink::Response::fallback(datum.id)]); @@ -924,7 +937,7 @@ mod tests { #[tokio::test] async fn test_fb_sink() { - let (sink_tx, mut sink_rx) = tokio::sync::mpsc::channel(10); + let (sink_tx, mut sink_rx) = mpsc::channel(10); // Start the source server let (source_shutdown_tx, source_shutdown_rx) = tokio::sync::oneshot::channel(); @@ -982,27 +995,34 @@ mod tests { let cln_token = CancellationToken::new(); - let source_client = SourceReader::new(SourceClient::new( + let source_reader = SourceReader::new(SourceClient::new( + create_rpc_channel(source_sock_file.clone()).await.unwrap(), + )) + .await + .expect("failed to connect to source server"); + + let source_acker = SourceAcker::new(SourceClient::new( create_rpc_channel(source_sock_file).await.unwrap(), )) .await .expect("failed to connect to source server"); - let sink_client = SinkWriter::new(SinkClient::new( + let sink_writer = SinkWriter::new(SinkClient::new( create_rpc_channel(sink_sock_file).await.unwrap(), )) .await .expect("failed to connect to sink server"); - let fb_sink_client = SinkWriter::new(SinkClient::new( + let fb_sink_writer = SinkWriter::new(SinkClient::new( create_rpc_channel(fb_sink_sock_file).await.unwrap(), )) .await .expect("failed to connect to fb sink server"); - let mut forwarder = ForwarderBuilder::new(source_client, sink_client, cln_token.clone()) - .fallback_sink_writer(fb_sink_client) - .build(); + let mut forwarder = + ForwarderBuilder::new(source_reader, source_acker, sink_writer, cln_token.clone()) + .fallback_sink_writer(fb_sink_writer) + .build(); let assert_handle = tokio::spawn(async move { let received_message = sink_rx.recv().await.unwrap(); diff --git a/rust/monovertex/src/lib.rs b/rust/monovertex/src/lib.rs index f7b267cbff..d3d612c2db 100644 --- a/rust/monovertex/src/lib.rs +++ b/rust/monovertex/src/lib.rs @@ -10,7 +10,7 @@ use crate::metrics::MetricsState; use crate::shared::create_rpc_channel; use crate::sink::{SinkWriter, FB_SINK_SOCKET, SINK_SOCKET}; use crate::sink_pb::sink_client::SinkClient; -use crate::source::{SourceReader, SOURCE_SOCKET}; +use crate::source::{SourceAcker, SourceReader, SOURCE_SOCKET}; use crate::source_pb::source_client::SourceClient; use crate::sourcetransform_pb::source_transform_client::SourceTransformClient; use crate::transformer::{SourceTransformer, TRANSFORMER_SOCKET}; @@ -159,9 +159,11 @@ async fn start_forwarder(cln_token: CancellationToken) -> Result<()> { // build the forwarder let source_reader = SourceReader::new(source_grpc_client.clone()).await?; + let source_acker = SourceAcker::new(source_grpc_client.clone()).await?; let sink_writer = SinkWriter::new(sink_grpc_client.clone()).await?; - let mut forwarder_builder = ForwarderBuilder::new(source_reader, sink_writer, cln_token); + let mut forwarder_builder = + ForwarderBuilder::new(source_reader, source_acker, sink_writer, cln_token); // add transformer if exists if let Some(transformer_grpc_client) = transformer_grpc_client { diff --git a/rust/monovertex/src/source.rs b/rust/monovertex/src/source.rs index 580271e465..eaafb3ae15 100644 --- a/rust/monovertex/src/source.rs +++ b/rust/monovertex/src/source.rs @@ -1,3 +1,4 @@ +use crate::config::config; use crate::error::Error::SourceError; use crate::error::Result; use crate::message::{Message, Offset}; @@ -8,47 +9,27 @@ use crate::source_pb::{ }; use base64::prelude::BASE64_STANDARD; use base64::Engine; -use std::thread::sleep; use tokio::sync::mpsc; -use tokio::task::JoinHandle; use tokio_stream::wrappers::ReceiverStream; use tonic::transport::Channel; use tonic::{Request, Streaming}; -use tracing::{info, warn}; pub(crate) const SOURCE_SOCKET: &str = "/var/run/numaflow/source.sock"; pub(crate) const SOURCE_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcer-server-info"; -/// SourceReader reads messages from a source and acks them. +/// SourceReader reads messages from a source. #[derive(Debug)] pub(crate) struct SourceReader { read_tx: mpsc::Sender, resp_stream: Streaming, - ack_tx: mpsc::Sender, - ack_handle: JoinHandle<()>, -} - -impl Drop for SourceReader { - fn drop(&mut self) { - // wait for the ack task to flush all the acks. - // FIXME: hacky way to wait for the ack task to finish. We should have a better way to handle this. - sleep(std::time::Duration::from_secs(30)); - // in a happy path scenario, the ack task would have already been finished. - if !self.ack_handle.is_finished() { - warn!("aborting ack task"); - self.ack_handle.abort(); - } - } } impl SourceReader { pub(crate) async fn new(mut client: SourceClient) -> Result { - let (read_tx, read_rx) = mpsc::channel(500); - let (ack_tx, ack_rx) = mpsc::channel(500); - + let (read_tx, read_rx) = mpsc::channel(config().batch_size as usize); let read_stream = ReceiverStream::new(read_rx); - // do a handshake with the server before we start sending read requests + // do a handshake for read with the server before we start sending read requests let handshake_request = ReadRequest { request: None, handshake: Some(source_pb::Handshake { sot: true }), @@ -72,21 +53,9 @@ impl SourceReader { return Err(SourceError("invalid handshake response".to_string())); } - // spawn a task to handle acks. - let mut ack_client = client.clone(); - let ack_handle = tokio::spawn(async move { - let ack_response = ack_client - .ack_fn(Request::new(ReceiverStream::new(ack_rx))) - .await - .expect("ack should not have failed"); - info!("Closing ack stream {:?}", ack_response); - }); - Ok(Self { read_tx, resp_stream, - ack_tx, - ack_handle, }) } @@ -123,6 +92,46 @@ impl SourceReader { } Ok(messages) } +} + +/// SourceAcker acks the messages from a source. +#[derive(Debug)] +pub(crate) struct SourceAcker { + ack_tx: mpsc::Sender, + ack_resp_stream: Streaming, +} + +impl SourceAcker { + pub(crate) async fn new(mut client: SourceClient) -> Result { + let (ack_tx, ack_rx) = mpsc::channel(config().batch_size as usize); + let ack_stream = ReceiverStream::new(ack_rx); + + // do a handshake for ack with the server before we start sending ack requests + let ack_handshake_request = AckRequest { + request: None, + handshake: Some(source_pb::Handshake { sot: true }), + }; + ack_tx + .send(ack_handshake_request) + .await + .map_err(|e| SourceError(format!("failed to send ack handshake request: {}", e)))?; + + let mut ack_resp_stream = client.ack_fn(Request::new(ack_stream)).await?.into_inner(); + + // first response from the server will be the handshake response. We need to check if the + // server has accepted the handshake. + let ack_handshake_response = ack_resp_stream.message().await?.ok_or(SourceError( + "failed to receive ack handshake response".to_string(), + ))?; + if ack_handshake_response.handshake.map_or(true, |h| !h.sot) { + return Err(SourceError("invalid ack handshake response".to_string())); + } + + Ok(Self { + ack_tx, + ack_resp_stream, + }) + } pub(crate) async fn ack(&mut self, offsets: Vec) -> Result { for offset in offsets { @@ -135,14 +144,22 @@ impl SourceReader { partition_id: offset.partition_id, }), }), + handshake: None, }; self.ack_tx .send(request) .await .map_err(|e| SourceError(e.to_string()))?; + + // wait for the ack response for each ack request + self.ack_resp_stream + .message() + .await? + .ok_or(SourceError("failed to receive ack response".to_string()))?; } Ok(AckResponse { result: Some(ack_response::Result { success: Some(()) }), + handshake: None, }) } } @@ -152,7 +169,7 @@ mod tests { use std::collections::HashSet; use crate::shared::create_rpc_channel; - use crate::source::SourceReader; + use crate::source::{SourceAcker, SourceReader}; use crate::source_pb::source_client::SourceClient; use chrono::Utc; use numaflow::source; @@ -237,16 +254,24 @@ mod tests { // TODO: flaky tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let mut source_client = SourceReader::new(SourceClient::new( + let mut source_reader = SourceReader::new(SourceClient::new( + create_rpc_channel(sock_file.clone()).await.unwrap(), + )) + .await + .map_err(|e| panic!("failed to create source reader: {:?}", e)) + .unwrap(); + + let mut source_acker = SourceAcker::new(SourceClient::new( create_rpc_channel(sock_file).await.unwrap(), )) .await + .map_err(|e| panic!("failed to create source acker: {:?}", e)) .unwrap(); - let messages = source_client.read(5, 1000).await.unwrap(); + let messages = source_reader.read(5, 1000).await.unwrap(); assert_eq!(messages.len(), 5); - let response = source_client + let response = source_acker .ack(messages.iter().map(|m| m.offset.clone()).collect()) .await .unwrap(); @@ -254,7 +279,8 @@ mod tests { // we need to drop the client, because if there are any in-flight requests // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 - drop(source_client); + drop(source_reader); + drop(source_acker); shutdown_tx .send(()) .expect("failed to send shutdown signal"); diff --git a/rust/servesink/Cargo.toml b/rust/servesink/Cargo.toml index 6d79dc2b7e..90a7c44696 100644 --- a/rust/servesink/Cargo.toml +++ b/rust/servesink/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] tonic = "0.12.0" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } -numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", branch = "source-streaming" } +numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", branch = "handshake" } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/test/udsource-e2e/testdata/simple-source-go.yaml b/test/udsource-e2e/testdata/simple-source-go.yaml index 84515c7007..5d0943bd03 100644 --- a/test/udsource-e2e/testdata/simple-source-go.yaml +++ b/test/udsource-e2e/testdata/simple-source-go.yaml @@ -11,7 +11,6 @@ spec: # A simple user-defined source for e2e testing # See https://github.com/numaproj/numaflow-go/tree/main/pkg/sourcer/examples/simple_source image: quay.io/yhl25/numaflow-go/source-simple-source:stable - imagePullPolicy: Always limits: readBatchSize: 500 scale: diff --git a/test/udsource-e2e/testdata/simple-source-java.yaml b/test/udsource-e2e/testdata/simple-source-java.yaml index c53ecf9b03..b85745ebf9 100644 --- a/test/udsource-e2e/testdata/simple-source-java.yaml +++ b/test/udsource-e2e/testdata/simple-source-java.yaml @@ -13,7 +13,6 @@ spec: # A simple user-defined source for e2e testing # See https://github.com/numaproj/numaflow-java/tree/main/examples/src/main/java/io/numaproj/numaflow/examples/source/simple image: quay.io/yhl25/numaflow-java/source-simple-source:stable - imagePullPolicy: IfNotPresent limits: readBatchSize: 500 - name: out diff --git a/test/udsource-e2e/testdata/simple-source-python.yaml b/test/udsource-e2e/testdata/simple-source-python.yaml index a64960e9fe..9862b63bb6 100644 --- a/test/udsource-e2e/testdata/simple-source-python.yaml +++ b/test/udsource-e2e/testdata/simple-source-python.yaml @@ -13,7 +13,6 @@ spec: # A simple user-defined source for e2e testing # See https://github.com/numaproj/numaflow-python/tree/main/examples/source/simple_source image: quay.io/numaio/numaflow-python/simple-source:stable - imagePullPolicy: Always limits: readBatchSize: 500 - name: out diff --git a/test/udsource-e2e/testdata/simple-source-rs.yaml b/test/udsource-e2e/testdata/simple-source-rs.yaml index dabf34df2f..0cff657496 100644 --- a/test/udsource-e2e/testdata/simple-source-rs.yaml +++ b/test/udsource-e2e/testdata/simple-source-rs.yaml @@ -10,8 +10,7 @@ spec: container: # A simple user-defined source for e2e testing # See https://github.com/numaproj/numaflow-go/tree/main/pkg/sourcer/examples/simple_source - image: quay.io/numaio/numaflow-rs/simple-source:stable - imagePullPolicy: Always + image: quay.io/yhl25/numaflow-rs/simple-source:stable limits: readBatchSize: 500 scale: