From f5486092529a2b5661a0c3939b040c5d3a849bcb Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 24 Dec 2024 15:50:52 +0530 Subject: [PATCH] Remove duplicate code, make serving a built-in source Signed-off-by: Sreekanth --- rust/Cargo.lock | 61 +- rust/Cargo.toml | 2 - rust/extns/numaflow-serving/Cargo.toml | 27 - rust/extns/numaflow-serving/src/app.rs | 276 ------ .../numaflow-serving/src/app/callback.rs | 219 ---- .../src/app/callback/state.rs | 378 ------- .../src/app/callback/store.rs | 35 - .../src/app/callback/store/memstore.rs | 217 ---- .../src/app/callback/store/redisstore.rs | 198 ---- .../numaflow-serving/src/app/response.rs | 60 -- .../extns/numaflow-serving/src/app/tracker.rs | 933 ------------------ rust/extns/numaflow-serving/src/config.rs | 235 ----- rust/extns/numaflow-serving/src/errors.rs | 50 - rust/extns/numaflow-serving/src/lib.rs | 12 - rust/extns/numaflow-serving/src/pipeline.rs | 154 --- rust/extns/numaflow-serving/src/source.rs | 194 ---- rust/numaflow-core/Cargo.toml | 1 - rust/numaflow-core/src/config/components.rs | 4 +- rust/numaflow-core/src/metrics.rs | 2 - .../src/shared/create_components.rs | 2 +- rust/numaflow-core/src/source.rs | 2 +- rust/numaflow-core/src/source/serving.rs | 10 +- rust/serving/Cargo.toml | 1 - rust/serving/src/app.rs | 132 +-- rust/serving/src/app/jetstream_proxy.rs | 104 +- rust/serving/src/config.rs | 2 - 26 files changed, 84 insertions(+), 3227 deletions(-) delete mode 100644 rust/extns/numaflow-serving/Cargo.toml delete mode 100644 rust/extns/numaflow-serving/src/app.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback/state.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback/store.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback/store/memstore.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs delete mode 100644 
rust/extns/numaflow-serving/src/app/response.rs delete mode 100644 rust/extns/numaflow-serving/src/app/tracker.rs delete mode 100644 rust/extns/numaflow-serving/src/config.rs delete mode 100644 rust/extns/numaflow-serving/src/errors.rs delete mode 100644 rust/extns/numaflow-serving/src/lib.rs delete mode 100644 rust/extns/numaflow-serving/src/pipeline.rs delete mode 100644 rust/extns/numaflow-serving/src/source.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 21a5f97246..8c9e480319 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -53,39 +53,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" -[[package]] -name = "async-nats" -version = "0.35.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab8df97cb8fc4a884af29ab383e9292ea0939cfcdd7d2a17179086dc6c427e7f" -dependencies = [ - "base64 0.22.1", - "bytes", - "futures", - "memchr", - "nkeys", - "nuid", - "once_cell", - "portable-atomic", - "rand", - "regex", - "ring", - "rustls-native-certs 0.7.3", - "rustls-pemfile 2.2.0", - "rustls-webpki 0.102.8", - "serde", - "serde_json", - "serde_nanos", - "serde_repr", - "thiserror 1.0.69", - "time", - "tokio", - "tokio-rustls 0.26.0", - "tracing", - "tryhard", - "url", -] - [[package]] name = "async-nats" version = "0.38.0" @@ -1748,7 +1715,7 @@ dependencies = [ name = "numaflow-core" version = "0.1.0" dependencies = [ - "async-nats 0.38.0", + "async-nats", "axum", "axum-server", "backoff", @@ -1762,7 +1729,6 @@ dependencies = [ "numaflow-models", "numaflow-pb", "numaflow-pulsar", - "numaflow-serving", "parking_lot", "pep440_rs", "pin-project", @@ -1828,30 +1794,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "numaflow-serving" -version = "0.1.0" -dependencies = [ - "axum", - "axum-server", - "backoff", - "base64 0.22.1", - "bytes", - "chrono", - "numaflow-models", - "rcgen", - "redis", - "serde", - "serde_json", - 
"thiserror 2.0.8", - "tokio", - "tower 0.4.13", - "tower-http", - "tracing", - "trait-variant", - "uuid", -] - [[package]] name = "object" version = "0.36.5" @@ -2885,7 +2827,6 @@ dependencies = [ name = "serving" version = "0.1.0" dependencies = [ - "async-nats 0.35.1", "axum", "axum-macros", "axum-server", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index b986f29316..db7deddb61 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -9,7 +9,6 @@ members = [ "numaflow-pb", "extns/numaflow-pulsar", "numaflow", - "extns/numaflow-serving", ] [workspace.lints.rust] @@ -60,7 +59,6 @@ numaflow-models = { path = "numaflow-models" } backoff = { path = "backoff" } numaflow-pb = { path = "numaflow-pb" } numaflow-pulsar = {path = "extns/numaflow-pulsar"} -numaflow-serving = {path = "extns/numaflow-serving"} tokio = "1.41.1" bytes = "1.7.1" tracing = "0.1.40" diff --git a/rust/extns/numaflow-serving/Cargo.toml b/rust/extns/numaflow-serving/Cargo.toml deleted file mode 100644 index 9a2abf9959..0000000000 --- a/rust/extns/numaflow-serving/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -name = "numaflow-serving" -version = "0.1.0" -edition = "2021" - -[lints] -workspace = true - -[dependencies] -axum.workspace = true -axum-server.workspace = true -tokio.workspace = true -bytes.workspace = true -tracing.workspace = true -serde = { workspace = true } -numaflow-models.workspace = true -backoff.workspace = true -rcgen = "0.13.1" -tower = "0.4.13" -tower-http = { version = "0.5.2", features = ["trace", "timeout"] } -uuid = { version = "1.10.0", features = ["v4"] } -thiserror = "2.0.8" -base64 = "0.22.1" -serde_json = "1.0.120" -trait-variant = "0.1.2" -redis = { version = "0.26.0", features = ["tokio-comp", "aio", "connection-manager"] } -chrono = { version = "0.4", features = ["serde"] } diff --git a/rust/extns/numaflow-serving/src/app.rs b/rust/extns/numaflow-serving/src/app.rs deleted file mode 100644 index 1d491bdace..0000000000 --- a/rust/extns/numaflow-serving/src/app.rs 
+++ /dev/null @@ -1,276 +0,0 @@ -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; - -use axum::body::Body; -use axum::extract::{MatchedPath, Request, State}; -use axum::http::{HeaderMap, StatusCode}; -use axum::middleware::{self, Next}; -use axum::response::{IntoResponse, Response}; -use axum::routing::{get, post}; -use axum::{Json, Router}; -use axum_server::tls_rustls::RustlsConfig; -use axum_server::Handle; -use bytes::Bytes; -use rcgen::{generate_simple_self_signed, Certificate, CertifiedKey, KeyPair}; -use tokio::sync::{mpsc, oneshot}; -use tower::ServiceBuilder; -use tower_http::timeout::TimeoutLayer; -use tower_http::trace::{DefaultOnResponse, TraceLayer}; -use uuid::Uuid; - -use crate::{Error, Message, MessageWrapper, Settings}; -/// -/// manage callbacks -pub(crate) mod callback; - -mod response; - -use crate::app::callback::state::State as CallbackState; - -pub(crate) mod tracker; - -use self::callback::store::Store; -use self::response::{ApiError, ServeResponse}; - -fn generate_certs() -> crate::Result<(Certificate, KeyPair)> { - let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) - .map_err(|e| Error::InitError(format!("Failed to generate cert {:?}", e)))?; - Ok((cert, key_pair)) -} - -pub(crate) async fn serve(app: AppState) -> crate::Result<()> -where - T: Clone + Send + Sync + Store + 'static, -{ - let (cert, key) = generate_certs()?; - - let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) - .await - .map_err(|e| Error::InitError(format!("Failed to create tls config {:?}", e)))?; - - // TODO: Move all env variables into one place. Some env variables are loaded when Settings is initialized - - tracing::info!(config = ?app.settings, "Starting server with config and pipeline spec"); - // Start the main server, which serves the application. 
- tokio::spawn(start_main_server(app, tls_config)); - - Ok(()) -} - -#[derive(Clone)] -pub struct AppState { - pub message: mpsc::Sender, - pub settings: Arc, - pub callback_state: CallbackState, -} - -const PUBLISH_ENDPOINTS: [&str; 3] = [ - "/v1/process/sync", - "/v1/process/sync_serve", - "/v1/process/async", -]; -// auth middleware to do token based authentication for all user facing routes -// if auth is enabled. -async fn auth_middleware( - State(api_auth_token): State>, - request: axum::extract::Request, - next: Next, -) -> Response { - let path = request.uri().path(); - - // we only need to check for the presence of the auth token in the request headers for the publish endpoints - if !PUBLISH_ENDPOINTS.contains(&path) { - return next.run(request).await; - } - - match api_auth_token { - Some(token) => { - // Check for the presence of the auth token in the request headers - let auth_token = match request.headers().get("Authorization") { - Some(token) => token, - None => { - return Response::builder() - .status(401) - .body(Body::empty()) - .expect("failed to build response") - } - }; - if auth_token.to_str().expect("auth token should be a string") - != format!("Bearer {}", token) - { - Response::builder() - .status(401) - .body(Body::empty()) - .expect("failed to build response") - } else { - next.run(request).await - } - } - None => { - // If the auth token is not set, we don't need to check for the presence of the auth token in the request headers - next.run(request).await - } - } -} - -async fn start_main_server(app: AppState, tls_config: RustlsConfig) -> crate::Result<()> -where - T: Clone + Send + Sync + Store + 'static, -{ - let app_addr: SocketAddr = format!("0.0.0.0:{}", &app.settings.app_listen_port) - .parse() - .map_err(|e| Error::InitError(format!("{e:?}")))?; - - let tid_header = app.settings.tid_header.clone(); - let layers = ServiceBuilder::new() - // Add tracing to all requests - .layer( - TraceLayer::new_for_http() - .make_span_with(move 
|req: &Request| { - let tid = req - .headers() - .get(&tid_header) - .and_then(|v| v.to_str().ok()) - .map(|v| v.to_string()) - .unwrap_or_else(|| Uuid::new_v4().to_string()); - - let matched_path = req - .extensions() - .get::() - .map(MatchedPath::as_str); - - tracing::info_span!("request", tid, method=?req.method(), matched_path) - }) - .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)), - ) - .layer( - // Graceful shutdown will wait for outstanding requests to complete. Add a timeout so - // requests don't hang forever. - TimeoutLayer::new(Duration::from_secs(app.settings.drain_timeout_secs)), - ) - // Add auth middleware to all user facing routes - .layer(middleware::from_fn_with_state( - app.settings.api_auth_token.clone(), - auth_middleware, - )); - - let handle = Handle::new(); - - let router = setup_app(app).await.layer(layers); - - tracing::info!(?app_addr, "Starting application server"); - - axum_server::bind_rustls(app_addr, tls_config) - .handle(handle) - .serve(router.into_make_service()) - .await - .map_err(|e| Error::InitError(format!("Starting web server for metrics: {}", e)))?; - - Ok(()) -} - -async fn setup_app(state: AppState) -> Router { - let router = Router::new() - .route("/health", get(health_check)) - .route("/livez", get(livez)) // Liveliness check - .route("/readyz", get(readyz)) - .with_state(state.clone()); - - let router = router.nest("/v1/process", routes(state.clone())); - router -} - -async fn health_check() -> impl IntoResponse { - "ok" -} - -async fn livez() -> impl IntoResponse { - StatusCode::NO_CONTENT -} - -async fn readyz( - State(app): State>, -) -> impl IntoResponse { - if app.callback_state.clone().ready().await { - StatusCode::NO_CONTENT - } else { - StatusCode::INTERNAL_SERVER_ERROR - } -} - -// extracts the ID from the headers, if not found, generates a new UUID -fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { - headers.get(tid_header).map_or_else( - || 
Uuid::new_v4().to_string(), - |v| String::from_utf8_lossy(v.as_bytes()).to_string(), - ) -} - -fn routes(app_state: AppState) -> Router { - let jetstream_proxy = jetstream_proxy(app_state); - jetstream_proxy -} - -const CALLBACK_URL_KEY: &str = "X-Numaflow-Callback-Url"; -const NUMAFLOW_RESP_ARRAY_LEN: &str = "Numaflow-Array-Len"; -const NUMAFLOW_RESP_ARRAY_IDX_LEN: &str = "Numaflow-Array-Index-Len"; - -struct ProxyState { - tid_header: String, - callback: CallbackState, - callback_url: String, - messages: mpsc::Sender, -} - -pub(crate) fn jetstream_proxy( - state: AppState, -) -> Router { - let proxy_state = Arc::new(ProxyState { - tid_header: state.settings.tid_header.clone(), - callback: state.callback_state.clone(), - messages: state.message.clone(), - callback_url: format!( - "https://{}:{}/v1/process/callback", - state.settings.host_ip, state.settings.app_listen_port - ), - }); - - let router = Router::new() - .route("/async", post(async_publish)) - .with_state(proxy_state); - router -} - -async fn async_publish( - State(proxy_state): State>>, - headers: HeaderMap, - body: Bytes, -) -> Result, ApiError> { - let id = extract_id_from_headers(&proxy_state.tid_header, &headers); - let mut msg_headers: HashMap = HashMap::new(); - for (k, v) in headers.iter() { - msg_headers.insert( - k.to_string(), - String::from_utf8_lossy(v.as_bytes()).to_string(), - ); - } - let (tx, rx) = oneshot::channel(); - let message = MessageWrapper { - confirm_save: tx, - message: Message { - value: body, - id: id.clone(), - headers: msg_headers, - }, - }; - proxy_state.messages.send(message).await.unwrap(); - rx.await.unwrap(); - - Ok(Json(ServeResponse::new( - "Successfully published message".to_string(), - id, - StatusCode::OK, - ))) -} diff --git a/rust/extns/numaflow-serving/src/app/callback.rs b/rust/extns/numaflow-serving/src/app/callback.rs deleted file mode 100644 index b4d43868ee..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback.rs +++ /dev/null @@ -1,219 +0,0 @@ 
-use axum::{body::Bytes, extract::State, http::HeaderMap, routing, Json, Router}; -use serde::{Deserialize, Serialize}; -use tracing::error; - -use self::store::Store; -use crate::app::response::ApiError; - -/// in-memory state store including connection tracking -pub(crate) mod state; -use state::State as CallbackState; - -/// store for storing the state -pub(crate) mod store; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub(crate) struct CallbackRequest { - pub(crate) id: String, - pub(crate) vertex: String, - pub(crate) cb_time: u64, - pub(crate) from_vertex: String, - pub(crate) tags: Option>, -} - -#[derive(Clone)] -struct CallbackAppState { - tid_header: String, - callback_state: CallbackState, -} - -pub fn callback_handler( - tid_header: String, - callback_state: CallbackState, -) -> Router { - let app_state = CallbackAppState { - tid_header, - callback_state, - }; - Router::new() - .route("/callback", routing::post(callback)) - .route("/callback_save", routing::post(callback_save)) - .with_state(app_state) -} - -async fn callback_save( - State(app_state): State>, - headers: HeaderMap, - body: Bytes, -) -> Result<(), ApiError> { - let id = headers - .get(&app_state.tid_header) - .map(|id| String::from_utf8_lossy(id.as_bytes()).to_string()) - .ok_or_else(|| ApiError::BadRequest("Missing id header".to_string()))?; - - app_state - .callback_state - .clone() - .save_response(id, body) - .await - .map_err(|e| { - error!(error=?e, "Saving body from callback save request"); - ApiError::InternalServerError( - "Failed to save body from callback save request".to_string(), - ) - })?; - - Ok(()) -} - -async fn callback( - State(app_state): State>, - Json(payload): Json>, -) -> Result<(), ApiError> { - app_state - .callback_state - .clone() - .insert_callback_requests(payload) - .await - .map_err(|e| { - error!(error=?e, "Inserting callback requests"); - ApiError::InternalServerError("Failed to insert callback requests".to_string()) - })?; - - Ok(()) -} - 
-#[cfg(test)] -mod tests { - use axum::body::Body; - use axum::extract::Request; - use axum::http::header::CONTENT_TYPE; - use axum::http::StatusCode; - use tower::ServiceExt; - - use super::*; - use crate::app::callback::state::State as CallbackState; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::app::tracker::MessageGraph; - use crate::pipeline::PipelineDCG; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiO
iJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZ
XMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - #[tokio::test] - async fn test_callback_failure() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let payload = vec![CallbackRequest { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }]; - - let res = Request::builder() - .method("POST") - .uri("/callback") - .header(CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_vec(&payload).unwrap())) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - // id is not 
registered, so it should return 500 - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[tokio::test] - async fn test_callback_success() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let mut state = CallbackState::new(msg_graph, store).await.unwrap(); - - let x = state.register("test_id".to_string()); - // spawn a task which will be awaited later - let handle = tokio::spawn(async move { - let _ = x.await.unwrap(); - }); - - let app = callback_handler("ID".to_owned(), state); - - let payload = vec![ - CallbackRequest { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }, - CallbackRequest { - id: "test_id".to_string(), - vertex: "cat".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }, - CallbackRequest { - id: "test_id".to_string(), - vertex: "out".to_string(), - cb_time: 12345, - from_vertex: "cat".to_string(), - tags: None, - }, - ]; - - let res = Request::builder() - .method("POST") - .uri("/callback") - .header(CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_vec(&payload).unwrap())) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - handle.await.unwrap(); - } - - #[tokio::test] - async fn test_callback_save() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let res = Request::builder() - .method("POST") - .uri("/callback_save") - .header(CONTENT_TYPE, "application/json") - .header("id", "test_id") - .body(Body::from("test_body")) - .unwrap(); - - let resp = 
app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[tokio::test] - async fn test_without_id_header() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let res = Request::builder() - .method("POST") - .uri("/callback_save") - .body(Body::from("test_body")) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } -} diff --git a/rust/extns/numaflow-serving/src/app/callback/state.rs b/rust/extns/numaflow-serving/src/app/callback/state.rs deleted file mode 100644 index 293478ead2..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback/state.rs +++ /dev/null @@ -1,378 +0,0 @@ -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; - -use tokio::sync::oneshot; - -use super::store::Store; -use crate::app::callback::{store::PayloadToSave, CallbackRequest}; -use crate::app::tracker::MessageGraph; -use crate::Error; - -struct RequestState { - // Channel to notify when all callbacks for a message is received - tx: oneshot::Sender>, - // CallbackRequest is immutable, while vtx_visited can grow. - vtx_visited: Vec>, -} - -#[derive(Clone)] -pub(crate) struct State { - // hashmap of vertex infos keyed by ID - // it also contains tx to trigger to response to the syncHTTP call - callbacks: Arc>>, - // generator to generate subgraph - msg_graph_generator: Arc, - // conn is to be used while reading and writing to redis. 
- store: T, -} - -impl State -where - T: Store, -{ - /// Create a new State to track connections and callback data - pub(crate) async fn new(msg_graph: MessageGraph, store: T) -> crate::Result { - Ok(Self { - callbacks: Arc::new(Mutex::new(HashMap::new())), - msg_graph_generator: Arc::new(msg_graph), - store, - }) - } - - /// register a new connection - /// The oneshot receiver will be notified when all callbacks for this connection is received from the numaflow pipeline - pub(crate) fn register(&mut self, id: String) -> oneshot::Receiver> { - // TODO: add an entry in Redis to note that the entry has been registered - - let (tx, rx) = oneshot::channel(); - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - guard.insert( - id.clone(), - RequestState { - tx, - vtx_visited: Vec::new(), - }, - ); - rx - } - - /// Retrieves the output of the numaflow pipeline - pub(crate) async fn retrieve_saved(&mut self, id: &str) -> Result>, Error> { - self.store.retrieve_datum(id).await - } - - pub(crate) async fn save_response( - &mut self, - id: String, - body: axum::body::Bytes, - ) -> crate::Result<()> { - // we have to differentiate between the saved responses and the callback requests - // saved responses are stored in "id_SAVED", callback requests are stored in "id" - self.store - .save(vec![PayloadToSave::DatumFromPipeline { - key: id, - value: body, - }]) - .await - } - - /// insert_callback_requests is used to insert the callback requests. - pub(crate) async fn insert_callback_requests( - &mut self, - cb_requests: Vec, - ) -> Result<(), Error> { - /* - TODO: should we consider batching the requests and then processing them? - that way algorithm can be invoked only once for a batch of requests - instead of invoking it for each request. 
- */ - let cb_requests: Vec> = - cb_requests.into_iter().map(Arc::new).collect(); - let redis_payloads: Vec = cb_requests - .iter() - .cloned() - .map(|cbr| PayloadToSave::Callback { - key: cbr.id.clone(), - value: Arc::clone(&cbr), - }) - .collect(); - - self.store.save(redis_payloads).await?; - - for cbr in cb_requests { - let id = cbr.id.clone(); - { - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - guard - .get_mut(&cbr.id) - .ok_or(Error::IDNotFound( - "Connection for the received callback is not present in the in-memory store", - ))? - .vtx_visited - .push(cbr); - } - - // check if the sub graph can be generated - match self.get_subgraph_from_memory(&id) { - Ok(_) => { - // if the sub graph is generated, then we can send the response - self.deregister(&id).await? - } - Err(e) => { - match e { - Error::SubGraphNotFound(_) => { - // if the sub graph is not generated, then we can continue - continue; - } - _ => { - // if there is an error, deregister with the error - self.deregister(&id).await? - } - } - } - } - } - Ok(()) - } - - /// Get the subgraph for the given ID from in-memory. - fn get_subgraph_from_memory(&self, id: &str) -> Result { - let callbacks = self.get_callbacks_from_memory(id).ok_or(Error::IDNotFound( - "Connection for the received callback is not present in the in-memory store", - ))?; - - self.get_subgraph(id.to_string(), callbacks) - } - - /// Get the subgraph for the given ID from persistent store. This is used querying for the status from the service endpoint even after the - /// request has been completed. 
- pub(crate) async fn retrieve_subgraph_from_storage( - &mut self, - id: &str, - ) -> Result { - // If the id is not found in the in-memory store, fetch from Redis - let callbacks: Vec> = - match self.retrieve_callbacks_from_storage(id).await { - Ok(callbacks) => callbacks, - Err(e) => { - return Err(e); - } - }; - // check if the sub graph can be generated - self.get_subgraph(id.to_string(), callbacks) - } - - // Generate subgraph from the given callbacks - fn get_subgraph( - &self, - id: String, - callbacks: Vec>, - ) -> Result { - match self - .msg_graph_generator - .generate_subgraph_from_callbacks(id, callbacks) - { - Ok(Some(sub_graph)) => Ok(sub_graph), - Ok(None) => Err(Error::SubGraphNotFound( - "Subgraph could not be generated for the given ID", - )), - Err(e) => Err(e), - } - } - - /// deregister is called to trigger response and delete all the data persisted for that ID - pub(crate) async fn deregister(&mut self, id: &str) -> Result<(), Error> { - let state = { - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - // we do not require the data stored in HashMap anymore - guard.remove(id) - }; - - let Some(state) = state else { - return Err(Error::IDNotFound( - "Connection for the received callback is not present in the in-memory store", - )); - }; - - state - .tx - .send(Ok(id.to_string())) - .map_err(|_| Error::Other("Application bug - Receiver is already dropped".to_string())) - } - - // Get the Callback value for the given ID - // TODO: Generate json serialized data here itself to avoid cloning. 
- fn get_callbacks_from_memory(&self, id: &str) -> Option>> { - let guard = self.callbacks.lock().expect("Getting lock on State"); - guard.get(id).map(|state| state.vtx_visited.clone()) - } - - // Get the Callback value for the given ID from persistent store - async fn retrieve_callbacks_from_storage( - &mut self, - id: &str, - ) -> Result>, Error> { - // If the id is not found in the in-memory store, fetch from Redis - let callbacks: Vec> = match self.store.retrieve_callbacks(id).await { - Ok(response) => response.into_iter().collect(), - Err(e) => { - return Err(e); - } - }; - Ok(callbacks) - } - - // Check if the store is ready - pub(crate) async fn ready(&mut self) -> bool { - self.store.ready().await - } -} - -#[cfg(test)] -mod tests { - use axum::body::Bytes; - - use super::*; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::pipeline::PipelineDCG; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ
3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7I
nR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - #[tokio::test] - async fn test_state() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, 
store).await.unwrap(); - - // Test register - let id = "test_id".to_string(); - let rx = state.register(id.clone()); - - let xid = id.clone(); - - // spawn a task to listen on the receiver, once we have received all the callbacks for the message - // we will get a response from the receiver with the message id - let handle = tokio::spawn(async move { - let result = rx.await.unwrap(); - // Tests deregister, and fetching the subgraph from the memory - assert_eq!(result.unwrap(), xid); - }); - - // Test save_response - let body = Bytes::from("Test Message"); - state.save_response(id.clone(), body).await.unwrap(); - - // Test retrieve_saved - let saved = state.retrieve_saved(&id).await.unwrap(); - assert_eq!(saved, vec!["Test Message".as_bytes()]); - - // Test insert_callback_requests - let cbs = vec![ - CallbackRequest { - id: id.clone(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }, - CallbackRequest { - id: id.clone(), - vertex: "planner".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: Some(vec!["tiger".to_owned(), "asciiart".to_owned()]), - }, - CallbackRequest { - id: id.clone(), - vertex: "tiger".to_string(), - cb_time: 12345, - from_vertex: "planner".to_string(), - tags: None, - }, - CallbackRequest { - id: id.clone(), - vertex: "asciiart".to_string(), - cb_time: 12345, - from_vertex: "planner".to_string(), - tags: None, - }, - CallbackRequest { - id: id.clone(), - vertex: "serve-sink".to_string(), - cb_time: 12345, - from_vertex: "tiger".to_string(), - tags: None, - }, - CallbackRequest { - id: id.clone(), - vertex: "serve-sink".to_string(), - cb_time: 12345, - from_vertex: "asciiart".to_string(), - tags: None, - }, - ]; - state.insert_callback_requests(cbs).await.unwrap(); - - let sub_graph = state.retrieve_subgraph_from_storage(&id).await; - assert!(sub_graph.is_ok()); - - handle.await.unwrap(); - } - - #[tokio::test] - async fn test_retrieve_saved_no_entry() { - let pipeline_spec: 
PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let id = "nonexistent_id".to_string(); - - // Try to retrieve saved data for an ID that doesn't exist - let result = state.retrieve_saved(&id).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_insert_callback_requests_invalid_id() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let cbs = vec![CallbackRequest { - id: "nonexistent_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }]; - - // Try to insert callback requests for an ID that hasn't been registered - let result = state.insert_callback_requests(cbs).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_retrieve_subgraph_from_storage_no_entry() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let id = "nonexistent_id".to_string(); - - // Try to retrieve a subgraph for an ID that doesn't exist - let result = state.retrieve_subgraph_from_storage(&id).await; - - // Check that an error is returned - assert!(result.is_err()); - } -} diff --git a/rust/extns/numaflow-serving/src/app/callback/store.rs b/rust/extns/numaflow-serving/src/app/callback/store.rs deleted file mode 100644 index af5f3c4368..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback/store.rs +++ /dev/null @@ -1,35 +0,0 @@ -use 
std::sync::Arc; - -use crate::app::callback::CallbackRequest; - -// in-memory store -pub(crate) mod memstore; -// redis as the store -pub(crate) mod redisstore; - -pub(crate) enum PayloadToSave { - /// Callback as sent by Numaflow to track the progression - Callback { - key: String, - value: Arc, - }, - /// Data sent by the Numaflow pipeline which is to be delivered as the response - DatumFromPipeline { - key: String, - value: axum::body::Bytes, - }, -} - -/// Store trait to store the callback information. -#[trait_variant::make(Store: Send)] -#[allow(dead_code)] -pub(crate) trait LocalStore { - async fn save(&mut self, messages: Vec) -> crate::Result<()>; - /// retrieve the callback payloads - async fn retrieve_callbacks( - &mut self, - id: &str, - ) -> Result>, crate::Error>; - async fn retrieve_datum(&mut self, id: &str) -> Result>, crate::Error>; - async fn ready(&mut self) -> bool; -} diff --git a/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs b/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs deleted file mode 100644 index 59355c76c6..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs +++ /dev/null @@ -1,217 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use super::PayloadToSave; -use crate::app::callback::CallbackRequest; -use crate::config::SAVED; -use crate::Error; - -/// `InMemoryStore` is an in-memory implementation of the `Store` trait. -/// It uses a `HashMap` to store data in memory. -#[derive(Clone)] -pub(crate) struct InMemoryStore { - /// The data field is a `HashMap` where the key is a `String` and the value is a `Vec>`. - /// It is wrapped in an `Arc>` to allow shared ownership and thread safety. - data: Arc>>>>, -} - -impl InMemoryStore { - /// Creates a new `InMemoryStore` with an empty `HashMap`. 
- #[allow(dead_code)] - pub(crate) fn new() -> Self { - Self { - data: Arc::new(std::sync::Mutex::new(HashMap::new())), - } - } -} - -impl super::Store for InMemoryStore { - /// Saves a vector of `PayloadToSave` into the `HashMap`. - /// Each `PayloadToSave` is serialized into bytes and stored in the `HashMap` under its key. - async fn save(&mut self, messages: Vec) -> crate::Result<()> { - let mut data = self.data.lock().unwrap(); - for msg in messages { - match msg { - PayloadToSave::Callback { key, value } => { - if key.is_empty() { - return Err(Error::StoreWrite("Key cannot be empty".to_string())); - } - let bytes = serde_json::to_vec(&*value) - .map_err(|e| Error::StoreWrite(format!("Serializing to bytes - {}", e)))?; - data.entry(key).or_default().push(bytes); - } - PayloadToSave::DatumFromPipeline { key, value } => { - if key.is_empty() { - return Err(Error::StoreWrite("Key cannot be empty".to_string())); - } - data.entry(format!("{}_{}", key, SAVED)) - .or_default() - .push(value.into()); - } - } - } - Ok(()) - } - - /// Retrieves callbacks for a given id from the `HashMap`. - /// Each callback is deserialized from bytes into a `CallbackRequest`. - async fn retrieve_callbacks(&mut self, id: &str) -> Result>, Error> { - let data = self.data.lock().unwrap(); - match data.get(id) { - Some(result) => { - let messages: Result, _> = result - .iter() - .map(|msg| { - let cbr: CallbackRequest = serde_json::from_slice(msg).map_err(|_| { - Error::StoreRead( - "Failed to parse CallbackRequest from bytes".to_string(), - ) - })?; - Ok(Arc::new(cbr)) - }) - .collect(); - messages - } - None => Err(Error::StoreRead(format!("No entry found for id: {}", id))), - } - } - - /// Retrieves data for a given id from the `HashMap`. - /// Each piece of data is deserialized from bytes into a `String`. 
- async fn retrieve_datum(&mut self, id: &str) -> Result>, Error> { - let id = format!("{}_{}", id, SAVED); - let data = self.data.lock().unwrap(); - match data.get(&id) { - Some(result) => Ok(result.to_vec()), - None => Err(Error::StoreRead(format!("No entry found for id: {}", id))), - } - } - - async fn ready(&mut self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - use crate::app::callback::store::{PayloadToSave, Store}; - use crate::app::callback::CallbackRequest; - - #[tokio::test] - async fn test_save_and_retrieve_callbacks() { - let mut store = InMemoryStore::new(); - let key = "test_key".to_string(); - let value = Arc::new(CallbackRequest { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }); - - // Save a callback - store - .save(vec![PayloadToSave::Callback { - key: key.clone(), - value: Arc::clone(&value), - }]) - .await - .unwrap(); - - // Retrieve the callback - let retrieved = store.retrieve_callbacks(&key).await.unwrap(); - - // Check that the retrieved callback is the same as the one we saved - assert_eq!(retrieved.len(), 1); - assert_eq!(retrieved[0].id, "test_id".to_string()) - } - - #[tokio::test] - async fn test_save_and_retrieve_datum() { - let mut store = InMemoryStore::new(); - let key = "test_key".to_string(); - let value = "test_value".to_string(); - - // Save a datum - store - .save(vec![PayloadToSave::DatumFromPipeline { - key: key.clone(), - value: value.clone().into(), - }]) - .await - .unwrap(); - - // Retrieve the datum - let retrieved = store.retrieve_datum(&key).await.unwrap(); - - // Check that the retrieved datum is the same as the one we saved - assert_eq!(retrieved.len(), 1); - assert_eq!(retrieved[0], value.as_bytes()); - } - - #[tokio::test] - async fn test_retrieve_callbacks_no_entry() { - let mut store = InMemoryStore::new(); - let key = "nonexistent_key".to_string(); - - // Try to retrieve a callback 
for a key that doesn't exist - let result = store.retrieve_callbacks(&key).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_retrieve_datum_no_entry() { - let mut store = InMemoryStore::new(); - let key = "nonexistent_key".to_string(); - - // Try to retrieve a datum for a key that doesn't exist - let result = store.retrieve_datum(&key).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_save_invalid_callback() { - let mut store = InMemoryStore::new(); - let value = Arc::new(CallbackRequest { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }); - - // Try to save a callback with an invalid key - let result = store - .save(vec![PayloadToSave::Callback { - key: "".to_string(), - value: Arc::clone(&value), - }]) - .await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_save_invalid_datum() { - let mut store = InMemoryStore::new(); - - // Try to save a datum with an invalid key - let result = store - .save(vec![PayloadToSave::DatumFromPipeline { - key: "".to_string(), - value: "test_value".into(), - }]) - .await; - - // Check that an error is returned - assert!(result.is_err()); - } -} diff --git a/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs b/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs deleted file mode 100644 index deae8b42cf..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::sync::Arc; - -use backoff::retry::Retry; -use backoff::strategy::fixed; -use redis::aio::ConnectionManager; -use redis::RedisError; -use tokio::sync::Semaphore; - -use super::PayloadToSave; -use crate::app::callback::CallbackRequest; -use crate::config::RedisConfig; -use crate::config::SAVED; -use crate::Error; - -const LPUSH: 
&str = "LPUSH"; -const LRANGE: &str = "LRANGE"; -const EXPIRE: &str = "EXPIRE"; - -// Handle to the Redis actor. -#[derive(Clone)] -pub(crate) struct RedisConnection { - conn_manager: ConnectionManager, - config: RedisConfig, -} - -impl RedisConnection { - /// Creates a new RedisConnection with concurrent operations on Redis set by max_tasks. - pub(crate) async fn new(config: RedisConfig) -> crate::Result { - let client = redis::Client::open(config.addr.as_str()) - .map_err(|e| Error::Connection(format!("Creating Redis client: {e:?}")))?; - let conn = client - .get_connection_manager() - .await - .map_err(|e| Error::Connection(format!("Connecting to Redis server: {e:?}")))?; - Ok(Self { - conn_manager: conn, - config, - }) - } - - async fn execute_redis_cmd( - conn_manager: &mut ConnectionManager, - ttl_secs: Option, - key: &str, - val: &Vec, - ) -> Result<(), RedisError> { - let mut pipe = redis::pipe(); - pipe.cmd(LPUSH).arg(key).arg(val); - - // if the ttl is configured, add the EXPIRE command to the pipeline - if let Some(ttl) = ttl_secs { - pipe.cmd(EXPIRE).arg(key).arg(ttl); - } - - // Execute the pipeline - pipe.query_async(conn_manager).await.map(|_: ()| ()) - } - - // write to Redis with retries - async fn write_to_redis(&self, key: &str, value: &Vec) -> crate::Result<()> { - let interval = fixed::Interval::from_millis(self.config.retries_duration_millis.into()) - .take(self.config.retries); - - Retry::retry( - interval, - || async { - // https://hackmd.io/@compiler-errors/async-closures - Self::execute_redis_cmd( - &mut self.conn_manager.clone(), - self.config.ttl_secs, - key, - value, - ) - .await - }, - |e: &RedisError| !e.is_unrecoverable_error(), - ) - .await - .map_err(|err| Error::StoreWrite(format!("Saving to redis: {}", err).to_string())) - } -} - -async fn handle_write_requests( - redis_conn: RedisConnection, - msg: PayloadToSave, -) -> crate::Result<()> { - match msg { - PayloadToSave::Callback { key, value } => { - // Convert the 
CallbackRequest to a byte array - let value = serde_json::to_vec(&*value) - .map_err(|e| Error::StoreWrite(format!("Serializing payload - {}", e)))?; - - redis_conn.write_to_redis(&key, &value).await - } - - // Write the byte array to Redis - PayloadToSave::DatumFromPipeline { key, value } => { - // we have to differentiate between the saved responses and the callback requests - // saved responses are stored in "id_SAVED", callback requests are stored in "id" - let key = format!("{}_{}", key, SAVED); - let value: Vec = value.into(); - - redis_conn.write_to_redis(&key, &value).await - } - } -} - -// It is possible to move the methods defined here to be methods on the Redis actor and communicate through channels. -// With that, all public APIs defined on RedisConnection can be on &self (immutable). -impl super::Store for RedisConnection { - // Attempt to save all payloads. Returns error if we fail to save at least one message. - async fn save(&mut self, messages: Vec) -> crate::Result<()> { - let mut tasks = vec![]; - // This is put in place not to overload Redis and also way some kind of - // flow control. 
- let sem = Arc::new(Semaphore::new(self.config.max_tasks)); - for msg in messages { - let permit = Arc::clone(&sem).acquire_owned().await; - let redis_conn = self.clone(); - let task = tokio::spawn(async move { - let _permit = permit; - handle_write_requests(redis_conn, msg).await - }); - tasks.push(task); - } - for task in tasks { - task.await.unwrap()?; - } - Ok(()) - } - - async fn retrieve_callbacks(&mut self, id: &str) -> Result>, Error> { - let result: Result>, RedisError> = redis::cmd(LRANGE) - .arg(id) - .arg(0) - .arg(-1) - .query_async(&mut self.conn_manager) - .await; - - match result { - Ok(result) => { - if result.is_empty() { - return Err(Error::StoreRead(format!("No entry found for id: {}", id))); - } - - let messages: Result, _> = result - .into_iter() - .map(|msg| { - let cbr: CallbackRequest = serde_json::from_slice(&msg).map_err(|e| { - Error::StoreRead(format!("Parsing payload from bytes - {}", e)) - })?; - Ok(Arc::new(cbr)) - }) - .collect(); - - messages - } - Err(e) => Err(Error::StoreRead(format!( - "Failed to read from redis: {:?}", - e - ))), - } - } - - async fn retrieve_datum(&mut self, id: &str) -> Result>, Error> { - // saved responses are stored in "id_SAVED" - let key = format!("{}_{}", id, SAVED); - let result: Result>, RedisError> = redis::cmd(LRANGE) - .arg(key) - .arg(0) - .arg(-1) - .query_async(&mut self.conn_manager) - .await; - - match result { - Ok(result) => { - if result.is_empty() { - return Err(Error::StoreRead(format!("No entry found for id: {}", id))); - } - - Ok(result) - } - Err(e) => Err(Error::StoreRead(format!( - "Failed to read from redis: {:?}", - e - ))), - } - } - - // Check if the Redis connection is healthy - async fn ready(&mut self) -> bool { - let mut conn = self.conn_manager.clone(); - match redis::cmd("PING").query_async::(&mut conn).await { - Ok(response) => response == "PONG", - Err(_) => false, - } - } -} diff --git a/rust/extns/numaflow-serving/src/app/response.rs 
b/rust/extns/numaflow-serving/src/app/response.rs deleted file mode 100644 index 40064a1f78..0000000000 --- a/rust/extns/numaflow-serving/src/app/response.rs +++ /dev/null @@ -1,60 +0,0 @@ -use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; -use axum::Json; -use chrono::{DateTime, Utc}; -use serde::Serialize; - -// Response sent by the serve handler sync/async to the client(user). -#[derive(Serialize)] -pub(crate) struct ServeResponse { - pub(crate) message: String, - pub(crate) id: String, - pub(crate) code: u16, - pub(crate) timestamp: DateTime, -} - -impl ServeResponse { - pub(crate) fn new(message: String, id: String, status: StatusCode) -> Self { - Self { - code: status.as_u16(), - message, - id, - timestamp: Utc::now(), - } - } -} - -// Error response sent by all the handlers to the client(user). -#[derive(Debug, Serialize)] -pub enum ApiError { - BadRequest(String), - InternalServerError(String), - BadGateway(String), -} - -impl IntoResponse for ApiError { - fn into_response(self) -> Response { - #[derive(Serialize)] - struct ErrorBody { - message: String, - code: u16, - timestamp: DateTime, - } - - let (status, message) = match self { - ApiError::BadRequest(message) => (StatusCode::BAD_REQUEST, message), - ApiError::InternalServerError(message) => (StatusCode::INTERNAL_SERVER_ERROR, message), - ApiError::BadGateway(message) => (StatusCode::BAD_GATEWAY, message), - }; - - ( - status, - Json(ErrorBody { - code: status.as_u16(), - message, - timestamp: Utc::now(), - }), - ) - .into_response() - } -} diff --git a/rust/extns/numaflow-serving/src/app/tracker.rs b/rust/extns/numaflow-serving/src/app/tracker.rs deleted file mode 100644 index 33137f45db..0000000000 --- a/rust/extns/numaflow-serving/src/app/tracker.rs +++ /dev/null @@ -1,933 +0,0 @@ -use std::collections::HashMap; -use std::string::ToString; -use std::sync::Arc; - -use serde::{Deserialize, Serialize}; - -use crate::app::callback::CallbackRequest; -use crate::pipeline::{Edge, 
OperatorType, PipelineDCG}; -use crate::Error; - -fn compare_slice(operator: &OperatorType, a: &[String], b: &[String]) -> bool { - match operator { - OperatorType::And => a.iter().all(|val| b.contains(val)), - OperatorType::Or => a.iter().any(|val| b.contains(val)), - OperatorType::Not => !a.iter().any(|val| b.contains(val)), - } -} - -type Graph = HashMap>; - -#[derive(Serialize, Deserialize, Debug)] -struct Subgraph { - id: String, - blocks: Vec, -} - -const DROP: &str = "U+005C__DROP__"; - -/// MessageGraph is a struct that generates the graph from the source vertex to the downstream vertices -/// for a message using the given callbacks. -pub(crate) struct MessageGraph { - dag: Graph, -} - -/// Block is a struct that contains the information about the block in the subgraph. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub(crate) struct Block { - from: String, - to: String, - cb_time: u64, -} - -// CallbackRequestWrapper is a struct that contains the information about the callback request and -// whether it has been visited or not. It is used to keep track of the visited callbacks. -#[derive(Debug)] -struct CallbackRequestWrapper { - callback_request: Arc, - visited: bool, -} - -impl MessageGraph { - /// This function generates a sub graph from a list of callbacks. - /// It first creates a HashMap to map each vertex to its corresponding callbacks. - /// Then it finds the source vertex by checking if the vertex and from_vertex fields are the same. - /// Finally, it calls the `generate_subgraph` function to generate the subgraph from the source vertex. 
- pub(crate) fn generate_subgraph_from_callbacks( - &self, - id: String, - callbacks: Vec>, - ) -> Result, Error> { - // Create a HashMap to map each vertex to its corresponding callbacks - let mut callback_map: HashMap> = HashMap::new(); - let mut source_vertex = None; - - // Find the source vertex, source vertex is the vertex that has the same vertex and from_vertex - for callback in callbacks { - // Check if the vertex is present in the graph - if !self.dag.contains_key(&callback.from_vertex) { - return Err(Error::SubGraphInvalidInput(format!( - "Invalid callback: {}, vertex: {}", - callback.id, callback.from_vertex - ))); - } - - if callback.vertex == callback.from_vertex { - source_vertex = Some(callback.vertex.clone()); - } - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::clone(&callback), - visited: false, - }); - } - - // If there is no source vertex, return None - let source_vertex = match source_vertex { - Some(vertex) => vertex, - None => return Ok(None), - }; - - // Create a new subgraph. - let mut subgraph: Subgraph = Subgraph { - id, - blocks: Vec::new(), - }; - // Call the `generate_subgraph` function to generate the subgraph from the source vertex - let result = self.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - // If the subgraph is generated successfully, serialize it into a JSON string and return it. - // Otherwise, return None - if result { - match serde_json::to_string(&subgraph) { - Ok(json) => Ok(Some(json)), - Err(e) => Err(Error::SubGraphGenerator(e.to_string())), - } - } else { - Ok(None) - } - } - - // generate_subgraph function is a recursive function that generates the sub graph from the source vertex for - // the given list of callbacks. 
The function returns true if the subgraph is generated successfully(if we are - // able to find a subgraph for the message using the given callbacks), it - // updates the subgraph with the path from the source vertex to the downstream vertices. - fn generate_subgraph( - &self, - current: String, - from: String, - callback_map: &mut HashMap>, - subgraph: &mut Subgraph, - ) -> bool { - let mut current_callback: Option> = None; - - // we need to borrow the callback_map as mutable to update the visited flag of the callback - // so that next time when we visit the same callback, we can skip it. Because there can be cases - // where there will be multiple callbacks for the same vertex. - // If there are no callbacks for the current vertex, we should not continue - // because it means we have not received all the callbacks for the given message. - let Some(callbacks) = callback_map.get_mut(¤t) else { - return false; - }; - - // iterate over the callbacks for the current vertex and find the one that has not been visited, - // and it is coming from the same vertex as the current vertex - for callback in callbacks { - if callback.callback_request.from_vertex == from && !callback.visited { - callback.visited = true; - current_callback = Some(Arc::clone(&callback.callback_request)); - break; - } - } - - // If there is no callback which is not visited and its parent is the "from" vertex, then we should - // return false because we have not received all the callbacks for the given message. 
- let Some(current_callback) = current_callback else { - return false; - }; - - // add the current block to the subgraph - subgraph.blocks.push(Block { - from: current_callback.from_vertex.clone(), - to: current_callback.vertex.clone(), - cb_time: current_callback.cb_time, - }); - - // if the current vertex has a DROP tag, then we should not proceed further - // and return true - if current_callback - .tags - .as_ref() - .map_or(false, |tags| tags.contains(&DROP.to_string())) - { - return true; - } - - // recursively invoke the downstream vertices of the current vertex, if any - if let Some(edges) = self.dag.get(¤t) { - for edge in edges { - // check if the edge should proceed based on the conditions - // if there are no conditions, we should proceed with the edge - // if there are conditions, we should check the tags of the current callback - // with the tags of the edge and the operator of the tags to decide if we should - // proceed with the edge - let should_proceed = edge - .conditions - .as_ref() - // If the edge has conditions, get the tags - .and_then(|conditions| conditions.tags.as_ref()) - // If there are no conditions or tags, default to true (i.e., proceed with the edge) - // If there are tags, compare the tags with the current callback's tags and the operator - // to decide if we should proceed with the edge. - .map_or(true, |tags| { - current_callback - .tags - .as_ref() - // If the current callback has no tags we should not proceed with the edge for "and" and "or" operators - // because we expect the current callback to have tags specified in the edge. - // If the current callback has no tags we should proceed with the edge for "not" operator. - // because we don't expect the current callback to have tags specified in the edge. 
- .map_or( - tags.operator.as_ref() == Some(&OperatorType::Not), - |callback_tags| { - tags.operator.as_ref().map_or(false, |operator| { - // If there is no operator, default to false (i.e., do not proceed with the edge) - // If there is an operator, compare the current callback's tags with the edge's tags - compare_slice(operator, callback_tags, &tags.values) - }) - }, - ) - }); - - // if the conditions are not met, then proceed to the next edge - if !should_proceed { - continue; - } - - // proceed to the downstream vertex - // if any of the downstream vertex returns false, then we should return false. - if !self.generate_subgraph(edge.to.clone(), current.clone(), callback_map, subgraph) - { - return false; - } - } - } - // if there are no downstream vertices, or all the downstream vertices returned true, - // we can return true - true - } - - // from_env reads the pipeline stored in the environment variable and creates a MessageGraph from it. - pub(crate) fn from_pipeline(pipeline_spec: &PipelineDCG) -> Result { - let mut dag = Graph::with_capacity(pipeline_spec.edges.len()); - for edge in &pipeline_spec.edges { - dag.entry(edge.from.clone()).or_default().push(edge.clone()); - } - - Ok(MessageGraph { dag }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::pipeline::{Conditions, Tag, Vertex}; - - #[test] - fn test_no_subgraph() { - let mut dag: Graph = HashMap::new(); - dag.insert( - "a".to_string(), - vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "a".to_string(), - to: "c".to_string(), - conditions: None, - }, - ], - ); - let message_graph = MessageGraph { dag }; - - let mut callback_map: HashMap> = HashMap::new(); - callback_map.insert( - "a".to_string(), - vec![CallbackRequestWrapper { - callback_request: Arc::new(CallbackRequest { - id: "uuid1".to_string(), - vertex: "a".to_string(), - cb_time: 1, - from_vertex: "a".to_string(), - tags: None, - }), - visited: false, - }], - ); - - let 
mut subgraph: Subgraph = Subgraph { - id: "uuid1".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - "a".to_string(), - "a".to_string(), - &mut callback_map, - &mut subgraph, - ); - - assert!(!result); - } - - #[test] - fn test_generate_subgraph() { - let mut dag: Graph = HashMap::new(); - dag.insert( - "a".to_string(), - vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "a".to_string(), - to: "c".to_string(), - conditions: None, - }, - ], - ); - let message_graph = MessageGraph { dag }; - - let mut callback_map: HashMap> = HashMap::new(); - callback_map.insert( - "a".to_string(), - vec![CallbackRequestWrapper { - callback_request: Arc::new(CallbackRequest { - id: "uuid1".to_string(), - vertex: "a".to_string(), - cb_time: 1, - from_vertex: "a".to_string(), - tags: None, - }), - visited: false, - }], - ); - - callback_map.insert( - "b".to_string(), - vec![CallbackRequestWrapper { - callback_request: Arc::new(CallbackRequest { - id: "uuid1".to_string(), - vertex: "b".to_string(), - cb_time: 1, - from_vertex: "a".to_string(), - tags: None, - }), - visited: false, - }], - ); - - callback_map.insert( - "c".to_string(), - vec![CallbackRequestWrapper { - callback_request: Arc::new(CallbackRequest { - id: "uuid1".to_string(), - vertex: "c".to_string(), - cb_time: 1, - from_vertex: "a".to_string(), - tags: None, - }), - visited: false, - }], - ); - - let mut subgraph: Subgraph = Subgraph { - id: "uuid1".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - "a".to_string(), - "a".to_string(), - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_generate_subgraph_complex() { - let pipeline = PipelineDCG { - vertices: vec![ - Vertex { - name: "a".to_string(), - }, - Vertex { - name: "b".to_string(), - }, - Vertex { - name: "c".to_string(), - }, - Vertex { - name: "d".to_string(), - }, - Vertex { - name: 
"e".to_string(), - }, - Vertex { - name: "f".to_string(), - }, - Vertex { - name: "g".to_string(), - }, - Vertex { - name: "h".to_string(), - }, - Vertex { - name: "i".to_string(), - }, - ], - edges: vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "a".to_string(), - to: "c".to_string(), - conditions: None, - }, - Edge { - from: "b".to_string(), - to: "d".to_string(), - conditions: None, - }, - Edge { - from: "c".to_string(), - to: "e".to_string(), - conditions: None, - }, - Edge { - from: "d".to_string(), - to: "f".to_string(), - conditions: None, - }, - Edge { - from: "e".to_string(), - to: "f".to_string(), - conditions: None, - }, - Edge { - from: "f".to_string(), - to: "g".to_string(), - conditions: None, - }, - Edge { - from: "g".to_string(), - to: "h".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::And), - values: vec!["even".to_string()], - }), - }), - }, - Edge { - from: "g".to_string(), - to: "i".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::Or), - values: vec!["odd".to_string()], - }), - }), - }, - ], - }; - - let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); - let source_vertex = "a".to_string(); - - let raw_callback = r#"[ - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "a", - "cb_time": 123456789 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456867 - }, - { - "id": "xxxx", - "vertex": "c", - "from_vertex": "a", - "cb_time": 123456819 - }, - { - "id": "xxxx", - "vertex": "d", - "from_vertex": "b", - "cb_time": 123456840 - }, - { - "id": "xxxx", - "vertex": "e", - "from_vertex": "c", - "cb_time": 123456843 - }, - { - "id": "xxxx", - "vertex": "f", - "from_vertex": "d", - "cb_time": 123456854 - }, - { - "id": "xxxx", - "vertex": "f", - "from_vertex": "e", - "cb_time": 123456886 - }, - { - "id": "xxxx", - "vertex": "g", - "from_vertex": "f", - 
"tags": ["even"], - "cb_time": 123456885 - }, - { - "id": "xxxx", - "vertex": "g", - "from_vertex": "f", - "tags": ["even"], - "cb_time": 123456888 - }, - { - "id": "xxxx", - "vertex": "h", - "from_vertex": "g", - "cb_time": 123456889 - }, - { - "id": "xxxx", - "vertex": "h", - "from_vertex": "g", - "cb_time": 123456890 - } - ]"#; - - let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); - let mut callback_map: HashMap> = HashMap::new(); - - for callback in callbacks { - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::new(callback), - visited: false, - }); - } - - let mut subgraph: Subgraph = Subgraph { - id: "xxxx".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_simple_dropped_message() { - let pipeline = PipelineDCG { - vertices: vec![ - Vertex { - name: "a".to_string(), - }, - Vertex { - name: "b".to_string(), - }, - Vertex { - name: "c".to_string(), - }, - ], - edges: vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "b".to_string(), - to: "c".to_string(), - conditions: None, - }, - ], - }; - - let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); - let source_vertex = "a".to_string(); - - let raw_callback = r#" - [ - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "a", - "cb_time": 123456789 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456867, - "tags": ["U+005C__DROP__"] - } - ]"#; - - let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); - let mut callback_map: HashMap> = HashMap::new(); - - for callback in callbacks { - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::new(callback), - visited: false, - }); - } - - 
let mut subgraph: Subgraph = Subgraph { - id: "xxxx".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_complex_dropped_message() { - let pipeline = PipelineDCG { - vertices: vec![ - Vertex { - name: "a".to_string(), - }, - Vertex { - name: "b".to_string(), - }, - Vertex { - name: "c".to_string(), - }, - Vertex { - name: "d".to_string(), - }, - Vertex { - name: "e".to_string(), - }, - Vertex { - name: "f".to_string(), - }, - Vertex { - name: "g".to_string(), - }, - Vertex { - name: "h".to_string(), - }, - Vertex { - name: "i".to_string(), - }, - ], - edges: vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "a".to_string(), - to: "c".to_string(), - conditions: None, - }, - Edge { - from: "b".to_string(), - to: "d".to_string(), - conditions: None, - }, - Edge { - from: "c".to_string(), - to: "e".to_string(), - conditions: None, - }, - Edge { - from: "d".to_string(), - to: "f".to_string(), - conditions: None, - }, - Edge { - from: "e".to_string(), - to: "f".to_string(), - conditions: None, - }, - Edge { - from: "f".to_string(), - to: "g".to_string(), - conditions: None, - }, - Edge { - from: "g".to_string(), - to: "h".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::And), - values: vec!["even".to_string()], - }), - }), - }, - Edge { - from: "g".to_string(), - to: "i".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::Or), - values: vec!["odd".to_string()], - }), - }), - }, - ], - }; - - let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); - let source_vertex = "a".to_string(); - - let raw_callback = r#" - [ - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "a", - "cb_time": 123456789 - }, - { - "id": "xxxx", - "vertex": "b", - 
"from_vertex": "a", - "cb_time": 123456867 - }, - { - "id": "xxxx", - "vertex": "c", - "from_vertex": "a", - "cb_time": 123456819, - "tags": ["U+005C__DROP__"] - }, - { - "id": "xxxx", - "vertex": "d", - "from_vertex": "b", - "cb_time": 123456840 - }, - { - "id": "xxxx", - "vertex": "f", - "from_vertex": "d", - "cb_time": 123456854 - }, - { - "id": "xxxx", - "vertex": "g", - "from_vertex": "f", - "tags": ["even"], - "cb_time": 123456885 - }, - { - "id": "xxxx", - "vertex": "h", - "from_vertex": "g", - "cb_time": 123456889 - } - ]"#; - - let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); - let mut callback_map: HashMap> = HashMap::new(); - - for callback in callbacks { - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::new(callback), - visited: false, - }); - } - - let mut subgraph: Subgraph = Subgraph { - id: "xxxx".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_simple_cycle_pipeline() { - let pipeline = PipelineDCG { - vertices: vec![ - Vertex { - name: "a".to_string(), - }, - Vertex { - name: "b".to_string(), - }, - Vertex { - name: "c".to_string(), - }, - ], - edges: vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "b".to_string(), - to: "a".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::Not), - values: vec!["failed".to_string()], - }), - }), - }, - ], - }; - - let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); - let source_vertex = "a".to_string(); - - let raw_callback = r#" - [ - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "a", - "cb_time": 123456789 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456867, - "tags": ["failed"] - }, - { - "id": 
"xxxx", - "vertex": "a", - "from_vertex": "b", - "cb_time": 123456819 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456819 - }, - { - "id": "xxxx", - "vertex": "c", - "from_vertex": "b", - "cb_time": 123456819 - } - ]"#; - - let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); - let mut callback_map: HashMap> = HashMap::new(); - - for callback in callbacks { - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::new(callback), - visited: false, - }); - } - - let mut subgraph: Subgraph = Subgraph { - id: "xxxx".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_generate_subgraph_from_callbacks_with_invalid_vertex() { - // Create a simple graph - let mut dag: Graph = HashMap::new(); - dag.insert( - "a".to_string(), - vec![Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }], - ); - let message_graph = MessageGraph { dag }; - - // Create a callback with an invalid vertex - let callbacks = vec![Arc::new(CallbackRequest { - id: "test".to_string(), - vertex: "invalid_vertex".to_string(), - from_vertex: "invalid_vertex".to_string(), - cb_time: 1, - tags: None, - })]; - - // Call the function with the invalid callback - let result = message_graph.generate_subgraph_from_callbacks("test".to_string(), callbacks); - - // Check that the function returned an error - assert!(result.is_err()); - assert!(matches!(result, Err(Error::SubGraphInvalidInput(_)))); - } -} diff --git a/rust/extns/numaflow-serving/src/config.rs b/rust/extns/numaflow-serving/src/config.rs deleted file mode 100644 index d306b47d70..0000000000 --- a/rust/extns/numaflow-serving/src/config.rs +++ /dev/null @@ -1,235 +0,0 @@ -use std::collections::HashMap; - -use base64::prelude::BASE64_STANDARD; -use 
base64::Engine; -use serde::{Deserialize, Serialize}; - -use crate::pipeline::PipelineDCG; -use crate::Error; - -const ENV_NUMAFLOW_SERVING_SOURCE_OBJECT: &str = "NUMAFLOW_SERVING_SOURCE_OBJECT"; -const ENV_NUMAFLOW_SERVING_STORE_TTL: &str = "NUMAFLOW_SERVING_STORE_TTL"; -const ENV_NUMAFLOW_SERVING_HOST_IP: &str = "NUMAFLOW_SERVING_HOST_IP"; -const ENV_NUMAFLOW_SERVING_APP_PORT: &str = "NUMAFLOW_SERVING_APP_LISTEN_PORT"; -const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; -const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; - -pub(crate) const SAVED: &str = "SAVED"; - -#[derive(Debug, Deserialize, Clone, PartialEq)] -pub struct RedisConfig { - pub addr: String, - pub max_tasks: usize, - pub retries: usize, - pub retries_duration_millis: u16, - pub ttl_secs: Option, -} - -impl Default for RedisConfig { - fn default() -> Self { - Self { - addr: "redis://127.0.0.1:6379".to_owned(), - max_tasks: 50, - retries: 5, - retries_duration_millis: 100, - // TODO: we might need an option type here. Zero value of u32 can be used instead of None - ttl_secs: Some(1), - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Settings { - pub tid_header: String, - pub app_listen_port: u16, - pub metrics_server_listen_port: u16, - pub upstream_addr: String, - pub drain_timeout_secs: u64, - /// The IP address of the numaserve pod. 
This will be used to construct the value for X-Numaflow-Callback-Url header - pub host_ip: String, - pub api_auth_token: Option, - pub redis: RedisConfig, - pub pipeline_spec: PipelineDCG, -} - -impl Default for Settings { - fn default() -> Self { - Self { - tid_header: "ID".to_owned(), - app_listen_port: 3000, - metrics_server_listen_port: 3001, - upstream_addr: "localhost:8888".to_owned(), - drain_timeout_secs: 10, - host_ip: "127.0.0.1".to_owned(), - api_auth_token: None, - redis: RedisConfig::default(), - pipeline_spec: Default::default(), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Serving { - #[serde(rename = "msgIDHeaderKey")] - pub msg_id_header_key: Option, - #[serde(rename = "store")] - pub callback_storage: CallbackStorageConfig, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct CallbackStorageConfig { - pub url: String, -} - -/// This implementation is to load settings from env variables -impl TryFrom> for Settings { - type Error = Error; - fn try_from(env_vars: HashMap) -> std::result::Result { - let host_ip = env_vars - .get(ENV_NUMAFLOW_SERVING_HOST_IP) - .ok_or_else(|| { - Error::ParseConfig(format!( - "Environment variable {ENV_NUMAFLOW_SERVING_HOST_IP} is not set" - )) - })? - .to_owned(); - - let pipeline_spec: PipelineDCG = env_vars - .get(ENV_MIN_PIPELINE_SPEC) - .ok_or_else(|| { - Error::ParseConfig(format!( - "Pipeline spec is not set using environment variable {ENV_MIN_PIPELINE_SPEC}" - )) - })? 
- .parse() - .map_err(|e| { - Error::ParseConfig(format!( - "Parsing pipeline spec: {}: error={e:?}", - env_vars.get(ENV_MIN_PIPELINE_SPEC).unwrap() - )) - })?; - - let mut settings = Settings { - host_ip, - pipeline_spec, - ..Default::default() - }; - - if let Some(api_auth_token) = env_vars.get(ENV_NUMAFLOW_SERVING_AUTH_TOKEN) { - settings.api_auth_token = Some(api_auth_token.to_owned()); - } - - if let Some(app_port) = env_vars.get(ENV_NUMAFLOW_SERVING_APP_PORT) { - settings.app_listen_port = app_port.parse().map_err(|e| { - Error::ParseConfig(format!( - "Parsing {ENV_NUMAFLOW_SERVING_APP_PORT}(set to '{app_port}'): {e:?}" - )) - })?; - } - - // Update redis.ttl_secs from environment variable - if let Some(ttl_secs) = env_vars.get(ENV_NUMAFLOW_SERVING_STORE_TTL) { - let ttl_secs: u32 = ttl_secs.parse().map_err(|e| { - Error::ParseConfig(format!("parsing {ENV_NUMAFLOW_SERVING_STORE_TTL}: {e:?}")) - })?; - settings.redis.ttl_secs = Some(ttl_secs); - } - - let Some(source_spec_encoded) = env_vars.get(ENV_NUMAFLOW_SERVING_SOURCE_OBJECT) else { - return Ok(settings); - }; - - let source_spec_decoded = BASE64_STANDARD - .decode(source_spec_encoded.as_bytes()) - .map_err(|e| Error::ParseConfig(format!("decoding NUMAFLOW_SERVING_SOURCE: {e:?}")))?; - - let source_spec = serde_json::from_slice::(&source_spec_decoded) - .map_err(|e| Error::ParseConfig(format!("parsing NUMAFLOW_SERVING_SOURCE: {e:?}")))?; - - // Update tid_header from source_spec - if let Some(msg_id_header_key) = source_spec.msg_id_header_key { - settings.tid_header = msg_id_header_key; - } - - // Update redis.addr from source_spec, currently we only support redis as callback storage - settings.redis.addr = source_spec.callback_storage.url; - - Ok(settings) - } -} - -#[cfg(test)] -mod tests { - use crate::pipeline::{Edge, Vertex}; - - use super::*; - - #[test] - fn test_default_config() { - let settings = Settings::default(); - - assert_eq!(settings.tid_header, "ID"); - 
assert_eq!(settings.app_listen_port, 3000); - assert_eq!(settings.metrics_server_listen_port, 3001); - assert_eq!(settings.upstream_addr, "localhost:8888"); - assert_eq!(settings.drain_timeout_secs, 10); - assert_eq!(settings.redis.addr, "redis://127.0.0.1:6379"); - assert_eq!(settings.redis.max_tasks, 50); - assert_eq!(settings.redis.retries, 5); - assert_eq!(settings.redis.retries_duration_millis, 100); - } - - #[test] - fn test_config_parse() { - // Set up the environment variables - let env_vars = [ - (ENV_NUMAFLOW_SERVING_HOST_IP, "10.2.3.5"), - (ENV_NUMAFLOW_SERVING_AUTH_TOKEN, "api-auth-token"), - (ENV_NUMAFLOW_SERVING_APP_PORT, "8443"), - (ENV_NUMAFLOW_SERVING_STORE_TTL, "86400"), - (ENV_NUMAFLOW_SERVING_SOURCE_OBJECT, "eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX0="), - (ENV_MIN_PIPELINE_SPEC, "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6InNlcnZpbmctaW4iLCJzb3VyY2UiOnsic2VydmluZyI6eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX19LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciIsImVudiI6W3sibmFtZSI6IlJVU1RfTE9HIiwidmFsdWUiOiJpbmZvIn1dfSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2aW5nLXNpbmsiLCJzaW5rIjp7InVkc2luayI6eyJjb250YWluZXIiOnsiaW1hZ2UiOiJxdWF5LmlvL251bWFpby9udW1hZmxvdy1ycy9zaW5rLWxvZzpzdGFibGUiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9fX0sInJldHJ5U3RyYXRlZ3kiOnt9fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fV0sImVkZ2VzIjpbeyJ
mcm9tIjoic2VydmluZy1pbiIsInRvIjoic2VydmluZy1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH1dLCJsaWZlY3ljbGUiOnt9LCJ3YXRlcm1hcmsiOnt9fQ==") - ]; - - // Call the config method - let settings: Settings = env_vars - .into_iter() - .map(|(key, val)| (key.to_owned(), val.to_owned())) - .collect::>() - .try_into() - .unwrap(); - - let expected_config = Settings { - tid_header: "X-Numaflow-Id".into(), - app_listen_port: 8443, - metrics_server_listen_port: 3001, - upstream_addr: "localhost:8888".into(), - drain_timeout_secs: 10, - redis: RedisConfig { - addr: "redis://redis:6379".into(), - max_tasks: 50, - retries: 5, - retries_duration_millis: 100, - ttl_secs: Some(86400), - }, - host_ip: "10.2.3.5".into(), - api_auth_token: Some("api-auth-token".into()), - pipeline_spec: PipelineDCG { - vertices: vec![ - Vertex { - name: "serving-in".into(), - }, - Vertex { - name: "serving-sink".into(), - }, - ], - edges: vec![Edge { - from: "serving-in".into(), - to: "serving-sink".into(), - conditions: None, - }], - }, - }; - assert_eq!(settings, expected_config); - } -} diff --git a/rust/extns/numaflow-serving/src/errors.rs b/rust/extns/numaflow-serving/src/errors.rs deleted file mode 100644 index 7446faee71..0000000000 --- a/rust/extns/numaflow-serving/src/errors.rs +++ /dev/null @@ -1,50 +0,0 @@ -use tokio::sync::oneshot; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Initialization error - {0}")] - InitError(String), - - #[error("Failed to parse configuration - {0}")] - ParseConfig(String), - // - // callback errors - // TODO: store the ID too? 
- #[error("IDNotFound Error - {0}")] - IDNotFound(&'static str), - - #[error("SubGraphGenerator Error - {0}")] - // subgraph generator errors - SubGraphGenerator(String), - - #[error("StoreWrite Error - {0}")] - // Store write errors - StoreWrite(String), - - #[error("SubGraphNotFound Error - {0}")] - // Sub Graph Not Found Error - SubGraphNotFound(&'static str), - - #[error("SubGraphInvalidInput Error - {0}")] - // Sub Graph Invalid Input Error - SubGraphInvalidInput(String), - - #[error("StoreRead Error - {0}")] - // Store read errors - StoreRead(String), - - #[error("Metrics Error - {0}")] - // Metrics errors - MetricsServer(String), - - #[error("Connection Error - {0}")] - Connection(String), - - #[error("Failed to receive message from channel. Actor task is terminated: {0:?}")] - ActorTaskTerminated(oneshot::error::RecvError), - - #[error("{0}")] - Other(String), -} - -pub type Result = std::result::Result; diff --git a/rust/extns/numaflow-serving/src/lib.rs b/rust/extns/numaflow-serving/src/lib.rs deleted file mode 100644 index a6a446203e..0000000000 --- a/rust/extns/numaflow-serving/src/lib.rs +++ /dev/null @@ -1,12 +0,0 @@ -mod source; -pub use source::{Message, MessageWrapper, ServingSource}; - -mod app; -pub mod config; - -mod errors; -pub use errors::{Error, Result}; - -pub(crate) mod pipeline; - -pub type Settings = config::Settings; diff --git a/rust/extns/numaflow-serving/src/pipeline.rs b/rust/extns/numaflow-serving/src/pipeline.rs deleted file mode 100644 index 50236f5381..0000000000 --- a/rust/extns/numaflow-serving/src/pipeline.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::str::FromStr; - -use base64::prelude::BASE64_STANDARD; -use base64::Engine; -use numaflow_models::models::PipelineSpec; -use serde::{Deserialize, Serialize}; - -use crate::Error::ParseConfig; - -// OperatorType is an enum that contains the types of operators -// that can be used in the conditions for the edge. 
-#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] -pub enum OperatorType { - #[serde(rename = "and")] - And, - #[serde(rename = "or")] - Or, - #[serde(rename = "not")] - Not, -} - -#[allow(dead_code)] -impl OperatorType { - fn as_str(&self) -> &'static str { - match self { - OperatorType::And => "and", - OperatorType::Or => "or", - OperatorType::Not => "not", - } - } -} - -impl From for OperatorType { - fn from(s: String) -> Self { - match s.as_str() { - "and" => OperatorType::And, - "or" => OperatorType::Or, - "not" => OperatorType::Not, - _ => panic!("Invalid operator type: {}", s), - } - } -} - -// Tag is a struct that contains the information about the tags for the edge -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Tag { - pub operator: Option, - pub values: Vec, -} - -// Conditions is a struct that contains the information about the conditions for the edge -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Conditions { - pub tags: Option, -} - -// Edge is a struct that contains the information about the edge in the pipeline. 
-#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Edge { - pub from: String, - pub to: String, - pub conditions: Option, -} - -/// DCG (directed compute graph) of the pipeline with minimal information build using vertices and edges -/// from the pipeline spec -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] -pub struct PipelineDCG { - pub vertices: Vec, - pub edges: Vec, -} - -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Vertex { - pub name: String, -} - -impl FromStr for PipelineDCG { - type Err = crate::Error; - - fn from_str(pipeline_spec_encoded: &str) -> Result { - let full_pipeline_spec_decoded = BASE64_STANDARD - .decode(pipeline_spec_encoded) - .map_err(|e| ParseConfig(format!("Decoding pipeline from env: {e:?}")))?; - - let full_pipeline_spec = - serde_json::from_slice::(&full_pipeline_spec_decoded) - .map_err(|e| ParseConfig(format!("parsing pipeline from env: {e:?}")))?; - - let vertices: Vec = full_pipeline_spec - .vertices - .ok_or(ParseConfig("missing vertices in pipeline spec".to_string()))? - .iter() - .map(|v| Vertex { - name: v.name.clone(), - }) - .collect(); - - let edges: Vec = full_pipeline_spec - .edges - .ok_or(ParseConfig("missing edges in pipeline spec".to_string()))? 
- .iter() - .map(|e| { - let conditions = e.conditions.clone().map(|c| Conditions { - tags: Some(Tag { - operator: c.tags.operator.map(|o| o.into()), - values: c.tags.values.clone(), - }), - }); - - Edge { - from: e.from.clone(), - to: e.to.clone(), - conditions, - } - }) - .collect(); - - Ok(PipelineDCG { vertices, edges }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_pipeline_load() { - let pipeline: PipelineDCG = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbW
FnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbG
F0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0=".parse().unwrap(); - - assert_eq!(pipeline.vertices.len(), 8); - assert_eq!(pipeline.edges.len(), 10); - assert_eq!(pipeline.vertices[0].name, "in"); - assert_eq!(pipeline.edges[0].from, "in"); - assert_eq!(pipeline.edges[0].to, "planner"); - assert!(pipeline.edges[0].conditions.is_none()); - - assert_eq!(pipeline.vertices[1].name, "planner"); - assert_eq!(pipeline.edges[1].from, "planner"); - assert_eq!(pipeline.edges[1].to, "asciiart"); - assert_eq!( - pipeline.edges[1].conditions, - Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::Or), - values: vec!["asciiart".to_owned()] - }) - }) - ); - - assert_eq!(pipeline.vertices[2].name, "tiger"); - assert_eq!(pipeline.vertices[3].name, "dog"); - } -} diff --git 
a/rust/extns/numaflow-serving/src/source.rs b/rust/extns/numaflow-serving/src/source.rs deleted file mode 100644 index af7d8aaa8f..0000000000 --- a/rust/extns/numaflow-serving/src/source.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use bytes::Bytes; -use tokio::sync::{mpsc, oneshot}; -use tokio::time::Instant; - -use crate::app::callback::state::State as CallbackState; -use crate::app::callback::store::redisstore::RedisConnection; -use crate::app::tracker::MessageGraph; -use crate::{app, Settings}; -use crate::{Error, Result}; - -pub struct MessageWrapper { - pub confirm_save: oneshot::Sender<()>, - pub message: Message, -} - -pub struct Message { - pub value: Bytes, - pub id: String, - pub headers: HashMap, -} - -enum ActorMessage { - Read { - batch_size: usize, - timeout_at: Instant, - reply_to: oneshot::Sender>, - }, - Ack { - offsets: Vec, - reply_to: oneshot::Sender<()>, - }, -} - -struct ServingSourceActor { - messages: mpsc::Receiver, - handler_rx: mpsc::Receiver, - tracker: HashMap>, -} - -impl ServingSourceActor { - async fn start( - settings: Arc, - handler_rx: mpsc::Receiver, - ) -> Result<()> { - let (messages_tx, messages_rx) = mpsc::channel(10000); - // Create a redis store to store the callbacks and the custom responses - let redis_store = RedisConnection::new(settings.redis.clone()).await?; - // Create the message graph from the pipeline spec and the redis store - let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec).map_err(|e| { - Error::InitError(format!( - "Creating message graph from pipeline spec: {:?}", - e - )) - })?; - let callback_state = CallbackState::new(msg_graph, redis_store).await?; - - tokio::spawn(async move { - let mut serving_actor = ServingSourceActor { - messages: messages_rx, - handler_rx, - tracker: HashMap::new(), - }; - serving_actor.run().await; - }); - let app = app::AppState { - message: messages_tx, - settings, - callback_state, - }; 
- app::serve(app).await.unwrap(); - Ok(()) - } - - async fn run(&mut self) { - while let Some(msg) = self.handler_rx.recv().await { - self.handle_message(msg).await; - } - } - - async fn handle_message(&mut self, actor_msg: ActorMessage) { - match actor_msg { - ActorMessage::Read { - batch_size, - timeout_at, - reply_to, - } => { - let messages = self.read(batch_size, timeout_at).await; - let _ = reply_to.send(messages); - } - ActorMessage::Ack { offsets, reply_to } => { - self.ack(offsets).await; - let _ = reply_to.send(()); - } - } - } - - async fn read(&mut self, count: usize, timeout_at: Instant) -> Vec { - let mut messages = vec![]; - loop { - if messages.len() >= count || Instant::now() >= timeout_at { - break; - } - let message = match self.messages.try_recv() { - Ok(msg) => msg, - Err(mpsc::error::TryRecvError::Empty) => break, - Err(e) => { - tracing::error!(?e, "Receiving messages from the serving channel"); // FIXME: - return messages; - } - }; - let MessageWrapper { - confirm_save, - message, - } = message; - - self.tracker.insert(message.id.clone(), confirm_save); - messages.push(message); - } - messages - } - - async fn ack(&mut self, offsets: Vec) { - for offset in offsets { - let offset = offset - .strip_suffix("-0") - .expect("offset does not end with '-0'"); // FIXME: we hardcode 0 as the partition index when constructing offset - let confirm_save_tx = self - .tracker - .remove(offset) - .expect("offset was not found in the tracker"); - confirm_save_tx - .send(()) - .expect("Sending on confirm_save channel"); - } - } -} - -#[derive(Clone)] -pub struct ServingSource { - batch_size: usize, - // timeout for each batch read request - timeout: Duration, - actor_tx: mpsc::Sender, -} - -impl ServingSource { - pub async fn new( - settings: Arc, - batch_size: usize, - timeout: Duration, - ) -> Result { - let (actor_tx, actor_rx) = mpsc::channel(10); - ServingSourceActor::start(settings, actor_rx).await?; - Ok(Self { - batch_size, - timeout, - actor_tx, - 
}) - } - - pub async fn read_messages(&self) -> Result> { - let start = Instant::now(); - let (tx, rx) = oneshot::channel(); - let actor_msg = ActorMessage::Read { - reply_to: tx, - batch_size: self.batch_size, - timeout_at: Instant::now() + self.timeout, - }; - let _ = self.actor_tx.send(actor_msg).await; - let messages = rx.await.map_err(Error::ActorTaskTerminated)?; - tracing::debug!( - count = messages.len(), - requested_count = self.batch_size, - time_taken_ms = start.elapsed().as_millis(), - "Got messages from Serving source" - ); - Ok(messages) - } - - pub async fn ack_messages(&self, offsets: Vec) -> Result<()> { - let (tx, rx) = oneshot::channel(); - let actor_msg = ActorMessage::Ack { - offsets, - reply_to: tx, - }; - let _ = self.actor_tx.send(actor_msg).await; - rx.await.map_err(Error::ActorTaskTerminated)?; - Ok(()) - } -} diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index dc0f6582da..ea7e22a04e 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -19,7 +19,6 @@ numaflow-models.workspace = true numaflow-pb.workspace = true serving.workspace = true backoff.workspace = true -numaflow-serving.workspace = true axum.workspace = true axum-server.workspace = true bytes.workspace = true diff --git a/rust/numaflow-core/src/config/components.rs b/rust/numaflow-core/src/config/components.rs index 29734c98a1..833ad8950b 100644 --- a/rust/numaflow-core/src/config/components.rs +++ b/rust/numaflow-core/src/config/components.rs @@ -38,7 +38,7 @@ pub(crate) mod source { Generator(GeneratorConfig), UserDefined(UserDefinedConfig), Pulsar(PulsarSourceConfig), - Serving(Arc), + Serving(Arc), } impl From> for SourceType { @@ -111,7 +111,7 @@ pub(crate) mod source { // There should be only one option (user-defined) to define the settings. 
fn try_from(cfg: Box) -> Result { let env_vars = env::vars().collect::>(); - let mut settings: numaflow_serving::Settings = env_vars.try_into()?; + let mut settings: serving::Settings = env_vars.try_into()?; settings.tid_header = cfg.msg_id_header_key; diff --git a/rust/numaflow-core/src/metrics.rs b/rust/numaflow-core/src/metrics.rs index fa79e457b8..2a672ec31d 100644 --- a/rust/numaflow-core/src/metrics.rs +++ b/rust/numaflow-core/src/metrics.rs @@ -600,8 +600,6 @@ pub(crate) async fn start_metrics_https_server( addr: SocketAddr, metrics_state: UserDefinedContainerState, ) -> crate::Result<()> { - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); - // Generate a self-signed certificate let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) .map_err(|e| Error::Metrics(format!("Generating self-signed certificate: {}", e)))?; diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs index eeae1136e8..c077a1f44f 100644 --- a/rust/numaflow-core/src/shared/create_components.rs +++ b/rust/numaflow-core/src/shared/create_components.rs @@ -5,7 +5,7 @@ use numaflow_pb::clients::map::map_client::MapClient; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::source::source_client::SourceClient; use numaflow_pb::clients::sourcetransformer::source_transform_client::SourceTransformClient; -use numaflow_serving::ServingSource; +use serving::ServingSource; use tokio_util::sync::CancellationToken; use tonic::transport::Channel; diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs index 73bc69b1ec..a30fc9777e 100644 --- a/rust/numaflow-core/src/source.rs +++ b/rust/numaflow-core/src/source.rs @@ -1,5 +1,4 @@ use numaflow_pulsar::source::PulsarSource; -use numaflow_serving::ServingSource; use std::sync::Arc; use tokio::sync::OwnedSemaphorePermit; use tokio::sync::Semaphore; @@ -39,6 +38,7 @@ pub(crate) mod 
generator; pub(crate) mod pulsar; pub(crate) mod serving; +use serving::ServingSource; /// Set of Read related items that has to be implemented to become a Source. pub(crate) trait SourceReader { diff --git a/rust/numaflow-core/src/source/serving.rs b/rust/numaflow-core/src/source/serving.rs index a06f418d66..02074904eb 100644 --- a/rust/numaflow-core/src/source/serving.rs +++ b/rust/numaflow-core/src/source/serving.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use numaflow_serving::ServingSource; +pub(crate) use serving::ServingSource; use crate::message::{MessageID, StringOffset}; use crate::Error; @@ -8,10 +8,10 @@ use crate::Result; use super::{get_vertex_name, Message, Offset}; -impl TryFrom for Message { +impl TryFrom for Message { type Error = Error; - fn try_from(message: numaflow_serving::Message) -> Result { + fn try_from(message: serving::Message) -> Result { let offset = Offset::String(StringOffset::new(message.id.clone(), 0)); Ok(Message { @@ -30,8 +30,8 @@ impl TryFrom for Message { } } -impl From for Error { - fn from(value: numaflow_serving::Error) -> Self { +impl From for Error { + fn from(value: serving::Error) -> Self { Error::Source(value.to_string()) } } diff --git a/rust/serving/Cargo.toml b/rust/serving/Cargo.toml index 0dbe0818c5..427bc84ce3 100644 --- a/rust/serving/Cargo.toml +++ b/rust/serving/Cargo.toml @@ -19,7 +19,6 @@ backoff.workspace = true axum.workspace = true axum-server.workspace = true bytes.workspace = true -async-nats = "0.35.1" axum-macros = "0.4.1" hyper-util = { version = "0.1.6", features = ["client-legacy"] } serde = { version = "1.0.204", features = ["derive"] } diff --git a/rust/serving/src/app.rs b/rust/serving/src/app.rs index b8bbc4c76e..5161d67ac0 100644 --- a/rust/serving/src/app.rs +++ b/rust/serving/src/app.rs @@ -246,16 +246,15 @@ async fn routes( mod tests { use std::sync::Arc; - use async_nats::jetstream::stream; use axum::http::StatusCode; - use tokio::sync::mpsc; - use tokio::time::{sleep, Duration}; use 
tower::ServiceExt; use super::*; use crate::app::callback::store::memstore::InMemoryStore; - use crate::config::generate_certs; use crate::Settings; + use callback::state::State as CallbackState; + use tokio::sync::mpsc; + use tracker::MessageGraph; const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUif
X19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiO
nsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; @@ -266,141 +265,78 @@ mod tests { #[tokio::test] async fn test_setup_app() -> Result<()> { let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); let mem_store = InMemoryStore::new(); let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; + let (tx, _) = mpsc::channel(10); + let app = AppState { + message: tx, + settings, + callback_state, + }; - let result = setup_app(settings, context, callback_state).await; + let result = 
setup_app(app).await; assert!(result.is_ok()); Ok(()) } #[cfg(feature = "all-tests")] #[tokio::test] - async fn test_livez() -> Result<()> { + async fn test_health_check_endpoints() -> Result<()> { let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); let mem_store = InMemoryStore::new(); - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - + let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; - let result = setup_app(settings, context, callback_state).await; + let (messages_tx, _messages_rx) = mpsc::channel(10); + let app = AppState { + message: messages_tx, + settings, + callback_state, + }; + + let router = setup_app(app).await.unwrap(); let request = Request::builder().uri("/livez").body(Body::empty())?; - - let response = result?.oneshot(request).await?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::NO_CONTENT); - Ok(()) - } - - #[cfg(feature = "all-tests")] - #[tokio::test] - async fn test_readyz() -> Result<()> { - let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); - - let mem_store = InMemoryStore::new(); - let pipeline_spec = 
PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - - let callback_state = CallbackState::new(msg_graph, mem_store).await?; - - let result = setup_app(settings, context, callback_state).await; let request = Request::builder().uri("/readyz").body(Body::empty())?; - - let response = result.unwrap().oneshot(request).await?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::NO_CONTENT); - Ok(()) - } - #[tokio::test] - async fn test_health_check() { - let response = health_check().await; - let response = response.into_response(); + let request = Request::builder().uri("/health").body(Body::empty())?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::OK); + Ok(()) } #[cfg(feature = "all-tests")] #[tokio::test] async fn test_auth_middleware() -> Result<()> { - let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); + let settings = Settings { + api_auth_token: Some("test-token".into()), + ..Default::default() + }; let mem_store = InMemoryStore::new(); let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; + let (messages_tx, _messages_rx) = mpsc::channel(10); let app_state = AppState { - settings, + message: messages_tx, + settings: Arc::new(settings), callback_state, - context, }; - let app = Router::new() - .nest("/v1/process", routes(app_state).await.unwrap()) - .layer(middleware::from_fn_with_state( - Some("test_token".to_owned()), - 
auth_middleware, - )); - - let res = app + let router = setup_app(app_state).await.unwrap(); + let res = router .oneshot( axum::extract::Request::builder() .uri("/v1/process/sync") diff --git a/rust/serving/src/app/jetstream_proxy.rs b/rust/serving/src/app/jetstream_proxy.rs index e5a7a661e1..30fb8e1ef9 100644 --- a/rust/serving/src/app/jetstream_proxy.rs +++ b/rust/serving/src/app/jetstream_proxy.rs @@ -256,8 +256,6 @@ fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { mod tests { use std::sync::Arc; - use async_nats::jetstream; - use async_nats::jetstream::stream; use axum::body::{to_bytes, Body}; use axum::extract::Request; use axum::http::header::{CONTENT_LENGTH, CONTENT_TYPE}; @@ -298,46 +296,38 @@ mod tests { #[tokio::test] async fn test_async_publish() -> Result<(), Box> { - let settings = Settings::default(); - let settings = Arc::new(settings); - let client = async_nats::connect(&settings.jetstream.url) - .await - .map_err(|e| format!("Connecting to Jetstream: {:?}", e))?; - - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "default"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e))?; + const ID_HEADER: &str = "X-Numaflow-ID"; + const ID_VALUE: &str = "foobar"; + let settings = Settings { + tid_header: ID_HEADER.into(), + ..Default::default() + }; let mock_store = MockStore {}; - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec) - .map_err(|e| format!("Failed to create message graph from pipeline spec: {:?}", e))?; - + let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mock_store).await?; + + let 
(messages_tx, mut messages_rx) = mpsc::channel(10); let app_state = AppState { + message: messages_tx, + settings: Arc::new(settings), callback_state, - context, - settings, }; + let app = jetstream_proxy(app_state).await?; let res = Request::builder() .method("POST") .uri("/async") .header(CONTENT_TYPE, "text/plain") - .header("id", id) + .header(ID_HEADER, ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.oneshot(res).await.unwrap(); + let message = messages_rx.recv().await.unwrap(); + assert_eq!(message.message.id, ID_VALUE); + message.confirm_save.send(()).unwrap(); assert_eq!(response.status(), StatusCode::OK); let result = extract_response_from_body(response.into_body()).await; @@ -345,7 +335,7 @@ result, json!({ "message": "Successfully published message", - "id": id, + "id": ID_VALUE, "code": 200 }) ); @@ -387,20 +377,12 @@ #[tokio::test] async fn test_sync_publish() { - let settings = Settings::default(); - let client = async_nats::connect(&settings.jetstream.url).await.unwrap(); - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "sync_pub"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e)); + const ID_HEADER: &str = "X-Numaflow-ID"; + const ID_VALUE: &str = "foobar"; + let settings = Settings { + tid_header: ID_HEADER.into(), + ..Default::default() + }; let mem_store = InMemoryStore::new(); let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); @@ -408,16 +390,17 @@ let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); - let settings = Arc::new(settings); + let (messages_tx, mut messages_rx) = mpsc::channel(10); let app_state = AppState { - settings, + message: messages_tx, + settings: Arc::new(settings), callback_state: 
callback_state.clone(), - context, }; + let app = jetstream_proxy(app_state).await.unwrap(); tokio::spawn(async move { - let cbs = create_default_callbacks(id); + let cbs = create_default_callbacks(ID_VALUE); let mut retries = 0; loop { match callback_state.insert_callback_requests(cbs.clone()).await { @@ -437,11 +420,13 @@ mod tests { .method("POST") .uri("/sync") .header("Content-Type", "text/plain") - .header("id", id) + .header(ID_HEADER, ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.clone().oneshot(res).await.unwrap(); + let message = messages_rx.recv().await.unwrap(); + message.confirm_save.send(()).unwrap(); assert_eq!(response.status(), StatusCode::OK); let result = extract_response_from_body(response.into_body()).await; @@ -449,7 +434,7 @@ mod tests { result, json!({ "message": "Successfully processed the message", - "id": id, + "id": ID_VALUE, "code": 200 }) ); @@ -457,20 +442,8 @@ mod tests { #[tokio::test] async fn test_sync_publish_serve() { + const ID_VALUE: &str = "foobar"; let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await.unwrap(); - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "sync_serve_pub"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e)); let mem_store = InMemoryStore::new(); let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); @@ -478,16 +451,17 @@ mod tests { let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); + let (messages_tx, mut messages_rx) = mpsc::channel(10); let app_state = AppState { + message: messages_tx, settings, callback_state: callback_state.clone(), - context, }; let app = jetstream_proxy(app_state).await.unwrap(); // pipeline is in -> cat -> out, so we will 
have 3 callback requests - let cbs = create_default_callbacks(id); + let cbs = create_default_callbacks(ID_VALUE); // spawn a tokio task which will insert the callback requests to the callback state // if it fails, sleep for 10ms and retry @@ -526,11 +500,13 @@ mod tests { .method("POST") .uri("/sync_serve") .header("Content-Type", "text/plain") - .header("id", id) + .header("ID", ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.oneshot(res).await.unwrap(); + let message = messages_rx.recv().await.unwrap(); + message.confirm_save.send(()).unwrap(); assert_eq!(response.status(), StatusCode::OK); let content_len = response.headers().get(CONTENT_LENGTH).unwrap(); diff --git a/rust/serving/src/config.rs b/rust/serving/src/config.rs index 03fb2beb17..16c2ee125c 100644 --- a/rust/serving/src/config.rs +++ b/rust/serving/src/config.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use std::fmt::Debug; -use async_nats::rustls; use base64::prelude::BASE64_STANDARD; use base64::Engine; use rcgen::{generate_simple_self_signed, Certificate, CertifiedKey, KeyPair}; @@ -20,7 +19,6 @@ const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; pub fn generate_certs() -> std::result::Result<(Certificate, KeyPair), String> { - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) .map_err(|e| format!("Failed to generate cert {:?}", e))?; Ok((cert, key_pair))