Skip to content

Commit

Permalink
test update
Browse files Browse the repository at this point in the history
Signed-off-by: Sidhant Kohli <[email protected]>
  • Loading branch information
Sidhant Kohli committed Jul 15, 2024
1 parent 7e16c60 commit df658c0
Showing 1 changed file with 113 additions and 40 deletions.
153 changes: 113 additions & 40 deletions src/batchmap.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;

use chrono::{DateTime, Utc};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
Expand Down Expand Up @@ -261,9 +262,13 @@ where
let (grpc_response_tx, grpc_response_rx) =
channel::<Result<proto::BatchMapResponse, Status>>(1);

let shutdown_tx = self._shutdown_tx.clone();

// call the user's batch map handle
let batch_map_handle = self.handler.batchmap(rx);
let counter_orig = Arc::new(AtomicUsize::new(0));

let counter = counter_orig.clone();
// write to the user-defined channel
tokio::spawn(async move {
while let Some(next_message) = stream
Expand All @@ -277,64 +282,57 @@ where
tx.send(datum)
.await
.expect("send be successfully received!");
counter.fetch_add(1, Ordering::Relaxed);
}
});

// wait for the sink handle to respond
let responses = batch_map_handle.await;

// TODO(): add the check for length of responses

println!(
"Received responses from the batch map handle {}",
responses.len()
);

let counter2 = counter_orig.clone();
tokio::spawn(async move {
// check if the number of responses is equal to the number of messages received
let num_responses = counter2.load(Ordering::Relaxed);
println!("Number of responses: {}", num_responses);
if num_responses != responses.len() {
grpc_response_tx
.send(Err(Status::internal(
"number of responses does not \
match the number of messages received",
)))
.await
.expect("send to grpc response channel failed");

// Send a shutdown signal to the grpc server.
shutdown_tx.send(()).await.expect("shutdown_tx send failed");
}
// forward the responses
for response in responses {
let mut results: Vec<proto::batch_map_response::Result> = Vec::new();
// convert the response to proto format
for message in response.message {
let resp = proto::batch_map_response::Result {
keys: message.keys.unwrap_or_default(),
value: message.value,
tags: message.tags.unwrap_or_default(),
};
// append the response to vector
results.push(resp)
}
let grpc_resp = crate::batchmap::proto::BatchMapResponse {
results,
id: response.id,
};
// send the response to the grpc client
println!("Sending response to client {}", grpc_resp.id);

let send_result = grpc_response_tx.send(Ok(grpc_resp)).await;

if let Err(e) = send_result {
println!("Sending response to client {}", response.id);
let send_result = grpc_response_tx
.send(Ok(proto::BatchMapResponse {
results: response.message.into_iter().map(|m| m.into()).collect(),
id: response.id,
}))
.await;
// if the send fails, return an error status on the streaming endpoint
if send_result.is_err() {
grpc_response_tx
.send(Err(Status::internal(e.to_string())))
.send(Err(Status::internal(
send_result.err().unwrap().to_string(),
)))
.await
.expect("send to grpc response channel failed");
return;
}
println!("Sent response from grpc");
}
});

// // forward the responses
// for response in responses {
// println!("Sending response to client {}", response.id);
// let send_result = grpc_response_tx.send(Ok(proto::BatchMapResponse {
// results: response.message.into_iter().map(|m| m.into()).collect(),
// id: response.id,
// })).await;
// // if the send fails, return an error status on the streaming endpoint
// if send_result.is_err() {
// return Err(Status::internal("Failed to send response to client"));
// }
// }

// Return the receiver stream to the client
Ok(Response::new(ReceiverStream::new(grpc_response_rx)))
}
Expand Down Expand Up @@ -544,4 +542,79 @@ mod tests {
assert!(task.is_finished(), "gRPC server is still running");
Ok(())
}

#[tokio::test]
async fn error_length() -> Result<(), Box<dyn Error>> {
struct Logger;
#[tonic::async_trait]
impl batchmap::BatchMapper for Logger {
async fn batchmap(&self, mut input: Receiver<Datum>) -> Vec<BatchResponse> {
let mut responses: Vec<BatchResponse> = Vec::new();
while let Some(datum) = input.recv().await {}
responses
}
}

let tmp_dir = TempDir::new()?;
let sock_file = tmp_dir.path().join("batchmap.sock");
let server_info_file = tmp_dir.path().join("batchmapper-server-info");

let mut server = batchmap::Server::new(Logger)
.with_server_info_file(&server_info_file)
.with_socket_file(&sock_file)
.with_max_message_size(10240);

assert_eq!(server.max_message_size(), 10240);
assert_eq!(server.server_info_file(), server_info_file);
assert_eq!(server.socket_file(), sock_file);

let (shutdown_tx, shutdown_rx) = oneshot::channel();
let task = tokio::spawn(async move { server.start_with_shutdown(shutdown_rx).await });

tokio::time::sleep(Duration::from_millis(50)).await;

// https://github.com/hyperium/tonic/blob/master/examples/src/uds/client.rs
let channel = tonic::transport::Endpoint::try_from("http://[::]:50051")?
.connect_with_connector(service_fn(move |_: Uri| {
// https://rust-lang.github.io/async-book/03_async_await/01_chapter.html#async-lifetimes
let sock_file = sock_file.clone();
async move {
Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
UnixStream::connect(sock_file).await?,
))
}
}))
.await?;

let mut client = BatchMapClient::new(channel);
let request = batchmap::proto::BatchMapRequest {
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
id: "1".to_string(),
headers: Default::default(),
};

let resp = client
.batch_map_fn(tokio_stream::iter(vec![request]))
.await?;
let mut r = resp.into_inner();

let mut error_flag = false;

if let Err(e) = r.message().await {
assert_eq!(e.code(), tonic::Code::Internal);
assert!(e.message().contains(
"number of responses does not \
match the number of messages received"
));
error_flag = true;
}
// Check if the error flag is set
assert!(error_flag, "Expected error from server");
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(task.is_finished(), "gRPC server is still running");
Ok(())
}
}

0 comments on commit df658c0

Please sign in to comment.