Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Bidirectional streaming for source transformer #91

Merged
merged 10 commits into from
Oct 1, 2024
31 changes: 25 additions & 6 deletions proto/sourcetransform.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,38 @@ service SourceTransform {
// SourceTransformFn applies a function to each request element.
// In addition to map function, SourceTransformFn also supports assigning a new event time to response.
// SourceTransformFn can be used only at source vertex by source data transformer.
rpc SourceTransformFn(SourceTransformRequest) returns (SourceTransformResponse);
rpc SourceTransformFn(stream SourceTransformRequest) returns (stream SourceTransformResponse);

// IsReady is the heartbeat endpoint for gRPC.
rpc IsReady(google.protobuf.Empty) returns (ReadyResponse);
}

/*
* Handshake message between client and server to indicate the start of transmission.
* It is exchanged once at the beginning of a SourceTransformFn stream, before any
* transform request: the client sends it as the first message and the server echoes
* it back in its first response.
*/
message Handshake {
// Required field indicating the start of transmission ("sot").
bool sot = 1;
}

/**
* SourceTransformRequest represents a request element.
*/
message SourceTransformRequest {
repeated string keys = 1;
bytes value = 2;
google.protobuf.Timestamp event_time = 3;
google.protobuf.Timestamp watermark = 4;
map<string, string> headers = 5;
message Request {
repeated string keys = 1;
bytes value = 2;
google.protobuf.Timestamp event_time = 3;
google.protobuf.Timestamp watermark = 4;
map<string, string> headers = 5;
// This ID is used to uniquely identify a transform request
string id = 6;
}
Request request = 1;
optional Handshake handshake = 2;
}


/**
* SourceTransformResponse represents a response element.
*/
Expand All @@ -37,6 +52,10 @@ message SourceTransformResponse {
repeated string tags = 4;
}
repeated Result results = 1;
// This ID is used to match the response to the request it corresponds to.
string id = 2;
// Handshake message between client and server to indicate the start of transmission.
optional Handshake handshake = 3;
}

