diff --git a/examples/simple-source/src/main.rs b/examples/simple-source/src/main.rs index 8e78970..1fe06e9 100644 --- a/examples/simple-source/src/main.rs +++ b/examples/simple-source/src/main.rs @@ -61,9 +61,11 @@ pub(crate) mod simple_source { self.yet_to_ack.write().unwrap().extend(message_offsets) } - async fn ack(&self, offset: Offset) { - let x = &String::from_utf8(offset.offset).unwrap(); - self.yet_to_ack.write().unwrap().remove(x); + async fn ack(&self, offset: Vec) { + for offset in offset { + let x = &String::from_utf8(offset.offset).unwrap(); + self.yet_to_ack.write().unwrap().remove(x); + } } async fn pending(&self) -> usize { diff --git a/proto/sink.proto b/proto/sink.proto index db68e69..861e60e 100644 --- a/proto/sink.proto +++ b/proto/sink.proto @@ -77,7 +77,7 @@ message SinkResponse { // err_msg is the error message, set it if success is set to false. string err_msg = 3; } - Result result = 1; + repeated Result results = 1; optional Handshake handshake = 2; optional TransmissionStatus status = 3; } \ No newline at end of file diff --git a/proto/source.proto b/proto/source.proto index 8878ac6..d78f4d0 100644 --- a/proto/source.proto +++ b/proto/source.proto @@ -111,7 +111,7 @@ message ReadResponse { message AckRequest { message Request { // Required field holding the offset to be acked - Offset offset = 1; + repeated Offset offsets = 1; } // Required field holding the request. The list will be ordered and will have the same order as the original Read response. Request request = 1; diff --git a/src/servers/sink.v1.rs b/src/servers/sink.v1.rs index 5fb32cb..5155ea7 100644 --- a/src/servers/sink.v1.rs +++ b/src/servers/sink.v1.rs @@ -61,8 +61,8 @@ pub struct TransmissionStatus { /// SinkResponse is the individual response of each message written to the sink. #[derive(Clone, PartialEq, ::prost::Message)] pub struct SinkResponse { - #[prost(message, optional, tag = "1")] - pub result: ::core::option::Option, + #[prost(message, repeated, tag = "1")] + pub results: ::prost::alloc::vec::Vec, #[prost(message, optional, tag = "2")] pub handshake: ::core::option::Option, #[prost(message, optional, tag = "3")] diff --git a/src/servers/source.v1.rs b/src/servers/source.v1.rs index 68210fc..95a500f 100644 --- a/src/servers/source.v1.rs +++ b/src/servers/source.v1.rs @@ -179,8 +179,8 @@ pub mod ack_request { #[derive(Clone, PartialEq, ::prost::Message)] pub struct Request { /// Required field holding the offset to be acked - #[prost(message, optional, tag = "1")] - pub offset: ::core::option::Option, + #[prost(message, repeated, tag = "1")] + pub offsets: ::prost::alloc::vec::Vec, } } /// diff --git a/src/sink.rs b/src/sink.rs index ef07cfa..5b6a751 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -232,6 +232,7 @@ where // loop until the global stream has been shutdown. let mut global_stream_ended = false; while !global_stream_ended { + let start = std::time::Instant::now(); // for every batch, we need to read from the stream. The end-of-batch is // encoded in the request. global_stream_ended = Self::process_sink_batch( @@ -240,6 +241,7 @@ where grpc_resp_tx.clone(), ) .await?; + println!("Time taken for batch: {:?}", start.elapsed().as_micros()); } Ok(()) } @@ -259,20 +261,19 @@ where // spawn the UDF let sinker_handle = tokio::spawn(async move { let responses = sink_handle.sink(rx).await; - for response in responses { - resp_tx - .send(Ok(SinkResponse { - result: Some(response.into()), - handshake: None, - status: None, - })) - .await - .expect("Sending response to channel"); - } + resp_tx + .send(Ok(SinkResponse { + results: responses.into_iter().map(|r| r.into()).collect(), + handshake: None, + status: None, + })) + .await + .expect("Sending response to channel"); + // send an EOT message to the client to indicate the end of transmission for this batch resp_tx .send(Ok(SinkResponse { - result: None, + results: vec![], handshake: None, status: Some(sink_pb::TransmissionStatus { eot: true }), })) @@ -385,7 +386,7 @@ where if let Some(handshake) = handshake_request.handshake { resp_tx .send(Ok(SinkResponse { - result: None, + results: vec![], handshake: Some(handshake), status: None, })) @@ -641,32 +642,31 @@ mod tests { let mut resp_stream = resp.into_inner(); // handshake response let resp = resp_stream.message().await.unwrap().unwrap(); - assert!(resp.result.is_none()); assert!(resp.handshake.is_some()); let resp = resp_stream.message().await.unwrap().unwrap(); - assert!(resp.result.is_some()); - let msg = &resp.result.unwrap(); + assert!(!resp.results.is_empty()); + let msg = &resp.results.get(0).unwrap(); assert_eq!(msg.err_msg, ""); assert_eq!(msg.id, "1"); // eot for first request let resp = resp_stream.message().await.unwrap().unwrap(); - assert!(resp.result.is_none()); + assert!(resp.results.is_empty()); assert!(resp.handshake.is_none()); let msg = &resp.status.unwrap(); assert!(msg.eot); let resp = resp_stream.message().await.unwrap().unwrap(); - assert!(resp.result.is_some()); + assert!(!resp.results.is_empty()); assert!(resp.handshake.is_none()); - let msg = &resp.result.unwrap(); + let msg = &resp.results.get(0).unwrap(); assert_eq!(msg.err_msg, ""); assert_eq!(msg.id, "2"); // eot for second request let resp = resp_stream.message().await.unwrap().unwrap(); - assert!(resp.result.is_none()); + assert!(resp.results.is_empty()); assert!(resp.handshake.is_none()); let msg = &resp.status.unwrap(); assert!(msg.eot); @@ -773,7 +773,7 @@ mod tests { // handshake response let resp = resp_stream.message().await.unwrap().unwrap(); - assert!(resp.result.is_none()); + assert!(resp.results.is_empty()); assert!(resp.handshake.is_some()); let err_resp = resp_stream.message().await; diff --git a/src/source.rs b/src/source.rs index 79adf63..190f9f2 100644 --- a/src/source.rs +++ b/src/source.rs @@ -52,7 +52,7 @@ pub trait Sourcer { /// Reads the messages from the source and sends them to the transmitter. async fn read(&self, request: SourceReadRequest, transmitter: Sender); /// Acknowledges the message that has been processed by the user-defined source. - async fn ack(&self, offset: Offset); + async fn ack(&self, offset: Vec); /// Returns the number of messages that are yet to be processed by the user-defined source. async fn pending(&self) -> usize; /// Returns the partitions associated with the source. This will be used by the platform to determine @@ -275,14 +275,16 @@ where let request = ack_request.request .ok_or_else(|| SourceError(ErrorKind::InternalError("Invalid request, request can't be empty".to_string())))?; - let offset = request.offset - .ok_or_else(|| SourceError(ErrorKind::InternalError("Invalid request, offset can't be empty".to_string())))?; - - handler_fn - .ack(Offset { - offset: offset.offset, + let offsets = request.offsets + .iter() + .map(|offset| Offset { + offset: offset.offset.clone(), partition_id: offset.partition_id, }) + .collect(); + + handler_fn + .ack(offsets) .await; // the return of handler_fn implicitly means that the ack is successful; hence @@ -602,11 +604,13 @@ mod tests { self.yet_to_ack.write().unwrap().extend(message_offsets) } - async fn ack(&self, offset: Offset) { - self.yet_to_ack - .write() - .unwrap() - .remove(&String::from_utf8(offset.offset).unwrap()); + async fn ack(&self, offset: Vec) { + for offset in offset { + self.yet_to_ack + .write() + .unwrap() + .remove(&String::from_utf8(offset.offset).unwrap()); + } } async fn pending(&self) -> usize { @@ -705,10 +709,10 @@ mod tests { for resp in response_values.iter() { let ack_request = proto::AckRequest { request: Some(proto::ack_request::Request { - offset: Some(proto::Offset { + offsets: vec![proto::Offset { offset: resp.offset.clone().unwrap().offset, partition_id: resp.offset.clone().unwrap().partition_id, - }), + }], }), handshake: None, };