From 5af45242d20674e79a68ae09f92345ac4cc3184d Mon Sep 17 00:00:00 2001 From: Florian Lemaitre Date: Thu, 14 Nov 2024 18:20:02 +0100 Subject: [PATCH] [WIP] Server tests --- packages/rust/armonik/Cargo.lock | 1 + packages/rust/armonik/Cargo.toml | 13 +- packages/rust/armonik/examples/client.rs | 20 - packages/rust/armonik/examples/server.rs | 128 ------- .../objects/agent/create_results_metadata.rs | 9 +- packages/rust/armonik/src/server/events.rs | 4 +- packages/rust/armonik/src/server/mod.rs | 32 +- packages/rust/armonik/src/server/results.rs | 8 +- packages/rust/armonik/src/server/submitter.rs | 8 +- packages/rust/armonik/tests/agent.rs | 351 ++++++++++++++++++ packages/rust/armonik/tests/applications.rs | 58 +++ packages/rust/armonik/tests/auth.rs | 43 +++ packages/rust/armonik/tests/common/mod.rs | 19 + packages/rust/armonik/tests/events.rs | 106 ++++++ packages/rust/armonik/tests/sessions.rs | 347 +++++++++++++++++ 15 files changed, 971 insertions(+), 176 deletions(-) delete mode 100644 packages/rust/armonik/examples/client.rs delete mode 100644 packages/rust/armonik/examples/server.rs create mode 100644 packages/rust/armonik/tests/agent.rs create mode 100644 packages/rust/armonik/tests/applications.rs create mode 100644 packages/rust/armonik/tests/auth.rs create mode 100644 packages/rust/armonik/tests/common/mod.rs create mode 100644 packages/rust/armonik/tests/events.rs create mode 100644 packages/rust/armonik/tests/sessions.rs diff --git a/packages/rust/armonik/Cargo.lock b/packages/rust/armonik/Cargo.lock index 6313ab959..8a0517650 100644 --- a/packages/rust/armonik/Cargo.lock +++ b/packages/rust/armonik/Cargo.lock @@ -36,6 +36,7 @@ checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13" name = "armonik" version = "3.21.0-beta-0" dependencies = [ + "async-stream", "eyre", "futures", "http-body-util", diff --git a/packages/rust/armonik/Cargo.toml b/packages/rust/armonik/Cargo.toml index 60a65bc0e..d25a1c92b 100644 --- a/packages/rust/armonik/Cargo.toml +++ b/packages/rust/armonik/Cargo.toml @@ -48,6 +48,7 @@ hyper-util = { version = "0.1", features = ["client", "http1"] } http-body-util = "0.1" serde_json = "1.0" serial_test = "3.1" +async-stream = "0.3" tokio = { version = "1.41", features = [ "rt-multi-thread", "macros", @@ -58,10 +59,10 @@ tokio = { version = "1.41", features = [ [build-dependencies] tonic-build = "0.12" -[[example]] -name = "client" -required-features = ["client"] +[[test]] +name = "sessions" +required-features = ["client", "server"] -[[example]] -name = "server" -required-features = ["server"] +[[test]] +name = "agent" +required-features = ["client", "server"] diff --git a/packages/rust/armonik/examples/client.rs b/packages/rust/armonik/examples/client.rs deleted file mode 100644 index 0e4c27d43..000000000 --- a/packages/rust/armonik/examples/client.rs +++ /dev/null @@ -1,20 +0,0 @@ -use tracing_subscriber::{prelude::*, EnvFilter}; - -#[tokio::main] -async fn main() -> Result<(), eyre::Report> { - tracing_subscriber::registry() - .with(tracing_subscriber::fmt::layer()) - .with(EnvFilter::from_default_env()) - .init(); - let client = armonik::Client::new().await?; - - let session = tokio::time::timeout( - tokio::time::Duration::from_secs(1), - client.sessions().create([""], Default::default()), - ) - .await??; - - println!("Created session {session} using partition"); - - Ok(()) -} diff --git a/packages/rust/armonik/examples/server.rs b/packages/rust/armonik/examples/server.rs deleted file mode 100644 index 3465b4286..000000000 --- a/packages/rust/armonik/examples/server.rs +++ /dev/null @@ -1,128 +0,0 @@ -use std::sync::Arc; - -use tokio_util::sync::CancellationToken; -use tracing_subscriber::{prelude::*, EnvFilter}; - -use armonik::server::SessionsServiceExt; -use armonik::sessions; - -pub struct Server; - -impl armonik::server::SessionsService for Server { - /// Get a sessions list using pagination, filters and sorting. - async fn list( - self: Arc, - _request: sessions::list::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } - - /// Get a session by its id. - async fn get( - self: Arc, - _request: sessions::get::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } - - /// Cancel a session by its id. - async fn cancel( - self: Arc, - _request: sessions::cancel::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } - - /// Create a session - async fn create( - self: Arc, - _request: sessions::create::Request, - cancellation_token: CancellationToken, - ) -> std::result::Result { - tracing::info!("create called"); - if let Some(()) = cancellation_token - .run_until_cancelled(tokio::time::sleep(tokio::time::Duration::from_secs(2))) - .await - { - tracing::info!("create returned"); - Ok(sessions::create::Response { - session_id: String::from("abc"), - }) - } else { - tracing::info!("client cancelled RPC"); - tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; - tracing::info!("future still running"); - Err(tonic::Status::aborted("client cancelled RPC")) - } - } - - /// Pause a session by its id. - async fn pause( - self: Arc, - _request: sessions::pause::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } - - /// Resume a paused session by its id. - async fn resume( - self: Arc, - _request: sessions::resume::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } - - /// Close a session by its id. - async fn close( - self: Arc, - _request: sessions::close::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } - - /// Purge a session by its id. Removes Results data. - async fn purge( - self: Arc, - _request: sessions::purge::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } - - /// Delete a session by its id. Removes metadata from Results, Sessions and Tasks associated to the session. - async fn delete( - self: Arc, - _request: sessions::delete::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } - - /// Stops clients and/or workers from submitting new tasks in the given session. - async fn stop_submission( - self: Arc, - _request: sessions::stop_submission::Request, - _cancellation_token: CancellationToken, - ) -> std::result::Result { - todo!() - } -} - -#[tokio::main] -pub async fn main() -> Result<(), eyre::Report> { - tracing_subscriber::registry() - .with(tracing_subscriber::fmt::layer()) - .with(EnvFilter::from_default_env()) - .init(); - tonic::transport::Server::builder() - .add_service(Server.sessions_server()) - .serve("127.0.0.1:3456".parse()?) - .await?; - Ok(()) -} diff --git a/packages/rust/armonik/src/objects/agent/create_results_metadata.rs b/packages/rust/armonik/src/objects/agent/create_results_metadata.rs index 0e7724056..524e346f8 100644 --- a/packages/rust/armonik/src/objects/agent/create_results_metadata.rs +++ b/packages/rust/armonik/src/objects/agent/create_results_metadata.rs @@ -62,7 +62,14 @@ impl From for v3::agent::CreateResultsMetaDataResponse { fn from(value: Response) -> Self { Self { communication_token: value.communication_token, - results: value.results.into_values().map(Into::into).collect(), + results: value + .results + .into_iter() + .map(|(k, v)| { + debug_assert_eq!(k, v.name); + v.into() + }) + .collect(), } } } diff --git a/packages/rust/armonik/src/server/events.rs b/packages/rust/armonik/src/server/events.rs index d49ef9c6c..26d93a096 100644 --- a/packages/rust/armonik/src/server/events.rs +++ b/packages/rust/armonik/src/server/events.rs @@ -32,9 +32,7 @@ impl EventsServiceExt for T { #[crate::reexports::async_trait] impl v3::events::events_server::Events for T { - type GetEventsStream = crate::reexports::tokio_stream::wrappers::ReceiverStream< - Result, - >; + type GetEventsStream = crate::server::ServerStream; async fn get_events( self: Arc, request: tonic::Request, diff --git a/packages/rust/armonik/src/server/mod.rs b/packages/rust/armonik/src/server/mod.rs index 2edbb3cfa..00a8e9442 100644 --- a/packages/rust/armonik/src/server/mod.rs +++ b/packages/rust/armonik/src/server/mod.rs @@ -5,10 +5,10 @@ mod events; mod partitions; mod results; mod sessions; +mod submitter; mod tasks; mod versions; mod worker; -mod submitter; pub use agent::{AgentService, AgentServiceExt}; pub use applications::{ApplicationsService, ApplicationsServiceExt}; @@ -17,10 +17,10 @@ pub use events::{EventsService, EventsServiceExt}; pub use partitions::{PartitionsService, PartitionsServiceExt}; pub use results::{ResultsService, ResultsServiceExt}; pub use sessions::{SessionsService, SessionsServiceExt}; +pub use submitter::{SubmitterService, SubmitterServiceExt}; pub use tasks::{TasksService, TasksServiceExt}; pub use versions::{VersionsService, VersionsServiceExt}; pub use worker::{WorkerService, WorkerServiceExt}; -pub use submitter::{SubmitterService, SubmitterServiceExt}; macro_rules! define_trait_methods { (trait $name:ident {$($(#[$attr:meta])* fn $service:ident::$method:ident ;)* $(--- $($body:tt)*)?}) => { @@ -56,7 +56,7 @@ macro_rules! impl_trait_methods { (unary ($self:ident, $request:ident) { $inner:path }) => { { let ct = tokio_util::sync::CancellationToken::new(); - let _cancel_guard = ct.clone().drop_guard(); + let _drop_guard = ct.clone().drop_guard(); let fut = tokio::spawn(async move { $inner($self, $request.into_inner().into(), ct).await}); match fut.await { Ok(Ok(res)) => Ok(tonic::Response::new(res.into())), @@ -68,7 +68,7 @@ macro_rules! impl_trait_methods { (stream client ($self:ident, $request:ident) { $inner:path }) => { { let ct = tokio_util::sync::CancellationToken::new(); - let _cancel_guard = ct.clone().drop_guard(); + let _drop_guard = ct.clone().drop_guard(); let fut = tokio::spawn(async move { $inner( $self, @@ -86,7 +86,7 @@ macro_rules! impl_trait_methods { (stream server ($self:ident, $request:ident) { $inner:path }) => { { let ct = tokio_util::sync::CancellationToken::new(); - let _cancel_guard = ct.clone().drop_guard(); + let drop_guard = ct.clone().drop_guard(); let fut = tokio::spawn(async move { $inner($self, $request.into_inner().into(), ct).await }); match fut.await { Ok(Ok(stream)) => { @@ -100,7 +100,10 @@ macro_rules! impl_trait_methods { }); Ok(tonic::Response::new( - crate::reexports::tokio_stream::wrappers::ReceiverStream::new(rx), + crate::server::ServerStream{ + receiver: rx, + drop_guard, + }, )) } Ok(Err(err)) => Err(err), @@ -110,5 +113,22 @@ macro_rules! impl_trait_methods { }; } +pub struct ServerStream { + receiver: tokio::sync::mpsc::Receiver>, + #[allow(unused)] + drop_guard: tokio_util::sync::DropGuard, +} + +impl crate::reexports::tokio_stream::Stream for ServerStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.receiver.poll_recv(cx) + } +} + use define_trait_methods; use impl_trait_methods; diff --git a/packages/rust/armonik/src/server/results.rs b/packages/rust/armonik/src/server/results.rs index 798670acb..521c20116 100644 --- a/packages/rust/armonik/src/server/results.rs +++ b/packages/rust/armonik/src/server/results.rs @@ -87,9 +87,7 @@ super::impl_trait_methods! { crate::server::impl_trait_methods!(stream client (self, request) {ResultsService::upload}) } - type DownloadResultDataStream = crate::reexports::tokio_stream::wrappers::ReceiverStream< - Result, - >; + type DownloadResultDataStream = crate::server::ServerStream; async fn download_result_data( self: std::sync::Arc, request: tonic::Request, @@ -100,9 +98,7 @@ super::impl_trait_methods! { super::impl_trait_methods!(stream server (self, request) {ResultsService::download}) } - type WatchResultsStream = crate::reexports::tokio_stream::wrappers::ReceiverStream< - Result, - >; + type WatchResultsStream = crate::server::ServerStream; async fn watch_results( self: std::sync::Arc, _request: tonic::Request>, diff --git a/packages/rust/armonik/src/server/submitter.rs b/packages/rust/armonik/src/server/submitter.rs index fdfdbc58d..b2e17502e 100644 --- a/packages/rust/armonik/src/server/submitter.rs +++ b/packages/rust/armonik/src/server/submitter.rs @@ -89,9 +89,7 @@ super::impl_trait_methods! { } - type TryGetResultStreamStream = crate::reexports::tokio_stream::wrappers::ReceiverStream< - Result, - >; + type TryGetResultStreamStream = crate::server::ServerStream; async fn try_get_result_stream( self: std::sync::Arc, request: tonic::Request, @@ -103,9 +101,7 @@ super::impl_trait_methods! { } - type WatchResultsStream = crate::reexports::tokio_stream::wrappers::ReceiverStream< - Result, - >; + type WatchResultsStream = crate::server::ServerStream; async fn watch_results( self: std::sync::Arc, _request: tonic::Request>, diff --git a/packages/rust/armonik/tests/agent.rs b/packages/rust/armonik/tests/agent.rs new file mode 100644 index 000000000..a6a844cf0 --- /dev/null +++ b/packages/rust/armonik/tests/agent.rs @@ -0,0 +1,351 @@ +use std::sync::Arc; + +use armonik::{agent, reexports::tokio_stream::StreamExt, server::AgentServiceExt}; + +mod common; + +#[derive(Debug, Clone, Default)] +struct Service { + failure: Option, + wait: Option, +} + +impl armonik::server::AgentService for Service { + async fn create_results_metadata( + self: Arc, + request: agent::create_results_metadata::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(agent::create_results_metadata::Response { + communication_token: request.communication_token, + results: request + .names + .into_iter() + .map(|name| { + ( + name.clone(), + agent::ResultMetaData { + session_id: String::from("rpc-create-results-metadata-output"), + name, + ..Default::default() + }, + ) + }) + .collect(), + }) + }, + ) + .await + } + + async fn create_results( + self: Arc, + request: agent::create_results::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(agent::create_results::Response { + communication_token: request.communication_token, + results: request + .results + .into_iter() + .map(|(name, _)| { + eprintln!("NAME: {name}"); + ( + name.clone(), + agent::ResultMetaData { + name, + session_id: String::from("rpc-create-results-output"), + ..Default::default() + }, + ) + }) + .collect(), + }) + }, + ) + .await + } + + async fn notify_result_data( + self: Arc, + request: agent::notify_result_data::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(agent::notify_result_data::Response { + result_ids: vec![ + request.communication_token, + String::from("rpc-notify-result-data-output"), + ], + }) + }, + ) + .await + } + + async fn submit_tasks( + self: Arc, + request: agent::submit_tasks::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(agent::submit_tasks::Response { + communication_token: request.communication_token, + items: vec![agent::submit_tasks::ResponseItem { + task_id: String::from("rpc-submit-tasks-output"), + ..Default::default() + }], + }) + }, + ) + .await + } + + async fn get_resource_data( + self: Arc, + _request: agent::get_resource_data::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(agent::get_resource_data::Response { + result_id: String::from("rpc-get-resource-data-output"), + }) + }, + ) + .await + } + + async fn get_common_data( + self: Arc, + _request: agent::get_common_data::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(agent::get_common_data::Response { + result_id: String::from("rpc-get-common-data-output"), + }) + }, + ) + .await + } + + async fn get_direct_data( + self: Arc, + _request: agent::get_direct_data::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(agent::get_direct_data::Response { + result_id: String::from("rpc-get-direct-data-output"), + }) + }, + ) + .await + } + + async fn create_tasks( + self: Arc, + request: impl tonic::codegen::tokio_stream::Stream< + Item = Result, + > + Send + + 'static, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> Result { + let mut request = std::pin::pin!(request); + let mut token = None; + loop { + match request.next().await { + Some(Ok(agent::create_tasks::Request::InitTaskRequest { + communication_token, + .. + })) => { + token = Some(communication_token); + } + Some(Ok(_)) => {} + Some(Err(err)) => return Err(err), + None => break, + } + } + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(agent::create_tasks::Response::Status { + communication_token: token.unwrap_or_default(), + statuses: vec![agent::create_tasks::Status::TaskInfo { + task_id: String::from("rpc-create-tasks-output"), + expected_output_keys: vec![], + data_dependencies: vec![], + payload_id: String::new(), + }], + }) + }, + ) + .await + } +} + +#[tokio::test] +async fn create_results_metadata() { + let mut client = armonik::Client::with_channel(Service::default().agent_server()).agent(); + + let response = client + .create_results_metadata("rpc-create-results-metadata-input", "", ["result-id"]) + .await + .unwrap(); + + assert_eq!( + response["result-id"].session_id, + "rpc-create-results-metadata-output" + ); +} + +#[tokio::test] +async fn create_results() { + let mut client = armonik::Client::with_channel(Service::default().agent_server()).agent(); + + let response = client + .create_results("rpc-create-results-input", "", [("result-id", b"")]) + .await + .unwrap(); + + assert_eq!( + response["result-id"].session_id, + "rpc-create-results-output" + ); +} + +#[tokio::test] +async fn notify_result_data() { + let mut client = armonik::Client::with_channel(Service::default().agent_server()).agent(); + + let response = client + .notify_result_data("rpc-notify-result-data-input", "", [""]) + .await + .unwrap(); + + assert_eq!(response[0], "rpc-notify-result-data-input"); + assert_eq!(response[1], "rpc-notify-result-data-output"); +} + +#[tokio::test] +async fn submit_tasks() { + let mut client: armonik::client::AgentClient< + armonik::api::v3::agent::agent_server::AgentServer, + > = armonik::Client::with_channel(Service::default().agent_server()).agent(); + + let response = client + .submit_tasks("rpc-submit-tasks-input", "", None, []) + .await + .unwrap(); + + assert_eq!(response[0].task_id, "rpc-submit-tasks-output"); +} + +#[tokio::test] +async fn get_resource_data() { + let mut client: armonik::client::AgentClient< + armonik::api::v3::agent::agent_server::AgentServer, + > = armonik::Client::with_channel(Service::default().agent_server()).agent(); + + let response = client + .call(agent::get_resource_data::Request { + communication_token: String::from("rpc-get-resource-data-input"), + result_id: String::from(""), + }) + .await + .unwrap(); + + assert_eq!(response.result_id, "rpc-get-resource-data-output"); +} + +#[tokio::test] +async fn get_common_data() { + let mut client: armonik::client::AgentClient< + armonik::api::v3::agent::agent_server::AgentServer, + > = armonik::Client::with_channel(Service::default().agent_server()).agent(); + + let response = client + .call(agent::get_common_data::Request { + communication_token: String::from("rpc-get-common-data-input"), + result_id: String::from(""), + }) + .await + .unwrap(); + + assert_eq!(response.result_id, "rpc-get-common-data-output"); +} + +#[tokio::test] +async fn get_direct_data() { + let mut client: armonik::client::AgentClient< + armonik::api::v3::agent::agent_server::AgentServer, + > = armonik::Client::with_channel(Service::default().agent_server()).agent(); + + let response = client + .call(agent::get_direct_data::Request { + communication_token: String::from("rpc-get-direct-data-input"), + result_id: String::from(""), + }) + .await + .unwrap(); + + assert_eq!(response.result_id, "rpc-get-direct-data-output"); +} + +#[tokio::test] +async fn create_tasks() { + let mut client: armonik::client::AgentClient< + armonik::api::v3::agent::agent_server::AgentServer, + > = armonik::Client::with_channel(Service::default().agent_server()).agent(); + + let response = client + .create_tasks(futures::stream::iter([ + agent::create_tasks::Request::InitRequest { + communication_token: String::from("rpc-create-tasks-input"), + request: agent::create_tasks::InitRequest { task_options: None }, + }, + ])) + .await + .unwrap(); + + match &response[0] { + agent::create_tasks::Status::TaskInfo { task_id, .. } => { + assert_eq!(task_id, "rpc-create-tasks-output"); + } + agent::create_tasks::Status::Error(err) => { + panic!("Expected TaskInfo, but got Error({err})") + } + } +} diff --git a/packages/rust/armonik/tests/applications.rs b/packages/rust/armonik/tests/applications.rs new file mode 100644 index 000000000..c772351b1 --- /dev/null +++ b/packages/rust/armonik/tests/applications.rs @@ -0,0 +1,58 @@ +use std::sync::Arc; + +use armonik::{applications, server::ApplicationsServiceExt}; + +mod common; + +#[derive(Debug, Clone, Default)] +struct Service { + failure: Option, + wait: Option, +} + +impl armonik::server::ApplicationsService for Service { + async fn list( + self: Arc, + request: applications::list::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(applications::list::Response { + applications: vec![applications::Raw { + name: String::from("rpc-list-output"), + ..Default::default() + }], + page: request.page, + page_size: request.page_size, + total: 1337, + }) + }, + ) + .await + } +} + +#[tokio::test] +async fn list() { + let mut client = + armonik::Client::with_channel(Service::default().applications_server()).applications(); + + let response = client + .list( + armonik::applications::filter::Or::default(), + armonik::applications::Sort::default(), + 3, + 12, + ) + .await + .unwrap(); + + assert_eq!(response.page, 3); + assert_eq!(response.page_size, 12); + assert_eq!(response.total, 1337); + assert_eq!(response.applications[0].name, "rpc-list-output"); +} diff --git a/packages/rust/armonik/tests/auth.rs b/packages/rust/armonik/tests/auth.rs new file mode 100644 index 000000000..8340f72ce --- /dev/null +++ b/packages/rust/armonik/tests/auth.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; + +use armonik::{auth, server::AuthServiceExt}; + +mod common; + +#[derive(Debug, Clone, Default)] +struct Service { + failure: Option, + wait: Option, +} + +impl armonik::server::AuthService for Service { + async fn current_user( + self: Arc, + _request: auth::current_user::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(auth::current_user::Response { + user: auth::User { + username: String::from("rpc-current-user-output"), + ..Default::default() + }, + }) + }, + ) + .await + } +} + +#[tokio::test] +async fn current_user() { + let mut client = armonik::Client::with_channel(Service::default().auth_server()).auth(); + + let response = client.current_user().await.unwrap(); + + assert_eq!(response.username, "rpc-current-user-output"); +} diff --git a/packages/rust/armonik/tests/common/mod.rs b/packages/rust/armonik/tests/common/mod.rs new file mode 100644 index 000000000..d1cf592c1 --- /dev/null +++ b/packages/rust/armonik/tests/common/mod.rs @@ -0,0 +1,19 @@ +#[allow(unused)] +pub(crate) async fn unary_rpc_impl( + duration: Option, + failure: Option, + cancellation_token: tokio_util::sync::CancellationToken, + response: impl FnOnce() -> Result, +) -> Result { + if let Some(duration) = duration { + cancellation_token + .run_until_cancelled(tokio::time::sleep(duration)) + .await + .ok_or(tonic::Status::cancelled("Request has been cancelled"))?; + } + if let Some(failure) = failure { + Err(failure) + } else { + response() + } +} diff --git a/packages/rust/armonik/tests/events.rs b/packages/rust/armonik/tests/events.rs new file mode 100644 index 000000000..c5d40cf67 --- /dev/null +++ b/packages/rust/armonik/tests/events.rs @@ -0,0 +1,106 @@ +use std::sync::Arc; + +use armonik::{events, reexports::tokio_stream::StreamExt, server::EventsServiceExt}; + +mod common; + +struct Service { + failure: Option, + wait: Option, + dropped: tokio_util::sync::CancellationToken, +} + +impl armonik::server::EventsService for Service { + async fn subscribe( + self: Arc, + request: events::subscribe::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> Result< + impl tonic::codegen::tokio_stream::Stream< + Item = Result, + > + Send, + tonic::Status, + > { + let end_ct = self.dropped.clone(); + Ok(async_stream::try_stream! { + let _drop_guard = end_ct.drop_guard(); + loop { + if let Some(duration) = self.wait.clone() { + cancellation_token + .run_until_cancelled(tokio::time::sleep(duration)) + .await + .ok_or(tonic::Status::cancelled("Request has been cancelled"))?; + } else if cancellation_token.is_cancelled() { + eprintln!("cancelled"); + Err(tonic::Status::cancelled("Request has been cancelled"))?; + } + + if let Some(failure) = self.failure.clone() { + eprintln!("failure: {failure:?}"); + Err(failure)? + } + + eprintln!("event"); + yield events::subscribe::Response{ + session_id: request.session_id.clone(), + update: events::Update::NewResult(events::NewResult { + result_id: String::from("rpc-subscribe-output"), + ..Default::default() + }), + }; + } + }) + } +} + +#[tokio::test] +async fn subscribe() { + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let mut client = armonik::Client::with_channel( + Service { + failure: None, + wait: None, + dropped: cancellation_token.clone(), + } + .events_server(), + ) + .events(); + + let mut response = client + .subscribe( + "rpc-subscribe-input", + armonik::tasks::filter::Or::default(), + armonik::results::filter::Or::default(), + [events::EventsEnum::Unspecified], + ) + .await + .unwrap(); + + let event = response.next().await.unwrap().unwrap(); + + assert_eq!(event.session_id, "rpc-subscribe-input"); + match event.update { + events::Update::NewResult(new_result) => { + assert_eq!(new_result.result_id, "rpc-subscribe-output") + } + event => panic!("expected a NewResult, but got {event:?}"), + } + + match response.next().await { + Some(Ok(event)) => eprintln!("Got event: {event:?}"), + Some(Err(err)) => eprintln!("Got error: {err:?}"), + None => { + eprintln!("Got end of stream"); + } + } + + std::mem::drop(response); + + if cancellation_token + .run_until_cancelled(tokio::time::sleep(tokio::time::Duration::from_millis(10))) + .await + .is_some() + { + panic!("Expected a cancel, but got a timeout"); + } +} diff --git a/packages/rust/armonik/tests/sessions.rs b/packages/rust/armonik/tests/sessions.rs new file mode 100644 index 000000000..8c510586a --- /dev/null +++ b/packages/rust/armonik/tests/sessions.rs @@ -0,0 +1,347 @@ +use std::sync::Arc; + +use armonik::{server::SessionsServiceExt, sessions}; + +mod common; + +#[derive(Debug, Clone, Default)] +struct Service { + failure: Option, + wait: Option, +} + +impl armonik::server::SessionsService for Service { + async fn list( + self: Arc, + request: sessions::list::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::list::Response { + sessions: vec![sessions::Raw { + session_id: String::from("rpc-list-output"), + ..Default::default() + }], + page: request.page, + page_size: request.page_size, + total: 1337, + }) + }, + ) + .await + } + + async fn get( + self: Arc, + request: sessions::get::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::get::Response { + session: sessions::Raw { + session_id: request.session_id, + partition_ids: vec![String::from("rpc-get-output")], + ..Default::default() + }, + }) + }, + ) + .await + } + + async fn cancel( + self: Arc, + request: sessions::cancel::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::cancel::Response { + session: sessions::Raw { + session_id: request.session_id, + partition_ids: vec![String::from("rpc-cancel-output")], + ..Default::default() + }, + }) + }, + ) + .await + } + + async fn create( + self: Arc, + _request: sessions::create::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::create::Response { + session_id: String::from("rpc-create-output"), + }) + }, + ) + .await + } + + async fn pause( + self: Arc, + request: sessions::pause::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::pause::Response { + session: sessions::Raw { + session_id: request.session_id, + partition_ids: vec![String::from("rpc-pause-output")], + ..Default::default() + }, + }) + }, + ) + .await + } + + async fn resume( + self: Arc, + request: sessions::resume::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::resume::Response { + session: sessions::Raw { + session_id: request.session_id, + partition_ids: vec![String::from("rpc-resume-output")], + ..Default::default() + }, + }) + }, + ) + .await + } + + async fn close( + self: Arc, + request: sessions::close::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::close::Response { + session: sessions::Raw { + session_id: request.session_id, + partition_ids: vec![String::from("rpc-close-output")], + ..Default::default() + }, + }) + }, + ) + .await + } + + async fn purge( + self: Arc, + request: sessions::purge::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::purge::Response { + session: sessions::Raw { + session_id: request.session_id, + partition_ids: vec![String::from("rpc-purge-output")], + ..Default::default() + }, + }) + }, + ) + .await + } + + async fn delete( + self: Arc, + request: sessions::delete::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::delete::Response { + session: sessions::Raw { + session_id: request.session_id, + partition_ids: vec![String::from("rpc-delete-output")], + ..Default::default() + }, + }) + }, + ) + .await + } + + async fn stop_submission( + self: Arc, + request: sessions::stop_submission::Request, + cancellation_token: tokio_util::sync::CancellationToken, + ) -> std::result::Result { + common::unary_rpc_impl( + self.wait.clone(), + self.failure.clone(), + cancellation_token, + || { + Ok(sessions::stop_submission::Response { + session: sessions::Raw { + session_id: request.session_id, + partition_ids: vec![String::from("rpc-stop-output")], + ..Default::default() + }, + }) + }, + ) + .await + } +} + +#[tokio::test] +async fn list() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client + .list( + armonik::sessions::filter::Or::default(), + armonik::sessions::Sort::default(), + true, + 3, + 12, + ) + .await + .unwrap(); + + assert_eq!(response.page, 3); + assert_eq!(response.page_size, 12); + assert_eq!(response.total, 1337); + assert_eq!(response.sessions[0].session_id, "rpc-list-output"); +} + +#[tokio::test] +async fn get() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client.get("rpc-get-input").await.unwrap(); + + assert_eq!(response.session_id, "rpc-get-input"); + assert_eq!(response.partition_ids[0], "rpc-get-output"); +} + +#[tokio::test] +async fn cancel() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client.cancel("rpc-cancel-input").await.unwrap(); + + assert_eq!(response.session_id, "rpc-cancel-input"); + assert_eq!(response.partition_ids[0], "rpc-cancel-output"); +} + +#[tokio::test] +async fn create() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client + .create(vec![String::from("rpc-create-input")], Default::default()) + .await + .unwrap(); + + assert_eq!(response, "rpc-create-output"); +} + +#[tokio::test] +async fn pause() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client.pause("rpc-pause-input").await.unwrap(); + + assert_eq!(response.session_id, "rpc-pause-input"); + assert_eq!(response.partition_ids[0], "rpc-pause-output"); +} + +#[tokio::test] +async fn resume() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client.resume("rpc-resume-input").await.unwrap(); + + assert_eq!(response.session_id, "rpc-resume-input"); + assert_eq!(response.partition_ids[0], "rpc-resume-output"); +} + +#[tokio::test] +async fn close() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client.close("rpc-close-input").await.unwrap(); + + assert_eq!(response.session_id, "rpc-close-input"); + assert_eq!(response.partition_ids[0], "rpc-close-output"); +} + +#[tokio::test] +async fn purge() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client.purge("rpc-purge-input").await.unwrap(); + + assert_eq!(response.session_id, "rpc-purge-input"); + assert_eq!(response.partition_ids[0], "rpc-purge-output"); +} + +#[tokio::test] +async fn delete() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client.delete("rpc-delete-input").await.unwrap(); + + assert_eq!(response.session_id, "rpc-delete-input"); + assert_eq!(response.partition_ids[0], "rpc-delete-output"); +} + +#[tokio::test] +async fn stop_submission() { + let mut client = armonik::Client::with_channel(Service::default().sessions_server()).sessions(); + + let response = client + .stop_submission("rpc-stop-input", true, true) + .await + .unwrap(); + + assert_eq!(response.session_id, "rpc-stop-input"); + assert_eq!(response.partition_ids[0], "rpc-stop-output"); +}