/**
Expand Down
258 changes: 192 additions & 66 deletions src/sourcetransform.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
use crate::error::Error::SourceTransformerError;
use crate::error::ErrorKind::UserDefinedError;
use crate::shared::{self, prost_timestamp_from_utc};
use crate::error::Error::{self, SourceTransformerError};
use crate::error::ErrorKind;
BulkBeing marked this conversation as resolved.
Show resolved Hide resolved
use crate::shared::{self, prost_timestamp_from_utc, utc_from_timestamp};
use chrono::{DateTime, Utc};
use proto::SourceTransformResponse;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use tonic::{async_trait, Request, Response, Status};
use tonic::{async_trait, Request, Response, Status, Streaming};
use tracing::{error, info};

const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sourcetransform.sock";
const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcetransformer-server-info";
const DEFAULT_CHANNEL_SIZE: usize = 1000;

const DROP: &str = "U+005C__DROP__";

Expand Down Expand Up @@ -218,54 +223,116 @@ impl From<Message> for proto::source_transform_response::Result {
}
}

impl From<proto::SourceTransformRequest> for SourceTransformRequest {
fn from(value: proto::SourceTransformRequest) -> Self {
Self {
keys: value.keys,
value: value.value,
watermark: shared::utc_from_timestamp(value.watermark),
eventtime: shared::utc_from_timestamp(value.event_time),
headers: value.headers,
}
}
}

#[async_trait]
impl<T> proto::source_transform_server::SourceTransform for SourceTransformerService<T>
where
T: SourceTransformer + Send + Sync + 'static,
{
type SourceTransformFnStream = ReceiverStream<Result<SourceTransformResponse, Status>>;

async fn source_transform_fn(
&self,
request: Request<proto::SourceTransformRequest>,
) -> Result<Response<proto::SourceTransformResponse>, Status> {
let request = request.into_inner();
request: Request<Streaming<proto::SourceTransformRequest>>,
) -> Result<Response<Self::SourceTransformFnStream>, Status> {
let mut stream = request.into_inner();
let handler = Arc::clone(&self.handler);
let handle = tokio::spawn(async move { handler.transform(request.into()).await });
let shutdown_tx = self.shutdown_tx.clone();
let cancellation_token = self.cancellation_token.clone();

// Wait for the handler to finish processing the request. If the server is shutting down(token will be cancelled),
// then return an error.
tokio::select! {
result = handle => {
match result {
Ok(messages) => Ok(Response::new(proto::SourceTransformResponse {
results: messages.into_iter().map(|msg| msg.into()).collect(),
})),
Err(e) => {
tracing::error!("Error in source transform handler: {:?}", e);
// Send a shutdown signal to the server to do a graceful shutdown because there was
// a panic in the handler.
shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server");
Err(Status::internal(SourceTransformerError(UserDefinedError(e.to_string())).to_string()))

let (tx, rx) =
mpsc::channel::<Result<SourceTransformResponse, Status>>(DEFAULT_CHANNEL_SIZE);

// do the handshake first to let the client know that we are ready to receive transform requests.
let handshake_request = stream
BulkBeing marked this conversation as resolved.
Show resolved Hide resolved
.message()
.await
.map_err(|e| Status::internal(format!("handshake failed {}", e)))?
.ok_or_else(|| Status::internal("stream closed before handshake"))?;

if let Some(handshake) = handshake_request.handshake {
tx.send(Ok(SourceTransformResponse {
results: vec![],
id: "".to_string(),
handshake: Some(handshake),
}))
.await
.map_err(|e| Status::internal(format!("failed to send handshake response {}", e)))?;
} else {
return Err(Status::invalid_argument("Handshake not present"));
}

let handle: JoinHandle<Result<(), Error>> = tokio::spawn({
let shutdown_tx = self.shutdown_tx.clone();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document your thought.

let cancellation_token = self.cancellation_token.clone();
let tx = tx.clone();
async move {
loop {
tokio::select! {
transform_request = stream.message() => {
yhl25 marked this conversation as resolved.
Show resolved Hide resolved
let transform_request = transform_request.map_err(|e| SourceTransformerError(ErrorKind::InternalError(e.to_string())))?
.ok_or_else(||SourceTransformerError(ErrorKind::InternalError("Stream closed".to_string())))?;

let Some(request) = transform_request.request else {
return Err(SourceTransformerError(ErrorKind::InternalError("Transform request can not be none".to_string())));
};

let message_id = request.id.clone();
let handler_input = SourceTransformRequest{
keys: request.keys,
value: request.value,
watermark: utc_from_timestamp(request.watermark),
eventtime: utc_from_timestamp(request.event_time),
headers: request.headers
};

BulkBeing marked this conversation as resolved.
Show resolved Hide resolved
let handler = handler.clone();
// let messages = handler.transform(handler_input).await;
let udf_tranform_task = tokio::spawn(async move { handler.transform(handler_input).await });
BulkBeing marked this conversation as resolved.
Show resolved Hide resolved
let messages = tokio::select! {
result = udf_tranform_task => {
match result {
Ok(messages) => messages,
Err(e) => {
tracing::error!("Failed to run transform function: {e:?}");
// Send a shutdown signal to the server to do a graceful shutdown because there was
// a panic in the handler.
shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server");
return Err(SourceTransformerError(ErrorKind::UserDefinedError("panic in transform UDF".to_string())));
}
}
}
};
tx.send(Ok(SourceTransformResponse{
results: messages.into_iter().map(|msg| msg.into()).collect(),
id: message_id,
handshake: None,
})).await.expect("sending messages to the client over gRPC channel");

}
_ = cancellation_token.cancelled() => {
info!("Cancellation token is cancelled, shutting down");
break;
}
}
}
},
_ = cancellation_token.cancelled() => {
Err(Status::internal(SourceTransformerError(UserDefinedError("Server is shutting down".to_string())).to_string()))
},
}
Ok(())
}
});

let shutdown_tx = self.shutdown_tx.clone();
tokio::spawn(async move {
let Err(e) = handle.await else {
return;
};
error!("Shutting down gRPC channel: {e:?}");
tx.send(Err(Status::internal(e.to_string())))
.await
.expect("Sending error message to gRPC response channel");
shutdown_tx
.send(())
.await
.expect("Writing to shutdown channel");
});

Ok(Response::new(ReceiverStream::new(rx)))
}

async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
Expand Down Expand Up @@ -390,11 +457,13 @@ mod tests {

use tempfile::TempDir;
use tokio::net::UnixStream;
use tokio::sync::oneshot;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::Uri;
use tower::service_fn;

use crate::sourcetransform;
use crate::sourcetransform::proto;
use crate::sourcetransform::proto::source_transform_client::SourceTransformClient;

#[tokio::test]
Expand Down Expand Up @@ -447,21 +516,59 @@ mod tests {
.await?;

let mut client = SourceTransformClient::new(channel);
let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest {
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
headers: Default::default(),
});

let resp = client.source_transform_fn(request).await?;
let resp = resp.into_inner();
let (tx, rx) = mpsc::channel(2);

let handshake_request = proto::SourceTransformRequest {
request: None,
handshake: Some(proto::Handshake { sot: true }),
};
tx.send(handshake_request).await.unwrap();

let mut stream = tokio::time::timeout(
Duration::from_secs(2),
client.source_transform_fn(ReceiverStream::new(rx)),
)
.await
.map_err(|_| "timeout while getting stream for source_transform_fn")??
.into_inner();

let handshake_resp = stream.message().await?.unwrap();
assert!(
handshake_resp.results.is_empty(),
"The handshake response should not contain any messages"
);
assert!(
handshake_resp.id.is_empty(),
"The message id of the handshake response should be empty"
);
assert!(
handshake_resp.handshake.is_some(),
"Not a valid response for handshake request"
);

let request = sourcetransform::proto::SourceTransformRequest {
request: Some(proto::source_transform_request::Request {
id: "1".to_string(),
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
headers: Default::default(),
}),
handshake: None,
};

tx.send(request).await.unwrap();

let resp = stream.message().await?.unwrap();
assert_eq!(resp.results.len(), 1, "Expected single message from server");
let msg = &resp.results[0];
assert_eq!(msg.keys.first(), Some(&"first".to_owned()));
assert_eq!(msg.value, "hello".as_bytes());

drop(tx);

shutdown_tx
.send(())
.expect("Sending shutdown signal to gRPC server");
Expand Down Expand Up @@ -515,21 +622,40 @@ mod tests {
.await?;

let mut client = SourceTransformClient::new(channel);
let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest {
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
headers: Default::default(),
});

let resp = client.source_transform_fn(request).await;
assert!(resp.is_err(), "Expected error from server");

if let Err(e) = resp {
assert_eq!(e.code(), tonic::Code::Internal);
assert!(e.message().contains("User Defined Error"));
}
let (tx, rx) = mpsc::channel(2);
let handshake_request = proto::SourceTransformRequest {
request: None,
handshake: Some(proto::Handshake { sot: true }),
};
tx.send(handshake_request).await.unwrap();

let mut stream = tokio::time::timeout(
Duration::from_secs(2),
client.source_transform_fn(ReceiverStream::new(rx)),
)
.await
.map_err(|_| "timeout while getting stream for source_transform_fn")??
.into_inner();

let handshake_resp = stream.message().await?.unwrap();
assert!(
handshake_resp.handshake.is_some(),
"Not a valid response for handshake request"
);

let request = proto::SourceTransformRequest {
request: Some(proto::source_transform_request::Request {
id: "1".to_string(),
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
headers: Default::default(),
}),
handshake: None,
};
tx.send(request).await.unwrap();

// server should shut down gracefully because there was a panic in the handler.
for _ in 0..10 {
Expand Down
Loading