From a04f9cd0e07655732e2927f72df51bc336af73de Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 30 Sep 2024 12:44:34 +0530 Subject: [PATCH 01/10] Merge new changes for rust Signed-off-by: Sreekanth --- rust/Cargo.lock | 2 +- rust/numaflow-core/Cargo.toml | 2 +- .../numaflow-core/proto/sourcetransform.proto | 33 ++- rust/numaflow-core/src/config.rs | 18 +- rust/numaflow-core/src/message.rs | 16 +- .../numaflow-core/src/monovertex/forwarder.rs | 32 +-- .../src/transformer/user_defined.rs | 224 ++++++++++++++---- rust/servesink/Cargo.toml | 2 +- 8 files changed, 232 insertions(+), 97 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 655f30bc4d..b9cfa9d3bd 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1557,7 +1557,7 @@ dependencies = [ [[package]] name = "numaflow" version = "0.1.1" -source = "git+https://github.com/numaproj/numaflow-rs.git?rev=0c1682864a4b906fab52e149cfd7cacc679ce688#0c1682864a4b906fab52e149cfd7cacc679ce688" +source = "git+https://github.com/BulkBeing/numaflow-rs.git?rev=6eb7f3865d42a8ab11ade51622dc4d8feda25b5e#6eb7f3865d42a8ab11ade51622dc4d8feda25b5e" dependencies = [ "chrono", "futures-util", diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index 85a3bc39b1..962901cb2e 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -38,7 +38,7 @@ log = "0.4.22" [dev-dependencies] tempfile = "3.11.0" -numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "0c1682864a4b906fab52e149cfd7cacc679ce688" } +numaflow = { git = "https://github.com/BulkBeing/numaflow-rs.git", rev = "6eb7f3865d42a8ab11ade51622dc4d8feda25b5e" } [build-dependencies] tonic-build = "0.12.1" diff --git a/rust/numaflow-core/proto/sourcetransform.proto b/rust/numaflow-core/proto/sourcetransform.proto index 18e045c323..9d0a63a9dc 100644 --- a/rust/numaflow-core/proto/sourcetransform.proto +++ b/rust/numaflow-core/proto/sourcetransform.proto @@ -9,21 +9,36 @@ service SourceTransform { // SourceTransformFn applies a function to each request element. // In addition to map function, SourceTransformFn also supports assigning a new event time to response. // SourceTransformFn can be used only at source vertex by source data transformer. - rpc SourceTransformFn(SourceTransformRequest) returns (SourceTransformResponse); + rpc SourceTransformFn(stream SourceTransformRequest) returns (stream SourceTransformResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); } +/* + * Handshake message between client and server to indicate the start of transmission. + */ + message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; +} + + /** * SourceTransformerRequest represents a request element. */ message SourceTransformRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + // This ID is used to uniquely identify a transform request + string id = 6; + } + Request request = 1; + optional Handshake handshake = 2; } /** @@ -37,6 +52,10 @@ message SourceTransformResponse { repeated string tags = 4; } repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; + // Handshake message between client and server to indicate the start of transmission. + optional Handshake handshake = 3; } /** @@ -44,4 +63,4 @@ message SourceTransformResponse { */ message ReadyResponse { bool ready = 1; -} \ No newline at end of file +} diff --git a/rust/numaflow-core/src/config.rs b/rust/numaflow-core/src/config.rs index 5d245ed397..c3263e999c 100644 --- a/rust/numaflow-core/src/config.rs +++ b/rust/numaflow-core/src/config.rs @@ -3,6 +3,7 @@ use base64::prelude::BASE64_STANDARD; use base64::Engine; use numaflow_models::models::{Backoff, MonoVertex, RetryStrategy}; use std::env; +use std::fmt::Display; use std::sync::OnceLock; const DEFAULT_SOURCE_SOCKET: &str = "/var/run/numaflow/source.sock"; @@ -53,17 +54,14 @@ impl OnFailureStrategy { _ => Some(DEFAULT_SINK_RETRY_ON_FAIL_STRATEGY), } } +} - /// Converts the `OnFailureStrategy` enum variant to a String. - /// This facilitates situations where the enum needs to be displayed or logged as a string. - /// - /// # Returns - /// A string representing the `OnFailureStrategy` enum variant. - fn to_string(&self) -> String { +impl Display for OnFailureStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match *self { - OnFailureStrategy::Retry => "retry".to_string(), - OnFailureStrategy::Fallback => "fallback".to_string(), - OnFailureStrategy::Drop => "drop".to_string(), + OnFailureStrategy::Retry => write!(f, "retry"), + OnFailureStrategy::Fallback => write!(f, "fallback"), + OnFailureStrategy::Drop => write!(f, "drop"), } } } @@ -647,4 +645,4 @@ mod tests { let drop = OnFailureStrategy::Drop; assert_eq!(drop.to_string(), "drop"); } -} +} \ No newline at end of file diff --git a/rust/numaflow-core/src/message.rs b/rust/numaflow-core/src/message.rs index b99a61b31d..d230e994fb 100644 --- a/rust/numaflow-core/src/message.rs +++ b/rust/numaflow-core/src/message.rs @@ -7,7 +7,7 @@ use chrono::{DateTime, Utc}; use crate::error::Error; use crate::monovertex::sink_pb::sink_request::Request; use crate::monovertex::sink_pb::SinkRequest; -use crate::monovertex::source_pb; +use crate::monovertex::{source_pb, sourcetransform_pb}; use crate::monovertex::source_pb::{read_response, AckRequest}; use crate::monovertex::sourcetransform_pb::SourceTransformRequest; use crate::shared::utils::{prost_timestamp_from_utc, utc_from_timestamp}; @@ -58,11 +58,15 @@ impl From for AckRequest { impl From for SourceTransformRequest { fn from(message: Message) -> Self { Self { - keys: message.keys, - value: message.value, - event_time: prost_timestamp_from_utc(message.event_time), - watermark: None, - headers: message.headers, + request: Some(sourcetransform_pb::source_transform_request::Request { + id: message.id, + keys: message.keys, + value: message.value, + event_time: prost_timestamp_from_utc(message.event_time), + watermark: None, + headers: message.headers, + }), + handshake: None, } } } diff --git a/rust/numaflow-core/src/monovertex/forwarder.rs b/rust/numaflow-core/src/monovertex/forwarder.rs index a32aff093b..ab58cfad03 100644 --- a/rust/numaflow-core/src/monovertex/forwarder.rs +++ b/rust/numaflow-core/src/monovertex/forwarder.rs @@ -1,3 +1,10 @@ +use chrono::Utc; +use log::warn; +use std::collections::HashMap; +use tokio::time::sleep; +use tokio_util::sync::CancellationToken; +use tracing::{debug, info}; + use crate::config::{config, OnFailureStrategy}; use crate::error; use crate::error::Error; @@ -8,13 +15,6 @@ use crate::monovertex::sink_pb::Status::{Failure, Fallback, Success}; use crate::sink::user_defined::SinkWriter; use crate::source::user_defined::Source; use crate::transformer::user_defined::SourceTransformer; -use chrono::Utc; -use log::warn; -use std::collections::HashMap; -use tokio::task::JoinSet; -use tokio::time::sleep; -use tokio_util::sync::CancellationToken; -use tracing::{debug, info}; /// Forwarder is responsible for reading messages from the source, applying transformation if /// transformer is present, writing the messages to the sink, and then acknowledging the messages @@ -193,26 +193,14 @@ impl Forwarder { // Applies transformation to the messages if transformer is present // we concurrently apply transformation to all the messages. - async fn apply_transformer(&self, messages: Vec) -> error::Result> { - let Some(transformer_client) = &self.source_transformer else { + async fn apply_transformer(&mut self, messages: Vec) -> error::Result> { + let Some(transformer_client) = &mut self.source_transformer else { // return early if there is no transformer return Ok(messages); }; let start_time = tokio::time::Instant::now(); - let mut jh = JoinSet::new(); - for message in messages { - let mut transformer_client = transformer_client.clone(); - jh.spawn(async move { transformer_client.transform_fn(message).await }); - } - - let mut results = Vec::new(); - while let Some(task) = jh.join_next().await { - let result = task.map_err(|e| Error::TransformerError(format!("{:?}", e)))?; - if let Some(result) = result? { - results.extend(result); - } - } + let results = transformer_client.transform_fn(messages).await?; debug!( "Transformer latency - {}ms", diff --git a/rust/numaflow-core/src/transformer/user_defined.rs b/rust/numaflow-core/src/transformer/user_defined.rs index de7b765b79..71a9d24cd6 100644 --- a/rust/numaflow-core/src/transformer/user_defined.rs +++ b/rust/numaflow-core/src/transformer/user_defined.rs @@ -1,67 +1,178 @@ -use crate::error; -use crate::message::Message; -use crate::monovertex::sourcetransform_pb::source_transform_client::SourceTransformClient; -use crate::monovertex::sourcetransform_pb::SourceTransformRequest; -use crate::shared::utils::utc_from_timestamp; +use std::collections::HashMap; + use tonic::transport::Channel; +use tonic::{Request, Streaming}; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; +use tracing::warn; +use crate::error::{Result, Error}; +use crate::message::{Message, Offset}; +use crate::monovertex::sourcetransform_pb::{self, SourceTransformRequest, SourceTransformResponse, source_transform_client::SourceTransformClient}; +use crate::shared::utils::utc_from_timestamp; +use crate::config::config; const DROP: &str = "U+005C__DROP__"; /// TransformerClient is a client to interact with the transformer server. -#[derive(Clone)] pub struct SourceTransformer { - client: SourceTransformClient, + read_tx: mpsc::Sender, + resp_stream: Streaming, } impl SourceTransformer { - pub(crate) async fn new(client: SourceTransformClient) -> error::Result { - Ok(Self { client }) - } + pub(crate) async fn new(mut client: SourceTransformClient) -> Result { + let (read_tx, read_rx) = mpsc::channel(config().batch_size as usize); + let read_stream = ReceiverStream::new(read_rx); - pub(crate) async fn transform_fn( - &mut self, - message: Message, - ) -> error::Result>> { - // fields which will not be changed - let offset = message.offset.clone(); - let id = message.id.clone(); - let headers = message.headers.clone(); - - // TODO: is this complex? the reason to do this is, tomorrow when we have the normal - // Pipeline CRD, we can require the Into trait. - let response = self - .client - .source_transform_fn(>::into(message)) + // do a handshake for read with the server before we start sending read requests + let handshake_request = SourceTransformRequest { + request: None, + handshake: Some(sourcetransform_pb::Handshake { sot: true }), + }; + read_tx.send(handshake_request).await.map_err(|e| { + Error::TransformerError(format!("failed to send handshake request: {}", e)) + })?; + + let mut resp_stream = client + .source_transform_fn(Request::new(read_stream)) .await? .into_inner(); - let mut messages = Vec::new(); - for result in response.results { - // if the message is tagged with DROP, we will not forward it. - if result.tags.contains(&DROP.to_string()) { - return Ok(None); + // first response from the server will be the handshake response. We need to check if the + // server has accepted the handshake. + let handshake_response = resp_stream.message().await?.ok_or(Error::TransformerError( + "failed to receive handshake response".to_string(), + ))?; + // handshake cannot to None during the initial phase and it has to set `sot` to true. + if handshake_response.handshake.map_or(true, |h| !h.sot) { + return Err(Error::TransformerError( + "invalid handshake response".to_string(), + )); + } + + Ok(Self { + read_tx, + resp_stream, + }) + } + + pub(crate) async fn transform_fn(&mut self, messages: Vec) -> Result> { + // fields which will not be changed + struct MessageInfo { + offset: Offset, + headers: HashMap, + } + + let mut tracker: HashMap = HashMap::with_capacity(messages.len()); + for message in &messages { + tracker.insert( + message.id.clone(), + MessageInfo { + offset: message.offset.clone(), + headers: message.headers.clone(), + }, + ); + } + + // Cancellation token is used to cancel either sending task (if an error occurs while receiving) or receiving messages (if an error occurs on sending task) + let token = CancellationToken::new(); + + // Send transform requests to the source transformer server + let sender_task: JoinHandle> = tokio::spawn({ + let read_tx = self.read_tx.clone(); + let token = token.clone(); + async move { + for msg in messages { + let result = tokio::select! { + result = read_tx.send(msg.into()) => result, + _ = token.cancelled() => { + warn!("Cancellation token was cancelled while sending source transform requests"); + return Ok(()); + }, + }; + + match result { + Ok(()) => continue, + Err(e) => { + token.cancel(); + return Err(Error::TransformerError(e.to_string())); + } + }; + } + Ok(()) } - let message = Message { - keys: result.keys, - value: result.value, - offset: offset.clone(), - id: id.clone(), - event_time: utc_from_timestamp(result.event_time), - headers: headers.clone(), + }); + + // Receive transformer results + let mut messages = Vec::new(); + while !tracker.is_empty() { + let resp = tokio::select! { + _ = token.cancelled() => { + break; + }, + resp = self.resp_stream.message() => {resp} + }; + + let resp = match resp { + Ok(Some(val)) => val, + Ok(None) => { + // Logging at warning level since we don't expect this to happen + warn!("Source transformer server closed its sending end of the stream. No more messages to receive"); + token.cancel(); + break; + } + Err(e) => { + token.cancel(); + return Err(Error::TransformerError(format!( + "gRPC error while receiving messages from source transformer server: {e:?}" + ))); + } + }; + + let Some((msg_id, msg_info)) = tracker.remove_entry(&resp.id) else { + token.cancel(); + return Err(Error::TransformerError(format!( + "Received message with unknown ID {}", + resp.id + ))); }; - messages.push(message); + + for (i, result) in resp.results.into_iter().enumerate() { + // TODO: Expose metrics + if result.tags.iter().any(|x| x == DROP) { + continue; + } + let message = Message { + id: format!("{}-{}", msg_id, i), + keys: result.keys, + value: result.value, + offset: msg_info.offset.clone(), + event_time: utc_from_timestamp(result.event_time), + headers: msg_info.headers.clone(), + }; + messages.push(message); + } } - Ok(Some(messages)) + sender_task.await.unwrap().map_err(|e| { + Error::TransformerError(format!( + "Sending messages to gRPC transformer failed: {e:?}", + )) + })?; + + Ok(messages) } } #[cfg(test)] mod tests { use std::error::Error; + use std::time::Duration; - use crate::monovertex::sourcetransform_pb::source_transform_client::SourceTransformClient; use crate::shared::utils::create_rpc_channel; + use crate::transformer::user_defined::sourcetransform_pb::source_transform_client::SourceTransformClient; use crate::transformer::user_defined::SourceTransformer; use numaflow::sourcetransform; use tempfile::TempDir; @@ -105,7 +216,7 @@ mod tests { let mut client = SourceTransformer::new(SourceTransformClient::new( create_rpc_channel(sock_file).await?, )) - .await?; + .await?; let message = crate::message::Message { keys: vec!["first".into()], @@ -115,18 +226,29 @@ mod tests { offset: "0".into(), }, event_time: chrono::Utc::now(), - id: "".to_string(), + id: "1".to_string(), headers: Default::default(), }; - let resp = client.transform_fn(message).await?; - assert!(resp.is_some()); - assert_eq!(resp.unwrap().len(), 1); + let resp = tokio::time::timeout( + tokio::time::Duration::from_secs(2), + client.transform_fn(vec![message]), + ) + .await??; + assert_eq!(resp.len(), 1); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(client); shutdown_tx .send(()) .expect("failed to send shutdown signal"); - handle.await.expect("failed to join server task"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); Ok(()) } @@ -169,7 +291,7 @@ mod tests { let mut client = SourceTransformer::new(SourceTransformClient::new( create_rpc_channel(sock_file).await?, )) - .await?; + .await?; let message = crate::message::Message { keys: vec!["second".into()], @@ -183,8 +305,12 @@ mod tests { headers: Default::default(), }; - let resp = client.transform_fn(message).await?; - assert!(resp.is_none()); + let resp = client.transform_fn(vec![message]).await?; + assert!(resp.is_empty()); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(client); shutdown_tx .send(()) @@ -192,4 +318,4 @@ mod tests { handle.await.expect("failed to join server task"); Ok(()) } -} +} \ No newline at end of file diff --git a/rust/servesink/Cargo.toml b/rust/servesink/Cargo.toml index a9a768ac6c..76de0e491d 100644 --- a/rust/servesink/Cargo.toml +++ b/rust/servesink/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] tonic = "0.12.0" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } -numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "0c1682864a4b906fab52e149cfd7cacc679ce688" } +numaflow = { git = "https://github.com/BulkBeing/numaflow-rs.git", rev = "6eb7f3865d42a8ab11ade51622dc4d8feda25b5e" } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } From 8c96ca79f8a26e9c32e44f857d91b32641a7b57d Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 30 Sep 2024 13:55:27 +0530 Subject: [PATCH 02/10] Merge changes from main Signed-off-by: Sreekanth --- Makefile | 2 +- go.mod | 4 +- go.sum | 8 +- hack/generate-proto.sh | 7 +- pkg/apis/proto/daemon/daemon_grpc.pb.go | 25 +- .../proto/mvtxdaemon/mvtxdaemon_grpc.pb.go | 16 +- .../sourcetransform/v1/sourcetransform.proto | 30 +- pkg/isb/tracker/message_tracker.go | 56 ++ .../tracker/message_tracker_test.go} | 30 +- pkg/sdkclient/sourcetransformer/client.go | 136 ++++- .../sourcetransformer/client_test.go | 171 ++++-- pkg/sdkclient/sourcetransformer/interface.go | 2 +- .../forward/applier/sourcetransformer.go | 8 +- pkg/sources/forward/data_forward.go | 67 +- pkg/sources/forward/data_forward_test.go | 91 ++- pkg/sources/forward/shutdown_test.go | 12 +- pkg/sources/source.go | 2 +- pkg/sources/transformer/grpc_transformer.go | 181 +++--- .../transformer/grpc_transformer_test.go | 575 +++++------------- pkg/udf/forward/forward.go | 2 +- pkg/udf/rpc/grpc_batch_map.go | 33 +- pkg/udf/rpc/tracker.go | 75 --- pkg/webhook/validator/validator.go | 5 +- test/transformer-e2e/transformer_test.go | 27 +- 24 files changed, 751 insertions(+), 814 deletions(-) create mode 100644 pkg/isb/tracker/message_tracker.go rename pkg/{udf/rpc/tracker_test.go => isb/tracker/message_tracker_test.go} (53%) delete mode 100644 pkg/udf/rpc/tracker.go diff --git a/Makefile b/Makefile index a4bc2012bc..11d91c5890 100644 --- a/Makefile +++ b/Makefile @@ -244,7 +244,7 @@ manifests: crds kubectl kustomize config/extensions/webhook > config/validating-webhook-install.yaml $(GOPATH)/bin/golangci-lint: - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b `go env GOPATH`/bin v1.54.1 + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b `go env GOPATH`/bin v1.61.0 .PHONY: lint lint: $(GOPATH)/bin/golangci-lint diff --git a/go.mod b/go.mod index c2d7d6edd5..8cc5299772 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe github.com/nats-io/nats-server/v2 v2.10.20 github.com/nats-io/nats.go v1.37.0 - github.com/numaproj/numaflow-go v0.8.2-0.20240923064822-e16694a878d0 + github.com/numaproj/numaflow-go v0.8.2-0.20240930081452-bd8cc005573a github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_model v0.5.0 github.com/prometheus/common v0.45.0 @@ -55,7 +55,7 @@ require ( golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 google.golang.org/grpc v1.66.0 - google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 + google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.4.0 google.golang.org/protobuf v1.34.2 k8s.io/api v0.29.2 k8s.io/apimachinery v0.29.2 diff --git a/go.sum b/go.sum index b17e994439..bd6f0692f9 100644 --- a/go.sum +++ b/go.sum @@ -485,8 +485,8 @@ github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDm github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/numaproj/numaflow-go v0.8.2-0.20240923064822-e16694a878d0 h1:qPqZfJdPdsz4qymyzMSNICQe/xBnx9P/G3hRbC1DR7k= -github.com/numaproj/numaflow-go v0.8.2-0.20240923064822-e16694a878d0/go.mod h1:g4JZOyUPhjfhv+kR0sX5d8taw/dasgKPXLvQBi39mJ4= +github.com/numaproj/numaflow-go v0.8.2-0.20240930081452-bd8cc005573a h1:xbpsfHFjZFsm99bC6x9/plMDIBIEkdUt4J/EMiEifrg= +github.com/numaproj/numaflow-go v0.8.2-0.20240930081452-bd8cc005573a/go.mod h1:FaCMeV0V9SiLcVf2fwT+GeTJHNaK2gdQsTAIqQ4x7oc= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= @@ -1049,8 +1049,8 @@ google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.66.0 h1:DibZuoBznOxbDQxRINckZcUvnCEvrW9pcWIE2yF9r1c= google.golang.org/grpc v1.66.0/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 h1:rNBFJjBCOgVr9pWD7rs/knKL4FRTKgpZmsRfV214zcA= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0/go.mod h1:Dk1tviKTvMCz5tvh7t+fh94dhmQVHuCt2OzJB3CTW9Y= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.4.0 h1:9SxA29VM43MF5Z9dQu694wmY5t8E/Gxr7s+RSxiIDmc= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.4.0/go.mod h1:yZOK5zhQMiALmuweVdIVoQPa6eIJyXn2B9g5dJDhqX4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/hack/generate-proto.sh b/hack/generate-proto.sh index bf970ce318..7d9f19cb67 100755 --- a/hack/generate-proto.sh +++ b/hack/generate-proto.sh @@ -22,11 +22,14 @@ install-protobuf() { ARCH=$(uname_arch) echo "OS: $OS ARCH: $ARCH" + if [[ "$ARCH" = "amd64" ]]; then + ARCH="x86_64" + elif [[ "$ARCH" = "arm64" ]]; then + ARCH="aarch_64" + fi BINARY_URL=$PB_REL/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-${OS}-${ARCH}.zip if [[ "$OS" = "darwin" ]]; then BINARY_URL=$PB_REL/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-osx-universal_binary.zip - elif [[ "$OS" = "linux" ]]; then - BINARY_URL=$PB_REL/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip fi echo "Downloading $BINARY_URL" diff --git a/pkg/apis/proto/daemon/daemon_grpc.pb.go b/pkg/apis/proto/daemon/daemon_grpc.pb.go index 61e15a2a62..6b348d8fdf 100644 --- a/pkg/apis/proto/daemon/daemon_grpc.pb.go +++ b/pkg/apis/proto/daemon/daemon_grpc.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.3.0 +// - protoc-gen-go-grpc v1.4.0 // - protoc v5.27.2 // source: pkg/apis/proto/daemon/daemon.proto @@ -30,8 +30,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( DaemonService_ListBuffers_FullMethodName = "/daemon.DaemonService/ListBuffers" @@ -44,6 +44,8 @@ const ( // DaemonServiceClient is the client API for DaemonService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// DaemonService is a grpc service that is used to provide APIs for giving any pipeline information. type DaemonServiceClient interface { ListBuffers(ctx context.Context, in *ListBuffersRequest, opts ...grpc.CallOption) (*ListBuffersResponse, error) GetBuffer(ctx context.Context, in *GetBufferRequest, opts ...grpc.CallOption) (*GetBufferResponse, error) @@ -62,8 +64,9 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient { } func (c *daemonServiceClient) ListBuffers(ctx context.Context, in *ListBuffersRequest, opts ...grpc.CallOption) (*ListBuffersResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListBuffersResponse) - err := c.cc.Invoke(ctx, DaemonService_ListBuffers_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ListBuffers_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -71,8 +74,9 @@ func (c *daemonServiceClient) ListBuffers(ctx context.Context, in *ListBuffersRe } func (c *daemonServiceClient) GetBuffer(ctx context.Context, in *GetBufferRequest, opts ...grpc.CallOption) (*GetBufferResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetBufferResponse) - err := c.cc.Invoke(ctx, DaemonService_GetBuffer_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetBuffer_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -80,8 +84,9 @@ func (c *daemonServiceClient) GetBuffer(ctx context.Context, in *GetBufferReques } func (c *daemonServiceClient) GetVertexMetrics(ctx context.Context, in *GetVertexMetricsRequest, opts ...grpc.CallOption) (*GetVertexMetricsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetVertexMetricsResponse) - err := c.cc.Invoke(ctx, DaemonService_GetVertexMetrics_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetVertexMetrics_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -89,8 +94,9 @@ func (c *daemonServiceClient) GetVertexMetrics(ctx context.Context, in *GetVerte } func (c *daemonServiceClient) GetPipelineWatermarks(ctx context.Context, in *GetPipelineWatermarksRequest, opts ...grpc.CallOption) (*GetPipelineWatermarksResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetPipelineWatermarksResponse) - err := c.cc.Invoke(ctx, DaemonService_GetPipelineWatermarks_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetPipelineWatermarks_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -98,8 +104,9 @@ func (c *daemonServiceClient) GetPipelineWatermarks(ctx context.Context, in *Get } func (c *daemonServiceClient) GetPipelineStatus(ctx context.Context, in *GetPipelineStatusRequest, opts ...grpc.CallOption) (*GetPipelineStatusResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetPipelineStatusResponse) - err := c.cc.Invoke(ctx, DaemonService_GetPipelineStatus_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetPipelineStatus_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -109,6 +116,8 @@ func (c *daemonServiceClient) GetPipelineStatus(ctx context.Context, in *GetPipe // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility +// +// DaemonService is a grpc service that is used to provide APIs for giving any pipeline information. type DaemonServiceServer interface { ListBuffers(context.Context, *ListBuffersRequest) (*ListBuffersResponse, error) GetBuffer(context.Context, *GetBufferRequest) (*GetBufferResponse, error) diff --git a/pkg/apis/proto/mvtxdaemon/mvtxdaemon_grpc.pb.go b/pkg/apis/proto/mvtxdaemon/mvtxdaemon_grpc.pb.go index 33f0b26d6b..76477c3de0 100644 --- a/pkg/apis/proto/mvtxdaemon/mvtxdaemon_grpc.pb.go +++ b/pkg/apis/proto/mvtxdaemon/mvtxdaemon_grpc.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.3.0 +// - protoc-gen-go-grpc v1.4.0 // - protoc v5.27.2 // source: pkg/apis/proto/mvtxdaemon/mvtxdaemon.proto @@ -31,8 +31,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( MonoVertexDaemonService_GetMonoVertexMetrics_FullMethodName = "/mvtxdaemon.MonoVertexDaemonService/GetMonoVertexMetrics" @@ -42,6 +42,8 @@ const ( // MonoVertexDaemonServiceClient is the client API for MonoVertexDaemonService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// MonoVertexDaemonService is a grpc service that is used to provide APIs for giving any MonoVertex information. type MonoVertexDaemonServiceClient interface { GetMonoVertexMetrics(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMonoVertexMetricsResponse, error) GetMonoVertexStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMonoVertexStatusResponse, error) @@ -56,8 +58,9 @@ func NewMonoVertexDaemonServiceClient(cc grpc.ClientConnInterface) MonoVertexDae } func (c *monoVertexDaemonServiceClient) GetMonoVertexMetrics(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMonoVertexMetricsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetMonoVertexMetricsResponse) - err := c.cc.Invoke(ctx, MonoVertexDaemonService_GetMonoVertexMetrics_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, MonoVertexDaemonService_GetMonoVertexMetrics_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -65,8 +68,9 @@ func (c *monoVertexDaemonServiceClient) GetMonoVertexMetrics(ctx context.Context } func (c *monoVertexDaemonServiceClient) GetMonoVertexStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMonoVertexStatusResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetMonoVertexStatusResponse) - err := c.cc.Invoke(ctx, MonoVertexDaemonService_GetMonoVertexStatus_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, MonoVertexDaemonService_GetMonoVertexStatus_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -76,6 +80,8 @@ func (c *monoVertexDaemonServiceClient) GetMonoVertexStatus(ctx context.Context, // MonoVertexDaemonServiceServer is the server API for MonoVertexDaemonService service. // All implementations must embed UnimplementedMonoVertexDaemonServiceServer // for forward compatibility +// +// MonoVertexDaemonService is a grpc service that is used to provide APIs for giving any MonoVertex information. type MonoVertexDaemonServiceServer interface { GetMonoVertexMetrics(context.Context, *emptypb.Empty) (*GetMonoVertexMetricsResponse, error) GetMonoVertexStatus(context.Context, *emptypb.Empty) (*GetMonoVertexStatusResponse, error) diff --git a/pkg/apis/proto/sourcetransform/v1/sourcetransform.proto b/pkg/apis/proto/sourcetransform/v1/sourcetransform.proto index b93d82b9a8..740ae1c671 100644 --- a/pkg/apis/proto/sourcetransform/v1/sourcetransform.proto +++ b/pkg/apis/proto/sourcetransform/v1/sourcetransform.proto @@ -28,21 +28,35 @@ service SourceTransform { // SourceTransformFn applies a function to each request element. // In addition to map function, SourceTransformFn also supports assigning a new event time to response. // SourceTransformFn can be used only at source vertex by source data transformer. - rpc SourceTransformFn(SourceTransformRequest) returns (SourceTransformResponse); + rpc SourceTransformFn(stream SourceTransformRequest) returns (stream SourceTransformResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); } +/* + * Handshake message between client and server to indicate the start of transmission. + */ + message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; +} + /** * SourceTransformerRequest represents a request element. */ message SourceTransformRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + // This ID is used to uniquely identify a transform request + string id = 6; + } + Request request = 1; + optional Handshake handshake = 2; } /** @@ -56,6 +70,10 @@ message SourceTransformResponse { repeated string tags = 4; } repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; + // Handshake message between client and server to indicate the start of transmission. + optional Handshake handshake = 3; } /** diff --git a/pkg/isb/tracker/message_tracker.go b/pkg/isb/tracker/message_tracker.go new file mode 100644 index 0000000000..dfd608e5bf --- /dev/null +++ b/pkg/isb/tracker/message_tracker.go @@ -0,0 +1,56 @@ +package tracker + +import ( + "sync" + + "github.com/numaproj/numaflow/pkg/isb" +) + +// MessageTracker is used to store a key value pair for string and *ReadMessage +// as it can be accessed by concurrent goroutines, we keep all operations +// under a mutex +type MessageTracker struct { + lock sync.RWMutex + m map[string]*isb.ReadMessage +} + +// NewMessageTracker initializes a new instance of a Tracker +func NewMessageTracker(messages []*isb.ReadMessage) *MessageTracker { + m := make(map[string]*isb.ReadMessage, len(messages)) + for _, msg := range messages { + id := msg.ReadOffset.String() + m[id] = msg + } + return &MessageTracker{ + m: m, + lock: sync.RWMutex{}, + } +} + +// Remove will remove the entry for a given id and return the stored value corresponding to this id. +// A `nil` return value indicates that the id doesn't exist in the tracker. +func (t *MessageTracker) Remove(id string) *isb.ReadMessage { + t.lock.Lock() + defer t.lock.Unlock() + item, ok := t.m[id] + if !ok { + return nil + } + delete(t.m, id) + return item +} + +// IsEmpty is a helper function which checks if the Tracker map is empty +// return true if empty +func (t *MessageTracker) IsEmpty() bool { + t.lock.RLock() + defer t.lock.RUnlock() + return len(t.m) == 0 +} + +// Len returns the number of messages currently stored in the tracker +func (t *MessageTracker) Len() int { + t.lock.RLock() + defer t.lock.RUnlock() + return len(t.m) +} diff --git a/pkg/udf/rpc/tracker_test.go b/pkg/isb/tracker/message_tracker_test.go similarity index 53% rename from pkg/udf/rpc/tracker_test.go rename to pkg/isb/tracker/message_tracker_test.go index 21704f4425..3c2ae767d0 100644 --- a/pkg/udf/rpc/tracker_test.go +++ b/pkg/isb/tracker/message_tracker_test.go @@ -1,4 +1,4 @@ -package rpc +package tracker import ( "testing" @@ -6,32 +6,34 @@ import ( "github.com/stretchr/testify/assert" + "github.com/numaproj/numaflow/pkg/isb" "github.com/numaproj/numaflow/pkg/isb/testutils" ) func TestTracker_AddRequest(t *testing.T) { - tr := NewTracker() readMessages := testutils.BuildTestReadMessages(3, time.Unix(1661169600, 0), nil) - for _, msg := range readMessages { - tr.addRequest(&msg) + messages := make([]*isb.ReadMessage, len(readMessages)) + for i, msg := range readMessages { + messages[i] = &msg } + tr := NewMessageTracker(messages) id := readMessages[0].ReadOffset.String() - m, ok := tr.getRequest(id) - assert.True(t, ok) + m := tr.Remove(id) + assert.NotNil(t, m) assert.Equal(t, readMessages[0], *m) } func TestTracker_RemoveRequest(t *testing.T) { - tr := NewTracker() readMessages := testutils.BuildTestReadMessages(3, time.Unix(1661169600, 0), nil) - for _, msg := range readMessages { - tr.addRequest(&msg) + messages := make([]*isb.ReadMessage, len(readMessages)) + for i, msg := range readMessages { + messages[i] = &msg } + tr := NewMessageTracker(messages) id := readMessages[0].ReadOffset.String() - m, ok := tr.getRequest(id) - assert.True(t, ok) + m := tr.Remove(id) + assert.NotNil(t, m) assert.Equal(t, readMessages[0], *m) - tr.removeRequest(id) - _, ok = tr.getRequest(id) - assert.False(t, ok) + m = tr.Remove(id) + assert.Nil(t, m) } diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index d9d47302c0..7f3327ac5f 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -18,6 +18,8 @@ package sourcetransformer import ( "context" + "fmt" + "time" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" @@ -28,16 +30,18 @@ import ( sdkerr "github.com/numaproj/numaflow/pkg/sdkclient/error" grpcutil "github.com/numaproj/numaflow/pkg/sdkclient/grpc" "github.com/numaproj/numaflow/pkg/sdkclient/serverinfo" + "github.com/numaproj/numaflow/pkg/shared/logging" ) // client contains the grpc connection and the grpc client. type client struct { conn *grpc.ClientConn grpcClt transformpb.SourceTransformClient + stream transformpb.SourceTransform_SourceTransformFnClient } // New creates a new client object. -func New(serverInfo *serverinfo.ServerInfo, inputOptions ...sdkclient.Option) (Client, error) { +func New(ctx context.Context, serverInfo *serverinfo.ServerInfo, inputOptions ...sdkclient.Option) (Client, error) { var opts = sdkclient.DefaultOptions(sdkclient.SourceTransformerAddr) for _, inputOption := range inputOptions { @@ -53,18 +57,81 @@ func New(serverInfo *serverinfo.ServerInfo, inputOptions ...sdkclient.Option) (C c := new(client) c.conn = conn c.grpcClt = transformpb.NewSourceTransformClient(conn) + + var logger = logging.FromContext(ctx) + +waitUntilReady: + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("waiting for transformer gRPC server to be ready: %w", ctx.Err()) + default: + _, err := c.IsReady(ctx, &emptypb.Empty{}) + if err != nil { + logger.Warnf("Transformer server is not ready: %v", err) + time.Sleep(100 * time.Millisecond) + continue waitUntilReady + } + break waitUntilReady + } + } + + c.stream, err = c.grpcClt.SourceTransformFn(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create a gRPC stream for source transform: %w", err) + } + + if err := doHandshake(c.stream); err != nil { + return nil, err + } + return c, nil } +func doHandshake(stream transformpb.SourceTransform_SourceTransformFnClient) error { + // Send handshake request + handshakeReq := &transformpb.SourceTransformRequest{ + Handshake: &transformpb.Handshake{ + Sot: true, + }, + } + if err := stream.Send(handshakeReq); err != nil { + return fmt.Errorf("failed to send handshake request for source tansform: %w", err) + } + + handshakeResp, err := stream.Recv() + if err != nil { + return fmt.Errorf("failed to receive handshake response from source transform stream: %w", err) + } + if resp := handshakeResp.GetHandshake(); resp == nil || !resp.GetSot() { + return fmt.Errorf("invalid handshake response for source transform. Received='%+v'", resp) + } + return nil +} + // NewFromClient creates a new client object from a grpc client. This is used for testing. -func NewFromClient(c transformpb.SourceTransformClient) (Client, error) { +func NewFromClient(ctx context.Context, c transformpb.SourceTransformClient) (Client, error) { + stream, err := c.SourceTransformFn(ctx) + if err != nil { + return nil, err + } + + if err := doHandshake(stream); err != nil { + return nil, err + } + return &client{ grpcClt: c, + stream: stream, }, nil } // CloseConn closes the grpc client connection. -func (c *client) CloseConn(ctx context.Context) error { +func (c *client) CloseConn(_ context.Context) error { + err := c.stream.CloseSend() + if err != nil { + return err + } if c.conn == nil { return nil } @@ -81,11 +148,60 @@ func (c *client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) { } // SourceTransformFn SourceTransformerFn applies a function to each request element. -func (c *client) SourceTransformFn(ctx context.Context, request *transformpb.SourceTransformRequest) (*transformpb.SourceTransformResponse, error) { - transformResponse, err := c.grpcClt.SourceTransformFn(ctx, request) - err = sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn", err) - if err != nil { - return nil, err - } - return transformResponse, nil +// Response channel will not be closed. Caller can select on response and error channel to exit on first error. +func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transformpb.SourceTransformRequest) (<-chan *transformpb.SourceTransformResponse, <-chan error) { + clientErrCh := make(chan error) + responseCh := make(chan *transformpb.SourceTransformResponse) + + // This channel is to send the error from the goroutine that receives messages from the stream to the goroutine that sends requests to the server. + // This ensures that we don't need to use clientErrCh in both goroutines. The caller of this function will only be listening for the first error value in clientErrCh. + // If both goroutines were sending error message to this channel (eg. stream failure), one of them will be stuck in sending can not shutdown cleanly. + errCh := make(chan error, 1) + + logger := logging.FromContext(ctx) + + // Receive responses from the stream + go func() { + for { + resp, err := c.stream.Recv() + if err != nil { + // we don't need an EOF check because we only close the stream during shutdown. + errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn", err) + close(errCh) + return + } + + select { + case <-ctx.Done(): + logger.Warnf("Context cancelled. Stopping retrieving messages from the stream") + return + case responseCh <- resp: + } + } + }() + + go func() { + for { + select { + case <-ctx.Done(): + clientErrCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", ctx.Err()) + return + case err := <-errCh: + clientErrCh <- err + return + case msg, ok := <-request: + if !ok { + // stream is only closed during shutdown + return + } + err := c.stream.Send(msg) + if err != nil { + clientErrCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err) + return + } + } + } + }() + + return responseCh, clientErrCh } diff --git a/pkg/sdkclient/sourcetransformer/client_test.go b/pkg/sdkclient/sourcetransformer/client_test.go index 27526312fd..619f9533f8 100644 --- a/pkg/sdkclient/sourcetransformer/client_test.go +++ b/pkg/sdkclient/sourcetransformer/client_test.go @@ -18,80 +18,139 @@ package sourcetransformer import ( "context" + "errors" "fmt" - "reflect" + "net" "testing" + "time" - "github.com/golang/mock/gomock" transformpb "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" - transformermock "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1/transformmock" + "github.com/numaproj/numaflow-go/pkg/sourcetransformer" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestClient_IsReady(t *testing.T) { var ctx = context.Background() + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + return sourcetransformer.MessagesBuilder() + }), + } + + // Start the gRPC server + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + defer conn.Close() + + // Create a client connection to the server + client := transformpb.NewSourceTransformClient(conn) - ctrl := gomock.NewController(t) - defer ctrl.Finish() + testClient, err := NewFromClient(ctx, client) + require.NoError(t, err) - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().IsReady(gomock.Any(), gomock.Any()).Return(&transformpb.ReadyResponse{Ready: true}, nil) - mockClient.EXPECT().IsReady(gomock.Any(), gomock.Any()).Return(&transformpb.ReadyResponse{Ready: false}, fmt.Errorf("mock connection refused")) + ready, err := testClient.IsReady(ctx, &emptypb.Empty{}) + require.True(t, ready) + require.NoError(t, err) +} - testClient, err := NewFromClient(mockClient) - assert.NoError(t, err) - reflect.DeepEqual(testClient, &client{ - grpcClt: mockClient, +func newServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientConn { + lis := bufconn.Listen(100) + t.Cleanup(func() { + _ = lis.Close() }) - ready, err := testClient.IsReady(ctx, &emptypb.Empty{}) - assert.True(t, ready) - assert.NoError(t, err) + server := grpc.NewServer() + t.Cleanup(func() { + server.Stop() + }) - ready, err = testClient.IsReady(ctx, &emptypb.Empty{}) - assert.False(t, ready) - assert.EqualError(t, err, "mock connection refused") -} + register(server) -func TestClient_SourceTransformFn(t *testing.T) { - var ctx = context.Background() + errChan := make(chan error, 1) + go func() { + // t.Fatal should only be called from the goroutine running the test + if err := server.Serve(lis); err != nil { + errChan <- err + } + }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), gomock.Any()).Return(&transformpb.SourceTransformResponse{Results: []*transformpb.SourceTransformResponse_Result{ - { - Keys: []string{"temp-key"}, - Value: []byte("mock result"), - Tags: nil, - }, - }}, nil) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), gomock.Any()).Return(&transformpb.SourceTransformResponse{Results: []*transformpb.SourceTransformResponse_Result{ - { - Keys: []string{"temp-key"}, - Value: []byte("mock result"), - Tags: nil, - }, - }}, fmt.Errorf("mock connection refused")) - - testClient, err := NewFromClient(mockClient) - assert.NoError(t, err) - reflect.DeepEqual(testClient, &client{ - grpcClt: mockClient, + dialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + conn, err := grpc.NewClient("passthrough://", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + t.Cleanup(func() { + _ = conn.Close() }) + if err != nil { + t.Fatalf("Creating new gRPC client connection: %v", err) + } + + var grpcServerErr error + select { + case grpcServerErr = <-errChan: + case <-time.After(500 * time.Millisecond): + grpcServerErr = errors.New("gRPC server didn't start in 500ms") + } + if err != nil { + t.Fatalf("Failed to start gRPC server: %v", grpcServerErr) + } + + return conn +} - result, err := testClient.SourceTransformFn(ctx, &transformpb.SourceTransformRequest{}) - assert.Equal(t, &transformpb.SourceTransformResponse{Results: []*transformpb.SourceTransformResponse_Result{ - { - Keys: []string{"temp-key"}, - Value: []byte("mock result"), - Tags: nil, - }, - }}, result) - assert.NoError(t, err) - - _, err = testClient.SourceTransformFn(ctx, &transformpb.SourceTransformRequest{}) - assert.EqualError(t, err, "NonRetryable: mock connection refused") +func TestClient_SourceTransformFn(t *testing.T) { + var testTime = time.Date(2021, 8, 15, 14, 30, 45, 100, time.Local) + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + msg := datum.Value() + return sourcetransformer.MessagesBuilder().Append(sourcetransformer.NewMessage(msg, testTime).WithKeys([]string{keys[0] + "_test"})) + }), + } + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + var ctx = context.Background() + client, _ := NewFromClient(ctx, transformClient) + + reqChan := make(chan *transformpb.SourceTransformRequest, 1) + go func() { + for i := 0; i < 5; i++ { + reqChan <- &transformpb.SourceTransformRequest{ + Request: &transformpb.SourceTransformRequest_Request{ + Keys: []string{fmt.Sprintf("client_key_%d", i)}, + Value: []byte("test"), + }, + } + } + }() + + respChan, errChan := client.SourceTransformFn(ctx, reqChan) + var results [][]*transformpb.SourceTransformResponse_Result + for i := 0; i < 5; i++ { + var resp *transformpb.SourceTransformResponse + var err error + select { + case resp = <-respChan: + case err = <-errChan: + } + assert.NoError(t, err) + results = append(results, resp.GetResults()) + } + expected := [][]*transformpb.SourceTransformResponse_Result{ + {{Keys: []string{"client_key_0_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + {{Keys: []string{"client_key_1_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + {{Keys: []string{"client_key_2_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + {{Keys: []string{"client_key_3_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + {{Keys: []string{"client_key_4_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + } + assert.ElementsMatch(t, expected, results) } diff --git a/pkg/sdkclient/sourcetransformer/interface.go b/pkg/sdkclient/sourcetransformer/interface.go index 4d8e3d8f71..6006e380c8 100644 --- a/pkg/sdkclient/sourcetransformer/interface.go +++ b/pkg/sdkclient/sourcetransformer/interface.go @@ -27,5 +27,5 @@ import ( type Client interface { CloseConn(ctx context.Context) error IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) - SourceTransformFn(ctx context.Context, request *transformpb.SourceTransformRequest) (*transformpb.SourceTransformResponse, error) + SourceTransformFn(ctx context.Context, request <-chan *transformpb.SourceTransformRequest) (<-chan *transformpb.SourceTransformResponse, <-chan error) } diff --git a/pkg/sources/forward/applier/sourcetransformer.go b/pkg/sources/forward/applier/sourcetransformer.go index 795cd4c5a2..a935d511ea 100644 --- a/pkg/sources/forward/applier/sourcetransformer.go +++ b/pkg/sources/forward/applier/sourcetransformer.go @@ -25,13 +25,13 @@ import ( // SourceTransformApplier applies the source transform on the read message and gives back a new message. Any UserError will be retried here, while // InternalErr can be returned and could be retried by the callee. type SourceTransformApplier interface { - ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) + ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) } // ApplySourceTransformFunc is a function type that implements SourceTransformApplier interface. -type ApplySourceTransformFunc func(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) +type ApplySourceTransformFunc func(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) // ApplyTransform implements SourceTransformApplier interface. -func (f ApplySourceTransformFunc) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return f(ctx, message) +func (f ApplySourceTransformFunc) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + return f(ctx, messages) } diff --git a/pkg/sources/forward/data_forward.go b/pkg/sources/forward/data_forward.go index 48ff97da3a..9eddafa30d 100644 --- a/pkg/sources/forward/data_forward.go +++ b/pkg/sources/forward/data_forward.go @@ -305,34 +305,14 @@ func (df *DataForward) forwardAChunk(ctx context.Context) { // If a user-defined transformer exists, apply it if df.opts.transformer != nil { - // user-defined transformer concurrent processing request channel - transformerCh := make(chan *isb.ReadWriteMessagePair) - - // create a pool of Transformer Processors - var wg sync.WaitGroup - for i := 0; i < df.opts.transformerConcurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - df.concurrentApplyTransformer(ctx, transformerCh) - }() + for _, m := range readMessages { + // assign watermark to the message + m.Watermark = time.Time(processorWM) } concurrentTransformerProcessingStart := time.Now() - for idx, m := range readMessages { + readWriteMessagePairs = df.applyTransformer(ctx, readMessages) - // assign watermark to the message - m.Watermark = time.Time(processorWM) - readWriteMessagePairs[idx].ReadMessage = m - // send transformer processing work to the channel. Thus, the results of the transformer - // application on a read message will be stored as the corresponding writeMessage in readWriteMessagePairs - transformerCh <- &readWriteMessagePairs[idx] - } - // let the go routines know that there is no more work - close(transformerCh) - // wait till the processing is done. this will not be an infinite wait because the transformer processing will exit if - // context.Done() is closed. - wg.Wait() df.opts.logger.Debugw("concurrent applyTransformer completed", zap.Int("concurrency", df.opts.transformerConcurrency), zap.Duration("took", time.Since(concurrentTransformerProcessingStart)), @@ -661,42 +641,12 @@ func (df *DataForward) writeToBuffer(ctx context.Context, toBufferPartition isb. return writeOffsets, nil } -// concurrentApplyTransformer applies the transformer based on the request from the channel -func (df *DataForward) concurrentApplyTransformer(ctx context.Context, readMessagePair <-chan *isb.ReadWriteMessagePair) { - for message := range readMessagePair { - start := time.Now() - metrics.SourceTransformerReadMessagesCount.With(map[string]string{ - metrics.LabelVertex: df.vertexName, - metrics.LabelPipeline: df.pipelineName, - metrics.LabelVertexReplicaIndex: strconv.Itoa(int(df.vertexReplica)), - metrics.LabelPartitionName: df.reader.GetName(), - }).Inc() - - writeMessages, err := df.applyTransformer(ctx, message.ReadMessage) - metrics.SourceTransformerWriteMessagesCount.With(map[string]string{ - metrics.LabelVertex: df.vertexName, - metrics.LabelPipeline: df.pipelineName, - metrics.LabelVertexReplicaIndex: strconv.Itoa(int(df.vertexReplica)), - metrics.LabelPartitionName: df.reader.GetName(), - }).Add(float64(len(writeMessages))) - - message.WriteMessages = append(message.WriteMessages, writeMessages...) - message.Err = err - metrics.SourceTransformerProcessingTime.With(map[string]string{ - metrics.LabelVertex: df.vertexName, - metrics.LabelPipeline: df.pipelineName, - metrics.LabelVertexReplicaIndex: strconv.Itoa(int(df.vertexReplica)), - metrics.LabelPartitionName: df.reader.GetName(), - }).Observe(float64(time.Since(start).Microseconds())) - } -} - // applyTransformer applies the transformer and will block if there is any InternalErr. On the other hand, if this is a UserError // the skip flag is set. The ShutDown flag will only if there is an InternalErr and ForceStop has been invoked. // The UserError retry will be done on the applyTransformer. -func (df *DataForward) applyTransformer(ctx context.Context, readMessage *isb.ReadMessage) ([]*isb.WriteMessage, error) { +func (df *DataForward) applyTransformer(ctx context.Context, messages []*isb.ReadMessage) []isb.ReadWriteMessagePair { for { - writeMessages, err := df.opts.transformer.ApplyTransform(ctx, readMessage) + transformResults, err := df.opts.transformer.ApplyTransform(ctx, messages) if err != nil { df.opts.logger.Errorw("Transformer.Apply error", zap.Error(err)) // TODO: implement retry with backoff etc. @@ -712,12 +662,11 @@ func (df *DataForward) applyTransformer(ctx context.Context, readMessage *isb.Re metrics.LabelVertexType: string(dfv1.VertexTypeSource), metrics.LabelVertexReplicaIndex: strconv.Itoa(int(df.vertexReplica)), }).Inc() - - return nil, err + return []isb.ReadWriteMessagePair{{Err: err}} } continue } - return writeMessages, nil + return transformResults } } diff --git a/pkg/sources/forward/data_forward_test.go b/pkg/sources/forward/data_forward_test.go index 25e41a9fa6..96cb6760e6 100644 --- a/pkg/sources/forward/data_forward_test.go +++ b/pkg/sources/forward/data_forward_test.go @@ -121,8 +121,16 @@ func (f myForwardTest) WhereTo(_ []string, _ []string, s string) ([]forwarder.Ve }}, nil } -func (f myForwardTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "test-vertex", message) +func (f myForwardTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + out := make([]isb.ReadWriteMessagePair, len(messages)) + for i, msg := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "test-vertex", msg) + out[i] = isb.ReadWriteMessagePair{ + ReadMessage: msg, + WriteMessages: writeMsg, + } + } + return out, nil } func TestNewDataForward(t *testing.T) { @@ -856,36 +864,31 @@ func (f *mySourceForwardTestRoundRobin) WhereTo(_ []string, _ []string, s string // such that we can verify message IsLate attribute gets set to true. var testSourceNewEventTime = testSourceWatermark.Add(time.Duration(-1) * time.Minute) -func (f mySourceForwardTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return func(ctx context.Context, readMessage *isb.ReadMessage) ([]*isb.WriteMessage, error) { - _ = ctx - offset := readMessage.ReadOffset - payload := readMessage.Body.Payload - parentPaneInfo := readMessage.MessageInfo - - // apply source data transformer - _ = payload - // copy the payload - result := payload - // assign new event time - parentPaneInfo.EventTime = testSourceNewEventTime - var key []string - - writeMessage := isb.Message{ +func (f mySourceForwardTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + message.MessageInfo.EventTime = testSourceNewEventTime + writeMsg := isb.Message{ Header: isb.Header{ - MessageInfo: parentPaneInfo, + MessageInfo: message.MessageInfo, ID: isb.MessageID{ VertexName: "test-vertex", - Offset: offset.String(), + Offset: message.ReadOffset.String(), }, - Keys: key, + Keys: []string{}, }, Body: isb.Body{ - Payload: result, + Payload: message.Body.Payload, }, } - return []*isb.WriteMessage{{Message: writeMessage}}, nil - }(ctx, message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: []*isb.WriteMessage{{ + Message: writeMsg, + }}, + } + } + return results, nil } // TestSourceWatermarkPublisher is a dummy implementation of isb.SourceWatermarkPublisher interface @@ -1153,8 +1156,16 @@ func (f myForwardDropTest) WhereTo(_ []string, _ []string, s string) ([]forwarde return []forwarder.VertexBuffer{}, nil } -func (f myForwardDropTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "test-vertex", message) +func (f myForwardDropTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "test-vertex", message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: writeMsg, + } + } + return results, nil } type myForwardToAllTest struct { @@ -1174,8 +1185,16 @@ func (f *myForwardToAllTest) WhereTo(_ []string, _ []string, s string) ([]forwar return output, nil } -func (f *myForwardToAllTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "test-vertex", message) +func (f *myForwardToAllTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "test-vertex", message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: writeMsg, + } + } + return results, nil } type myForwardInternalErrTest struct { @@ -1188,7 +1207,7 @@ func (f myForwardInternalErrTest) WhereTo(_ []string, _ []string, s string) ([]f }}, nil } -func (f myForwardInternalErrTest) ApplyTransform(_ context.Context, _ *isb.ReadMessage) ([]*isb.WriteMessage, error) { +func (f myForwardInternalErrTest) ApplyTransform(ctx context.Context, _ []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { return nil, &udfapplier.ApplyUDFErr{ UserUDFErr: false, InternalErr: struct { @@ -1209,8 +1228,16 @@ func (f myForwardApplyWhereToErrTest) WhereTo(_ []string, _ []string, s string) }}, fmt.Errorf("whereToStep failed") } -func (f myForwardApplyWhereToErrTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "test-vertex", message) +func (f myForwardApplyWhereToErrTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "test-vertex", message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: writeMsg, + } + } + return results, nil } type myForwardApplyTransformerErrTest struct { @@ -1223,7 +1250,7 @@ func (f myForwardApplyTransformerErrTest) WhereTo(_ []string, _ []string, s stri }}, nil } -func (f myForwardApplyTransformerErrTest) ApplyTransform(_ context.Context, _ *isb.ReadMessage) ([]*isb.WriteMessage, error) { +func (f myForwardApplyTransformerErrTest) ApplyTransform(_ context.Context, _ []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { return nil, fmt.Errorf("transformer error") } diff --git a/pkg/sources/forward/shutdown_test.go b/pkg/sources/forward/shutdown_test.go index a4ffc5e2e2..34003e729f 100644 --- a/pkg/sources/forward/shutdown_test.go +++ b/pkg/sources/forward/shutdown_test.go @@ -43,8 +43,16 @@ func (s myShutdownTest) WhereTo([]string, []string, string) ([]forwarder.VertexB return []forwarder.VertexBuffer{}, nil } -func (s myShutdownTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "", message) +func (f myShutdownTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "", message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: writeMsg, + } + } + return results, nil } func (s myShutdownTest) ApplyMapStream(ctx context.Context, message *isb.ReadMessage, writeMessageCh chan<- isb.WriteMessage) error { diff --git a/pkg/sources/source.go b/pkg/sources/source.go index 0b3e23a94b..69bc0c0099 100644 --- a/pkg/sources/source.go +++ b/pkg/sources/source.go @@ -240,7 +240,7 @@ func (sp *SourceProcessor) Start(ctx context.Context) error { return err } - srcTransformerClient, err := sourcetransformer.New(serverInfo, sdkclient.WithMaxMessageSize(maxMessageSize)) + srcTransformerClient, err := sourcetransformer.New(ctx, serverInfo, sdkclient.WithMaxMessageSize(maxMessageSize)) if err != nil { return fmt.Errorf("failed to create transformer gRPC client, %w", err) } diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index 14b414a348..e578687fb3 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -18,19 +18,18 @@ package transformer import ( "context" + "errors" "fmt" "time" v1 "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" - "google.golang.org/protobuf/types/known/emptypb" - "google.golang.org/protobuf/types/known/timestamppb" - "k8s.io/apimachinery/pkg/util/wait" - "github.com/numaproj/numaflow/pkg/isb" - sdkerr "github.com/numaproj/numaflow/pkg/sdkclient/error" + "github.com/numaproj/numaflow/pkg/isb/tracker" "github.com/numaproj/numaflow/pkg/sdkclient/sourcetransformer" "github.com/numaproj/numaflow/pkg/shared/logging" "github.com/numaproj/numaflow/pkg/udf/rpc" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" ) // GRPCBasedTransformer applies user-defined transformer over gRPC (over Unix Domain Socket) client/server where server is the transformer. @@ -75,68 +74,41 @@ func (u *GRPCBasedTransformer) CloseConn(ctx context.Context) error { return u.client.CloseConn(ctx) } -func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, readMessage *isb.ReadMessage) ([]*isb.WriteMessage, error) { - keys := readMessage.Keys - payload := readMessage.Body.Payload - offset := readMessage.ReadOffset - parentMessageInfo := readMessage.MessageInfo - var req = &v1.SourceTransformRequest{ - Keys: keys, - Value: payload, - EventTime: timestamppb.New(parentMessageInfo.EventTime), - Watermark: timestamppb.New(readMessage.Watermark), - Headers: readMessage.Headers, - } +var errSourceTransformFnEmptyMsgId = errors.New("response from SourceTransformFn doesn't contain a message id") - response, err := u.client.SourceTransformFn(ctx, req) - if err != nil { - udfErr, _ := sdkerr.FromError(err) - switch udfErr.ErrorKind() { - case sdkerr.Retryable: - var success bool - _ = wait.ExponentialBackoffWithContext(ctx, wait.Backoff{ - // retry every "duration * factor + [0, jitter]" interval for 5 times - Duration: 1 * time.Second, - Factor: 1, - Jitter: 0.1, - Steps: 5, - }, func(_ context.Context) (done bool, err error) { - response, err = u.client.SourceTransformFn(ctx, req) - if err != nil { - udfErr, _ = sdkerr.FromError(err) - switch udfErr.ErrorKind() { - case sdkerr.Retryable: - return false, nil - case sdkerr.NonRetryable: - return true, nil - default: - return true, nil - } - } - success = true - return true, nil - }) - if !success { - return nil, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - } - } - case sdkerr.NonRetryable: - return nil, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, +func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + var transformResults []isb.ReadWriteMessagePair + inputChan := make(chan *v1.SourceTransformRequest) + respChan, errChan := u.client.SourceTransformFn(ctx, inputChan) + + logger := logging.FromContext(ctx) + + msgTracker := tracker.NewMessageTracker(messages) + + go func() { + defer close(inputChan) + for _, msg := range messages { + req := &v1.SourceTransformRequest{ + Request: &v1.SourceTransformRequest_Request{ + Keys: msg.Keys, + Value: msg.Body.Payload, + EventTime: timestamppb.New(msg.MessageInfo.EventTime), + Watermark: timestamppb.New(msg.Watermark), + Headers: msg.Headers, + Id: msg.ReadOffset.String(), }, } - default: - return nil, &rpc.ApplyUDFErr{ + inputChan <- req + } + }() + + messageCount := len(messages) + +loop: + for { + select { + case err := <-errChan: + err = &rpc.ApplyUDFErr{ UserUDFErr: false, Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), InternalErr: rpc.InternalErr{ @@ -144,34 +116,63 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, readMessage * MainCarDown: false, }, } - } - } + return nil, err + case resp, ok := <-respChan: + if !ok { + logger.Warn("Response channel from source transform client was closed.") + break loop + } + msgId := resp.GetId() + if msgId == "" { + return nil, errSourceTransformFnEmptyMsgId + } + parentMessage := msgTracker.Remove(msgId) + if parentMessage == nil { + return nil, errors.New("tracker doesn't contain the message ID received from the response") + } + messageCount-- - taggedMessages := make([]*isb.WriteMessage, 0) - for i, result := range response.GetResults() { - keys := result.Keys - if result.EventTime != nil { - // Transformer supports changing event time. - parentMessageInfo.EventTime = result.EventTime.AsTime() - } - taggedMessage := &isb.WriteMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: parentMessageInfo, - ID: isb.MessageID{ - VertexName: u.vertexName, - Offset: offset.String(), - Index: int32(i), + var taggedMessages []*isb.WriteMessage + for i, result := range resp.GetResults() { + keys := result.Keys + if result.EventTime != nil { + // Transformer supports changing event time. + parentMessage.MessageInfo.EventTime = result.EventTime.AsTime() + } + taggedMessage := &isb.WriteMessage{ + Message: isb.Message{ + Header: isb.Header{ + MessageInfo: parentMessage.MessageInfo, + ID: isb.MessageID{ + VertexName: u.vertexName, + Offset: parentMessage.ReadOffset.String(), + Index: int32(i), + }, + Keys: keys, + }, + Body: isb.Body{ + Payload: result.Value, + }, }, - Keys: keys, - }, - Body: isb.Body{ - Payload: result.Value, - }, - }, - Tags: result.Tags, + Tags: result.Tags, + } + taggedMessages = append(taggedMessages, taggedMessage) + } + responsePair := isb.ReadWriteMessagePair{ + ReadMessage: parentMessage, + WriteMessages: taggedMessages, + Err: nil, + } + transformResults = append(transformResults, responsePair) + + if messageCount == 0 { + break loop + } } - taggedMessages = append(taggedMessages, taggedMessage) } - return taggedMessages, nil + + if !msgTracker.IsEmpty() { + return nil, fmt.Errorf("transform response for all requests were not received from UDF. Remaining=%d", msgTracker.Len()) + } + return transformResults, nil } diff --git a/pkg/sources/transformer/grpc_transformer_test.go b/pkg/sources/transformer/grpc_transformer_test.go index 959a40bf51..46279bb93a 100644 --- a/pkg/sources/transformer/grpc_transformer_test.go +++ b/pkg/sources/transformer/grpc_transformer_test.go @@ -19,101 +19,60 @@ package transformer import ( "context" "encoding/json" - "fmt" + "errors" + "net" "testing" "time" - "github.com/golang/mock/gomock" - v1 "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" - transformermock "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1/transformmock" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/timestamppb" + "github.com/numaproj/numaflow-go/pkg/sourcetransformer" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + transformpb "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" "github.com/numaproj/numaflow/pkg/isb" "github.com/numaproj/numaflow/pkg/isb/testutils" - "github.com/numaproj/numaflow/pkg/sdkclient/sourcetransformer" + sourcetransformerSdk "github.com/numaproj/numaflow/pkg/sdkclient/sourcetransformer" "github.com/numaproj/numaflow/pkg/udf/rpc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" ) -func NewMockGRPCBasedTransformer(mockClient *transformermock.MockSourceTransformClient) *GRPCBasedTransformer { - c, _ := sourcetransformer.NewFromClient(mockClient) - return &GRPCBasedTransformer{"test-vertex", c} -} - -func TestGRPCBasedTransformer_WaitUntilReadyWithMockClient(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().IsReady(gomock.Any(), gomock.Any()).Return(&v1.ReadyResponse{Ready: true}, nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - err := u.WaitUntilReady(ctx) - assert.NoError(t, err) -} - -type rpcMsg struct { - msg proto.Message -} - -func (r *rpcMsg) Matches(msg interface{}) bool { - m, ok := msg.(proto.Message) - if !ok { - return false +func TestGRPCBasedTransformer_WaitUntilReadyWithServer(t *testing.T) { + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + return sourcetransformer.Messages{} + }), } - return proto.Equal(m, r.msg) -} -func (r *rpcMsg) String() string { - return fmt.Sprintf("is %s", r.msg) + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + client, _ := sourcetransformerSdk.NewFromClient(context.Background(), transformClient) + u := NewGRPCBasedTransformer("testVertex", client) + err := u.WaitUntilReady(context.Background()) + assert.NoError(t, err) } -func TestGRPCBasedTransformer_BasicApplyWithMockClient(t *testing.T) { +func TestGRPCBasedTransformer_BasicApplyWithServer(t *testing.T) { t.Run("test success", func(t *testing.T) { - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_success_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169600, 0)), - Watermark: timestamppb.New(time.Time{}), + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + return sourcetransformer.MessagesBuilder().Append(sourcetransformer.NewMessage(datum.Value(), datum.EventTime()).WithKeys(keys)) + }), } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(&v1.SourceTransformResponse{ - Results: []*v1.SourceTransformResponse_Result{ - { - Keys: []string{"test_success_key"}, - Value: []byte(`forward_message`), - }, - }, - }, nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - u := NewMockGRPCBasedTransformer(mockClient) - got, err := u.ApplyTransform(ctx, &isb.ReadMessage{ + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + ctx := context.Background() + client, err := sourcetransformerSdk.NewFromClient(ctx, transformClient) + require.NoError(t, err, "creating source transformer client") + u := NewGRPCBasedTransformer("testVertex", client) + + got, err := u.ApplyTransform(ctx, []*isb.ReadMessage{{ Message: isb.Message{ Header: isb.Header{ MessageInfo: isb.MessageInfo{ @@ -130,94 +89,33 @@ func TestGRPCBasedTransformer_BasicApplyWithMockClient(t *testing.T) { }, }, ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, + }}, ) assert.NoError(t, err) - assert.Equal(t, req.Keys, got[0].Keys) - assert.Equal(t, req.Value, got[0].Payload) + assert.Equal(t, []string{"test_success_key"}, got[0].WriteMessages[0].Keys) + assert.Equal(t, []byte(`forward_message`), got[0].WriteMessages[0].Payload) }) t.Run("test error", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_error_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169660, 0)), - Watermark: timestamppb.New(time.Time{}), + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + return sourcetransformer.Messages{} + }), } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, fmt.Errorf("mock error")) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - u := NewMockGRPCBasedTransformer(mockClient) - _, err := u.ApplyTransform(ctx, &isb.ReadMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: isb.MessageInfo{ - EventTime: time.Unix(1661169660, 0), - }, - ID: isb.MessageID{ - VertexName: "test-vertex", - Offset: "0-0", - }, - Keys: []string{"test_error_key"}, - }, - Body: isb.Body{ - Payload: []byte(`forward_message`), - }, - }, - ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, - ) - assert.ErrorIs(t, err, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("%s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) }) - }) + transformClient := transformpb.NewSourceTransformClient(conn) + ctx, cancel := context.WithCancel(context.Background()) + client, err := sourcetransformerSdk.NewFromClient(ctx, transformClient) + require.NoError(t, err, "creating source transformer client") + u := NewGRPCBasedTransformer("testVertex", client) - t.Run("test error retryable: failed after 5 retries", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() + // This cancelled context is passed to the ApplyTransform function to simulate failure + cancel() - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_error_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169660, 0)), - Watermark: timestamppb.New(time.Time{}), - } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - _, err := u.ApplyTransform(ctx, &isb.ReadMessage{ + _, err = u.ApplyTransform(ctx, []*isb.ReadMessage{{ Message: isb.Message{ Header: isb.Header{ MessageInfo: isb.MessageInfo{ @@ -234,292 +132,155 @@ func TestGRPCBasedTransformer_BasicApplyWithMockClient(t *testing.T) { }, }, ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, + }}, ) - assert.ErrorIs(t, err, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("%s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - }) - }) - - t.Run("test error retryable: failed after 1 retry", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_error_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169660, 0)), - Watermark: timestamppb.New(time.Time{}), - } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.InvalidArgument, "mock test err: non retryable").Err()) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - _, err := u.ApplyTransform(ctx, &isb.ReadMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: isb.MessageInfo{ - EventTime: time.Unix(1661169660, 0), - }, - ID: isb.MessageID{ - VertexName: "test-vertex", - Offset: "0-0", - }, - Keys: []string{"test_error_key"}, - }, - Body: isb.Body{ - Payload: []byte(`forward_message`), - }, - }, - ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, - ) - assert.ErrorIs(t, err, &rpc.ApplyUDFErr{ + expectedUDFErr := &rpc.ApplyUDFErr{ UserUDFErr: false, - Message: fmt.Sprintf("%s", err), + Message: "gRPC client.SourceTransformFn failed, NonRetryable: context canceled", InternalErr: rpc.InternalErr{ Flag: true, MainCarDown: false, }, - }) - }) - - t.Run("test error retryable: success after 1 retry", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_success_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169720, 0)), - Watermark: timestamppb.New(time.Time{}), - } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(&v1.SourceTransformResponse{ - Results: []*v1.SourceTransformResponse_Result{ - { - Keys: []string{"test_success_key"}, - Value: []byte(`forward_message`), - }, - }, - }, nil) - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - got, err := u.ApplyTransform(ctx, &isb.ReadMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: isb.MessageInfo{ - EventTime: time.Unix(1661169720, 0), - }, - ID: isb.MessageID{ - VertexName: "test-vertex", - Offset: "0-0", - }, - Keys: []string{"test_success_key"}, - }, - Body: isb.Body{ - Payload: []byte(`forward_message`), - }, - }, - ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, - ) - assert.NoError(t, err) - assert.Equal(t, req.Keys, got[0].Keys) - assert.Equal(t, req.Value, got[0].Payload) - }) - - t.Run("test error non retryable", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_error_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169660, 0)), - Watermark: timestamppb.New(time.Time{}), } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.InvalidArgument, "mock test err: non retryable").Err()) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - _, err := u.ApplyTransform(ctx, &isb.ReadMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: isb.MessageInfo{ - EventTime: time.Unix(1661169660, 0), - }, - ID: isb.MessageID{ - VertexName: "test-vertex", - Offset: "0-0", - }, - Keys: []string{"test_error_key"}, - }, - Body: isb.Body{ - Payload: []byte(`forward_message`), - }, - }, - ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, - ) - assert.ErrorIs(t, err, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("%s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - }) + var receivedErr *rpc.ApplyUDFErr + assert.ErrorAs(t, err, &receivedErr) + assert.Equal(t, expectedUDFErr, receivedErr) }) } -func TestGRPCBasedTransformer_ApplyWithMockClient_ChangePayload(t *testing.T) { - multiplyBy2 := func(body []byte) interface{} { - var result testutils.PayloadForTest - _ = json.Unmarshal(body, &result) - result.Value = result.Value * 2 - return result - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, datum *v1.SourceTransformRequest, opts ...grpc.CallOption) (*v1.SourceTransformResponse, error) { +func TestGRPCBasedTransformer_ApplyWithServer_ChangePayload(t *testing.T) { + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { var originalValue testutils.PayloadForTest - _ = json.Unmarshal(datum.GetValue(), &originalValue) - doubledValue, _ := json.Marshal(multiplyBy2(datum.GetValue()).(testutils.PayloadForTest)) - var Results []*v1.SourceTransformResponse_Result + _ = json.Unmarshal(datum.Value(), &originalValue) + doubledValue := testutils.PayloadForTest{ + Value: originalValue.Value * 2, + Key: originalValue.Key, + } + doubledValueBytes, _ := json.Marshal(&doubledValue) + + var resultKeys []string if originalValue.Value%2 == 0 { - Results = append(Results, &v1.SourceTransformResponse_Result{ - Keys: []string{"even"}, - Value: doubledValue, - }) + resultKeys = []string{"even"} } else { - Results = append(Results, &v1.SourceTransformResponse_Result{ - Keys: []string{"odd"}, - Value: doubledValue, - }) - } - datumList := &v1.SourceTransformResponse{ - Results: Results, + resultKeys = []string{"odd"} } - return datumList, nil - }, - ).AnyTimes() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() + return sourcetransformer.MessagesBuilder().Append(sourcetransformer.NewMessage(doubledValueBytes, datum.EventTime()).WithKeys(resultKeys)) + }), + } - u := NewMockGRPCBasedTransformer(mockClient) + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + ctx := context.Background() + client, _ := sourcetransformerSdk.NewFromClient(ctx, transformClient) + u := NewGRPCBasedTransformer("testVertex", client) var count = int64(10) readMessages := testutils.BuildTestReadMessages(count, time.Unix(1661169600, 0), nil) - - var results = make([][]byte, len(readMessages)) - var resultKeys = make([][]string, len(readMessages)) + messages := make([]*isb.ReadMessage, len(readMessages)) for idx, readMessage := range readMessages { - apply, err := u.ApplyTransform(ctx, &readMessage) - assert.NoError(t, err) - results[idx] = apply[0].Payload - resultKeys[idx] = apply[0].Header.Keys + messages[idx] = &readMessage } + apply, err := u.ApplyTransform(context.TODO(), messages) + assert.NoError(t, err) - var expectedResults = make([][]byte, count) - var expectedKeys = make([][]string, count) - for idx, readMessage := range readMessages { + for _, pair := range apply { + resultPayload := pair.WriteMessages[0].Payload + resultKeys := pair.WriteMessages[0].Header.Keys var readMessagePayload testutils.PayloadForTest - _ = json.Unmarshal(readMessage.Payload, &readMessagePayload) + _ = json.Unmarshal(pair.ReadMessage.Payload, &readMessagePayload) + var expectedKeys []string if readMessagePayload.Value%2 == 0 { - expectedKeys[idx] = []string{"even"} + expectedKeys = []string{"even"} } else { - expectedKeys[idx] = []string{"odd"} + expectedKeys = []string{"odd"} } - marshal, _ := json.Marshal(multiplyBy2(readMessage.Payload)) - expectedResults[idx] = marshal - } + assert.Equal(t, expectedKeys, resultKeys) - assert.Equal(t, expectedResults, results) - assert.Equal(t, expectedKeys, resultKeys) + doubledValue := testutils.PayloadForTest{ + Key: readMessagePayload.Key, + Value: readMessagePayload.Value * 2, + } + marshal, _ := json.Marshal(doubledValue) + assert.Equal(t, marshal, resultPayload) + } } -func TestGRPCBasedTransformer_ApplyWithMockClient_ChangeEventTime(t *testing.T) { - testEventTime := time.Date(1992, 2, 8, 0, 0, 0, 100, time.UTC) - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, datum *v1.SourceTransformRequest, opts ...grpc.CallOption) (*v1.SourceTransformResponse, error) { - var Results []*v1.SourceTransformResponse_Result - Results = append(Results, &v1.SourceTransformResponse_Result{ - Keys: []string{"even"}, - Value: datum.Value, - EventTime: timestamppb.New(testEventTime), - }) - datumList := &v1.SourceTransformResponse{ - Results: Results, - } - return datumList, nil - }, - ).AnyTimes() +func newServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientConn { + lis := bufconn.Listen(100) + t.Cleanup(func() { + _ = lis.Close() + }) + + server := grpc.NewServer() + t.Cleanup(func() { + server.Stop() + }) + + register(server) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() + errChan := make(chan error, 1) go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") + // t.Fatal should only be called from the goroutine running the test + if err := server.Serve(lis); err != nil { + errChan <- err } }() - u := NewMockGRPCBasedTransformer(mockClient) + dialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + conn, err := grpc.NewClient("passthrough://", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + t.Cleanup(func() { + _ = conn.Close() + }) + if err != nil { + t.Fatalf("Creating new gRPC client connection: %v", err) + } + + var grpcServerErr error + select { + case grpcServerErr = <-errChan: + case <-time.After(500 * time.Millisecond): + grpcServerErr = errors.New("gRPC server didn't start in 500ms") + } + if err != nil { + t.Fatalf("Failed to start gRPC server: %v", grpcServerErr) + } + + return conn +} + +func TestGRPCBasedTransformer_Apply_ChangeEventTime(t *testing.T) { + testEventTime := time.Date(1992, 2, 8, 0, 0, 0, 100, time.UTC) + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + msg := datum.Value() + return sourcetransformer.MessagesBuilder().Append(sourcetransformer.NewMessage(msg, testEventTime).WithKeys([]string{"even"})) + }), + } + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + ctx := context.Background() + client, _ := sourcetransformerSdk.NewFromClient(ctx, transformClient) + u := NewGRPCBasedTransformer("testVertex", client) var count = int64(2) readMessages := testutils.BuildTestReadMessages(count, time.Unix(1661169600, 0), nil) - for _, readMessage := range readMessages { - apply, err := u.ApplyTransform(ctx, &readMessage) - assert.NoError(t, err) - assert.Equal(t, testEventTime, apply[0].EventTime) + messages := make([]*isb.ReadMessage, len(readMessages)) + for idx, readMessage := range readMessages { + messages[idx] = &readMessage + } + apply, err := u.ApplyTransform(context.TODO(), messages) + assert.NoError(t, err) + for _, pair := range apply { + assert.NoError(t, pair.Err) + assert.Equal(t, testEventTime, pair.WriteMessages[0].EventTime) } } diff --git a/pkg/udf/forward/forward.go b/pkg/udf/forward/forward.go index 53efc945da..e768808cc3 100644 --- a/pkg/udf/forward/forward.go +++ b/pkg/udf/forward/forward.go @@ -481,7 +481,7 @@ func (isdf *InterStepDataForward) streamMessage(ctx context.Context, dataMessage if len(dataMessages) > 1 { errMsg := "data message size is not 1 with map UDF streaming" isdf.opts.logger.Errorw(errMsg) - return nil, fmt.Errorf(errMsg) + return nil, errors.New(errMsg) } else if len(dataMessages) == 1 { // send to map UDF only the data messages diff --git a/pkg/udf/rpc/grpc_batch_map.go b/pkg/udf/rpc/grpc_batch_map.go index 6d6c397642..ce65d201fb 100644 --- a/pkg/udf/rpc/grpc_batch_map.go +++ b/pkg/udf/rpc/grpc_batch_map.go @@ -26,26 +26,21 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/numaproj/numaflow/pkg/isb" + "github.com/numaproj/numaflow/pkg/isb/tracker" "github.com/numaproj/numaflow/pkg/sdkclient/batchmapper" "github.com/numaproj/numaflow/pkg/shared/logging" ) // GRPCBasedBatchMap is a map applier that uses gRPC client to invoke the map UDF. It implements the applier.MapApplier interface. type GRPCBasedBatchMap struct { - vertexName string - client batchmapper.Client - requestTracker *tracker + vertexName string + client batchmapper.Client } func NewUDSgRPCBasedBatchMap(vertexName string, client batchmapper.Client) *GRPCBasedBatchMap { return &GRPCBasedBatchMap{ vertexName: vertexName, client: client, - // requestTracker is used to store the read messages in a key, value manner where - // key is the read offset and the reference to read message as the value. - // Once the results are received from the UDF, we map the responses to the corresponding request - // using a lookup on this tracker. - requestTracker: NewTracker(), } } @@ -93,18 +88,17 @@ func (u *GRPCBasedBatchMap) ApplyBatchMap(ctx context.Context, messages []*isb.R // trackerReq is used to store the read messages in a key, value manner where // key is the read offset and the reference to read message as the value. // Once the results are received from the UDF, we map the responses to the corresponding request - // using a lookup on this tracker. - trackerReq := NewTracker() + // using a lookup on this Tracker. + trackerReq := tracker.NewMessageTracker(messages) // Read routine: this goroutine iterates over the input messages and sends each // of the read messages to the grpc client after transforming it to a BatchMapRequest. // Once all messages are sent, it closes the input channel to indicate that all requests have been read. - // On creating a new request, we add it to a tracker map so that the responses on the stream + // On creating a new request, we add it to a Tracker map so that the responses on the stream // can be mapped backed to the given parent request go func() { defer close(inputChan) for _, msg := range messages { - trackerReq.addRequest(msg) inputChan <- u.parseInputRequest(msg) } }() @@ -139,14 +133,14 @@ loop: } // Get the unique request ID for which these responses are meant for. msgId := grpcResp.GetId() - // Fetch the request value for the given ID from the tracker - parentMessage, ok := trackerReq.getRequest(msgId) - if !ok { - // this case is when the given request ID was not present in the tracker. + // Fetch the request value for the given ID from the Tracker + parentMessage := trackerReq.Remove(msgId) + if parentMessage == nil { + // this case is when the given request ID was not present in the Tracker. // This means that either the UDF added an incorrect ID // This cannot be processed further and should result in an error // Can there be another case for this? - logger.Error("Request missing from tracker, ", msgId) + logger.Error("Request missing from message tracker, ", msgId) return nil, fmt.Errorf("incorrect ID found during batch map processing") } // parse the responses received @@ -159,12 +153,11 @@ loop: Err: nil, } udfResults = append(udfResults, responsePair) - trackerReq.removeRequest(msgId) } } - // check if there are elements left in the tracker. This cannot be an acceptable case as we want the + // check if there are elements left in the Tracker. This cannot be an acceptable case as we want the // UDF to send responses for all elements. - if !trackerReq.isEmpty() { + if !trackerReq.IsEmpty() { logger.Error("BatchMap response for all requests not received from UDF") return nil, fmt.Errorf("batchMap response for all requests not received from UDF") } diff --git a/pkg/udf/rpc/tracker.go b/pkg/udf/rpc/tracker.go deleted file mode 100644 index 60b57a7af9..0000000000 --- a/pkg/udf/rpc/tracker.go +++ /dev/null @@ -1,75 +0,0 @@ -package rpc - -import ( - "sync" - - "github.com/numaproj/numaflow/pkg/isb" -) - -// tracker is used to store a key value pair for string and *isb.ReadMessage -// as it can be accessed by concurrent goroutines, we keep all operations -// under a mutex -type tracker struct { - lock sync.RWMutex - m map[string]*isb.ReadMessage -} - -// NewTracker initializes a new instance of a tracker -func NewTracker() *tracker { - return &tracker{ - m: make(map[string]*isb.ReadMessage), - lock: sync.RWMutex{}, - } -} - -// addRequest add a new entry for a given message to the tracker. -// the key is chosen as the read offset of the message -func (t *tracker) addRequest(msg *isb.ReadMessage) { - id := msg.ReadOffset.String() - t.set(id, msg) -} - -// getRequest returns the message corresponding to a given id, along with a bool -// to indicate if it does not exist -func (t *tracker) getRequest(id string) (*isb.ReadMessage, bool) { - return t.get(id) -} - -// removeRequest will remove the entry for a given id -func (t *tracker) removeRequest(id string) { - t.delete(id) -} - -// get is a helper function which fetches the message corresponding to a given id -// it acquires a lock before accessing the map -func (t *tracker) get(key string) (*isb.ReadMessage, bool) { - t.lock.RLock() - defer t.lock.RUnlock() - item, ok := t.m[key] - return item, ok -} - -// set is a helper function which add a key, value pair to the tracker map -// it acquires a lock before accessing the map -func (t *tracker) set(key string, msg *isb.ReadMessage) { - t.lock.Lock() - defer t.lock.Unlock() - t.m[key] = msg -} - -// delete is a helper function which will remove the entry for a given id -// it acquires a lock before accessing the map -func (t *tracker) delete(key string) { - t.lock.Lock() - defer t.lock.Unlock() - delete(t.m, key) -} - -// isEmpty is a helper function which checks if the tracker map is empty -// return true if empty -func (t *tracker) isEmpty() bool { - t.lock.RLock() - defer t.lock.RUnlock() - items := len(t.m) - return items == 0 -} diff --git a/pkg/webhook/validator/validator.go b/pkg/webhook/validator/validator.go index d5f2e86664..6d4e3e46a1 100644 --- a/pkg/webhook/validator/validator.go +++ b/pkg/webhook/validator/validator.go @@ -83,7 +83,10 @@ func GetValidator(ctx context.Context, NumaClient v1alpha1.NumaflowV1alpha1Inter // DeniedResponse constructs a denied AdmissionResponse func DeniedResponse(reason string, args ...interface{}) *admissionv1.AdmissionResponse { - result := apierrors.NewBadRequest(fmt.Sprintf(reason, args...)).Status() + if len(args) > 0 { + reason = fmt.Sprintf(reason, args) + } + result := apierrors.NewBadRequest(reason).Status() return &admissionv1.AdmissionResponse{ Result: &result, Allowed: false, diff --git a/test/transformer-e2e/transformer_test.go b/test/transformer-e2e/transformer_test.go index 55b88f3683..8c07db744c 100644 --- a/test/transformer-e2e/transformer_test.go +++ b/test/transformer-e2e/transformer_test.go @@ -173,23 +173,24 @@ func (s *TransformerSuite) TestSourceTransformer() { } var wg sync.WaitGroup - wg.Add(4) - go func() { - defer wg.Done() - s.testSourceTransformer("python") - }() - go func() { - defer wg.Done() - s.testSourceTransformer("java") - }() + wg.Add(1) + // FIXME: Enable these tests after corresponding SDKs are changed to support bidirectional streaming + //go func() { + // defer wg.Done() + // s.testSourceTransformer("python") + //}() + //go func() { + // defer wg.Done() + // s.testSourceTransformer("java") + //}() go func() { defer wg.Done() s.testSourceTransformer("go") }() - go func() { - defer wg.Done() - s.testSourceTransformer("rust") - }() + //go func() { + // defer wg.Done() + // s.testSourceTransformer("rust") + //}() wg.Wait() } From 9792e583aa07d5ac6489eb0e5cc36f1efacba9dc Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 1 Oct 2024 10:08:30 +0530 Subject: [PATCH 03/10] Update Go and Rust sdks to latest master Signed-off-by: Sreekanth --- go.mod | 2 +- go.sum | 4 ++-- rust/Cargo.lock | 2 +- rust/numaflow-core/Cargo.toml | 2 +- rust/servesink/Cargo.toml | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 8cc5299772..ba62a6f28d 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe github.com/nats-io/nats-server/v2 v2.10.20 github.com/nats-io/nats.go v1.37.0 - github.com/numaproj/numaflow-go v0.8.2-0.20240930081452-bd8cc005573a + github.com/numaproj/numaflow-go v0.8.2-0.20241001031210-60188185d9c0 github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_model v0.5.0 github.com/prometheus/common v0.45.0 diff --git a/go.sum b/go.sum index bd6f0692f9..9670ccac4b 100644 --- a/go.sum +++ b/go.sum @@ -485,8 +485,8 @@ github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDm github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/numaproj/numaflow-go v0.8.2-0.20240930081452-bd8cc005573a h1:xbpsfHFjZFsm99bC6x9/plMDIBIEkdUt4J/EMiEifrg= -github.com/numaproj/numaflow-go v0.8.2-0.20240930081452-bd8cc005573a/go.mod h1:FaCMeV0V9SiLcVf2fwT+GeTJHNaK2gdQsTAIqQ4x7oc= +github.com/numaproj/numaflow-go v0.8.2-0.20241001031210-60188185d9c0 h1:MN4Q36mPrXqPrv2dNoK3gyV7c1CGwUF3wNJxTZSw1lk= +github.com/numaproj/numaflow-go v0.8.2-0.20241001031210-60188185d9c0/go.mod h1:FaCMeV0V9SiLcVf2fwT+GeTJHNaK2gdQsTAIqQ4x7oc= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= diff --git a/rust/Cargo.lock b/rust/Cargo.lock index b9cfa9d3bd..e2b3045712 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1557,7 +1557,7 @@ dependencies = [ [[package]] name = "numaflow" version = "0.1.1" -source = "git+https://github.com/BulkBeing/numaflow-rs.git?rev=6eb7f3865d42a8ab11ade51622dc4d8feda25b5e#6eb7f3865d42a8ab11ade51622dc4d8feda25b5e" +source = "git+https://github.com/numaproj/numaflow-rs.git?rev=30d8ce1972fd3f0c0b8059fee209516afeef0088#30d8ce1972fd3f0c0b8059fee209516afeef0088" dependencies = [ "chrono", "futures-util", diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index 962901cb2e..a10a46b9ab 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -38,7 +38,7 @@ log = "0.4.22" [dev-dependencies] tempfile = "3.11.0" -numaflow = { git = "https://github.com/BulkBeing/numaflow-rs.git", rev = "6eb7f3865d42a8ab11ade51622dc4d8feda25b5e" } +numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "30d8ce1972fd3f0c0b8059fee209516afeef0088" } [build-dependencies] tonic-build = "0.12.1" diff --git a/rust/servesink/Cargo.toml b/rust/servesink/Cargo.toml index 76de0e491d..80430c169b 100644 --- a/rust/servesink/Cargo.toml +++ b/rust/servesink/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] tonic = "0.12.0" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } -numaflow = { git = "https://github.com/BulkBeing/numaflow-rs.git", rev = "6eb7f3865d42a8ab11ade51622dc4d8feda25b5e" } +numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "30d8ce1972fd3f0c0b8059fee209516afeef0088" } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } From b8d0498ba42e9429e536cff93598d994bba7e93a Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Tue, 1 Oct 2024 15:48:16 +0530 Subject: [PATCH 04/10] dummy testing Signed-off-by: Yashash H L --- pkg/sdkclient/sourcetransformer/client.go | 15 ++++++ pkg/shared/expr/eval_bool.go | 27 +++++------ pkg/sources/generator/tickgen.go | 1 - .../event_time/event_time_extractor.go | 32 +++++-------- .../time_extraction_filter.go | 48 +++++++++---------- pkg/sources/transformer/grpc_transformer.go | 20 +++++++- .../extract-event-time-from-payload.yaml | 7 ++- 7 files changed, 88 insertions(+), 62 deletions(-) diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index 7f3327ac5f..1e0710f3eb 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -19,6 +19,7 @@ package sourcetransformer import ( "context" "fmt" + "log" "time" "google.golang.org/grpc" @@ -152,6 +153,10 @@ func (c *client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) { func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transformpb.SourceTransformRequest) (<-chan *transformpb.SourceTransformResponse, <-chan error) { clientErrCh := make(chan error) responseCh := make(chan *transformpb.SourceTransformResponse) + //defer func() { + // close(responseCh) + // clientErrCh = nil + //}() // This channel is to send the error from the goroutine that receives messages from the stream to the goroutine that sends requests to the server. // This ensures that we don't need to use clientErrCh in both goroutines. The caller of this function will only be listening for the first error value in clientErrCh. @@ -165,18 +170,23 @@ func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transfor for { resp, err := c.stream.Recv() if err != nil { + log.Println("Error in receiving response from the stream") // we don't need an EOF check because we only close the stream during shutdown. errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn", err) close(errCh) return } + log.Println("Received response from the stream with id - ", resp.GetId()) select { case <-ctx.Done(): + log.Println("Context cancelled. Stopping retrieving messages from the stream") logger.Warnf("Context cancelled. Stopping retrieving messages from the stream") return case responseCh <- resp: + log.Println("Sent response to the channel with id - ", resp.GetId()) } + log.Println("We got a message from the stream") } }() @@ -184,21 +194,26 @@ func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transfor for { select { case <-ctx.Done(): + log.Println("Context cancelled. Stopping sending messages to the stream") clientErrCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", ctx.Err()) return case err := <-errCh: + log.Println("Error in sending request to the stream") clientErrCh <- err return case msg, ok := <-request: if !ok { + log.Println("Request channel closed. Stopping sending messages to the stream") // stream is only closed during shutdown return } + log.Println("Trying to send request to the stream with id - ", msg.GetRequest().GetId()) err := c.stream.Send(msg) if err != nil { clientErrCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err) return } + log.Println("Sent request to the stream with id - ", msg.GetRequest().GetId()) } } }() diff --git a/pkg/shared/expr/eval_bool.go b/pkg/shared/expr/eval_bool.go index 683fdb9f54..4a8b95ab08 100644 --- a/pkg/shared/expr/eval_bool.go +++ b/pkg/shared/expr/eval_bool.go @@ -22,7 +22,6 @@ import ( "strconv" "github.com/Masterminds/sprig/v3" - "github.com/antonmedv/expr" ) var sprigFuncMap = sprig.GenericFuncMap() @@ -30,19 +29,19 @@ var sprigFuncMap = sprig.GenericFuncMap() const root = "payload" func EvalBool(expression string, msg []byte) (bool, error) { - msgMap := map[string]interface{}{ - root: string(msg), - } - env := getFuncMap(msgMap) - result, err := expr.Eval(expression, env) - if err != nil { - return false, fmt.Errorf("unable to evaluate expression '%s': %s", expression, err) - } - resultBool, ok := result.(bool) - if !ok { - return false, fmt.Errorf("unable to cast expression result '%s' to bool", result) - } - return resultBool, nil + //msgMap := map[string]interface{}{ + // root: string(msg), + //} + //env := getFuncMap(msgMap) + //result, err := expr.Eval(expression, env) + //if err != nil { + // return false, fmt.Errorf("unable to evaluate expression '%s': %s", expression, err) + //} + //resultBool, ok := result.(bool) + //if !ok { + // return false, fmt.Errorf("unable to cast expression result '%s' to bool", result) + //} + return true, nil } func getFuncMap(m map[string]interface{}) map[string]interface{} { diff --git a/pkg/sources/generator/tickgen.go b/pkg/sources/generator/tickgen.go index ff00ba8cba..c0cdb9dcf1 100644 --- a/pkg/sources/generator/tickgen.go +++ b/pkg/sources/generator/tickgen.go @@ -202,7 +202,6 @@ loop: tickgenSourceReadCount.With(map[string]string{metrics.LabelVertex: mg.vertexName, metrics.LabelPipeline: mg.pipelineName}).Inc() msgs = append(msgs, mg.newReadMessage(r.key, r.data, r.offset, r.ts)) case <-timeout: - mg.logger.Infow("Timed out waiting for messages to read.", zap.Duration("waited", mg.readTimeout)) break loop } } diff --git a/pkg/sources/transformer/builtin/event_time/event_time_extractor.go b/pkg/sources/transformer/builtin/event_time/event_time_extractor.go index c844c518db..393baf52e5 100644 --- a/pkg/sources/transformer/builtin/event_time/event_time_extractor.go +++ b/pkg/sources/transformer/builtin/event_time/event_time_extractor.go @@ -21,10 +21,8 @@ import ( "fmt" "time" - "github.com/araddon/dateparse" "github.com/numaproj/numaflow-go/pkg/sourcetransformer" - "github.com/numaproj/numaflow/pkg/shared/expr" "github.com/numaproj/numaflow/pkg/shared/logging" ) @@ -67,21 +65,17 @@ func New(args map[string]string) (sourcetransformer.SourceTransformFunc, error) // apply compiles the payload to extract the new event time. If there is any error during extraction, // we pass on the original input event time. Otherwise, we assign the new event time to the message. func (e eventTimeExtractor) apply(payload []byte, et time.Time, keys []string) (sourcetransformer.Message, error) { - timeStr, err := expr.EvalStr(e.expression, payload) - if err != nil { - return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err - } - - var newEventTime time.Time - time.Local, _ = time.LoadLocation("UTC") - if e.format != "" { - newEventTime, err = time.Parse(e.format, timeStr) - } else { - newEventTime, err = dateparse.ParseStrict(timeStr) - } - if err != nil { - return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err - } else { - return sourcetransformer.NewMessage(payload, newEventTime).WithKeys(keys), nil - } + //timeStr, err := expr.EvalStr(e.expression, payload) + //if err != nil { + // return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err + //} + // + //var newEventTime time.Time + //time.Local, _ = time.LoadLocation("UTC") + //if e.format != "" { + // newEventTime, err = time.Parse(e.format, timeStr) + //} else { + // newEventTime, err = dateparse.ParseStrict(timeStr) + //} + return sourcetransformer.NewMessage(payload, et).WithKeys(keys), nil } diff --git a/pkg/sources/transformer/builtin/time_extraction_filter/time_extraction_filter.go b/pkg/sources/transformer/builtin/time_extraction_filter/time_extraction_filter.go index bb1515cb1e..752c8db563 100644 --- a/pkg/sources/transformer/builtin/time_extraction_filter/time_extraction_filter.go +++ b/pkg/sources/transformer/builtin/time_extraction_filter/time_extraction_filter.go @@ -21,10 +21,8 @@ import ( "fmt" "time" - "github.com/araddon/dateparse" "github.com/numaproj/numaflow-go/pkg/sourcetransformer" - "github.com/numaproj/numaflow/pkg/shared/expr" "github.com/numaproj/numaflow/pkg/shared/logging" ) @@ -69,27 +67,27 @@ func New(args map[string]string) (sourcetransformer.SourceTransformFunc, error) } func (e expressions) apply(et time.Time, payload []byte, keys []string) (sourcetransformer.Message, error) { - result, err := expr.EvalBool(e.filterExpr, payload) - if err != nil { - return sourcetransformer.MessageToDrop(et), err - } - if result { - timeStr, err := expr.EvalStr(e.eventTimeExpr, payload) - if err != nil { - return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err - } - var newEventTime time.Time - time.Local, _ = time.LoadLocation("UTC") - if e.eventTimeFormat != "" { - newEventTime, err = time.Parse(e.eventTimeFormat, timeStr) - } else { - newEventTime, err = dateparse.ParseStrict(timeStr) - } - if err != nil { - return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err - } else { - return sourcetransformer.NewMessage(payload, newEventTime).WithKeys(keys), nil - } - } - return sourcetransformer.MessageToDrop(et), nil + //result, err := expr.EvalBool(e.filterExpr, payload) + //if err != nil { + // return sourcetransformer.MessageToDrop(et), err + //} + //if result { + // timeStr, err := expr.EvalStr(e.eventTimeExpr, payload) + // if err != nil { + // return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err + // } + // var newEventTime time.Time + // time.Local, _ = time.LoadLocation("UTC") + // if e.eventTimeFormat != "" { + // newEventTime, err = time.Parse(e.eventTimeFormat, timeStr) + // } else { + // newEventTime, err = dateparse.ParseStrict(timeStr) + // } + // if err != nil { + // return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err + // } else { + // return sourcetransformer.NewMessage(payload, newEventTime).WithKeys(keys), nil + // } + //} + return sourcetransformer.NewMessage(payload, et).WithKeys(keys), nil } diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index e578687fb3..e1fa045f4d 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -20,16 +20,18 @@ import ( "context" "errors" "fmt" + "log" "time" v1 "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/numaproj/numaflow/pkg/isb" "github.com/numaproj/numaflow/pkg/isb/tracker" "github.com/numaproj/numaflow/pkg/sdkclient/sourcetransformer" "github.com/numaproj/numaflow/pkg/shared/logging" "github.com/numaproj/numaflow/pkg/udf/rpc" - "google.golang.org/protobuf/types/known/emptypb" - "google.golang.org/protobuf/types/known/timestamppb" ) // GRPCBasedTransformer applies user-defined transformer over gRPC (over Unix Domain Socket) client/server where server is the transformer. @@ -80,6 +82,9 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i var transformResults []isb.ReadWriteMessagePair inputChan := make(chan *v1.SourceTransformRequest) respChan, errChan := u.client.SourceTransformFn(ctx, inputChan) + defer func() { + log.Println("Returned from ApplyTransform") + }() logger := logging.FromContext(ctx) @@ -88,6 +93,7 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i go func() { defer close(inputChan) for _, msg := range messages { + log.Println("Sending message to source transform client") req := &v1.SourceTransformRequest{ Request: &v1.SourceTransformRequest_Request{ Keys: msg.Keys, @@ -106,8 +112,10 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i loop: for { + log.Println("Waiting for response from source transform client") select { case err := <-errChan: + log.Println("Error from source transform client") err = &rpc.ApplyUDFErr{ UserUDFErr: false, Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), @@ -118,6 +126,7 @@ loop: } return nil, err case resp, ok := <-respChan: + println("Response from source transform client") if !ok { logger.Warn("Response channel from source transform client was closed.") break loop @@ -131,12 +140,14 @@ loop: return nil, errors.New("tracker doesn't contain the message ID received from the response") } messageCount-- + log.Println("Message count: ", messageCount) var taggedMessages []*isb.WriteMessage for i, result := range resp.GetResults() { keys := result.Keys if result.EventTime != nil { // Transformer supports changing event time. + log.Println("Updating event time from ", parentMessage.MessageInfo.EventTime.UnixMilli(), " to ", result.EventTime.AsTime().UnixMilli()) parentMessage.MessageInfo.EventTime = result.EventTime.AsTime() } taggedMessage := &isb.WriteMessage{ @@ -166,13 +177,18 @@ loop: transformResults = append(transformResults, responsePair) if messageCount == 0 { + log.Println("All messages are transformed.") break loop } } + log.Println("Received some response from source transform client") } + log.Println("Checking if all messages are transformed.") if !msgTracker.IsEmpty() { + log.Println("All messages are not transformed yet , pending messages count: ", msgTracker.Len()) return nil, fmt.Errorf("transform response for all requests were not received from UDF. Remaining=%d", msgTracker.Len()) } + log.Println("Exiting ApplyTransform") return transformResults, nil } diff --git a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml index 7bee8ef95f..8482fd9838 100644 --- a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml +++ b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml @@ -6,7 +6,12 @@ spec: vertices: - name: in source: - http: {} + generator: + # How many messages to generate in the duration. + rpu: 5 + duration: 1s + # Optional, size of each generated message, defaults to 10. + msgSize: 1024 transformer: builtin: name: eventTimeExtractor From b71a6eb32f5e7acec8193594dc5398e9fda40e37 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 1 Oct 2024 22:41:28 +0530 Subject: [PATCH 05/10] working - background goroutine --- pkg/sdkclient/sourcetransformer/client.go | 126 +++++++------ .../sourcetransformer/client_test.go | 21 +-- pkg/sdkclient/sourcetransformer/interface.go | 2 +- pkg/sources/forward/data_forward.go | 1 + pkg/sources/transformer/grpc_transformer.go | 168 +++++++----------- .../extract-event-time-from-payload.yaml | 14 +- 6 files changed, 146 insertions(+), 186 deletions(-) diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index 1e0710f3eb..9ba873ad14 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -19,7 +19,7 @@ package sourcetransformer import ( "context" "fmt" - "log" + "golang.org/x/sync/errgroup" "time" "google.golang.org/grpc" @@ -36,9 +36,32 @@ import ( // client contains the grpc connection and the grpc client. type client struct { - conn *grpc.ClientConn - grpcClt transformpb.SourceTransformClient - stream transformpb.SourceTransform_SourceTransformFnClient + conn *grpc.ClientConn + grpcClt transformpb.SourceTransformClient + stream transformpb.SourceTransform_SourceTransformFnClient + requestsCh chan<- *transformpb.SourceTransformRequest + responsesCh <-chan *transformpb.SourceTransformResponse + errCh <-chan error +} + +func sendRequests(stream transformpb.SourceTransform_SourceTransformFnClient, requestsCh <-chan *transformpb.SourceTransformRequest, errCh chan<- error) { + for request := range requestsCh { + if err := stream.Send(request); err != nil { + errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err) + return + } + } +} + +func receiveResponses(stream transformpb.SourceTransform_SourceTransformFnClient, responseCh chan<- *transformpb.SourceTransformResponse, errCh chan<- error) { + for { + resp, err := stream.Recv() + if err != nil { + errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Recv", err) + return + } + responseCh <- resp + } } // New creates a new client object. @@ -86,6 +109,15 @@ waitUntilReady: return nil, err } + errCh := make(chan error, 2) + requestsCh := make(chan *transformpb.SourceTransformRequest) + responsesCh := make(chan *transformpb.SourceTransformResponse) + go sendRequests(c.stream, requestsCh, errCh) + go receiveResponses(c.stream, responsesCh, errCh) + + c.errCh = errCh + c.requestsCh = requestsCh + c.responsesCh = responsesCh return c, nil } @@ -150,73 +182,39 @@ func (c *client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) { // SourceTransformFn SourceTransformerFn applies a function to each request element. // Response channel will not be closed. Caller can select on response and error channel to exit on first error. -func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transformpb.SourceTransformRequest) (<-chan *transformpb.SourceTransformResponse, <-chan error) { - clientErrCh := make(chan error) - responseCh := make(chan *transformpb.SourceTransformResponse) - //defer func() { - // close(responseCh) - // clientErrCh = nil - //}() - - // This channel is to send the error from the goroutine that receives messages from the stream to the goroutine that sends requests to the server. - // This ensures that we don't need to use clientErrCh in both goroutines. The caller of this function will only be listening for the first error value in clientErrCh. - // If both goroutines were sending error message to this channel (eg. stream failure), one of them will be stuck in sending can not shutdown cleanly. - errCh := make(chan error, 1) - - logger := logging.FromContext(ctx) - - // Receive responses from the stream - go func() { - for { - resp, err := c.stream.Recv() - if err != nil { - log.Println("Error in receiving response from the stream") - // we don't need an EOF check because we only close the stream during shutdown. - errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn", err) - close(errCh) - return - } - log.Println("Received response from the stream with id - ", resp.GetId()) +func (c *client) SourceTransformFn(ctx context.Context, requests []*transformpb.SourceTransformRequest) ([]*transformpb.SourceTransformResponse, error) { + //logger := logging.FromContext(ctx) + grp, grpCtx := errgroup.WithContext(ctx) + grp.Go(func() error { + for _, req := range requests { select { - case <-ctx.Done(): - log.Println("Context cancelled. Stopping retrieving messages from the stream") - logger.Warnf("Context cancelled. Stopping retrieving messages from the stream") - return - case responseCh <- resp: - log.Println("Sent response to the channel with id - ", resp.GetId()) + case <-grpCtx.Done(): + return grpCtx.Err() + case c.requestsCh <- req: + continue + case err := <-c.errCh: + return err } - log.Println("We got a message from the stream") } - }() + return nil + }) - go func() { - for { + resp := make([]*transformpb.SourceTransformResponse, 0, len(requests)) + grp.Go(func() error { + for i := 0; i < len(requests); i++ { select { - case <-ctx.Done(): - log.Println("Context cancelled. Stopping sending messages to the stream") - clientErrCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", ctx.Err()) - return - case err := <-errCh: - log.Println("Error in sending request to the stream") - clientErrCh <- err - return - case msg, ok := <-request: - if !ok { - log.Println("Request channel closed. Stopping sending messages to the stream") - // stream is only closed during shutdown - return - } - log.Println("Trying to send request to the stream with id - ", msg.GetRequest().GetId()) - err := c.stream.Send(msg) - if err != nil { - clientErrCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err) - return - } - log.Println("Sent request to the stream with id - ", msg.GetRequest().GetId()) + case <-grpCtx.Done(): + return grpCtx.Err() + case r := <-c.responsesCh: + resp = append(resp, r) + case err := <-c.errCh: + return err } } - }() + return nil + }) - return responseCh, clientErrCh + err := grp.Wait() + return resp, err } diff --git a/pkg/sdkclient/sourcetransformer/client_test.go b/pkg/sdkclient/sourcetransformer/client_test.go index 619f9533f8..c66abbd6ea 100644 --- a/pkg/sdkclient/sourcetransformer/client_test.go +++ b/pkg/sdkclient/sourcetransformer/client_test.go @@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -26,7 +26,6 @@ import ( transformpb "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" "github.com/numaproj/numaflow-go/pkg/sourcetransformer" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -121,10 +120,10 @@ func TestClient_SourceTransformFn(t *testing.T) { var ctx = context.Background() client, _ := NewFromClient(ctx, transformClient) - reqChan := make(chan *transformpb.SourceTransformRequest, 1) + requests := make([]*transformpb.SourceTransformRequest, 5) go func() { for i := 0; i < 5; i++ { - reqChan <- &transformpb.SourceTransformRequest{ + requests[i] = &transformpb.SourceTransformRequest{ Request: &transformpb.SourceTransformRequest_Request{ Keys: []string{fmt.Sprintf("client_key_%d", i)}, Value: []byte("test"), @@ -133,16 +132,10 @@ func TestClient_SourceTransformFn(t *testing.T) { } }() - respChan, errChan := client.SourceTransformFn(ctx, reqChan) + responses, err := client.SourceTransformFn(ctx, requests) + require.NoError(t, err) var results [][]*transformpb.SourceTransformResponse_Result - for i := 0; i < 5; i++ { - var resp *transformpb.SourceTransformResponse - var err error - select { - case resp = <-respChan: - case err = <-errChan: - } - assert.NoError(t, err) + for _, resp := range responses { results = append(results, resp.GetResults()) } expected := [][]*transformpb.SourceTransformResponse_Result{ @@ -152,5 +145,5 @@ func TestClient_SourceTransformFn(t *testing.T) { {{Keys: []string{"client_key_3_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, {{Keys: []string{"client_key_4_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, } - assert.ElementsMatch(t, expected, results) + require.ElementsMatch(t, expected, results) } diff --git a/pkg/sdkclient/sourcetransformer/interface.go b/pkg/sdkclient/sourcetransformer/interface.go index 6006e380c8..883353f3a6 100644 --- a/pkg/sdkclient/sourcetransformer/interface.go +++ b/pkg/sdkclient/sourcetransformer/interface.go @@ -27,5 +27,5 @@ import ( type Client interface { CloseConn(ctx context.Context) error IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) - SourceTransformFn(ctx context.Context, request <-chan *transformpb.SourceTransformRequest) (<-chan *transformpb.SourceTransformResponse, <-chan error) + SourceTransformFn(ctx context.Context, requests []*transformpb.SourceTransformRequest) ([]*transformpb.SourceTransformResponse, error) } diff --git a/pkg/sources/forward/data_forward.go b/pkg/sources/forward/data_forward.go index 9eddafa30d..21558aaccd 100644 --- a/pkg/sources/forward/data_forward.go +++ b/pkg/sources/forward/data_forward.go @@ -571,6 +571,7 @@ func (df *DataForward) writeToBuffer(ctx context.Context, toBufferPartition isb. zap.String("reason", err.Error()), zap.String("partition", toBufferPartition.GetName()), zap.String("vertex", df.vertexName), zap.String("pipeline", df.pipelineName), + zap.String("msg_id", msg.ID.String()), ) } else { needRetry = true diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index e1fa045f4d..27a7cd34f2 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -28,7 +28,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/numaproj/numaflow/pkg/isb" - "github.com/numaproj/numaflow/pkg/isb/tracker" "github.com/numaproj/numaflow/pkg/sdkclient/sourcetransformer" "github.com/numaproj/numaflow/pkg/shared/logging" "github.com/numaproj/numaflow/pkg/udf/rpc" @@ -80,115 +79,78 @@ var errSourceTransformFnEmptyMsgId = errors.New("response from SourceTransformFn func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { var transformResults []isb.ReadWriteMessagePair - inputChan := make(chan *v1.SourceTransformRequest) - respChan, errChan := u.client.SourceTransformFn(ctx, inputChan) - defer func() { - log.Println("Returned from ApplyTransform") - }() - - logger := logging.FromContext(ctx) - - msgTracker := tracker.NewMessageTracker(messages) - - go func() { - defer close(inputChan) - for _, msg := range messages { - log.Println("Sending message to source transform client") - req := &v1.SourceTransformRequest{ - Request: &v1.SourceTransformRequest_Request{ - Keys: msg.Keys, - Value: msg.Body.Payload, - EventTime: timestamppb.New(msg.MessageInfo.EventTime), - Watermark: timestamppb.New(msg.Watermark), - Headers: msg.Headers, - Id: msg.ReadOffset.String(), - }, - } - inputChan <- req + ctx, cancel := context.WithCancel(ctx) + defer cancel() + requests := make([]*v1.SourceTransformRequest, 0, len(messages)) + idToMsgMapping := make(map[string]*isb.ReadMessage) + for _, msg := range messages { + log.Println("Sending message to source transform client") + id := msg.ReadOffset.String() + idToMsgMapping[id] = msg + req := &v1.SourceTransformRequest{ + Request: &v1.SourceTransformRequest_Request{ + Keys: msg.Keys, + Value: msg.Body.Payload, + EventTime: timestamppb.New(msg.MessageInfo.EventTime), + Watermark: timestamppb.New(msg.Watermark), + Headers: msg.Headers, + Id: id, + }, } - }() - - messageCount := len(messages) + requests = append(requests, req) + } + responses, err := u.client.SourceTransformFn(ctx, requests) + + if err != nil { + err = &rpc.ApplyUDFErr{ + UserUDFErr: false, + Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), + InternalErr: rpc.InternalErr{ + Flag: true, + MainCarDown: false, + }, + } + return nil, err + } -loop: - for { - log.Println("Waiting for response from source transform client") - select { - case err := <-errChan: - log.Println("Error from source transform client") - err = &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - } - return nil, err - case resp, ok := <-respChan: - println("Response from source transform client") - if !ok { - logger.Warn("Response channel from source transform client was closed.") - break loop - } - msgId := resp.GetId() - if msgId == "" { - return nil, errSourceTransformFnEmptyMsgId - } - parentMessage := msgTracker.Remove(msgId) - if parentMessage == nil { - return nil, errors.New("tracker doesn't contain the message ID received from the response") + var taggedMessages []*isb.WriteMessage + for _, resp := range responses { + parentMessage, ok := idToMsgMapping[resp.GetId()] + if !ok { + panic("tracker doesn't contain the message ID received from the response") + } + for i, result := range resp.GetResults() { + keys := result.Keys + if result.EventTime != nil { + // Transformer supports changing event time. + log.Println("Updating event time from ", parentMessage.MessageInfo.EventTime.UnixMilli(), " to ", result.EventTime.AsTime().UnixMilli()) + parentMessage.MessageInfo.EventTime = result.EventTime.AsTime() } - messageCount-- - log.Println("Message count: ", messageCount) - - var taggedMessages []*isb.WriteMessage - for i, result := range resp.GetResults() { - keys := result.Keys - if result.EventTime != nil { - // Transformer supports changing event time. - log.Println("Updating event time from ", parentMessage.MessageInfo.EventTime.UnixMilli(), " to ", result.EventTime.AsTime().UnixMilli()) - parentMessage.MessageInfo.EventTime = result.EventTime.AsTime() - } - taggedMessage := &isb.WriteMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: parentMessage.MessageInfo, - ID: isb.MessageID{ - VertexName: u.vertexName, - Offset: parentMessage.ReadOffset.String(), - Index: int32(i), - }, - Keys: keys, - }, - Body: isb.Body{ - Payload: result.Value, + taggedMessage := &isb.WriteMessage{ + Message: isb.Message{ + Header: isb.Header{ + MessageInfo: parentMessage.MessageInfo, + ID: isb.MessageID{ + VertexName: u.vertexName, + Offset: parentMessage.ReadOffset.String(), + Index: int32(i), }, + Keys: keys, }, - Tags: result.Tags, - } - taggedMessages = append(taggedMessages, taggedMessage) - } - responsePair := isb.ReadWriteMessagePair{ - ReadMessage: parentMessage, - WriteMessages: taggedMessages, - Err: nil, - } - transformResults = append(transformResults, responsePair) - - if messageCount == 0 { - log.Println("All messages are transformed.") - break loop + Body: isb.Body{ + Payload: result.Value, + }, + }, + Tags: result.Tags, } + taggedMessages = append(taggedMessages, taggedMessage) } - log.Println("Received some response from source transform client") - } - - log.Println("Checking if all messages are transformed.") - if !msgTracker.IsEmpty() { - log.Println("All messages are not transformed yet , pending messages count: ", msgTracker.Len()) - return nil, fmt.Errorf("transform response for all requests were not received from UDF. Remaining=%d", msgTracker.Len()) + responsePair := isb.ReadWriteMessagePair{ + ReadMessage: parentMessage, + WriteMessages: taggedMessages, + Err: nil, + } + transformResults = append(transformResults, responsePair) } - log.Println("Exiting ApplyTransform") return transformResults, nil } diff --git a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml index 8482fd9838..896499984e 100644 --- a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml +++ b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml @@ -5,6 +5,10 @@ metadata: spec: vertices: - name: in + containerTemplate: + env: + - name: NUMAFLOW_PPROF + value: !!str "true" source: generator: # How many messages to generate in the duration. @@ -13,10 +17,12 @@ spec: # Optional, size of each generated message, defaults to 10. msgSize: 1024 transformer: - builtin: - name: eventTimeExtractor - kwargs: - expression: json(payload).item[1].time + container: + image: quay.io/numaio/numaflow-rs/source-transformer-now:stable +# builtin: +# name: eventTimeExtractor +# kwargs: +# expression: json(payload).item[1].time - name: out scale: min: 1 From f80918f5f73532626f9ba950cebc20cb71fd621a Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 1 Oct 2024 23:52:37 +0530 Subject: [PATCH 06/10] Change SourceTransformFn interface Signed-off-by: Sreekanth --- pkg/sdkclient/sourcetransformer/client.go | 103 +++++++----------- pkg/sources/transformer/grpc_transformer.go | 6 +- .../extract-event-time-from-payload.yaml | 6 +- 3 files changed, 49 insertions(+), 66 deletions(-) diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index 9ba873ad14..e11e45f750 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -19,7 +19,8 @@ package sourcetransformer import ( "context" "fmt" - "golang.org/x/sync/errgroup" + "log" + "sync" "time" "google.golang.org/grpc" @@ -36,32 +37,9 @@ import ( // client contains the grpc connection and the grpc client. type client struct { - conn *grpc.ClientConn - grpcClt transformpb.SourceTransformClient - stream transformpb.SourceTransform_SourceTransformFnClient - requestsCh chan<- *transformpb.SourceTransformRequest - responsesCh <-chan *transformpb.SourceTransformResponse - errCh <-chan error -} - -func sendRequests(stream transformpb.SourceTransform_SourceTransformFnClient, requestsCh <-chan *transformpb.SourceTransformRequest, errCh chan<- error) { - for request := range requestsCh { - if err := stream.Send(request); err != nil { - errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err) - return - } - } -} - -func receiveResponses(stream transformpb.SourceTransform_SourceTransformFnClient, responseCh chan<- *transformpb.SourceTransformResponse, errCh chan<- error) { - for { - resp, err := stream.Recv() - if err != nil { - errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Recv", err) - return - } - responseCh <- resp - } + conn *grpc.ClientConn + grpcClt transformpb.SourceTransformClient + stream transformpb.SourceTransform_SourceTransformFnClient } // New creates a new client object. @@ -109,15 +87,6 @@ waitUntilReady: return nil, err } - errCh := make(chan error, 2) - requestsCh := make(chan *transformpb.SourceTransformRequest) - responsesCh := make(chan *transformpb.SourceTransformResponse) - go sendRequests(c.stream, requestsCh, errCh) - go receiveResponses(c.stream, responsesCh, errCh) - - c.errCh = errCh - c.requestsCh = requestsCh - c.responsesCh = responsesCh return c, nil } @@ -185,36 +154,48 @@ func (c *client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) { func (c *client) SourceTransformFn(ctx context.Context, requests []*transformpb.SourceTransformRequest) ([]*transformpb.SourceTransformResponse, error) { //logger := logging.FromContext(ctx) - grp, grpCtx := errgroup.WithContext(ctx) - grp.Go(func() error { + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer func() { wg.Done() }() for _, req := range requests { select { - case <-grpCtx.Done(): - return grpCtx.Err() - case c.requestsCh <- req: - continue - case err := <-c.errCh: - return err + case <-ctx.Done(): + return + default: + } + log.Println("Sending request:", req.Request.Id) + if err := c.stream.Send(req); err != nil { + cancel(sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err)) + return } + log.Println("Sent request:", req.Request.Id) } - return nil - }) + }() - resp := make([]*transformpb.SourceTransformResponse, 0, len(requests)) - grp.Go(func() error { - for i := 0; i < len(requests); i++ { - select { - case <-grpCtx.Done(): - return grpCtx.Err() - case r := <-c.responsesCh: - resp = append(resp, r) - case err := <-c.errCh: - return err - } + responses := make([]*transformpb.SourceTransformResponse, 0, len(requests)) + for i := 0; i < len(requests); i++ { + select { + case <-ctx.Done(): + err := context.Cause(ctx) + log.Println("Context cancelled while receiving:", err) + return nil, err + default: } - return nil - }) + log.Println("Receiving response") + resp, err := c.stream.Recv() + log.Println("Received response:", resp.GetId(), err) + if err != nil { + err = sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Recv", err) + cancel(err) + return nil, err + } + responses = append(responses, resp) + } + + wg.Wait() - err := grp.Wait() - return resp, err + return responses, nil } diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index 27a7cd34f2..fac95c8c10 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -84,8 +84,7 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i requests := make([]*v1.SourceTransformRequest, 0, len(messages)) idToMsgMapping := make(map[string]*isb.ReadMessage) for _, msg := range messages { - log.Println("Sending message to source transform client") - id := msg.ReadOffset.String() + id := msg.Message.ID.String() idToMsgMapping[id] = msg req := &v1.SourceTransformRequest{ Request: &v1.SourceTransformRequest_Request{ @@ -99,7 +98,10 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i } requests = append(requests, req) } + + log.Println("Sending message to source transform client") responses, err := u.client.SourceTransformFn(ctx, requests) + log.Println("Received responses from source transform client") if err != nil { err = &rpc.ApplyUDFErr{ diff --git a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml index 896499984e..91f863dede 100644 --- a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml +++ b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml @@ -16,9 +16,9 @@ spec: duration: 1s # Optional, size of each generated message, defaults to 10. msgSize: 1024 - transformer: - container: - image: quay.io/numaio/numaflow-rs/source-transformer-now:stable +# transformer: +# container: +# image: quay.io/numaio/numaflow-rs/source-transformer-now:stable # builtin: # name: eventTimeExtractor # kwargs: From 0cc303f9c0e937917dc6ed455e68910115349e16 Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Wed, 2 Oct 2024 06:33:23 +0530 Subject: [PATCH 07/10] fix duplicates Signed-off-by: Yashash H L --- pkg/sdkclient/sourcetransformer/client.go | 4 +--- pkg/sources/forward/data_forward.go | 14 ++++++++++++++ pkg/sources/transformer/grpc_transformer.go | 8 ++++---- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index e11e45f750..73eee4897f 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -166,7 +166,6 @@ func (c *client) SourceTransformFn(ctx context.Context, requests []*transformpb. return default: } - log.Println("Sending request:", req.Request.Id) if err := c.stream.Send(req); err != nil { cancel(sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err)) return @@ -184,14 +183,13 @@ func (c *client) SourceTransformFn(ctx context.Context, requests []*transformpb. return nil, err default: } - log.Println("Receiving response") resp, err := c.stream.Recv() - log.Println("Received response:", resp.GetId(), err) if err != nil { err = sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Recv", err) cancel(err) return nil, err } + log.Println("Received response:", resp.GetId()) responses = append(responses, resp) } diff --git a/pkg/sources/forward/data_forward.go b/pkg/sources/forward/data_forward.go index 21558aaccd..bf38e349ae 100644 --- a/pkg/sources/forward/data_forward.go +++ b/pkg/sources/forward/data_forward.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "log" "strconv" "sync" "time" @@ -516,6 +517,13 @@ func (df *DataForward) writeToBuffers( for toVertexName, toVertexMessages := range messageToStep { writeOffsets[toVertexName] = make([][]isb.Offset, len(toVertexMessages)) } + for _, toVertexMessages := range messageToStep { + for idx, messages := range toVertexMessages { + for _, message := range messages { + log.Println(idx, " Writing message to buffer with message id: ", message.ID.String()) + } + } + } for toVertexName, toVertexBuffer := range df.toBuffers { for index, partition := range toVertexBuffer { writeOffsets[toVertexName][index], err = df.writeToBuffer(ctx, partition, messageToStep[toVertexName][index]) @@ -667,6 +675,12 @@ func (df *DataForward) applyTransformer(ctx context.Context, messages []*isb.Rea } continue } + for _, result := range transformResults { + log.Println("Transformed read message with id: ", result.ReadMessage.ID.String()) + for _, message := range result.WriteMessages { + log.Println("Transformed write message with id: ", message.ID.String()) + } + } return transformResults } } diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index fac95c8c10..41afbddc28 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -84,8 +84,9 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i requests := make([]*v1.SourceTransformRequest, 0, len(messages)) idToMsgMapping := make(map[string]*isb.ReadMessage) for _, msg := range messages { - id := msg.Message.ID.String() + id := msg.ReadOffset.String() idToMsgMapping[id] = msg + log.Println("Sending message with read offset ID: ", id, " message id: ", msg.ID.String()) req := &v1.SourceTransformRequest{ Request: &v1.SourceTransformRequest_Request{ Keys: msg.Keys, @@ -99,9 +100,7 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i requests = append(requests, req) } - log.Println("Sending message to source transform client") responses, err := u.client.SourceTransformFn(ctx, requests) - log.Println("Received responses from source transform client") if err != nil { err = &rpc.ApplyUDFErr{ @@ -115,12 +114,12 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i return nil, err } - var taggedMessages []*isb.WriteMessage for _, resp := range responses { parentMessage, ok := idToMsgMapping[resp.GetId()] if !ok { panic("tracker doesn't contain the message ID received from the response") } + var taggedMessages []*isb.WriteMessage for i, result := range resp.GetResults() { keys := result.Keys if result.EventTime != nil { @@ -145,6 +144,7 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i }, Tags: result.Tags, } + log.Println("Received message with ID: ", taggedMessage.ID.String()) taggedMessages = append(taggedMessages, taggedMessage) } responsePair := isb.ReadWriteMessagePair{ From 35ad8e6d3830c12a0b3c2fd42104925a45396ef4 Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Wed, 2 Oct 2024 08:03:47 +0530 Subject: [PATCH 08/10] fix issues and refactor Signed-off-by: Yashash H L --- pkg/sdkclient/grpc/grpc_utils.go | 2 - pkg/sdkclient/sourcetransformer/client.go | 62 +++++++++---------- pkg/shared/expr/eval_bool.go | 27 ++++---- pkg/sources/forward/data_forward.go | 15 +---- .../event_time/event_time_extractor.go | 32 ++++++---- .../time_extraction_filter.go | 48 +++++++------- pkg/sources/transformer/grpc_transformer.go | 30 +++++---- .../extract-event-time-from-payload.yaml | 23 ++----- test/transformer-e2e/transformer_test.go | 3 +- 9 files changed, 110 insertions(+), 132 deletions(-) diff --git a/pkg/sdkclient/grpc/grpc_utils.go b/pkg/sdkclient/grpc/grpc_utils.go index 293ba8e8d7..71ae252738 100644 --- a/pkg/sdkclient/grpc/grpc_utils.go +++ b/pkg/sdkclient/grpc/grpc_utils.go @@ -18,7 +18,6 @@ package grpc import ( "fmt" - "log" "strconv" "google.golang.org/grpc" @@ -56,7 +55,6 @@ func ConnectToServer(udsSockAddr string, serverInfo *serverinfo.ServerInfo, maxM ) } else { sockAddr = getUdsSockAddr(udsSockAddr) - log.Println("UDS Client:", sockAddr) conn, err = grpc.NewClient(sockAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMessageSize), grpc.MaxCallSendMsgSize(maxMessageSize))) diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index 73eee4897f..7b1d7f11bb 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -19,10 +19,9 @@ package sourcetransformer import ( "context" "fmt" - "log" - "sync" "time" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" @@ -152,48 +151,45 @@ func (c *client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) { // SourceTransformFn SourceTransformerFn applies a function to each request element. // Response channel will not be closed. Caller can select on response and error channel to exit on first error. func (c *client) SourceTransformFn(ctx context.Context, requests []*transformpb.SourceTransformRequest) ([]*transformpb.SourceTransformResponse, error) { - //logger := logging.FromContext(ctx) - - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer func() { wg.Done() }() + var eg errgroup.Group + // send n requests + eg.Go(func() error { for _, req := range requests { select { case <-ctx.Done(): - return + return nil default: } if err := c.stream.Send(req); err != nil { - cancel(sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err)) - return + return sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err) } - log.Println("Sent request:", req.Request.Id) } - }() + return nil + }) - responses := make([]*transformpb.SourceTransformResponse, 0, len(requests)) - for i := 0; i < len(requests); i++ { - select { - case <-ctx.Done(): - err := context.Cause(ctx) - log.Println("Context cancelled while receiving:", err) - return nil, err - default: - } - resp, err := c.stream.Recv() - if err != nil { - err = sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Recv", err) - cancel(err) - return nil, err + // receive n responses + responses := make([]*transformpb.SourceTransformResponse, len(requests)) + eg.Go(func() error { + for i := 0; i < len(requests); i++ { + select { + case <-ctx.Done(): + return nil + default: + } + resp, err := c.stream.Recv() + if err != nil { + return sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Recv", err) + } + responses[i] = resp } - log.Println("Received response:", resp.GetId()) - responses = append(responses, resp) - } + return nil + }) - wg.Wait() + // wait for the send and receive goroutines to finish + // if any of the goroutines return an error, the error will be caught here + if err := eg.Wait(); err != nil { + return nil, err + } return responses, nil } diff --git a/pkg/shared/expr/eval_bool.go b/pkg/shared/expr/eval_bool.go index 4a8b95ab08..683fdb9f54 100644 --- a/pkg/shared/expr/eval_bool.go +++ b/pkg/shared/expr/eval_bool.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/Masterminds/sprig/v3" + "github.com/antonmedv/expr" ) var sprigFuncMap = sprig.GenericFuncMap() @@ -29,19 +30,19 @@ var sprigFuncMap = sprig.GenericFuncMap() const root = "payload" func EvalBool(expression string, msg []byte) (bool, error) { - //msgMap := map[string]interface{}{ - // root: string(msg), - //} - //env := getFuncMap(msgMap) - //result, err := expr.Eval(expression, env) - //if err != nil { - // return false, fmt.Errorf("unable to evaluate expression '%s': %s", expression, err) - //} - //resultBool, ok := result.(bool) - //if !ok { - // return false, fmt.Errorf("unable to cast expression result '%s' to bool", result) - //} - return true, nil + msgMap := map[string]interface{}{ + root: string(msg), + } + env := getFuncMap(msgMap) + result, err := expr.Eval(expression, env) + if err != nil { + return false, fmt.Errorf("unable to evaluate expression '%s': %s", expression, err) + } + resultBool, ok := result.(bool) + if !ok { + return false, fmt.Errorf("unable to cast expression result '%s' to bool", result) + } + return resultBool, nil } func getFuncMap(m map[string]interface{}) map[string]interface{} { diff --git a/pkg/sources/forward/data_forward.go b/pkg/sources/forward/data_forward.go index bf38e349ae..913be67939 100644 --- a/pkg/sources/forward/data_forward.go +++ b/pkg/sources/forward/data_forward.go @@ -20,7 +20,6 @@ import ( "context" "errors" "fmt" - "log" "strconv" "sync" "time" @@ -517,13 +516,7 @@ func (df *DataForward) writeToBuffers( for toVertexName, toVertexMessages := range messageToStep { writeOffsets[toVertexName] = make([][]isb.Offset, len(toVertexMessages)) } - for _, toVertexMessages := range messageToStep { - for idx, messages := range toVertexMessages { - for _, message := range messages { - log.Println(idx, " Writing message to buffer with message id: ", message.ID.String()) - } - } - } + for toVertexName, toVertexBuffer := range df.toBuffers { for index, partition := range toVertexBuffer { writeOffsets[toVertexName][index], err = df.writeToBuffer(ctx, partition, messageToStep[toVertexName][index]) @@ -675,12 +668,6 @@ func (df *DataForward) applyTransformer(ctx context.Context, messages []*isb.Rea } continue } - for _, result := range transformResults { - log.Println("Transformed read message with id: ", result.ReadMessage.ID.String()) - for _, message := range result.WriteMessages { - log.Println("Transformed write message with id: ", message.ID.String()) - } - } return transformResults } } diff --git a/pkg/sources/transformer/builtin/event_time/event_time_extractor.go b/pkg/sources/transformer/builtin/event_time/event_time_extractor.go index 393baf52e5..c844c518db 100644 --- a/pkg/sources/transformer/builtin/event_time/event_time_extractor.go +++ b/pkg/sources/transformer/builtin/event_time/event_time_extractor.go @@ -21,8 +21,10 @@ import ( "fmt" "time" + "github.com/araddon/dateparse" "github.com/numaproj/numaflow-go/pkg/sourcetransformer" + "github.com/numaproj/numaflow/pkg/shared/expr" "github.com/numaproj/numaflow/pkg/shared/logging" ) @@ -65,17 +67,21 @@ func New(args map[string]string) (sourcetransformer.SourceTransformFunc, error) // apply compiles the payload to extract the new event time. If there is any error during extraction, // we pass on the original input event time. Otherwise, we assign the new event time to the message. func (e eventTimeExtractor) apply(payload []byte, et time.Time, keys []string) (sourcetransformer.Message, error) { - //timeStr, err := expr.EvalStr(e.expression, payload) - //if err != nil { - // return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err - //} - // - //var newEventTime time.Time - //time.Local, _ = time.LoadLocation("UTC") - //if e.format != "" { - // newEventTime, err = time.Parse(e.format, timeStr) - //} else { - // newEventTime, err = dateparse.ParseStrict(timeStr) - //} - return sourcetransformer.NewMessage(payload, et).WithKeys(keys), nil + timeStr, err := expr.EvalStr(e.expression, payload) + if err != nil { + return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err + } + + var newEventTime time.Time + time.Local, _ = time.LoadLocation("UTC") + if e.format != "" { + newEventTime, err = time.Parse(e.format, timeStr) + } else { + newEventTime, err = dateparse.ParseStrict(timeStr) + } + if err != nil { + return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err + } else { + return sourcetransformer.NewMessage(payload, newEventTime).WithKeys(keys), nil + } } diff --git a/pkg/sources/transformer/builtin/time_extraction_filter/time_extraction_filter.go b/pkg/sources/transformer/builtin/time_extraction_filter/time_extraction_filter.go index 752c8db563..bb1515cb1e 100644 --- a/pkg/sources/transformer/builtin/time_extraction_filter/time_extraction_filter.go +++ b/pkg/sources/transformer/builtin/time_extraction_filter/time_extraction_filter.go @@ -21,8 +21,10 @@ import ( "fmt" "time" + "github.com/araddon/dateparse" "github.com/numaproj/numaflow-go/pkg/sourcetransformer" + "github.com/numaproj/numaflow/pkg/shared/expr" "github.com/numaproj/numaflow/pkg/shared/logging" ) @@ -67,27 +69,27 @@ func New(args map[string]string) (sourcetransformer.SourceTransformFunc, error) } func (e expressions) apply(et time.Time, payload []byte, keys []string) (sourcetransformer.Message, error) { - //result, err := expr.EvalBool(e.filterExpr, payload) - //if err != nil { - // return sourcetransformer.MessageToDrop(et), err - //} - //if result { - // timeStr, err := expr.EvalStr(e.eventTimeExpr, payload) - // if err != nil { - // return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err - // } - // var newEventTime time.Time - // time.Local, _ = time.LoadLocation("UTC") - // if e.eventTimeFormat != "" { - // newEventTime, err = time.Parse(e.eventTimeFormat, timeStr) - // } else { - // newEventTime, err = dateparse.ParseStrict(timeStr) - // } - // if err != nil { - // return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err - // } else { - // return sourcetransformer.NewMessage(payload, newEventTime).WithKeys(keys), nil - // } - //} - return sourcetransformer.NewMessage(payload, et).WithKeys(keys), nil + result, err := expr.EvalBool(e.filterExpr, payload) + if err != nil { + return sourcetransformer.MessageToDrop(et), err + } + if result { + timeStr, err := expr.EvalStr(e.eventTimeExpr, payload) + if err != nil { + return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err + } + var newEventTime time.Time + time.Local, _ = time.LoadLocation("UTC") + if e.eventTimeFormat != "" { + newEventTime, err = time.Parse(e.eventTimeFormat, timeStr) + } else { + newEventTime, err = dateparse.ParseStrict(timeStr) + } + if err != nil { + return sourcetransformer.NewMessage(payload, et).WithKeys(keys), err + } else { + return sourcetransformer.NewMessage(payload, newEventTime).WithKeys(keys), nil + } + } + return sourcetransformer.MessageToDrop(et), nil } diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index 41afbddc28..1dbe6e0f8e 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -20,7 +20,6 @@ import ( "context" "errors" "fmt" - "log" "time" v1 "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" @@ -54,7 +53,7 @@ func (u *GRPCBasedTransformer) IsHealthy(ctx context.Context) error { // WaitUntilReady waits until the client is connected. func (u *GRPCBasedTransformer) WaitUntilReady(ctx context.Context) error { - log := logging.FromContext(ctx) + logger := logging.FromContext(ctx) for { select { case <-ctx.Done(): @@ -63,7 +62,7 @@ func (u *GRPCBasedTransformer) WaitUntilReady(ctx context.Context) error { if _, err := u.client.IsReady(ctx, &emptypb.Empty{}); err == nil { return nil } else { - log.Infof("waiting for transformer to be ready: %v", err) + logger.Infof("waiting for transformer to be ready: %v", err) time.Sleep(1 * time.Second) } } @@ -78,15 +77,16 @@ func (u *GRPCBasedTransformer) CloseConn(ctx context.Context) error { var errSourceTransformFnEmptyMsgId = errors.New("response from SourceTransformFn doesn't contain a message id") func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { - var transformResults []isb.ReadWriteMessagePair - ctx, cancel := context.WithCancel(ctx) - defer cancel() - requests := make([]*v1.SourceTransformRequest, 0, len(messages)) + transformResults := make([]isb.ReadWriteMessagePair, len(messages)) + requests := make([]*v1.SourceTransformRequest, len(messages)) idToMsgMapping := make(map[string]*isb.ReadMessage) - for _, msg := range messages { + + for i, msg := range messages { + // we track the id to the message mapping to be able to match the response with the original message. + // we use the original message's event time if the user doesn't change it. Also we use the original message's + // read offset + index as the id for the response. id := msg.ReadOffset.String() idToMsgMapping[id] = msg - log.Println("Sending message with read offset ID: ", id, " message id: ", msg.ID.String()) req := &v1.SourceTransformRequest{ Request: &v1.SourceTransformRequest_Request{ Keys: msg.Keys, @@ -97,7 +97,7 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i Id: id, }, } - requests = append(requests, req) + requests[i] = req } responses, err := u.client.SourceTransformFn(ctx, requests) @@ -114,17 +114,16 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i return nil, err } - for _, resp := range responses { + for i, resp := range responses { parentMessage, ok := idToMsgMapping[resp.GetId()] if !ok { panic("tracker doesn't contain the message ID received from the response") } - var taggedMessages []*isb.WriteMessage + taggedMessages := make([]*isb.WriteMessage, len(resp.GetResults())) for i, result := range resp.GetResults() { keys := result.Keys if result.EventTime != nil { // Transformer supports changing event time. - log.Println("Updating event time from ", parentMessage.MessageInfo.EventTime.UnixMilli(), " to ", result.EventTime.AsTime().UnixMilli()) parentMessage.MessageInfo.EventTime = result.EventTime.AsTime() } taggedMessage := &isb.WriteMessage{ @@ -144,15 +143,14 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i }, Tags: result.Tags, } - log.Println("Received message with ID: ", taggedMessage.ID.String()) - taggedMessages = append(taggedMessages, taggedMessage) + taggedMessages[i] = taggedMessage } responsePair := isb.ReadWriteMessagePair{ ReadMessage: parentMessage, WriteMessages: taggedMessages, Err: nil, } - transformResults = append(transformResults, responsePair) + transformResults[i] = responsePair } return transformResults, nil } diff --git a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml index 91f863dede..8066caf9ec 100644 --- a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml +++ b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml @@ -5,24 +5,13 @@ metadata: spec: vertices: - name: in - containerTemplate: - env: - - name: NUMAFLOW_PPROF - value: !!str "true" source: - generator: - # How many messages to generate in the duration. - rpu: 5 - duration: 1s - # Optional, size of each generated message, defaults to 10. - msgSize: 1024 -# transformer: -# container: -# image: quay.io/numaio/numaflow-rs/source-transformer-now:stable -# builtin: -# name: eventTimeExtractor -# kwargs: -# expression: json(payload).item[1].time + http: { } + transformer: + builtin: + name: eventTimeExtractor + kwargs: + expression: json(payload).item[1].time - name: out scale: min: 1 diff --git a/test/transformer-e2e/transformer_test.go b/test/transformer-e2e/transformer_test.go index 8c07db744c..e6b727fcb9 100644 --- a/test/transformer-e2e/transformer_test.go +++ b/test/transformer-e2e/transformer_test.go @@ -21,6 +21,7 @@ package e2e import ( "context" "encoding/json" + "errors" "fmt" "os" "strconv" @@ -142,7 +143,7 @@ wmLoop: for { select { case <-ctx.Done(): - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { s.T().Log("test timed out") assert.Fail(s.T(), "timed out") break wmLoop From 0edc2e5677f9b8a05463cdf1860b56214b15cdb5 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Wed, 2 Oct 2024 08:47:53 +0530 Subject: [PATCH 09/10] Return context cancellation error Signed-off-by: Sreekanth --- pkg/sdkclient/sourcetransformer/client.go | 4 ++-- pkg/sources/transformer/grpc_transformer.go | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index 7b1d7f11bb..92372ff7a4 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -157,7 +157,7 @@ func (c *client) SourceTransformFn(ctx context.Context, requests []*transformpb. for _, req := range requests { select { case <-ctx.Done(): - return nil + return ctx.Err() default: } if err := c.stream.Send(req); err != nil { @@ -173,7 +173,7 @@ func (c *client) SourceTransformFn(ctx context.Context, requests []*transformpb. for i := 0; i < len(requests); i++ { select { case <-ctx.Done(): - return nil + return ctx.Err() default: } resp, err := c.stream.Recv() diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index 1dbe6e0f8e..459e99f21b 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -18,7 +18,6 @@ package transformer import ( "context" - "errors" "fmt" "time" @@ -74,8 +73,6 @@ func (u *GRPCBasedTransformer) CloseConn(ctx context.Context) error { return u.client.CloseConn(ctx) } -var errSourceTransformFnEmptyMsgId = errors.New("response from SourceTransformFn doesn't contain a message id") - func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { transformResults := make([]isb.ReadWriteMessagePair, len(messages)) requests := make([]*v1.SourceTransformRequest, len(messages)) From 0b9cede25d785d7dbc52ec211539ea4de2a5cb11 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Wed, 2 Oct 2024 08:53:30 +0530 Subject: [PATCH 10/10] Fix unit test Signed-off-by: Sreekanth --- pkg/sources/transformer/grpc_transformer_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/sources/transformer/grpc_transformer_test.go b/pkg/sources/transformer/grpc_transformer_test.go index 46279bb93a..cd8ccbe852 100644 --- a/pkg/sources/transformer/grpc_transformer_test.go +++ b/pkg/sources/transformer/grpc_transformer_test.go @@ -137,7 +137,7 @@ func TestGRPCBasedTransformer_BasicApplyWithServer(t *testing.T) { expectedUDFErr := &rpc.ApplyUDFErr{ UserUDFErr: false, - Message: "gRPC client.SourceTransformFn failed, NonRetryable: context canceled", + Message: "gRPC client.SourceTransformFn failed, context canceled", InternalErr: rpc.InternalErr{ Flag: true, MainCarDown: false,