From 3a45e5f0a9d605dee9f0b210be2606fc5d51c101 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 23 Dec 2024 14:08:12 +0530 Subject: [PATCH 01/15] /v1/process/async publish works Signed-off-by: Sreekanth --- rust/Cargo.lock | 43 +- rust/Cargo.toml | 8 +- rust/extns/numaflow-serving/Cargo.toml | 27 + rust/extns/numaflow-serving/src/app.rs | 278 ++++++ .../numaflow-serving/src/app/callback.rs | 219 ++++ .../src/app/callback/state.rs | 378 +++++++ .../src/app/callback/store.rs | 35 + .../src/app/callback/store/memstore.rs | 217 ++++ .../src/app/callback/store/redisstore.rs | 198 ++++ .../numaflow-serving/src/app/response.rs | 60 ++ .../extns/numaflow-serving/src/app/tracker.rs | 933 ++++++++++++++++++ rust/extns/numaflow-serving/src/config.rs | 224 +++++ rust/extns/numaflow-serving/src/lib.rs | 248 +++++ rust/extns/numaflow-serving/src/pipeline.rs | 164 +++ rust/numaflow-core/Cargo.toml | 5 +- rust/numaflow-core/src/config/components.rs | 14 +- rust/numaflow-core/src/lib.rs | 1 + .../src/shared/create_components.rs | 14 +- rust/numaflow-core/src/source.rs | 11 + rust/numaflow-core/src/source/serving.rs | 77 ++ 20 files changed, 3133 insertions(+), 21 deletions(-) create mode 100644 rust/extns/numaflow-serving/Cargo.toml create mode 100644 rust/extns/numaflow-serving/src/app.rs create mode 100644 rust/extns/numaflow-serving/src/app/callback.rs create mode 100644 rust/extns/numaflow-serving/src/app/callback/state.rs create mode 100644 rust/extns/numaflow-serving/src/app/callback/store.rs create mode 100644 rust/extns/numaflow-serving/src/app/callback/store/memstore.rs create mode 100644 rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs create mode 100644 rust/extns/numaflow-serving/src/app/response.rs create mode 100644 rust/extns/numaflow-serving/src/app/tracker.rs create mode 100644 rust/extns/numaflow-serving/src/config.rs create mode 100644 rust/extns/numaflow-serving/src/lib.rs create mode 100644 rust/extns/numaflow-serving/src/pipeline.rs create mode 100644 rust/numaflow-core/src/source/serving.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index a210284fcd..1090d81b9c 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1740,6 +1740,7 @@ dependencies = [ "numaflow-models", "numaflow-pb", "numaflow-pulsar", + "numaflow-serving", "parking_lot", "pep440_rs", "pin-project", @@ -1755,7 +1756,7 @@ dependencies = [ "serde_json", "serving", "tempfile", - "thiserror 2.0.3", + "thiserror 2.0.8", "tokio", "tokio-stream", "tokio-util", @@ -1799,12 +1800,36 @@ dependencies = [ "prost 0.11.9", "pulsar", "serde", - "thiserror 2.0.3", + "thiserror 2.0.8", "tokio", "tonic", "tracing", ] +[[package]] +name = "numaflow-serving" +version = "0.1.0" +dependencies = [ + "axum", + "axum-server", + "backoff", + "base64 0.22.1", + "bytes", + "chrono", + "numaflow-models", + "rcgen", + "redis", + "serde", + "serde_json", + "thiserror 2.0.8", + "tokio", + "tower 0.4.13", + "tower-http", + "tracing", + "trait-variant", + "uuid", +] + [[package]] name = "object" version = "0.36.5" @@ -2227,7 +2252,7 @@ dependencies = [ "rustc-hash 2.1.0", "rustls 0.23.19", "socket2", - "thiserror 2.0.3", + "thiserror 2.0.8", "tokio", "tracing", ] @@ -2246,7 +2271,7 @@ dependencies = [ "rustls 0.23.19", "rustls-pki-types", "slab", - "thiserror 2.0.3", + "thiserror 2.0.8", "tinyvec", "tracing", "web-time", @@ -3082,11 +3107,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.3" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +checksum = "08f5383f3e0071702bf93ab5ee99b52d26936be9dedd9413067cbdcddcb6141a" dependencies = [ - "thiserror-impl 2.0.3", + "thiserror-impl 2.0.8", ] [[package]] @@ -3102,9 +3127,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.3" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +checksum = "f2f357fcec90b3caef6623a099691be676d033b40a058ac95d2a6ade6fa0c943" dependencies = [ "proc-macro2", "quote", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 8a6b41a1a4..b986f29316 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -9,6 +9,7 @@ members = [ "numaflow-pb", "extns/numaflow-pulsar", "numaflow", + "extns/numaflow-serving", ] [workspace.lints.rust] @@ -40,8 +41,8 @@ verbose_file_reads = "warn" # This profile optimizes for runtime performance and small binary size at the expense of longer build times. # Compared to default release profile, this profile reduced binary size from 29MB to 21MB # and increased build time (with only one line change in code) from 12 seconds to 133 seconds (tested on Mac M2 Max). -[profile.release] -lto = "fat" +# [profile.release] +# lto = "fat" # This profile optimizes for short build times at the expense of larger binary size and slower runtime performance. # If you have to rebuild image often, in Dockerfile you may replace `--release` passed to cargo command with `--profile quick-release` @@ -59,7 +60,10 @@ numaflow-models = { path = "numaflow-models" } backoff = { path = "backoff" } numaflow-pb = { path = "numaflow-pb" } numaflow-pulsar = {path = "extns/numaflow-pulsar"} +numaflow-serving = {path = "extns/numaflow-serving"} tokio = "1.41.1" +bytes = "1.7.1" tracing = "0.1.40" axum = "0.7.5" axum-server = { version = "0.7.1", features = ["tls-rustls"] } +serde = { version = "1.0.204", features = ["derive"] } diff --git a/rust/extns/numaflow-serving/Cargo.toml b/rust/extns/numaflow-serving/Cargo.toml new file mode 100644 index 0000000000..9a2abf9959 --- /dev/null +++ b/rust/extns/numaflow-serving/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "numaflow-serving" +version = "0.1.0" +edition = "2021" + +[lints] +workspace = true + +[dependencies] +axum.workspace = true +axum-server.workspace = true +tokio.workspace = true +bytes.workspace = true +tracing.workspace = true +serde = { workspace = true } +numaflow-models.workspace = true +backoff.workspace = true +rcgen = "0.13.1" +tower = "0.4.13" +tower-http = { version = "0.5.2", features = ["trace", "timeout"] } +uuid = { version = "1.10.0", features = ["v4"] } +thiserror = "2.0.8" +base64 = "0.22.1" +serde_json = "1.0.120" +trait-variant = "0.1.2" +redis = { version = "0.26.0", features = ["tokio-comp", "aio", "connection-manager"] } +chrono = { version = "0.4", features = ["serde"] } diff --git a/rust/extns/numaflow-serving/src/app.rs b/rust/extns/numaflow-serving/src/app.rs new file mode 100644 index 0000000000..f2433e928a --- /dev/null +++ b/rust/extns/numaflow-serving/src/app.rs @@ -0,0 +1,278 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use axum::body::Body; +use axum::extract::{MatchedPath, Request, State}; +use axum::http::{HeaderMap, StatusCode}; +use axum::middleware::{self, Next}; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use axum_server::tls_rustls::RustlsConfig; +use axum_server::Handle; +use bytes::Bytes; +use rcgen::{generate_simple_self_signed, Certificate, CertifiedKey, KeyPair}; +use tokio::sync::{mpsc, oneshot}; +use tower::ServiceBuilder; +use tower_http::timeout::TimeoutLayer; +use tower_http::trace::{DefaultOnResponse, TraceLayer}; +use uuid::Uuid; + +use crate::{Error, Message, MessageWrapper, Settings}; +/// +/// manage callbacks +pub(crate) mod callback; + +mod response; + +use crate::app::callback::state::State as CallbackState; + +pub(crate) mod tracker; + +use self::callback::store::Store; +use self::response::{ApiError, ServeResponse}; + +pub fn generate_certs() -> crate::Result<(Certificate, KeyPair)> { + let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) + .map_err(|e| Error::InitError(format!("Failed to generate cert {:?}", e)))?; + Ok((cert, key_pair)) +} + +pub(crate) async fn serve(app: AppState) -> crate::Result<()> +where + T: Clone + Send + Sync + Store + 'static, +{ + let (cert, key) = generate_certs()?; + + let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) + .await + .map_err(|e| Error::InitError(format!("Failed to create tls config {:?}", e)))?; + + // TODO: Move all env variables into one place. Some env variables are loaded when Settings is initialized + + tracing::info!(config = ?app.settings, "Starting server with config and pipeline spec"); + // Start the main server, which serves the application. + tokio::spawn(start_main_server(app, tls_config)); + + Ok(()) +} + +#[derive(Clone)] +pub struct AppState { + pub message: mpsc::Sender, + pub settings: Arc, + pub callback_state: CallbackState, +} + +const PUBLISH_ENDPOINTS: [&str; 3] = [ + "/v1/process/sync", + "/v1/process/sync_serve", + "/v1/process/async", +]; +// auth middleware to do token based authentication for all user facing routes +// if auth is enabled. +async fn auth_middleware( + State(api_auth_token): State>, + request: axum::extract::Request, + next: Next, +) -> Response { + let path = request.uri().path(); + + // we only need to check for the presence of the auth token in the request headers for the publish endpoints + if !PUBLISH_ENDPOINTS.contains(&path) { + return next.run(request).await; + } + + match api_auth_token { + Some(token) => { + // Check for the presence of the auth token in the request headers + let auth_token = match request.headers().get("Authorization") { + Some(token) => token, + None => { + return Response::builder() + .status(401) + .body(Body::empty()) + .expect("failed to build response") + } + }; + if auth_token.to_str().expect("auth token should be a string") + != format!("Bearer {}", token) + { + Response::builder() + .status(401) + .body(Body::empty()) + .expect("failed to build response") + } else { + next.run(request).await + } + } + None => { + // If the auth token is not set, we don't need to check for the presence of the auth token in the request headers + next.run(request).await + } + } +} + +async fn start_main_server(app: AppState, tls_config: RustlsConfig) -> crate::Result<()> +where + T: Clone + Send + Sync + Store + 'static, +{ + let app_addr: SocketAddr = format!("0.0.0.0:{}", &app.settings.app_listen_port) + .parse() + .map_err(|e| Error::InitError(format!("{e:?}")))?; + + let tid_header = app.settings.tid_header.clone(); + let layers = ServiceBuilder::new() + // Add tracing to all requests + .layer( + TraceLayer::new_for_http() + .make_span_with(move |req: &Request| { + let tid = req + .headers() + .get(&tid_header) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()) + .unwrap_or_else(|| Uuid::new_v4().to_string()); + + let matched_path = req + .extensions() + .get::() + .map(MatchedPath::as_str); + + tracing::info_span!("request", tid, method=?req.method(), matched_path) + }) + .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)), + ) + .layer( + // Graceful shutdown will wait for outstanding requests to complete. Add a timeout so + // requests don't hang forever. + TimeoutLayer::new(Duration::from_secs(app.settings.drain_timeout_secs)), + ) + // Add auth middleware to all user facing routes + .layer(middleware::from_fn_with_state( + app.settings.api_auth_token.clone(), + auth_middleware, + )); + + let handle = Handle::new(); + + let router = setup_app(app).await.layer(layers); + + tracing::info!(?app_addr, "Starting application server"); + + axum_server::bind_rustls(app_addr, tls_config) + .handle(handle) + .serve(router.into_make_service()) + .await + .map_err(|e| Error::InitError(format!("Starting web server for metrics: {}", e)))?; + + Ok(()) +} + +async fn setup_app(state: AppState) -> Router { + let router = Router::new() + .route("/health", get(health_check)) + .route("/livez", get(livez)) // Liveliness check + .route("/readyz", get(readyz)) + .with_state(state.clone()); + + let router = router.nest("/v1/process", routes(state.clone())); + router +} + +async fn health_check() -> impl IntoResponse { + "ok" +} + +async fn livez() -> impl IntoResponse { + StatusCode::NO_CONTENT +} + +async fn readyz( + State(app): State>, +) -> impl IntoResponse { + if app.callback_state.clone().ready().await { + StatusCode::NO_CONTENT + } else { + StatusCode::INTERNAL_SERVER_ERROR + } +} + +// extracts the ID from the headers, if not found, generates a new UUID +fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { + headers.get(tid_header).map_or_else( + || Uuid::new_v4().to_string(), + |v| String::from_utf8_lossy(v.as_bytes()).to_string(), + ) +} + +fn routes(app_state: AppState) -> Router { + let jetstream_proxy = jetstream_proxy(app_state); + jetstream_proxy +} + +const CALLBACK_URL_KEY: &str = "X-Numaflow-Callback-Url"; +const NUMAFLOW_RESP_ARRAY_LEN: &str = "Numaflow-Array-Len"; +const NUMAFLOW_RESP_ARRAY_IDX_LEN: &str = "Numaflow-Array-Index-Len"; + +struct ProxyState { + tid_header: String, + callback: CallbackState, + stream: String, + callback_url: String, + messages: mpsc::Sender, +} + +pub(crate) fn jetstream_proxy( + state: AppState, +) -> Router { + let proxy_state = Arc::new(ProxyState { + tid_header: state.settings.tid_header.clone(), + callback: state.callback_state.clone(), + stream: state.settings.jetstream.stream.clone(), + messages: state.message.clone(), + callback_url: format!( + "https://{}:{}/v1/process/callback", + state.settings.host_ip, state.settings.app_listen_port + ), + }); + + let router = Router::new() + .route("/async", post(async_publish)) + .with_state(proxy_state); + router +} + +async fn async_publish( + State(proxy_state): State>>, + headers: HeaderMap, + body: Bytes, +) -> Result, ApiError> { + let id = extract_id_from_headers(&proxy_state.tid_header, &headers); + let mut msg_headers: HashMap = HashMap::new(); + for (k, v) in headers.iter() { + msg_headers.insert( + k.to_string(), + String::from_utf8_lossy(v.as_bytes()).to_string(), + ); + } + let (tx, rx) = oneshot::channel(); + let message = MessageWrapper { + confirm_save: tx, + message: Message { + value: body, + id: id.clone(), + headers: msg_headers, + }, + }; + proxy_state.messages.send(message).await.unwrap(); + rx.await.unwrap(); + + Ok(Json(ServeResponse::new( + "Successfully published message".to_string(), + id, + StatusCode::OK, + ))) +} diff --git a/rust/extns/numaflow-serving/src/app/callback.rs b/rust/extns/numaflow-serving/src/app/callback.rs new file mode 100644 index 0000000000..b4d43868ee --- /dev/null +++ b/rust/extns/numaflow-serving/src/app/callback.rs @@ -0,0 +1,219 @@ +use axum::{body::Bytes, extract::State, http::HeaderMap, routing, Json, Router}; +use serde::{Deserialize, Serialize}; +use tracing::error; + +use self::store::Store; +use crate::app::response::ApiError; + +/// in-memory state store including connection tracking +pub(crate) mod state; +use state::State as CallbackState; + +/// store for storing the state +pub(crate) mod store; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub(crate) struct CallbackRequest { + pub(crate) id: String, + pub(crate) vertex: String, + pub(crate) cb_time: u64, + pub(crate) from_vertex: String, + pub(crate) tags: Option>, +} + +#[derive(Clone)] +struct CallbackAppState { + tid_header: String, + callback_state: CallbackState, +} + +pub fn callback_handler( + tid_header: String, + callback_state: CallbackState, +) -> Router { + let app_state = CallbackAppState { + tid_header, + callback_state, + }; + Router::new() + .route("/callback", routing::post(callback)) + .route("/callback_save", routing::post(callback_save)) + .with_state(app_state) +} + +async fn callback_save( + State(app_state): State>, + headers: HeaderMap, + body: Bytes, +) -> Result<(), ApiError> { + let id = headers + .get(&app_state.tid_header) + .map(|id| String::from_utf8_lossy(id.as_bytes()).to_string()) + .ok_or_else(|| ApiError::BadRequest("Missing id header".to_string()))?; + + app_state + .callback_state + .clone() + .save_response(id, body) + .await + .map_err(|e| { + error!(error=?e, "Saving body from callback save request"); + ApiError::InternalServerError( + "Failed to save body from callback save request".to_string(), + ) + })?; + + Ok(()) +} + +async fn callback( + State(app_state): State>, + Json(payload): Json>, +) -> Result<(), ApiError> { + app_state + .callback_state + .clone() + .insert_callback_requests(payload) + .await + .map_err(|e| { + error!(error=?e, "Inserting callback requests"); + ApiError::InternalServerError("Failed to insert callback requests".to_string()) + })?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use axum::body::Body; + use axum::extract::Request; + use axum::http::header::CONTENT_TYPE; + use axum::http::StatusCode; + use tower::ServiceExt; + + use super::*; + use crate::app::callback::state::State as CallbackState; + use crate::app::callback::store::memstore::InMemoryStore; + use crate::app::tracker::MessageGraph; + use crate::pipeline::PipelineDCG; + + const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; + + #[tokio::test] + async fn test_callback_failure() { + let store = InMemoryStore::new(); + let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); + let state = CallbackState::new(msg_graph, store).await.unwrap(); + let app = callback_handler("ID".to_owned(), state); + + let payload = vec![CallbackRequest { + id: "test_id".to_string(), + vertex: "in".to_string(), + cb_time: 12345, + from_vertex: "in".to_string(), + tags: None, + }]; + + let res = Request::builder() + .method("POST") + .uri("/callback") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_vec(&payload).unwrap())) + .unwrap(); + + let resp = app.oneshot(res).await.unwrap(); + // id is not registered, so it should return 500 + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[tokio::test] + async fn test_callback_success() { + let store = InMemoryStore::new(); + let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); + let mut state = CallbackState::new(msg_graph, store).await.unwrap(); + + let x = state.register("test_id".to_string()); + // spawn a task which will be awaited later + let handle = tokio::spawn(async move { + let _ = x.await.unwrap(); + }); + + let app = callback_handler("ID".to_owned(), state); + + let payload = vec![ + CallbackRequest { + id: "test_id".to_string(), + vertex: "in".to_string(), + cb_time: 12345, + from_vertex: "in".to_string(), + tags: None, + }, + CallbackRequest { + id: "test_id".to_string(), + vertex: "cat".to_string(), + cb_time: 12345, + from_vertex: "in".to_string(), + tags: None, + }, + CallbackRequest { + id: "test_id".to_string(), + vertex: "out".to_string(), + cb_time: 12345, + from_vertex: "cat".to_string(), + tags: None, + }, + ]; + + let res = Request::builder() + .method("POST") + .uri("/callback") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_vec(&payload).unwrap())) + .unwrap(); + + let resp = app.oneshot(res).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn test_callback_save() { + let store = InMemoryStore::new(); + let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); + let state = CallbackState::new(msg_graph, store).await.unwrap(); + let app = callback_handler("ID".to_owned(), state); + + let res = Request::builder() + .method("POST") + .uri("/callback_save") + .header(CONTENT_TYPE, "application/json") + .header("id", "test_id") + .body(Body::from("test_body")) + .unwrap(); + + let resp = app.oneshot(res).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn test_without_id_header() { + let store = InMemoryStore::new(); + let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); + let state = CallbackState::new(msg_graph, store).await.unwrap(); + let app = callback_handler("ID".to_owned(), state); + + let res = Request::builder() + .method("POST") + .uri("/callback_save") + .body(Body::from("test_body")) + .unwrap(); + + let resp = app.oneshot(res).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } +} diff --git a/rust/extns/numaflow-serving/src/app/callback/state.rs b/rust/extns/numaflow-serving/src/app/callback/state.rs new file mode 100644 index 0000000000..293478ead2 --- /dev/null +++ b/rust/extns/numaflow-serving/src/app/callback/state.rs @@ -0,0 +1,378 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use tokio::sync::oneshot; + +use super::store::Store; +use crate::app::callback::{store::PayloadToSave, CallbackRequest}; +use crate::app::tracker::MessageGraph; +use crate::Error; + +struct RequestState { + // Channel to notify when all callbacks for a message is received + tx: oneshot::Sender>, + // CallbackRequest is immutable, while vtx_visited can grow. + vtx_visited: Vec>, +} + +#[derive(Clone)] +pub(crate) struct State { + // hashmap of vertex infos keyed by ID + // it also contains tx to trigger to response to the syncHTTP call + callbacks: Arc>>, + // generator to generate subgraph + msg_graph_generator: Arc, + // conn is to be used while reading and writing to redis. + store: T, +} + +impl State +where + T: Store, +{ + /// Create a new State to track connections and callback data + pub(crate) async fn new(msg_graph: MessageGraph, store: T) -> crate::Result { + Ok(Self { + callbacks: Arc::new(Mutex::new(HashMap::new())), + msg_graph_generator: Arc::new(msg_graph), + store, + }) + } + + /// register a new connection + /// The oneshot receiver will be notified when all callbacks for this connection is received from the numaflow pipeline + pub(crate) fn register(&mut self, id: String) -> oneshot::Receiver> { + // TODO: add an entry in Redis to note that the entry has been registered + + let (tx, rx) = oneshot::channel(); + let mut guard = self.callbacks.lock().expect("Getting lock on State"); + guard.insert( + id.clone(), + RequestState { + tx, + vtx_visited: Vec::new(), + }, + ); + rx + } + + /// Retrieves the output of the numaflow pipeline + pub(crate) async fn retrieve_saved(&mut self, id: &str) -> Result>, Error> { + self.store.retrieve_datum(id).await + } + + pub(crate) async fn save_response( + &mut self, + id: String, + body: axum::body::Bytes, + ) -> crate::Result<()> { + // we have to differentiate between the saved responses and the callback requests + // saved responses are stored in "id_SAVED", callback requests are stored in "id" + self.store + .save(vec![PayloadToSave::DatumFromPipeline { + key: id, + value: body, + }]) + .await + } + + /// insert_callback_requests is used to insert the callback requests. + pub(crate) async fn insert_callback_requests( + &mut self, + cb_requests: Vec, + ) -> Result<(), Error> { + /* + TODO: should we consider batching the requests and then processing them? + that way algorithm can be invoked only once for a batch of requests + instead of invoking it for each request. + */ + let cb_requests: Vec> = + cb_requests.into_iter().map(Arc::new).collect(); + let redis_payloads: Vec = cb_requests + .iter() + .cloned() + .map(|cbr| PayloadToSave::Callback { + key: cbr.id.clone(), + value: Arc::clone(&cbr), + }) + .collect(); + + self.store.save(redis_payloads).await?; + + for cbr in cb_requests { + let id = cbr.id.clone(); + { + let mut guard = self.callbacks.lock().expect("Getting lock on State"); + guard + .get_mut(&cbr.id) + .ok_or(Error::IDNotFound( + "Connection for the received callback is not present in the in-memory store", + ))? + .vtx_visited + .push(cbr); + } + + // check if the sub graph can be generated + match self.get_subgraph_from_memory(&id) { + Ok(_) => { + // if the sub graph is generated, then we can send the response + self.deregister(&id).await? + } + Err(e) => { + match e { + Error::SubGraphNotFound(_) => { + // if the sub graph is not generated, then we can continue + continue; + } + _ => { + // if there is an error, deregister with the error + self.deregister(&id).await? + } + } + } + } + } + Ok(()) + } + + /// Get the subgraph for the given ID from in-memory. + fn get_subgraph_from_memory(&self, id: &str) -> Result { + let callbacks = self.get_callbacks_from_memory(id).ok_or(Error::IDNotFound( + "Connection for the received callback is not present in the in-memory store", + ))?; + + self.get_subgraph(id.to_string(), callbacks) + } + + /// Get the subgraph for the given ID from persistent store. This is used querying for the status from the service endpoint even after the + /// request has been completed. + pub(crate) async fn retrieve_subgraph_from_storage( + &mut self, + id: &str, + ) -> Result { + // If the id is not found in the in-memory store, fetch from Redis + let callbacks: Vec> = + match self.retrieve_callbacks_from_storage(id).await { + Ok(callbacks) => callbacks, + Err(e) => { + return Err(e); + } + }; + // check if the sub graph can be generated + self.get_subgraph(id.to_string(), callbacks) + } + + // Generate subgraph from the given callbacks + fn get_subgraph( + &self, + id: String, + callbacks: Vec>, + ) -> Result { + match self + .msg_graph_generator + .generate_subgraph_from_callbacks(id, callbacks) + { + Ok(Some(sub_graph)) => Ok(sub_graph), + Ok(None) => Err(Error::SubGraphNotFound( + "Subgraph could not be generated for the given ID", + )), + Err(e) => Err(e), + } + } + + /// deregister is called to trigger response and delete all the data persisted for that ID + pub(crate) async fn deregister(&mut self, id: &str) -> Result<(), Error> { + let state = { + let mut guard = self.callbacks.lock().expect("Getting lock on State"); + // we do not require the data stored in HashMap anymore + guard.remove(id) + }; + + let Some(state) = state else { + return Err(Error::IDNotFound( + "Connection for the received callback is not present in the in-memory store", + )); + }; + + state + .tx + .send(Ok(id.to_string())) + .map_err(|_| Error::Other("Application bug - Receiver is already dropped".to_string())) + } + + // Get the Callback value for the given ID + // TODO: Generate json serialized data here itself to avoid cloning. + fn get_callbacks_from_memory(&self, id: &str) -> Option>> { + let guard = self.callbacks.lock().expect("Getting lock on State"); + guard.get(id).map(|state| state.vtx_visited.clone()) + } + + // Get the Callback value for the given ID from persistent store + async fn retrieve_callbacks_from_storage( + &mut self, + id: &str, + ) -> Result>, Error> { + // If the id is not found in the in-memory store, fetch from Redis + let callbacks: Vec> = match self.store.retrieve_callbacks(id).await { + Ok(response) => response.into_iter().collect(), + Err(e) => { + return Err(e); + } + }; + Ok(callbacks) + } + + // Check if the store is ready + pub(crate) async fn ready(&mut self) -> bool { + self.store.ready().await + } +} + +#[cfg(test)] +mod tests { + use axum::body::Bytes; + + use super::*; + use crate::app::callback::store::memstore::InMemoryStore; + use crate::pipeline::PipelineDCG; + + const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; + + #[tokio::test] + async fn test_state() { + let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); + let store = InMemoryStore::new(); + let mut state = State::new(msg_graph, store).await.unwrap(); + + // Test register + let id = "test_id".to_string(); + let rx = state.register(id.clone()); + + let xid = id.clone(); + + // spawn a task to listen on the receiver, once we have received all the callbacks for the message + // we will get a response from the receiver with the message id + let handle = tokio::spawn(async move { + let result = rx.await.unwrap(); + // Tests deregister, and fetching the subgraph from the memory + assert_eq!(result.unwrap(), xid); + }); + + // Test save_response + let body = Bytes::from("Test Message"); + state.save_response(id.clone(), body).await.unwrap(); + + // Test retrieve_saved + let saved = state.retrieve_saved(&id).await.unwrap(); + assert_eq!(saved, vec!["Test Message".as_bytes()]); + + // Test insert_callback_requests + let cbs = vec![ + CallbackRequest { + id: id.clone(), + vertex: "in".to_string(), + cb_time: 12345, + from_vertex: "in".to_string(), + tags: None, + }, + CallbackRequest { + id: id.clone(), + vertex: "planner".to_string(), + cb_time: 12345, + from_vertex: "in".to_string(), + tags: Some(vec!["tiger".to_owned(), "asciiart".to_owned()]), + }, + CallbackRequest { + id: id.clone(), + vertex: "tiger".to_string(), + cb_time: 12345, + from_vertex: "planner".to_string(), + tags: None, + }, + CallbackRequest { + id: id.clone(), + vertex: "asciiart".to_string(), + cb_time: 12345, + from_vertex: "planner".to_string(), + tags: None, + }, + CallbackRequest { + id: id.clone(), + vertex: "serve-sink".to_string(), + cb_time: 12345, + from_vertex: "tiger".to_string(), + tags: None, + }, + CallbackRequest { + id: id.clone(), + vertex: "serve-sink".to_string(), + cb_time: 12345, + from_vertex: "asciiart".to_string(), + tags: None, + }, + ]; + state.insert_callback_requests(cbs).await.unwrap(); + + let sub_graph = state.retrieve_subgraph_from_storage(&id).await; + assert!(sub_graph.is_ok()); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn test_retrieve_saved_no_entry() { + let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); + let store = InMemoryStore::new(); + let mut state = State::new(msg_graph, store).await.unwrap(); + + let id = "nonexistent_id".to_string(); + + // Try to retrieve saved data for an ID that doesn't exist + let result = state.retrieve_saved(&id).await; + + // Check that an error is returned + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_insert_callback_requests_invalid_id() { + let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); + let store = InMemoryStore::new(); + let mut state = State::new(msg_graph, store).await.unwrap(); + + let cbs = vec![CallbackRequest { + id: "nonexistent_id".to_string(), + vertex: "in".to_string(), + cb_time: 12345, + from_vertex: "in".to_string(), + tags: None, + }]; + + // Try to insert callback requests for an ID that hasn't been registered + let result = state.insert_callback_requests(cbs).await; + + // Check that an error is returned + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_retrieve_subgraph_from_storage_no_entry() { + let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); + let store = InMemoryStore::new(); + let mut state = State::new(msg_graph, store).await.unwrap(); + + let id = "nonexistent_id".to_string(); + + // Try to retrieve a subgraph for an ID that doesn't exist + let result = state.retrieve_subgraph_from_storage(&id).await; + + // Check that an error is returned + assert!(result.is_err()); + } +} diff --git a/rust/extns/numaflow-serving/src/app/callback/store.rs b/rust/extns/numaflow-serving/src/app/callback/store.rs new file mode 100644 index 0000000000..af5f3c4368 --- /dev/null +++ b/rust/extns/numaflow-serving/src/app/callback/store.rs @@ -0,0 +1,35 @@ +use std::sync::Arc; + +use crate::app::callback::CallbackRequest; + +// in-memory store +pub(crate) mod memstore; +// redis as the store +pub(crate) mod redisstore; + +pub(crate) enum PayloadToSave { + /// Callback as sent by Numaflow to track the progression + Callback { + key: String, + value: Arc, + }, + /// Data sent by the Numaflow pipeline which is to be delivered as the response + DatumFromPipeline { + key: String, + value: axum::body::Bytes, + }, +} + +/// Store trait to store the callback information. +#[trait_variant::make(Store: Send)] +#[allow(dead_code)] +pub(crate) trait LocalStore { + async fn save(&mut self, messages: Vec) -> crate::Result<()>; + /// retrieve the callback payloads + async fn retrieve_callbacks( + &mut self, + id: &str, + ) -> Result>, crate::Error>; + async fn retrieve_datum(&mut self, id: &str) -> Result>, crate::Error>; + async fn ready(&mut self) -> bool; +} diff --git a/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs b/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs new file mode 100644 index 0000000000..59355c76c6 --- /dev/null +++ b/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs @@ -0,0 +1,217 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use super::PayloadToSave; +use crate::app::callback::CallbackRequest; +use crate::config::SAVED; +use crate::Error; + +/// `InMemoryStore` is an in-memory implementation of the `Store` trait. +/// It uses a `HashMap` to store data in memory. +#[derive(Clone)] +pub(crate) struct InMemoryStore { + /// The data field is a `HashMap` where the key is a `String` and the value is a `Vec>`. + /// It is wrapped in an `Arc>` to allow shared ownership and thread safety. + data: Arc>>>>, +} + +impl InMemoryStore { + /// Creates a new `InMemoryStore` with an empty `HashMap`. + #[allow(dead_code)] + pub(crate) fn new() -> Self { + Self { + data: Arc::new(std::sync::Mutex::new(HashMap::new())), + } + } +} + +impl super::Store for InMemoryStore { + /// Saves a vector of `PayloadToSave` into the `HashMap`. + /// Each `PayloadToSave` is serialized into bytes and stored in the `HashMap` under its key. + async fn save(&mut self, messages: Vec) -> crate::Result<()> { + let mut data = self.data.lock().unwrap(); + for msg in messages { + match msg { + PayloadToSave::Callback { key, value } => { + if key.is_empty() { + return Err(Error::StoreWrite("Key cannot be empty".to_string())); + } + let bytes = serde_json::to_vec(&*value) + .map_err(|e| Error::StoreWrite(format!("Serializing to bytes - {}", e)))?; + data.entry(key).or_default().push(bytes); + } + PayloadToSave::DatumFromPipeline { key, value } => { + if key.is_empty() { + return Err(Error::StoreWrite("Key cannot be empty".to_string())); + } + data.entry(format!("{}_{}", key, SAVED)) + .or_default() + .push(value.into()); + } + } + } + Ok(()) + } + + /// Retrieves callbacks for a given id from the `HashMap`. + /// Each callback is deserialized from bytes into a `CallbackRequest`. + async fn retrieve_callbacks(&mut self, id: &str) -> Result>, Error> { + let data = self.data.lock().unwrap(); + match data.get(id) { + Some(result) => { + let messages: Result, _> = result + .iter() + .map(|msg| { + let cbr: CallbackRequest = serde_json::from_slice(msg).map_err(|_| { + Error::StoreRead( + "Failed to parse CallbackRequest from bytes".to_string(), + ) + })?; + Ok(Arc::new(cbr)) + }) + .collect(); + messages + } + None => Err(Error::StoreRead(format!("No entry found for id: {}", id))), + } + } + + /// Retrieves data for a given id from the `HashMap`. + /// Each piece of data is deserialized from bytes into a `String`. + async fn retrieve_datum(&mut self, id: &str) -> Result>, Error> { + let id = format!("{}_{}", id, SAVED); + let data = self.data.lock().unwrap(); + match data.get(&id) { + Some(result) => Ok(result.to_vec()), + None => Err(Error::StoreRead(format!("No entry found for id: {}", id))), + } + } + + async fn ready(&mut self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::app::callback::store::{PayloadToSave, Store}; + use crate::app::callback::CallbackRequest; + + #[tokio::test] + async fn test_save_and_retrieve_callbacks() { + let mut store = InMemoryStore::new(); + let key = "test_key".to_string(); + let value = Arc::new(CallbackRequest { + id: "test_id".to_string(), + vertex: "in".to_string(), + cb_time: 12345, + from_vertex: "in".to_string(), + tags: None, + }); + + // Save a callback + store + .save(vec![PayloadToSave::Callback { + key: key.clone(), + value: Arc::clone(&value), + }]) + .await + .unwrap(); + + // Retrieve the callback + let retrieved = store.retrieve_callbacks(&key).await.unwrap(); + + // Check that the retrieved callback is the same as the one we saved + assert_eq!(retrieved.len(), 1); + assert_eq!(retrieved[0].id, "test_id".to_string()) + } + + #[tokio::test] + async fn test_save_and_retrieve_datum() { + let mut store = InMemoryStore::new(); + let key = "test_key".to_string(); + let value = "test_value".to_string(); + + // Save a datum + store + .save(vec![PayloadToSave::DatumFromPipeline { + key: key.clone(), + value: value.clone().into(), + }]) + .await + .unwrap(); + + // Retrieve the datum + let retrieved = store.retrieve_datum(&key).await.unwrap(); + + // Check that the retrieved datum is the same as the one we saved + assert_eq!(retrieved.len(), 1); + assert_eq!(retrieved[0], value.as_bytes()); + } + + #[tokio::test] + async fn test_retrieve_callbacks_no_entry() { + let mut store = InMemoryStore::new(); + let key = "nonexistent_key".to_string(); + + // Try to retrieve a callback for a key that doesn't exist + let result = store.retrieve_callbacks(&key).await; + + // Check that an error is returned + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_retrieve_datum_no_entry() { + let mut store = InMemoryStore::new(); + let key = "nonexistent_key".to_string(); + + // Try to retrieve a datum for a key that doesn't exist + let result = store.retrieve_datum(&key).await; + + // Check that an error is returned + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_save_invalid_callback() { + let mut store = InMemoryStore::new(); + let value = Arc::new(CallbackRequest { + id: "test_id".to_string(), + vertex: "in".to_string(), + cb_time: 12345, + from_vertex: "in".to_string(), + tags: None, + }); + + // Try to save a callback with an invalid key + let result = store + .save(vec![PayloadToSave::Callback { + key: "".to_string(), + value: Arc::clone(&value), + }]) + .await; + + // Check that an error is returned + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_save_invalid_datum() { + let mut store = InMemoryStore::new(); + + // Try to save a datum with an invalid key + let result = store + .save(vec![PayloadToSave::DatumFromPipeline { + key: "".to_string(), + value: "test_value".into(), + }]) + .await; + + // Check that an error is returned + assert!(result.is_err()); + } +} diff --git a/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs b/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs new file mode 100644 index 0000000000..deae8b42cf --- /dev/null +++ b/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs @@ -0,0 +1,198 @@ +use std::sync::Arc; + +use backoff::retry::Retry; +use backoff::strategy::fixed; +use redis::aio::ConnectionManager; +use redis::RedisError; +use tokio::sync::Semaphore; + +use super::PayloadToSave; +use crate::app::callback::CallbackRequest; +use crate::config::RedisConfig; +use crate::config::SAVED; +use crate::Error; + +const LPUSH: &str = "LPUSH"; +const LRANGE: &str = "LRANGE"; +const EXPIRE: &str = "EXPIRE"; + +// Handle to the Redis actor. +#[derive(Clone)] +pub(crate) struct RedisConnection { + conn_manager: ConnectionManager, + config: RedisConfig, +} + +impl RedisConnection { + /// Creates a new RedisConnection with concurrent operations on Redis set by max_tasks. + pub(crate) async fn new(config: RedisConfig) -> crate::Result { + let client = redis::Client::open(config.addr.as_str()) + .map_err(|e| Error::Connection(format!("Creating Redis client: {e:?}")))?; + let conn = client + .get_connection_manager() + .await + .map_err(|e| Error::Connection(format!("Connecting to Redis server: {e:?}")))?; + Ok(Self { + conn_manager: conn, + config, + }) + } + + async fn execute_redis_cmd( + conn_manager: &mut ConnectionManager, + ttl_secs: Option, + key: &str, + val: &Vec, + ) -> Result<(), RedisError> { + let mut pipe = redis::pipe(); + pipe.cmd(LPUSH).arg(key).arg(val); + + // if the ttl is configured, add the EXPIRE command to the pipeline + if let Some(ttl) = ttl_secs { + pipe.cmd(EXPIRE).arg(key).arg(ttl); + } + + // Execute the pipeline + pipe.query_async(conn_manager).await.map(|_: ()| ()) + } + + // write to Redis with retries + async fn write_to_redis(&self, key: &str, value: &Vec) -> crate::Result<()> { + let interval = fixed::Interval::from_millis(self.config.retries_duration_millis.into()) + .take(self.config.retries); + + Retry::retry( + interval, + || async { + // https://hackmd.io/@compiler-errors/async-closures + Self::execute_redis_cmd( + &mut self.conn_manager.clone(), + self.config.ttl_secs, + key, + value, + ) + .await + }, + |e: &RedisError| !e.is_unrecoverable_error(), + ) + .await + .map_err(|err| Error::StoreWrite(format!("Saving to redis: {}", err).to_string())) + } +} + +async fn handle_write_requests( + redis_conn: RedisConnection, + msg: PayloadToSave, +) -> crate::Result<()> { + match msg { + PayloadToSave::Callback { key, value } => { + // Convert the CallbackRequest to a byte array + let value = serde_json::to_vec(&*value) + .map_err(|e| Error::StoreWrite(format!("Serializing payload - {}", e)))?; + + redis_conn.write_to_redis(&key, &value).await + } + + // Write the byte array to Redis + PayloadToSave::DatumFromPipeline { key, value } => { + // we have to differentiate between the saved responses and the callback requests + // saved responses are stored in "id_SAVED", callback requests are stored in "id" + let key = format!("{}_{}", key, SAVED); + let value: Vec = value.into(); + + redis_conn.write_to_redis(&key, &value).await + } + } +} + +// It is possible to move the methods defined here to be methods on the Redis actor and communicate through channels. +// With that, all public APIs defined on RedisConnection can be on &self (immutable). +impl super::Store for RedisConnection { + // Attempt to save all payloads. Returns error if we fail to save at least one message. + async fn save(&mut self, messages: Vec) -> crate::Result<()> { + let mut tasks = vec![]; + // This is put in place not to overload Redis and also way some kind of + // flow control. + let sem = Arc::new(Semaphore::new(self.config.max_tasks)); + for msg in messages { + let permit = Arc::clone(&sem).acquire_owned().await; + let redis_conn = self.clone(); + let task = tokio::spawn(async move { + let _permit = permit; + handle_write_requests(redis_conn, msg).await + }); + tasks.push(task); + } + for task in tasks { + task.await.unwrap()?; + } + Ok(()) + } + + async fn retrieve_callbacks(&mut self, id: &str) -> Result>, Error> { + let result: Result>, RedisError> = redis::cmd(LRANGE) + .arg(id) + .arg(0) + .arg(-1) + .query_async(&mut self.conn_manager) + .await; + + match result { + Ok(result) => { + if result.is_empty() { + return Err(Error::StoreRead(format!("No entry found for id: {}", id))); + } + + let messages: Result, _> = result + .into_iter() + .map(|msg| { + let cbr: CallbackRequest = serde_json::from_slice(&msg).map_err(|e| { + Error::StoreRead(format!("Parsing payload from bytes - {}", e)) + })?; + Ok(Arc::new(cbr)) + }) + .collect(); + + messages + } + Err(e) => Err(Error::StoreRead(format!( + "Failed to read from redis: {:?}", + e + ))), + } + } + + async fn retrieve_datum(&mut self, id: &str) -> Result>, Error> { + // saved responses are stored in "id_SAVED" + let key = format!("{}_{}", id, SAVED); + let result: Result>, RedisError> = redis::cmd(LRANGE) + .arg(key) + .arg(0) + .arg(-1) + .query_async(&mut self.conn_manager) + .await; + + match result { + Ok(result) => { + if result.is_empty() { + return Err(Error::StoreRead(format!("No entry found for id: {}", id))); + } + + Ok(result) + } + Err(e) => Err(Error::StoreRead(format!( + "Failed to read from redis: {:?}", + e + ))), + } + } + + // Check if the Redis connection is healthy + async fn ready(&mut self) -> bool { + let mut conn = self.conn_manager.clone(); + match redis::cmd("PING").query_async::(&mut conn).await { + Ok(response) => response == "PONG", + Err(_) => false, + } + } +} diff --git a/rust/extns/numaflow-serving/src/app/response.rs b/rust/extns/numaflow-serving/src/app/response.rs new file mode 100644 index 0000000000..40064a1f78 --- /dev/null +++ b/rust/extns/numaflow-serving/src/app/response.rs @@ -0,0 +1,60 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::Json; +use chrono::{DateTime, Utc}; +use serde::Serialize; + +// Response sent by the serve handler sync/async to the client(user). +#[derive(Serialize)] +pub(crate) struct ServeResponse { + pub(crate) message: String, + pub(crate) id: String, + pub(crate) code: u16, + pub(crate) timestamp: DateTime, +} + +impl ServeResponse { + pub(crate) fn new(message: String, id: String, status: StatusCode) -> Self { + Self { + code: status.as_u16(), + message, + id, + timestamp: Utc::now(), + } + } +} + +// Error response sent by all the handlers to the client(user). +#[derive(Debug, Serialize)] +pub enum ApiError { + BadRequest(String), + InternalServerError(String), + BadGateway(String), +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + #[derive(Serialize)] + struct ErrorBody { + message: String, + code: u16, + timestamp: DateTime, + } + + let (status, message) = match self { + ApiError::BadRequest(message) => (StatusCode::BAD_REQUEST, message), + ApiError::InternalServerError(message) => (StatusCode::INTERNAL_SERVER_ERROR, message), + ApiError::BadGateway(message) => (StatusCode::BAD_GATEWAY, message), + }; + + ( + status, + Json(ErrorBody { + code: status.as_u16(), + message, + timestamp: Utc::now(), + }), + ) + .into_response() + } +} diff --git a/rust/extns/numaflow-serving/src/app/tracker.rs b/rust/extns/numaflow-serving/src/app/tracker.rs new file mode 100644 index 0000000000..33137f45db --- /dev/null +++ b/rust/extns/numaflow-serving/src/app/tracker.rs @@ -0,0 +1,933 @@ +use std::collections::HashMap; +use std::string::ToString; +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; + +use crate::app::callback::CallbackRequest; +use crate::pipeline::{Edge, OperatorType, PipelineDCG}; +use crate::Error; + +fn compare_slice(operator: &OperatorType, a: &[String], b: &[String]) -> bool { + match operator { + OperatorType::And => a.iter().all(|val| b.contains(val)), + OperatorType::Or => a.iter().any(|val| b.contains(val)), + OperatorType::Not => !a.iter().any(|val| b.contains(val)), + } +} + +type Graph = HashMap>; + +#[derive(Serialize, Deserialize, Debug)] +struct Subgraph { + id: String, + blocks: Vec, +} + +const DROP: &str = "U+005C__DROP__"; + +/// MessageGraph is a struct that generates the graph from the source vertex to the downstream vertices +/// for a message using the given callbacks. +pub(crate) struct MessageGraph { + dag: Graph, +} + +/// Block is a struct that contains the information about the block in the subgraph. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub(crate) struct Block { + from: String, + to: String, + cb_time: u64, +} + +// CallbackRequestWrapper is a struct that contains the information about the callback request and +// whether it has been visited or not. It is used to keep track of the visited callbacks. +#[derive(Debug)] +struct CallbackRequestWrapper { + callback_request: Arc, + visited: bool, +} + +impl MessageGraph { + /// This function generates a sub graph from a list of callbacks. + /// It first creates a HashMap to map each vertex to its corresponding callbacks. + /// Then it finds the source vertex by checking if the vertex and from_vertex fields are the same. + /// Finally, it calls the `generate_subgraph` function to generate the subgraph from the source vertex. + pub(crate) fn generate_subgraph_from_callbacks( + &self, + id: String, + callbacks: Vec>, + ) -> Result, Error> { + // Create a HashMap to map each vertex to its corresponding callbacks + let mut callback_map: HashMap> = HashMap::new(); + let mut source_vertex = None; + + // Find the source vertex, source vertex is the vertex that has the same vertex and from_vertex + for callback in callbacks { + // Check if the vertex is present in the graph + if !self.dag.contains_key(&callback.from_vertex) { + return Err(Error::SubGraphInvalidInput(format!( + "Invalid callback: {}, vertex: {}", + callback.id, callback.from_vertex + ))); + } + + if callback.vertex == callback.from_vertex { + source_vertex = Some(callback.vertex.clone()); + } + callback_map + .entry(callback.vertex.clone()) + .or_default() + .push(CallbackRequestWrapper { + callback_request: Arc::clone(&callback), + visited: false, + }); + } + + // If there is no source vertex, return None + let source_vertex = match source_vertex { + Some(vertex) => vertex, + None => return Ok(None), + }; + + // Create a new subgraph. + let mut subgraph: Subgraph = Subgraph { + id, + blocks: Vec::new(), + }; + // Call the `generate_subgraph` function to generate the subgraph from the source vertex + let result = self.generate_subgraph( + source_vertex.clone(), + source_vertex, + &mut callback_map, + &mut subgraph, + ); + + // If the subgraph is generated successfully, serialize it into a JSON string and return it. + // Otherwise, return None + if result { + match serde_json::to_string(&subgraph) { + Ok(json) => Ok(Some(json)), + Err(e) => Err(Error::SubGraphGenerator(e.to_string())), + } + } else { + Ok(None) + } + } + + // generate_subgraph function is a recursive function that generates the sub graph from the source vertex for + // the given list of callbacks. The function returns true if the subgraph is generated successfully(if we are + // able to find a subgraph for the message using the given callbacks), it + // updates the subgraph with the path from the source vertex to the downstream vertices. + fn generate_subgraph( + &self, + current: String, + from: String, + callback_map: &mut HashMap>, + subgraph: &mut Subgraph, + ) -> bool { + let mut current_callback: Option> = None; + + // we need to borrow the callback_map as mutable to update the visited flag of the callback + // so that next time when we visit the same callback, we can skip it. Because there can be cases + // where there will be multiple callbacks for the same vertex. + // If there are no callbacks for the current vertex, we should not continue + // because it means we have not received all the callbacks for the given message. + let Some(callbacks) = callback_map.get_mut(¤t) else { + return false; + }; + + // iterate over the callbacks for the current vertex and find the one that has not been visited, + // and it is coming from the same vertex as the current vertex + for callback in callbacks { + if callback.callback_request.from_vertex == from && !callback.visited { + callback.visited = true; + current_callback = Some(Arc::clone(&callback.callback_request)); + break; + } + } + + // If there is no callback which is not visited and its parent is the "from" vertex, then we should + // return false because we have not received all the callbacks for the given message. + let Some(current_callback) = current_callback else { + return false; + }; + + // add the current block to the subgraph + subgraph.blocks.push(Block { + from: current_callback.from_vertex.clone(), + to: current_callback.vertex.clone(), + cb_time: current_callback.cb_time, + }); + + // if the current vertex has a DROP tag, then we should not proceed further + // and return true + if current_callback + .tags + .as_ref() + .map_or(false, |tags| tags.contains(&DROP.to_string())) + { + return true; + } + + // recursively invoke the downstream vertices of the current vertex, if any + if let Some(edges) = self.dag.get(¤t) { + for edge in edges { + // check if the edge should proceed based on the conditions + // if there are no conditions, we should proceed with the edge + // if there are conditions, we should check the tags of the current callback + // with the tags of the edge and the operator of the tags to decide if we should + // proceed with the edge + let should_proceed = edge + .conditions + .as_ref() + // If the edge has conditions, get the tags + .and_then(|conditions| conditions.tags.as_ref()) + // If there are no conditions or tags, default to true (i.e., proceed with the edge) + // If there are tags, compare the tags with the current callback's tags and the operator + // to decide if we should proceed with the edge. + .map_or(true, |tags| { + current_callback + .tags + .as_ref() + // If the current callback has no tags we should not proceed with the edge for "and" and "or" operators + // because we expect the current callback to have tags specified in the edge. + // If the current callback has no tags we should proceed with the edge for "not" operator. + // because we don't expect the current callback to have tags specified in the edge. + .map_or( + tags.operator.as_ref() == Some(&OperatorType::Not), + |callback_tags| { + tags.operator.as_ref().map_or(false, |operator| { + // If there is no operator, default to false (i.e., do not proceed with the edge) + // If there is an operator, compare the current callback's tags with the edge's tags + compare_slice(operator, callback_tags, &tags.values) + }) + }, + ) + }); + + // if the conditions are not met, then proceed to the next edge + if !should_proceed { + continue; + } + + // proceed to the downstream vertex + // if any of the downstream vertex returns false, then we should return false. + if !self.generate_subgraph(edge.to.clone(), current.clone(), callback_map, subgraph) + { + return false; + } + } + } + // if there are no downstream vertices, or all the downstream vertices returned true, + // we can return true + true + } + + // from_env reads the pipeline stored in the environment variable and creates a MessageGraph from it. + pub(crate) fn from_pipeline(pipeline_spec: &PipelineDCG) -> Result { + let mut dag = Graph::with_capacity(pipeline_spec.edges.len()); + for edge in &pipeline_spec.edges { + dag.entry(edge.from.clone()).or_default().push(edge.clone()); + } + + Ok(MessageGraph { dag }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::pipeline::{Conditions, Tag, Vertex}; + + #[test] + fn test_no_subgraph() { + let mut dag: Graph = HashMap::new(); + dag.insert( + "a".to_string(), + vec![ + Edge { + from: "a".to_string(), + to: "b".to_string(), + conditions: None, + }, + Edge { + from: "a".to_string(), + to: "c".to_string(), + conditions: None, + }, + ], + ); + let message_graph = MessageGraph { dag }; + + let mut callback_map: HashMap> = HashMap::new(); + callback_map.insert( + "a".to_string(), + vec![CallbackRequestWrapper { + callback_request: Arc::new(CallbackRequest { + id: "uuid1".to_string(), + vertex: "a".to_string(), + cb_time: 1, + from_vertex: "a".to_string(), + tags: None, + }), + visited: false, + }], + ); + + let mut subgraph: Subgraph = Subgraph { + id: "uuid1".to_string(), + blocks: Vec::new(), + }; + let result = message_graph.generate_subgraph( + "a".to_string(), + "a".to_string(), + &mut callback_map, + &mut subgraph, + ); + + assert!(!result); + } + + #[test] + fn test_generate_subgraph() { + let mut dag: Graph = HashMap::new(); + dag.insert( + "a".to_string(), + vec![ + Edge { + from: "a".to_string(), + to: "b".to_string(), + conditions: None, + }, + Edge { + from: "a".to_string(), + to: "c".to_string(), + conditions: None, + }, + ], + ); + let message_graph = MessageGraph { dag }; + + let mut callback_map: HashMap> = HashMap::new(); + callback_map.insert( + "a".to_string(), + vec![CallbackRequestWrapper { + callback_request: Arc::new(CallbackRequest { + id: "uuid1".to_string(), + vertex: "a".to_string(), + cb_time: 1, + from_vertex: "a".to_string(), + tags: None, + }), + visited: false, + }], + ); + + callback_map.insert( + "b".to_string(), + vec![CallbackRequestWrapper { + callback_request: Arc::new(CallbackRequest { + id: "uuid1".to_string(), + vertex: "b".to_string(), + cb_time: 1, + from_vertex: "a".to_string(), + tags: None, + }), + visited: false, + }], + ); + + callback_map.insert( + "c".to_string(), + vec![CallbackRequestWrapper { + callback_request: Arc::new(CallbackRequest { + id: "uuid1".to_string(), + vertex: "c".to_string(), + cb_time: 1, + from_vertex: "a".to_string(), + tags: None, + }), + visited: false, + }], + ); + + let mut subgraph: Subgraph = Subgraph { + id: "uuid1".to_string(), + blocks: Vec::new(), + }; + let result = message_graph.generate_subgraph( + "a".to_string(), + "a".to_string(), + &mut callback_map, + &mut subgraph, + ); + + assert!(result); + } + + #[test] + fn test_generate_subgraph_complex() { + let pipeline = PipelineDCG { + vertices: vec![ + Vertex { + name: "a".to_string(), + }, + Vertex { + name: "b".to_string(), + }, + Vertex { + name: "c".to_string(), + }, + Vertex { + name: "d".to_string(), + }, + Vertex { + name: "e".to_string(), + }, + Vertex { + name: "f".to_string(), + }, + Vertex { + name: "g".to_string(), + }, + Vertex { + name: "h".to_string(), + }, + Vertex { + name: "i".to_string(), + }, + ], + edges: vec![ + Edge { + from: "a".to_string(), + to: "b".to_string(), + conditions: None, + }, + Edge { + from: "a".to_string(), + to: "c".to_string(), + conditions: None, + }, + Edge { + from: "b".to_string(), + to: "d".to_string(), + conditions: None, + }, + Edge { + from: "c".to_string(), + to: "e".to_string(), + conditions: None, + }, + Edge { + from: "d".to_string(), + to: "f".to_string(), + conditions: None, + }, + Edge { + from: "e".to_string(), + to: "f".to_string(), + conditions: None, + }, + Edge { + from: "f".to_string(), + to: "g".to_string(), + conditions: None, + }, + Edge { + from: "g".to_string(), + to: "h".to_string(), + conditions: Some(Conditions { + tags: Some(Tag { + operator: Some(OperatorType::And), + values: vec!["even".to_string()], + }), + }), + }, + Edge { + from: "g".to_string(), + to: "i".to_string(), + conditions: Some(Conditions { + tags: Some(Tag { + operator: Some(OperatorType::Or), + values: vec!["odd".to_string()], + }), + }), + }, + ], + }; + + let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); + let source_vertex = "a".to_string(); + + let raw_callback = r#"[ + { + "id": "xxxx", + "vertex": "a", + "from_vertex": "a", + "cb_time": 123456789 + }, + { + "id": "xxxx", + "vertex": "b", + "from_vertex": "a", + "cb_time": 123456867 + }, + { + "id": "xxxx", + "vertex": "c", + "from_vertex": "a", + "cb_time": 123456819 + }, + { + "id": "xxxx", + "vertex": "d", + "from_vertex": "b", + "cb_time": 123456840 + }, + { + "id": "xxxx", + "vertex": "e", + "from_vertex": "c", + "cb_time": 123456843 + }, + { + "id": "xxxx", + "vertex": "f", + "from_vertex": "d", + "cb_time": 123456854 + }, + { + "id": "xxxx", + "vertex": "f", + "from_vertex": "e", + "cb_time": 123456886 + }, + { + "id": "xxxx", + "vertex": "g", + "from_vertex": "f", + "tags": ["even"], + "cb_time": 123456885 + }, + { + "id": "xxxx", + "vertex": "g", + "from_vertex": "f", + "tags": ["even"], + "cb_time": 123456888 + }, + { + "id": "xxxx", + "vertex": "h", + "from_vertex": "g", + "cb_time": 123456889 + }, + { + "id": "xxxx", + "vertex": "h", + "from_vertex": "g", + "cb_time": 123456890 + } + ]"#; + + let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); + let mut callback_map: HashMap> = HashMap::new(); + + for callback in callbacks { + callback_map + .entry(callback.vertex.clone()) + .or_default() + .push(CallbackRequestWrapper { + callback_request: Arc::new(callback), + visited: false, + }); + } + + let mut subgraph: Subgraph = Subgraph { + id: "xxxx".to_string(), + blocks: Vec::new(), + }; + let result = message_graph.generate_subgraph( + source_vertex.clone(), + source_vertex, + &mut callback_map, + &mut subgraph, + ); + + assert!(result); + } + + #[test] + fn test_simple_dropped_message() { + let pipeline = PipelineDCG { + vertices: vec![ + Vertex { + name: "a".to_string(), + }, + Vertex { + name: "b".to_string(), + }, + Vertex { + name: "c".to_string(), + }, + ], + edges: vec![ + Edge { + from: "a".to_string(), + to: "b".to_string(), + conditions: None, + }, + Edge { + from: "b".to_string(), + to: "c".to_string(), + conditions: None, + }, + ], + }; + + let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); + let source_vertex = "a".to_string(); + + let raw_callback = r#" + [ + { + "id": "xxxx", + "vertex": "a", + "from_vertex": "a", + "cb_time": 123456789 + }, + { + "id": "xxxx", + "vertex": "b", + "from_vertex": "a", + "cb_time": 123456867, + "tags": ["U+005C__DROP__"] + } + ]"#; + + let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); + let mut callback_map: HashMap> = HashMap::new(); + + for callback in callbacks { + callback_map + .entry(callback.vertex.clone()) + .or_default() + .push(CallbackRequestWrapper { + callback_request: Arc::new(callback), + visited: false, + }); + } + + let mut subgraph: Subgraph = Subgraph { + id: "xxxx".to_string(), + blocks: Vec::new(), + }; + let result = message_graph.generate_subgraph( + source_vertex.clone(), + source_vertex, + &mut callback_map, + &mut subgraph, + ); + + assert!(result); + } + + #[test] + fn test_complex_dropped_message() { + let pipeline = PipelineDCG { + vertices: vec![ + Vertex { + name: "a".to_string(), + }, + Vertex { + name: "b".to_string(), + }, + Vertex { + name: "c".to_string(), + }, + Vertex { + name: "d".to_string(), + }, + Vertex { + name: "e".to_string(), + }, + Vertex { + name: "f".to_string(), + }, + Vertex { + name: "g".to_string(), + }, + Vertex { + name: "h".to_string(), + }, + Vertex { + name: "i".to_string(), + }, + ], + edges: vec![ + Edge { + from: "a".to_string(), + to: "b".to_string(), + conditions: None, + }, + Edge { + from: "a".to_string(), + to: "c".to_string(), + conditions: None, + }, + Edge { + from: "b".to_string(), + to: "d".to_string(), + conditions: None, + }, + Edge { + from: "c".to_string(), + to: "e".to_string(), + conditions: None, + }, + Edge { + from: "d".to_string(), + to: "f".to_string(), + conditions: None, + }, + Edge { + from: "e".to_string(), + to: "f".to_string(), + conditions: None, + }, + Edge { + from: "f".to_string(), + to: "g".to_string(), + conditions: None, + }, + Edge { + from: "g".to_string(), + to: "h".to_string(), + conditions: Some(Conditions { + tags: Some(Tag { + operator: Some(OperatorType::And), + values: vec!["even".to_string()], + }), + }), + }, + Edge { + from: "g".to_string(), + to: "i".to_string(), + conditions: Some(Conditions { + tags: Some(Tag { + operator: Some(OperatorType::Or), + values: vec!["odd".to_string()], + }), + }), + }, + ], + }; + + let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); + let source_vertex = "a".to_string(); + + let raw_callback = r#" + [ + { + "id": "xxxx", + "vertex": "a", + "from_vertex": "a", + "cb_time": 123456789 + }, + { + "id": "xxxx", + "vertex": "b", + "from_vertex": "a", + "cb_time": 123456867 + }, + { + "id": "xxxx", + "vertex": "c", + "from_vertex": "a", + "cb_time": 123456819, + "tags": ["U+005C__DROP__"] + }, + { + "id": "xxxx", + "vertex": "d", + "from_vertex": "b", + "cb_time": 123456840 + }, + { + "id": "xxxx", + "vertex": "f", + "from_vertex": "d", + "cb_time": 123456854 + }, + { + "id": "xxxx", + "vertex": "g", + "from_vertex": "f", + "tags": ["even"], + "cb_time": 123456885 + }, + { + "id": "xxxx", + "vertex": "h", + "from_vertex": "g", + "cb_time": 123456889 + } + ]"#; + + let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); + let mut callback_map: HashMap> = HashMap::new(); + + for callback in callbacks { + callback_map + .entry(callback.vertex.clone()) + .or_default() + .push(CallbackRequestWrapper { + callback_request: Arc::new(callback), + visited: false, + }); + } + + let mut subgraph: Subgraph = Subgraph { + id: "xxxx".to_string(), + blocks: Vec::new(), + }; + let result = message_graph.generate_subgraph( + source_vertex.clone(), + source_vertex, + &mut callback_map, + &mut subgraph, + ); + + assert!(result); + } + + #[test] + fn test_simple_cycle_pipeline() { + let pipeline = PipelineDCG { + vertices: vec![ + Vertex { + name: "a".to_string(), + }, + Vertex { + name: "b".to_string(), + }, + Vertex { + name: "c".to_string(), + }, + ], + edges: vec![ + Edge { + from: "a".to_string(), + to: "b".to_string(), + conditions: None, + }, + Edge { + from: "b".to_string(), + to: "a".to_string(), + conditions: Some(Conditions { + tags: Some(Tag { + operator: Some(OperatorType::Not), + values: vec!["failed".to_string()], + }), + }), + }, + ], + }; + + let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); + let source_vertex = "a".to_string(); + + let raw_callback = r#" + [ + { + "id": "xxxx", + "vertex": "a", + "from_vertex": "a", + "cb_time": 123456789 + }, + { + "id": "xxxx", + "vertex": "b", + "from_vertex": "a", + "cb_time": 123456867, + "tags": ["failed"] + }, + { + "id": "xxxx", + "vertex": "a", + "from_vertex": "b", + "cb_time": 123456819 + }, + { + "id": "xxxx", + "vertex": "b", + "from_vertex": "a", + "cb_time": 123456819 + }, + { + "id": "xxxx", + "vertex": "c", + "from_vertex": "b", + "cb_time": 123456819 + } + ]"#; + + let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); + let mut callback_map: HashMap> = HashMap::new(); + + for callback in callbacks { + callback_map + .entry(callback.vertex.clone()) + .or_default() + .push(CallbackRequestWrapper { + callback_request: Arc::new(callback), + visited: false, + }); + } + + let mut subgraph: Subgraph = Subgraph { + id: "xxxx".to_string(), + blocks: Vec::new(), + }; + let result = message_graph.generate_subgraph( + source_vertex.clone(), + source_vertex, + &mut callback_map, + &mut subgraph, + ); + + assert!(result); + } + + #[test] + fn test_generate_subgraph_from_callbacks_with_invalid_vertex() { + // Create a simple graph + let mut dag: Graph = HashMap::new(); + dag.insert( + "a".to_string(), + vec![Edge { + from: "a".to_string(), + to: "b".to_string(), + conditions: None, + }], + ); + let message_graph = MessageGraph { dag }; + + // Create a callback with an invalid vertex + let callbacks = vec![Arc::new(CallbackRequest { + id: "test".to_string(), + vertex: "invalid_vertex".to_string(), + from_vertex: "invalid_vertex".to_string(), + cb_time: 1, + tags: None, + })]; + + // Call the function with the invalid callback + let result = message_graph.generate_subgraph_from_callbacks("test".to_string(), callbacks); + + // Check that the function returned an error + assert!(result.is_err()); + assert!(matches!(result, Err(Error::SubGraphInvalidInput(_)))); + } +} diff --git a/rust/extns/numaflow-serving/src/config.rs b/rust/extns/numaflow-serving/src/config.rs new file mode 100644 index 0000000000..d8ace5716a --- /dev/null +++ b/rust/extns/numaflow-serving/src/config.rs @@ -0,0 +1,224 @@ +use std::collections::HashMap; + +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use serde::{Deserialize, Serialize}; + +use crate::pipeline::PipelineDCG; +use crate::Error; + +const ENV_NUMAFLOW_SERVING_SOURCE_OBJECT: &str = "NUMAFLOW_SERVING_SOURCE_OBJECT"; +const ENV_NUMAFLOW_SERVING_JETSTREAM_URL: &str = "NUMAFLOW_ISBSVC_JETSTREAM_URL"; +const ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM: &str = "NUMAFLOW_SERVING_JETSTREAM_STREAM"; +const ENV_NUMAFLOW_SERVING_STORE_TTL: &str = "NUMAFLOW_SERVING_STORE_TTL"; +const ENV_NUMAFLOW_SERVING_HOST_IP: &str = "NUMAFLOW_SERVING_HOST_IP"; +const ENV_NUMAFLOW_SERVING_APP_PORT: &str = "NUMAFLOW_SERVING_APP_LISTEN_PORT"; +const ENV_NUMAFLOW_SERVING_JETSTREAM_USER: &str = "NUMAFLOW_ISBSVC_JETSTREAM_USER"; +const ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD: &str = "NUMAFLOW_ISBSVC_JETSTREAM_PASSWORD"; +const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; +const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; + +pub(crate) const SAVED: &str = "SAVED"; + +#[derive(Deserialize, Clone, PartialEq)] +pub struct BasicAuth { + pub username: String, + pub password: String, +} + +impl std::fmt::Debug for BasicAuth { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let passwd_printable = if self.password.len() > 4 { + let passwd: String = self + .password + .chars() + .skip(self.password.len() - 2) + .take(2) + .collect(); + format!("***{}", passwd) + } else { + "*****".to_owned() + }; + write!(f, "{}:{}", self.username, passwd_printable) + } +} + +#[derive(Debug, Deserialize, Clone, PartialEq)] +pub struct JetStreamConfig { + pub stream: String, + pub url: String, + pub auth: Option, +} + +impl Default for JetStreamConfig { + fn default() -> Self { + Self { + stream: "default".to_owned(), + url: "localhost:4222".to_owned(), + auth: None, + } + } +} + +#[derive(Debug, Deserialize, Clone, PartialEq)] +pub struct RedisConfig { + pub addr: String, + pub max_tasks: usize, + pub retries: usize, + pub retries_duration_millis: u16, + pub ttl_secs: Option, +} + +impl Default for RedisConfig { + fn default() -> Self { + Self { + addr: "redis://127.0.0.1:6379".to_owned(), + max_tasks: 50, + retries: 5, + retries_duration_millis: 100, + // TODO: we might need an option type here. Zero value of u32 can be used instead of None + ttl_secs: Some(1), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Settings { + pub tid_header: String, + pub app_listen_port: u16, + pub metrics_server_listen_port: u16, + pub upstream_addr: String, + pub drain_timeout_secs: u64, + /// The IP address of the numaserve pod. This will be used to construct the value for X-Numaflow-Callback-Url header + pub host_ip: String, + pub api_auth_token: Option, + pub redis: RedisConfig, + pub jetstream: JetStreamConfig, + pub pipeline_spec: PipelineDCG, +} + +impl Default for Settings { + fn default() -> Self { + Self { + tid_header: "ID".to_owned(), + app_listen_port: 3000, + metrics_server_listen_port: 3001, + upstream_addr: "localhost:8888".to_owned(), + drain_timeout_secs: 10, + host_ip: "127.0.0.1".to_owned(), + api_auth_token: None, + redis: RedisConfig::default(), + jetstream: JetStreamConfig::default(), + pipeline_spec: Default::default(), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Serving { + #[serde(rename = "msgIDHeaderKey")] + pub msg_id_header_key: Option, + #[serde(rename = "store")] + pub callback_storage: CallbackStorageConfig, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct CallbackStorageConfig { + pub url: String, +} + +/// This implementation is to load settings from env variables +impl TryFrom> for Settings { + type Error = Error; + fn try_from(env_vars: HashMap) -> std::result::Result { + let host_ip = env_vars + .get(ENV_NUMAFLOW_SERVING_HOST_IP) + .ok_or_else(|| { + Error::ParseConfig(format!( + "Environment variable {ENV_NUMAFLOW_SERVING_HOST_IP} is not set" + )) + })? + .to_owned(); + + let pipeline_spec: PipelineDCG = env_vars + .get(ENV_MIN_PIPELINE_SPEC) + .ok_or_else(|| { + Error::ParseConfig(format!( + "Pipeline spec is not set using environment variable {ENV_MIN_PIPELINE_SPEC}" + )) + })? + .parse() + .map_err(|e| { + Error::ParseConfig(format!( + "Parsing pipeline spec: {}: error={e:?}", + env_vars.get(ENV_MIN_PIPELINE_SPEC).unwrap() + )) + })?; + + let mut settings = Settings { + host_ip, + pipeline_spec, + ..Default::default() + }; + + if let Some(jetstream_url) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_URL) { + settings.jetstream.url = jetstream_url.to_owned(); + } + + if let Some(jetstream_stream) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM) { + settings.jetstream.stream = jetstream_stream.to_owned(); + } + + if let Some(api_auth_token) = env_vars.get(ENV_NUMAFLOW_SERVING_AUTH_TOKEN) { + settings.api_auth_token = Some(api_auth_token.to_owned()); + } + + if let Some(app_port) = env_vars.get(ENV_NUMAFLOW_SERVING_APP_PORT) { + settings.app_listen_port = app_port.parse().map_err(|e| { + Error::ParseConfig(format!( + "Parsing {ENV_NUMAFLOW_SERVING_APP_PORT}(set to '{app_port}'): {e:?}" + )) + })?; + } + + // If username is set, the password also must be set + if let Some(username) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_USER) { + let Some(password) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD) else { + return Err(Error::ParseConfig(format!("Env variable {ENV_NUMAFLOW_SERVING_JETSTREAM_USER} is set, but {ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD} is not set"))); + }; + settings.jetstream.auth = Some(BasicAuth { + username: username.to_owned(), + password: password.to_owned(), + }); + } + + // Update redis.ttl_secs from environment variable + if let Some(ttl_secs) = env_vars.get(ENV_NUMAFLOW_SERVING_STORE_TTL) { + let ttl_secs: u32 = ttl_secs.parse().map_err(|e| { + Error::ParseConfig(format!("parsing {ENV_NUMAFLOW_SERVING_STORE_TTL}: {e:?}")) + })?; + settings.redis.ttl_secs = Some(ttl_secs); + } + + let Some(source_spec_encoded) = env_vars.get(ENV_NUMAFLOW_SERVING_SOURCE_OBJECT) else { + return Ok(settings); + }; + + let source_spec_decoded = BASE64_STANDARD + .decode(source_spec_encoded.as_bytes()) + .map_err(|e| Error::ParseConfig(format!("decoding NUMAFLOW_SERVING_SOURCE: {e:?}")))?; + + let source_spec = serde_json::from_slice::(&source_spec_decoded) + .map_err(|e| Error::ParseConfig(format!("parsing NUMAFLOW_SERVING_SOURCE: {e:?}")))?; + + // Update tid_header from source_spec + if let Some(msg_id_header_key) = source_spec.msg_id_header_key { + settings.tid_header = msg_id_header_key; + } + + // Update redis.addr from source_spec, currently we only support redis as callback storage + settings.redis.addr = source_spec.callback_storage.url; + + Ok(settings) + } +} diff --git a/rust/extns/numaflow-serving/src/lib.rs b/rust/extns/numaflow-serving/src/lib.rs new file mode 100644 index 0000000000..1da6d0c6c3 --- /dev/null +++ b/rust/extns/numaflow-serving/src/lib.rs @@ -0,0 +1,248 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::time::Instant; + +use crate::app::callback::state::State as CallbackState; +use crate::app::callback::store::redisstore::RedisConnection; +use crate::app::tracker::MessageGraph; + +mod app; +pub mod config; + +pub(crate) mod pipeline; + +struct MessageWrapper { + pub confirm_save: oneshot::Sender<()>, + message: Message, +} + +pub struct Message { + pub value: Bytes, + pub id: String, + pub headers: HashMap, +} + +enum ActorMessage { + Read { + batch_size: usize, + timeout_at: Instant, + reply_to: oneshot::Sender>, + }, + Ack { + offsets: Vec, + reply_to: oneshot::Sender<()>, + }, +} + +struct ServingSourceActor { + messages: mpsc::Receiver, + handler_rx: mpsc::Receiver, + tracker: HashMap>, +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Initialization error - {0}")] + InitError(String), + + #[error("Failed to parse configuration - {0}")] + ParseConfig(String), + // + // callback errors + // TODO: store the ID too? + #[error("IDNotFound Error - {0}")] + IDNotFound(&'static str), + + #[error("SubGraphGenerator Error - {0}")] + // subgraph generator errors + SubGraphGenerator(String), + + #[error("StoreWrite Error - {0}")] + // Store write errors + StoreWrite(String), + + #[error("SubGraphNotFound Error - {0}")] + // Sub Graph Not Found Error + SubGraphNotFound(&'static str), + + #[error("SubGraphInvalidInput Error - {0}")] + // Sub Graph Invalid Input Error + SubGraphInvalidInput(String), + + #[error("StoreRead Error - {0}")] + // Store read errors + StoreRead(String), + + #[error("Metrics Error - {0}")] + // Metrics errors + MetricsServer(String), + + #[error("Connection Error - {0}")] + Connection(String), + + #[error("Failed to receive message from channel. Actor task is terminated: {0:?}")] + ActorTaskTerminated(oneshot::error::RecvError), + + #[error("{0}")] + Other(String), +} + +type Result = std::result::Result; +pub type Settings = config::Settings; + +impl ServingSourceActor { + async fn start( + settings: Arc, + handler_rx: mpsc::Receiver, + ) -> Result<()> { + let (messages_tx, messages_rx) = mpsc::channel(10000); + // Create a redis store to store the callbacks and the custom responses + let redis_store = RedisConnection::new(settings.redis.clone()).await?; + // Create the message graph from the pipeline spec and the redis store + let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec).map_err(|e| { + Error::InitError(format!( + "Creating message graph from pipeline spec: {:?}", + e + )) + })?; + let callback_state = CallbackState::new(msg_graph, redis_store).await?; + + tokio::spawn(async move { + let mut serving_actor = ServingSourceActor { + messages: messages_rx, + handler_rx, + tracker: HashMap::new(), + }; + serving_actor.run().await; + }); + let app = app::AppState { + message: messages_tx, + settings, + callback_state, + }; + app::serve(app).await.unwrap(); + Ok(()) + } + + async fn run(&mut self) { + while let Some(msg) = self.handler_rx.recv().await { + self.handle_message(msg).await; + } + } + + async fn handle_message(&mut self, actor_msg: ActorMessage) { + match actor_msg { + ActorMessage::Read { + batch_size, + timeout_at, + reply_to, + } => { + let messages = self.read(batch_size, timeout_at).await; + let _ = reply_to.send(messages); + } + ActorMessage::Ack { offsets, reply_to } => { + self.ack(offsets).await; + let _ = reply_to.send(()); + } + } + } + + async fn read(&mut self, count: usize, timeout_at: Instant) -> Vec { + let mut messages = vec![]; + loop { + if messages.len() >= count || Instant::now() >= timeout_at { + break; + } + let message = match self.messages.try_recv() { + Ok(msg) => msg, + Err(mpsc::error::TryRecvError::Empty) => break, + Err(e) => { + tracing::error!(?e, "Receiving messages from the serving channel"); // FIXME: + return messages; + } + }; + let MessageWrapper { + confirm_save, + message, + } = message; + + self.tracker.insert(message.id.clone(), confirm_save); + messages.push(message); + } + messages + } + + async fn ack(&mut self, offsets: Vec) { + for offset in offsets { + let offset = offset + .strip_suffix("-0") + .expect("offset does not end with '-0'"); // FIXME: we hardcode 0 as the partition index when constructing offset + let confirm_save_tx = self + .tracker + .remove(offset) + .expect("offset was not found in the tracker"); + confirm_save_tx + .send(()) + .expect("Sending on confirm_save channel"); + } + } +} + +#[derive(Clone)] +pub struct ServingSource { + batch_size: usize, + // timeout for each batch read request + timeout: Duration, + actor_tx: mpsc::Sender, +} + +impl ServingSource { + pub async fn new( + settings: Arc, + batch_size: usize, + timeout: Duration, + ) -> Result { + let (actor_tx, actor_rx) = mpsc::channel(10); + ServingSourceActor::start(settings, actor_rx).await?; + Ok(Self { + batch_size, + timeout, + actor_tx, + }) + } + + pub async fn read_messages(&self) -> Result> { + let start = Instant::now(); + let (tx, rx) = oneshot::channel(); + let actor_msg = ActorMessage::Read { + reply_to: tx, + batch_size: self.batch_size, + timeout_at: Instant::now() + self.timeout, + }; + let _ = self.actor_tx.send(actor_msg).await; + let messages = rx.await.map_err(Error::ActorTaskTerminated)?; + tracing::debug!( + count = messages.len(), + requested_count = self.batch_size, + time_taken_ms = start.elapsed().as_millis(), + "Got messages from Serving source" + ); + Ok(messages) + } + + pub async fn ack_messages(&self, offsets: Vec) -> Result<()> { + let (tx, rx) = oneshot::channel(); + let actor_msg = ActorMessage::Ack { + offsets, + reply_to: tx, + }; + let _ = self.actor_tx.send(actor_msg).await; + rx.await.map_err(Error::ActorTaskTerminated)?; + Ok(()) + } +} diff --git a/rust/extns/numaflow-serving/src/pipeline.rs b/rust/extns/numaflow-serving/src/pipeline.rs new file mode 100644 index 0000000000..05768568ab --- /dev/null +++ b/rust/extns/numaflow-serving/src/pipeline.rs @@ -0,0 +1,164 @@ +use std::str::FromStr; + +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use numaflow_models::models::PipelineSpec; +use serde::{Deserialize, Serialize}; + +use crate::Error::ParseConfig; + +// OperatorType is an enum that contains the types of operators +// that can be used in the conditions for the edge. +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub enum OperatorType { + #[serde(rename = "and")] + And, + #[serde(rename = "or")] + Or, + #[serde(rename = "not")] + Not, +} + +#[allow(dead_code)] +impl OperatorType { + fn as_str(&self) -> &'static str { + match self { + OperatorType::And => "and", + OperatorType::Or => "or", + OperatorType::Not => "not", + } + } +} + +impl From for OperatorType { + fn from(s: String) -> Self { + match s.as_str() { + "and" => OperatorType::And, + "or" => OperatorType::Or, + "not" => OperatorType::Not, + _ => panic!("Invalid operator type: {}", s), + } + } +} + +// Tag is a struct that contains the information about the tags for the edge +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Tag { + pub operator: Option, + pub values: Vec, +} + +// Conditions is a struct that contains the information about the conditions for the edge +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Conditions { + pub tags: Option, +} + +// Edge is a struct that contains the information about the edge in the pipeline. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Edge { + pub from: String, + pub to: String, + pub conditions: Option, +} + +/// DCG (directed compute graph) of the pipeline with minimal information build using vertices and edges +/// from the pipeline spec +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[serde()] +pub struct PipelineDCG { + pub vertices: Vec, + pub edges: Vec, +} + +impl Default for PipelineDCG { + fn default() -> Self { + Self { + vertices: vec![], + edges: vec![], + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Vertex { + pub name: String, +} + +impl FromStr for PipelineDCG { + type Err = crate::Error; + + fn from_str(pipeline_spec_encoded: &str) -> Result { + let full_pipeline_spec_decoded = BASE64_STANDARD + .decode(pipeline_spec_encoded) + .map_err(|e| ParseConfig(format!("Decoding pipeline from env: {e:?}")))?; + + let full_pipeline_spec = + serde_json::from_slice::(&full_pipeline_spec_decoded) + .map_err(|e| ParseConfig(format!("parsing pipeline from env: {e:?}")))?; + + let vertices: Vec = full_pipeline_spec + .vertices + .ok_or(ParseConfig("missing vertices in pipeline spec".to_string()))? + .iter() + .map(|v| Vertex { + name: v.name.clone(), + }) + .collect(); + + let edges: Vec = full_pipeline_spec + .edges + .ok_or(ParseConfig("missing edges in pipeline spec".to_string()))? + .iter() + .map(|e| { + let conditions = e.conditions.clone().map(|c| Conditions { + tags: Some(Tag { + operator: c.tags.operator.map(|o| o.into()), + values: c.tags.values.clone(), + }), + }); + + Edge { + from: e.from.clone(), + to: e.to.clone(), + conditions, + } + }) + .collect(); + + Ok(PipelineDCG { vertices, edges }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pipeline_load() { + let pipeline: PipelineDCG = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0=".parse().unwrap(); + + assert_eq!(pipeline.vertices.len(), 8); + assert_eq!(pipeline.edges.len(), 10); + assert_eq!(pipeline.vertices[0].name, "in"); + assert_eq!(pipeline.edges[0].from, "in"); + assert_eq!(pipeline.edges[0].to, "planner"); + assert!(pipeline.edges[0].conditions.is_none()); + + assert_eq!(pipeline.vertices[1].name, "planner"); + assert_eq!(pipeline.edges[1].from, "planner"); + assert_eq!(pipeline.edges[1].to, "asciiart"); + assert_eq!( + pipeline.edges[1].conditions, + Some(Conditions { + tags: Some(Tag { + operator: Some(OperatorType::Or), + values: vec!["asciiart".to_owned()] + }) + }) + ); + + assert_eq!(pipeline.vertices[2].name, "tiger"); + assert_eq!(pipeline.vertices[3].name, "dog"); + } +} diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index b4688a135b..53cf4136fc 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -19,10 +19,12 @@ numaflow-models.workspace = true numaflow-pb.workspace = true serving.workspace = true backoff.workspace = true +numaflow-serving.workspace = true axum.workspace = true axum-server.workspace = true +bytes.workspace = true +serde.workspace = true tonic = "0.12.3" -bytes = "1.7.1" thiserror = "2.0.3" tokio-util = "0.7.11" tokio-stream = "0.1.15" @@ -36,7 +38,6 @@ serde_json = "1.0.122" trait-variant = "0.1.2" rcgen = "0.13.1" rustls = { version = "0.23.12", features = ["aws_lc_rs"] } -serde = { version = "1.0.204", features = ["derive"] } semver = "1.0" pep440_rs = "0.6.6" parking_lot = "0.12.3" diff --git a/rust/numaflow-core/src/config/components.rs b/rust/numaflow-core/src/config/components.rs index f17331ddaa..f9f696de18 100644 --- a/rust/numaflow-core/src/config/components.rs +++ b/rust/numaflow-core/src/config/components.rs @@ -5,6 +5,7 @@ pub(crate) mod source { use std::collections::HashMap; use std::env; + use std::sync::Arc; use std::{fmt::Debug, time::Duration}; use bytes::Bytes; @@ -33,7 +34,7 @@ pub(crate) mod source { Generator(GeneratorConfig), UserDefined(UserDefinedConfig), Pulsar(PulsarSourceConfig), - Serving(serving::Settings), + Serving(Arc), } impl From> for SourceType { @@ -106,10 +107,7 @@ pub(crate) mod source { // There should be only one option (user-defined) to define the settings. fn try_from(cfg: Box) -> Result { let env_vars = env::vars().collect::>(); - - let mut settings: serving::Settings = env_vars - .try_into() - .map_err(|e: serving::Error| Error::Config(e.to_string()))?; + let mut settings: numaflow_serving::Settings = env_vars.try_into()?; settings.tid_header = cfg.msg_id_header_key; @@ -144,7 +142,7 @@ pub(crate) mod source { } settings.redis.addr = cfg.store.url; - Ok(SourceType::Serving(settings)) + Ok(SourceType::Serving(Arc::new(settings))) } } @@ -164,6 +162,10 @@ pub(crate) mod source { return pulsar.try_into(); } + if let Some(serving) = source.serving.take() { + return serving.try_into(); + } + Err(Error::Config(format!("Invalid source type: {source:?}"))) } } diff --git a/rust/numaflow-core/src/lib.rs b/rust/numaflow-core/src/lib.rs index 727a119f1b..2237c5577a 100644 --- a/rust/numaflow-core/src/lib.rs +++ b/rust/numaflow-core/src/lib.rs @@ -52,6 +52,7 @@ mod pipeline; mod tracker; pub async fn run() -> Result<()> { + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let cln_token = CancellationToken::new(); let shutdown_cln_token = cln_token.clone(); diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs index 0e5fade691..b41eb045d0 100644 --- a/rust/numaflow-core/src/shared/create_components.rs +++ b/rust/numaflow-core/src/shared/create_components.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; use std::time::Duration; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::source::source_client::SourceClient; use numaflow_pb::clients::sourcetransformer::source_transform_client::SourceTransformClient; +use numaflow_serving::ServingSource; use tokio_util::sync::CancellationToken; use tonic::transport::Channel; @@ -266,8 +268,16 @@ pub async fn create_source( None, )) } - SourceType::Serving(_) => { - unimplemented!("Serving as built-in source is not yet implemented") + SourceType::Serving(config) => { + let serving = ServingSource::new(Arc::clone(config), batch_size, read_timeout).await?; + Ok(( + Source::new( + batch_size, + source::SourceType::Serving(serving), + tracker_handle, + ), + None, + )) } } } diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs index 66361d84ac..9e0f0f0b56 100644 --- a/rust/numaflow-core/src/source.rs +++ b/rust/numaflow-core/src/source.rs @@ -1,4 +1,5 @@ use numaflow_pulsar::source::PulsarSource; +use numaflow_serving::ServingSource; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; use tokio::time; @@ -34,6 +35,8 @@ pub(crate) mod generator; /// [Pulsar]: https://numaflow.numaproj.io/user-guide/sources/pulsar/ pub(crate) mod pulsar; +pub(crate) mod serving; + /// Set of Read related items that has to be implemented to become a Source. pub(crate) trait SourceReader { #[allow(dead_code)] @@ -65,6 +68,7 @@ pub(crate) enum SourceType { generator::GeneratorLagReader, ), Pulsar(PulsarSource), + Serving(ServingSource), } enum ActorMessage { @@ -177,6 +181,13 @@ impl Source { actor.run().await; }); } + SourceType::Serving(serving) => { + tokio::spawn(async move { + let actor = + SourceActor::new(receiver, serving.clone(), serving.clone(), serving); + actor.run().await; + }); + } }; Self { read_batch_size: batch_size, diff --git a/rust/numaflow-core/src/source/serving.rs b/rust/numaflow-core/src/source/serving.rs new file mode 100644 index 0000000000..a06f418d66 --- /dev/null +++ b/rust/numaflow-core/src/source/serving.rs @@ -0,0 +1,77 @@ +use std::sync::Arc; + +use numaflow_serving::ServingSource; + +use crate::message::{MessageID, StringOffset}; +use crate::Error; +use crate::Result; + +use super::{get_vertex_name, Message, Offset}; + +impl TryFrom for Message { + type Error = Error; + + fn try_from(message: numaflow_serving::Message) -> Result { + let offset = Offset::String(StringOffset::new(message.id.clone(), 0)); + + Ok(Message { + keys: Arc::from(vec![]), + tags: None, + value: message.value, + offset: Some(offset.clone()), + event_time: Default::default(), + id: MessageID { + vertex_name: get_vertex_name().to_string().into(), + offset: offset.to_string().into(), + index: 0, + }, + headers: message.headers, + }) + } +} + +impl From for Error { + fn from(value: numaflow_serving::Error) -> Self { + Error::Source(value.to_string()) + } +} + +impl super::SourceReader for ServingSource { + fn name(&self) -> &'static str { + "serving" + } + + async fn read(&mut self) -> Result> { + self.read_messages() + .await? + .into_iter() + .map(|msg| msg.try_into()) + .collect() + } + + fn partitions(&self) -> Vec { + vec![] + } +} + +impl super::SourceAcker for ServingSource { + async fn ack(&mut self, offsets: Vec) -> Result<()> { + let mut serving_offsets = vec![]; + for offset in offsets { + let Offset::String(offset) = offset else { + return Err(Error::Source(format!( + "Expected string offset for Serving source. Got {offset:?}" + ))); + }; + serving_offsets.push(offset.to_string()); + } + self.ack_messages(serving_offsets).await?; + Ok(()) + } +} + +impl super::LagReader for ServingSource { + async fn pending(&mut self) -> Result> { + Ok(None) + } +} From 198fe0095f6ed21df722c66a4f4c8a931c1a9cd7 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 23 Dec 2024 15:17:02 +0530 Subject: [PATCH 02/15] Unit tests for configurations Signed-off-by: Sreekanth --- rust/extns/numaflow-serving/src/app.rs | 2 +- rust/extns/numaflow-serving/src/config.rs | 106 ++++++++++ rust/extns/numaflow-serving/src/errors.rs | 50 +++++ rust/extns/numaflow-serving/src/lib.rs | 246 +--------------------- rust/extns/numaflow-serving/src/source.rs | 194 +++++++++++++++++ 5 files changed, 356 insertions(+), 242 deletions(-) create mode 100644 rust/extns/numaflow-serving/src/errors.rs create mode 100644 rust/extns/numaflow-serving/src/source.rs diff --git a/rust/extns/numaflow-serving/src/app.rs b/rust/extns/numaflow-serving/src/app.rs index f2433e928a..6290aac84c 100644 --- a/rust/extns/numaflow-serving/src/app.rs +++ b/rust/extns/numaflow-serving/src/app.rs @@ -34,7 +34,7 @@ pub(crate) mod tracker; use self::callback::store::Store; use self::response::{ApiError, ServeResponse}; -pub fn generate_certs() -> crate::Result<(Certificate, KeyPair)> { +fn generate_certs() -> crate::Result<(Certificate, KeyPair)> { let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) .map_err(|e| Error::InitError(format!("Failed to generate cert {:?}", e)))?; Ok((cert, key_pair)) diff --git a/rust/extns/numaflow-serving/src/config.rs b/rust/extns/numaflow-serving/src/config.rs index d8ace5716a..9a577636c5 100644 --- a/rust/extns/numaflow-serving/src/config.rs +++ b/rust/extns/numaflow-serving/src/config.rs @@ -222,3 +222,109 @@ impl TryFrom> for Settings { Ok(settings) } } + +#[cfg(test)] +mod tests { + use crate::pipeline::{Edge, Vertex}; + + use super::*; + + #[test] + fn test_basic_auth_debug_print() { + let auth = BasicAuth { + username: "js-auth-user".into(), + password: "js-auth-password".into(), + }; + let auth_debug = format!("{auth:?}"); + assert_eq!(auth_debug, "js-auth-user:***rd"); + } + + #[test] + fn test_default_config() { + let settings = Settings::default(); + + assert_eq!(settings.tid_header, "ID"); + assert_eq!(settings.app_listen_port, 3000); + assert_eq!(settings.metrics_server_listen_port, 3001); + assert_eq!(settings.upstream_addr, "localhost:8888"); + assert_eq!(settings.drain_timeout_secs, 10); + assert_eq!(settings.jetstream.stream, "default"); + assert_eq!(settings.jetstream.url, "localhost:4222"); + assert_eq!(settings.redis.addr, "redis://127.0.0.1:6379"); + assert_eq!(settings.redis.max_tasks, 50); + assert_eq!(settings.redis.retries, 5); + assert_eq!(settings.redis.retries_duration_millis, 100); + } + + #[test] + fn test_config_parse() { + // Set up the environment variables + let env_vars = [ + ( + ENV_NUMAFLOW_SERVING_JETSTREAM_URL, + "nats://isbsvc-default-js-svc.default.svc:4222", + ), + ( + ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM, + "ascii-art-pipeline-in-serving-source", + ), + (ENV_NUMAFLOW_SERVING_JETSTREAM_USER, "js-auth-user"), + (ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD, "js-user-password"), + (ENV_NUMAFLOW_SERVING_HOST_IP, "10.2.3.5"), + (ENV_NUMAFLOW_SERVING_AUTH_TOKEN, "api-auth-token"), + (ENV_NUMAFLOW_SERVING_APP_PORT, "8443"), + (ENV_NUMAFLOW_SERVING_STORE_TTL, "86400"), + (ENV_NUMAFLOW_SERVING_SOURCE_OBJECT, "eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX0="), + (ENV_MIN_PIPELINE_SPEC, "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6InNlcnZpbmctaW4iLCJzb3VyY2UiOnsic2VydmluZyI6eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX19LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciIsImVudiI6W3sibmFtZSI6IlJVU1RfTE9HIiwidmFsdWUiOiJpbmZvIn1dfSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2aW5nLXNpbmsiLCJzaW5rIjp7InVkc2luayI6eyJjb250YWluZXIiOnsiaW1hZ2UiOiJxdWF5LmlvL251bWFpby9udW1hZmxvdy1ycy9zaW5rLWxvZzpzdGFibGUiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9fX0sInJldHJ5U3RyYXRlZ3kiOnt9fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fV0sImVkZ2VzIjpbeyJmcm9tIjoic2VydmluZy1pbiIsInRvIjoic2VydmluZy1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH1dLCJsaWZlY3ljbGUiOnt9LCJ3YXRlcm1hcmsiOnt9fQ==") + ]; + + // Call the config method + let settings: Settings = env_vars + .into_iter() + .map(|(key, val)| (key.to_owned(), val.to_owned())) + .collect::>() + .try_into() + .unwrap(); + + let expected_config = Settings { + tid_header: "X-Numaflow-Id".into(), + app_listen_port: 8443, + metrics_server_listen_port: 3001, + upstream_addr: "localhost:8888".into(), + drain_timeout_secs: 10, + jetstream: JetStreamConfig { + stream: "ascii-art-pipeline-in-serving-source".into(), + url: "nats://isbsvc-default-js-svc.default.svc:4222".into(), + auth: Some(BasicAuth { + username: "js-auth-user".into(), + password: "js-user-password".into(), + }), + }, + redis: RedisConfig { + addr: "redis://redis:6379".into(), + max_tasks: 50, + retries: 5, + retries_duration_millis: 100, + ttl_secs: Some(86400), + }, + host_ip: "10.2.3.5".into(), + api_auth_token: Some("api-auth-token".into()), + pipeline_spec: PipelineDCG { + vertices: vec![ + Vertex { + name: "serving-in".into(), + }, + Vertex { + name: "serving-sink".into(), + }, + ], + edges: vec![Edge { + from: "serving-in".into(), + to: "serving-sink".into(), + conditions: None, + }], + }, + }; + assert_eq!(settings, expected_config); + } +} diff --git a/rust/extns/numaflow-serving/src/errors.rs b/rust/extns/numaflow-serving/src/errors.rs new file mode 100644 index 0000000000..7446faee71 --- /dev/null +++ b/rust/extns/numaflow-serving/src/errors.rs @@ -0,0 +1,50 @@ +use tokio::sync::oneshot; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Initialization error - {0}")] + InitError(String), + + #[error("Failed to parse configuration - {0}")] + ParseConfig(String), + // + // callback errors + // TODO: store the ID too? + #[error("IDNotFound Error - {0}")] + IDNotFound(&'static str), + + #[error("SubGraphGenerator Error - {0}")] + // subgraph generator errors + SubGraphGenerator(String), + + #[error("StoreWrite Error - {0}")] + // Store write errors + StoreWrite(String), + + #[error("SubGraphNotFound Error - {0}")] + // Sub Graph Not Found Error + SubGraphNotFound(&'static str), + + #[error("SubGraphInvalidInput Error - {0}")] + // Sub Graph Invalid Input Error + SubGraphInvalidInput(String), + + #[error("StoreRead Error - {0}")] + // Store read errors + StoreRead(String), + + #[error("Metrics Error - {0}")] + // Metrics errors + MetricsServer(String), + + #[error("Connection Error - {0}")] + Connection(String), + + #[error("Failed to receive message from channel. Actor task is terminated: {0:?}")] + ActorTaskTerminated(oneshot::error::RecvError), + + #[error("{0}")] + Other(String), +} + +pub type Result = std::result::Result; diff --git a/rust/extns/numaflow-serving/src/lib.rs b/rust/extns/numaflow-serving/src/lib.rs index 1da6d0c6c3..a6a446203e 100644 --- a/rust/extns/numaflow-serving/src/lib.rs +++ b/rust/extns/numaflow-serving/src/lib.rs @@ -1,248 +1,12 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use bytes::Bytes; -use tokio::sync::mpsc; -use tokio::sync::oneshot; -use tokio::time::Instant; - -use crate::app::callback::state::State as CallbackState; -use crate::app::callback::store::redisstore::RedisConnection; -use crate::app::tracker::MessageGraph; +mod source; +pub use source::{Message, MessageWrapper, ServingSource}; mod app; pub mod config; -pub(crate) mod pipeline; - -struct MessageWrapper { - pub confirm_save: oneshot::Sender<()>, - message: Message, -} - -pub struct Message { - pub value: Bytes, - pub id: String, - pub headers: HashMap, -} - -enum ActorMessage { - Read { - batch_size: usize, - timeout_at: Instant, - reply_to: oneshot::Sender>, - }, - Ack { - offsets: Vec, - reply_to: oneshot::Sender<()>, - }, -} - -struct ServingSourceActor { - messages: mpsc::Receiver, - handler_rx: mpsc::Receiver, - tracker: HashMap>, -} - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Initialization error - {0}")] - InitError(String), - - #[error("Failed to parse configuration - {0}")] - ParseConfig(String), - // - // callback errors - // TODO: store the ID too? - #[error("IDNotFound Error - {0}")] - IDNotFound(&'static str), - - #[error("SubGraphGenerator Error - {0}")] - // subgraph generator errors - SubGraphGenerator(String), - - #[error("StoreWrite Error - {0}")] - // Store write errors - StoreWrite(String), - - #[error("SubGraphNotFound Error - {0}")] - // Sub Graph Not Found Error - SubGraphNotFound(&'static str), - - #[error("SubGraphInvalidInput Error - {0}")] - // Sub Graph Invalid Input Error - SubGraphInvalidInput(String), - - #[error("StoreRead Error - {0}")] - // Store read errors - StoreRead(String), - - #[error("Metrics Error - {0}")] - // Metrics errors - MetricsServer(String), +mod errors; +pub use errors::{Error, Result}; - #[error("Connection Error - {0}")] - Connection(String), - - #[error("Failed to receive message from channel. Actor task is terminated: {0:?}")] - ActorTaskTerminated(oneshot::error::RecvError), - - #[error("{0}")] - Other(String), -} +pub(crate) mod pipeline; -type Result = std::result::Result; pub type Settings = config::Settings; - -impl ServingSourceActor { - async fn start( - settings: Arc, - handler_rx: mpsc::Receiver, - ) -> Result<()> { - let (messages_tx, messages_rx) = mpsc::channel(10000); - // Create a redis store to store the callbacks and the custom responses - let redis_store = RedisConnection::new(settings.redis.clone()).await?; - // Create the message graph from the pipeline spec and the redis store - let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec).map_err(|e| { - Error::InitError(format!( - "Creating message graph from pipeline spec: {:?}", - e - )) - })?; - let callback_state = CallbackState::new(msg_graph, redis_store).await?; - - tokio::spawn(async move { - let mut serving_actor = ServingSourceActor { - messages: messages_rx, - handler_rx, - tracker: HashMap::new(), - }; - serving_actor.run().await; - }); - let app = app::AppState { - message: messages_tx, - settings, - callback_state, - }; - app::serve(app).await.unwrap(); - Ok(()) - } - - async fn run(&mut self) { - while let Some(msg) = self.handler_rx.recv().await { - self.handle_message(msg).await; - } - } - - async fn handle_message(&mut self, actor_msg: ActorMessage) { - match actor_msg { - ActorMessage::Read { - batch_size, - timeout_at, - reply_to, - } => { - let messages = self.read(batch_size, timeout_at).await; - let _ = reply_to.send(messages); - } - ActorMessage::Ack { offsets, reply_to } => { - self.ack(offsets).await; - let _ = reply_to.send(()); - } - } - } - - async fn read(&mut self, count: usize, timeout_at: Instant) -> Vec { - let mut messages = vec![]; - loop { - if messages.len() >= count || Instant::now() >= timeout_at { - break; - } - let message = match self.messages.try_recv() { - Ok(msg) => msg, - Err(mpsc::error::TryRecvError::Empty) => break, - Err(e) => { - tracing::error!(?e, "Receiving messages from the serving channel"); // FIXME: - return messages; - } - }; - let MessageWrapper { - confirm_save, - message, - } = message; - - self.tracker.insert(message.id.clone(), confirm_save); - messages.push(message); - } - messages - } - - async fn ack(&mut self, offsets: Vec) { - for offset in offsets { - let offset = offset - .strip_suffix("-0") - .expect("offset does not end with '-0'"); // FIXME: we hardcode 0 as the partition index when constructing offset - let confirm_save_tx = self - .tracker - .remove(offset) - .expect("offset was not found in the tracker"); - confirm_save_tx - .send(()) - .expect("Sending on confirm_save channel"); - } - } -} - -#[derive(Clone)] -pub struct ServingSource { - batch_size: usize, - // timeout for each batch read request - timeout: Duration, - actor_tx: mpsc::Sender, -} - -impl ServingSource { - pub async fn new( - settings: Arc, - batch_size: usize, - timeout: Duration, - ) -> Result { - let (actor_tx, actor_rx) = mpsc::channel(10); - ServingSourceActor::start(settings, actor_rx).await?; - Ok(Self { - batch_size, - timeout, - actor_tx, - }) - } - - pub async fn read_messages(&self) -> Result> { - let start = Instant::now(); - let (tx, rx) = oneshot::channel(); - let actor_msg = ActorMessage::Read { - reply_to: tx, - batch_size: self.batch_size, - timeout_at: Instant::now() + self.timeout, - }; - let _ = self.actor_tx.send(actor_msg).await; - let messages = rx.await.map_err(Error::ActorTaskTerminated)?; - tracing::debug!( - count = messages.len(), - requested_count = self.batch_size, - time_taken_ms = start.elapsed().as_millis(), - "Got messages from Serving source" - ); - Ok(messages) - } - - pub async fn ack_messages(&self, offsets: Vec) -> Result<()> { - let (tx, rx) = oneshot::channel(); - let actor_msg = ActorMessage::Ack { - offsets, - reply_to: tx, - }; - let _ = self.actor_tx.send(actor_msg).await; - rx.await.map_err(Error::ActorTaskTerminated)?; - Ok(()) - } -} diff --git a/rust/extns/numaflow-serving/src/source.rs b/rust/extns/numaflow-serving/src/source.rs new file mode 100644 index 0000000000..af7d8aaa8f --- /dev/null +++ b/rust/extns/numaflow-serving/src/source.rs @@ -0,0 +1,194 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::Instant; + +use crate::app::callback::state::State as CallbackState; +use crate::app::callback::store::redisstore::RedisConnection; +use crate::app::tracker::MessageGraph; +use crate::{app, Settings}; +use crate::{Error, Result}; + +pub struct MessageWrapper { + pub confirm_save: oneshot::Sender<()>, + pub message: Message, +} + +pub struct Message { + pub value: Bytes, + pub id: String, + pub headers: HashMap, +} + +enum ActorMessage { + Read { + batch_size: usize, + timeout_at: Instant, + reply_to: oneshot::Sender>, + }, + Ack { + offsets: Vec, + reply_to: oneshot::Sender<()>, + }, +} + +struct ServingSourceActor { + messages: mpsc::Receiver, + handler_rx: mpsc::Receiver, + tracker: HashMap>, +} + +impl ServingSourceActor { + async fn start( + settings: Arc, + handler_rx: mpsc::Receiver, + ) -> Result<()> { + let (messages_tx, messages_rx) = mpsc::channel(10000); + // Create a redis store to store the callbacks and the custom responses + let redis_store = RedisConnection::new(settings.redis.clone()).await?; + // Create the message graph from the pipeline spec and the redis store + let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec).map_err(|e| { + Error::InitError(format!( + "Creating message graph from pipeline spec: {:?}", + e + )) + })?; + let callback_state = CallbackState::new(msg_graph, redis_store).await?; + + tokio::spawn(async move { + let mut serving_actor = ServingSourceActor { + messages: messages_rx, + handler_rx, + tracker: HashMap::new(), + }; + serving_actor.run().await; + }); + let app = app::AppState { + message: messages_tx, + settings, + callback_state, + }; + app::serve(app).await.unwrap(); + Ok(()) + } + + async fn run(&mut self) { + while let Some(msg) = self.handler_rx.recv().await { + self.handle_message(msg).await; + } + } + + async fn handle_message(&mut self, actor_msg: ActorMessage) { + match actor_msg { + ActorMessage::Read { + batch_size, + timeout_at, + reply_to, + } => { + let messages = self.read(batch_size, timeout_at).await; + let _ = reply_to.send(messages); + } + ActorMessage::Ack { offsets, reply_to } => { + self.ack(offsets).await; + let _ = reply_to.send(()); + } + } + } + + async fn read(&mut self, count: usize, timeout_at: Instant) -> Vec { + let mut messages = vec![]; + loop { + if messages.len() >= count || Instant::now() >= timeout_at { + break; + } + let message = match self.messages.try_recv() { + Ok(msg) => msg, + Err(mpsc::error::TryRecvError::Empty) => break, + Err(e) => { + tracing::error!(?e, "Receiving messages from the serving channel"); // FIXME: + return messages; + } + }; + let MessageWrapper { + confirm_save, + message, + } = message; + + self.tracker.insert(message.id.clone(), confirm_save); + messages.push(message); + } + messages + } + + async fn ack(&mut self, offsets: Vec) { + for offset in offsets { + let offset = offset + .strip_suffix("-0") + .expect("offset does not end with '-0'"); // FIXME: we hardcode 0 as the partition index when constructing offset + let confirm_save_tx = self + .tracker + .remove(offset) + .expect("offset was not found in the tracker"); + confirm_save_tx + .send(()) + .expect("Sending on confirm_save channel"); + } + } +} + +#[derive(Clone)] +pub struct ServingSource { + batch_size: usize, + // timeout for each batch read request + timeout: Duration, + actor_tx: mpsc::Sender, +} + +impl ServingSource { + pub async fn new( + settings: Arc, + batch_size: usize, + timeout: Duration, + ) -> Result { + let (actor_tx, actor_rx) = mpsc::channel(10); + ServingSourceActor::start(settings, actor_rx).await?; + Ok(Self { + batch_size, + timeout, + actor_tx, + }) + } + + pub async fn read_messages(&self) -> Result> { + let start = Instant::now(); + let (tx, rx) = oneshot::channel(); + let actor_msg = ActorMessage::Read { + reply_to: tx, + batch_size: self.batch_size, + timeout_at: Instant::now() + self.timeout, + }; + let _ = self.actor_tx.send(actor_msg).await; + let messages = rx.await.map_err(Error::ActorTaskTerminated)?; + tracing::debug!( + count = messages.len(), + requested_count = self.batch_size, + time_taken_ms = start.elapsed().as_millis(), + "Got messages from Serving source" + ); + Ok(messages) + } + + pub async fn ack_messages(&self, offsets: Vec) -> Result<()> { + let (tx, rx) = oneshot::channel(); + let actor_msg = ActorMessage::Ack { + offsets, + reply_to: tx, + }; + let _ = self.actor_tx.send(actor_msg).await; + rx.await.map_err(Error::ActorTaskTerminated)?; + Ok(()) + } +} From 94dcddd1b868757fab345ff1f1bba01d3d66896a Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 23 Dec 2024 15:40:43 +0530 Subject: [PATCH 03/15] Remove jetstream related config options Signed-off-by: Sreekanth --- rust/extns/numaflow-serving/src/app.rs | 2 - rust/extns/numaflow-serving/src/config.rs | 81 ----------------------- 2 files changed, 83 deletions(-) diff --git a/rust/extns/numaflow-serving/src/app.rs b/rust/extns/numaflow-serving/src/app.rs index 6290aac84c..1d491bdace 100644 --- a/rust/extns/numaflow-serving/src/app.rs +++ b/rust/extns/numaflow-serving/src/app.rs @@ -220,7 +220,6 @@ const NUMAFLOW_RESP_ARRAY_IDX_LEN: &str = "Numaflow-Array-Index-Len"; struct ProxyState { tid_header: String, callback: CallbackState, - stream: String, callback_url: String, messages: mpsc::Sender, } @@ -231,7 +230,6 @@ pub(crate) fn jetstream_proxy( let proxy_state = Arc::new(ProxyState { tid_header: state.settings.tid_header.clone(), callback: state.callback_state.clone(), - stream: state.settings.jetstream.stream.clone(), messages: state.message.clone(), callback_url: format!( "https://{}:{}/v1/process/callback", diff --git a/rust/extns/numaflow-serving/src/config.rs b/rust/extns/numaflow-serving/src/config.rs index 9a577636c5..27b0eb4f65 100644 --- a/rust/extns/numaflow-serving/src/config.rs +++ b/rust/extns/numaflow-serving/src/config.rs @@ -20,46 +20,6 @@ const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; pub(crate) const SAVED: &str = "SAVED"; -#[derive(Deserialize, Clone, PartialEq)] -pub struct BasicAuth { - pub username: String, - pub password: String, -} - -impl std::fmt::Debug for BasicAuth { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let passwd_printable = if self.password.len() > 4 { - let passwd: String = self - .password - .chars() - .skip(self.password.len() - 2) - .take(2) - .collect(); - format!("***{}", passwd) - } else { - "*****".to_owned() - }; - write!(f, "{}:{}", self.username, passwd_printable) - } -} - -#[derive(Debug, Deserialize, Clone, PartialEq)] -pub struct JetStreamConfig { - pub stream: String, - pub url: String, - pub auth: Option, -} - -impl Default for JetStreamConfig { - fn default() -> Self { - Self { - stream: "default".to_owned(), - url: "localhost:4222".to_owned(), - auth: None, - } - } -} - #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct RedisConfig { pub addr: String, @@ -93,7 +53,6 @@ pub struct Settings { pub host_ip: String, pub api_auth_token: Option, pub redis: RedisConfig, - pub jetstream: JetStreamConfig, pub pipeline_spec: PipelineDCG, } @@ -108,7 +67,6 @@ impl Default for Settings { host_ip: "127.0.0.1".to_owned(), api_auth_token: None, redis: RedisConfig::default(), - jetstream: JetStreamConfig::default(), pipeline_spec: Default::default(), } } @@ -161,14 +119,6 @@ impl TryFrom> for Settings { ..Default::default() }; - if let Some(jetstream_url) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_URL) { - settings.jetstream.url = jetstream_url.to_owned(); - } - - if let Some(jetstream_stream) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM) { - settings.jetstream.stream = jetstream_stream.to_owned(); - } - if let Some(api_auth_token) = env_vars.get(ENV_NUMAFLOW_SERVING_AUTH_TOKEN) { settings.api_auth_token = Some(api_auth_token.to_owned()); } @@ -181,17 +131,6 @@ impl TryFrom> for Settings { })?; } - // If username is set, the password also must be set - if let Some(username) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_USER) { - let Some(password) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD) else { - return Err(Error::ParseConfig(format!("Env variable {ENV_NUMAFLOW_SERVING_JETSTREAM_USER} is set, but {ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD} is not set"))); - }; - settings.jetstream.auth = Some(BasicAuth { - username: username.to_owned(), - password: password.to_owned(), - }); - } - // Update redis.ttl_secs from environment variable if let Some(ttl_secs) = env_vars.get(ENV_NUMAFLOW_SERVING_STORE_TTL) { let ttl_secs: u32 = ttl_secs.parse().map_err(|e| { @@ -229,16 +168,6 @@ mod tests { use super::*; - #[test] - fn test_basic_auth_debug_print() { - let auth = BasicAuth { - username: "js-auth-user".into(), - password: "js-auth-password".into(), - }; - let auth_debug = format!("{auth:?}"); - assert_eq!(auth_debug, "js-auth-user:***rd"); - } - #[test] fn test_default_config() { let settings = Settings::default(); @@ -248,8 +177,6 @@ mod tests { assert_eq!(settings.metrics_server_listen_port, 3001); assert_eq!(settings.upstream_addr, "localhost:8888"); assert_eq!(settings.drain_timeout_secs, 10); - assert_eq!(settings.jetstream.stream, "default"); - assert_eq!(settings.jetstream.url, "localhost:4222"); assert_eq!(settings.redis.addr, "redis://127.0.0.1:6379"); assert_eq!(settings.redis.max_tasks, 50); assert_eq!(settings.redis.retries, 5); @@ -292,14 +219,6 @@ mod tests { metrics_server_listen_port: 3001, upstream_addr: "localhost:8888".into(), drain_timeout_secs: 10, - jetstream: JetStreamConfig { - stream: "ascii-art-pipeline-in-serving-source".into(), - url: "nats://isbsvc-default-js-svc.default.svc:4222".into(), - auth: Some(BasicAuth { - username: "js-auth-user".into(), - password: "js-user-password".into(), - }), - }, redis: RedisConfig { addr: "redis://redis:6379".into(), max_tasks: 50, From ae42c294d2806a3c2be53434e9c429bee74f25bc Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 24 Dec 2024 10:12:36 +0530 Subject: [PATCH 04/15] Remove unused constants Signed-off-by: Sreekanth --- rust/extns/numaflow-serving/src/config.rs | 14 -------------- rust/extns/numaflow-serving/src/pipeline.rs | 12 +----------- 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/rust/extns/numaflow-serving/src/config.rs b/rust/extns/numaflow-serving/src/config.rs index 27b0eb4f65..d306b47d70 100644 --- a/rust/extns/numaflow-serving/src/config.rs +++ b/rust/extns/numaflow-serving/src/config.rs @@ -8,13 +8,9 @@ use crate::pipeline::PipelineDCG; use crate::Error; const ENV_NUMAFLOW_SERVING_SOURCE_OBJECT: &str = "NUMAFLOW_SERVING_SOURCE_OBJECT"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_URL: &str = "NUMAFLOW_ISBSVC_JETSTREAM_URL"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM: &str = "NUMAFLOW_SERVING_JETSTREAM_STREAM"; const ENV_NUMAFLOW_SERVING_STORE_TTL: &str = "NUMAFLOW_SERVING_STORE_TTL"; const ENV_NUMAFLOW_SERVING_HOST_IP: &str = "NUMAFLOW_SERVING_HOST_IP"; const ENV_NUMAFLOW_SERVING_APP_PORT: &str = "NUMAFLOW_SERVING_APP_LISTEN_PORT"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_USER: &str = "NUMAFLOW_ISBSVC_JETSTREAM_USER"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD: &str = "NUMAFLOW_ISBSVC_JETSTREAM_PASSWORD"; const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; @@ -187,16 +183,6 @@ mod tests { fn test_config_parse() { // Set up the environment variables let env_vars = [ - ( - ENV_NUMAFLOW_SERVING_JETSTREAM_URL, - "nats://isbsvc-default-js-svc.default.svc:4222", - ), - ( - ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM, - "ascii-art-pipeline-in-serving-source", - ), - (ENV_NUMAFLOW_SERVING_JETSTREAM_USER, "js-auth-user"), - (ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD, "js-user-password"), (ENV_NUMAFLOW_SERVING_HOST_IP, "10.2.3.5"), (ENV_NUMAFLOW_SERVING_AUTH_TOKEN, "api-auth-token"), (ENV_NUMAFLOW_SERVING_APP_PORT, "8443"), diff --git a/rust/extns/numaflow-serving/src/pipeline.rs b/rust/extns/numaflow-serving/src/pipeline.rs index 05768568ab..50236f5381 100644 --- a/rust/extns/numaflow-serving/src/pipeline.rs +++ b/rust/extns/numaflow-serving/src/pipeline.rs @@ -64,22 +64,12 @@ pub struct Edge { /// DCG (directed compute graph) of the pipeline with minimal information build using vertices and edges /// from the pipeline spec -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -#[serde()] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] pub struct PipelineDCG { pub vertices: Vec, pub edges: Vec, } -impl Default for PipelineDCG { - fn default() -> Self { - Self { - vertices: vec![], - edges: vec![], - } - } -} - #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Vertex { pub name: String, From 0bc6ca91c38c513aea202efafab566b1a3737e13 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 24 Dec 2024 10:29:20 +0530 Subject: [PATCH 05/15] Fix compilation errors Signed-off-by: Sreekanth --- rust/numaflow-core/src/shared/create_components.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs index feac461e23..eeae1136e8 100644 --- a/rust/numaflow-core/src/shared/create_components.rs +++ b/rust/numaflow-core/src/shared/create_components.rs @@ -343,6 +343,7 @@ pub async fn create_source( batch_size, source::SourceType::Serving(serving), tracker_handle, + source_config.read_ahead, ), None, )) From e20b022054e7504fd6186743e8a0d7159e6b0937 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 24 Dec 2024 12:51:28 +0530 Subject: [PATCH 06/15] Modify serving to be a built-in source Signed-off-by: Sreekanth --- rust/Cargo.lock | 1 + rust/numaflow/src/main.rs | 13 +- rust/serving/Cargo.toml | 1 + rust/serving/src/app.rs | 133 +++------------- rust/serving/src/app/jetstream_proxy.rs | 143 +++++++++-------- rust/serving/src/config.rs | 142 ++++++----------- rust/serving/src/error.rs | 4 + rust/serving/src/lib.rs | 43 +++--- rust/serving/src/pipeline.rs | 13 +- rust/serving/src/source.rs | 194 ++++++++++++++++++++++++ 10 files changed, 363 insertions(+), 324 deletions(-) create mode 100644 rust/serving/src/source.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index aaac1c520b..21a5f97246 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2891,6 +2891,7 @@ dependencies = [ "axum-server", "backoff", "base64 0.22.1", + "bytes", "chrono", "hyper-util", "numaflow-models", diff --git a/rust/numaflow/src/main.rs b/rust/numaflow/src/main.rs index 60e26ef850..9a5ab6fe82 100644 --- a/rust/numaflow/src/main.rs +++ b/rust/numaflow/src/main.rs @@ -1,7 +1,5 @@ -use std::collections::HashMap; use std::env; use std::error::Error; -use std::sync::Arc; use tracing::error; use tracing_subscriber::layer::SubscriberExt; @@ -31,14 +29,7 @@ async fn main() -> Result<(), Box> { async fn run() -> Result<(), Box> { let args: Vec = env::args().collect(); // Based on the argument, run the appropriate component. - if args.contains(&"--serving".to_string()) { - let env_vars: HashMap = env::vars().collect(); - let settings: serving::Settings = env_vars.try_into()?; - let settings = Arc::new(settings); - serving::serve(settings) - .await - .map_err(|e| format!("Error running serving: {e:?}"))?; - } else if args.contains(&"--servesink".to_string()) { + if args.contains(&"--servesink".to_string()) { servesink::servesink() .await .map_err(|e| format!("Error running servesink: {e:?}"))?; @@ -47,5 +38,5 @@ async fn run() -> Result<(), Box> { .await .map_err(|e| format!("Error running rust binary: {e:?}"))? } - Err("Invalid argument. Use --serving, --servesink, or --rust".into()) + Err("Invalid argument. Use --servesink, or --rust".into()) } diff --git a/rust/serving/Cargo.toml b/rust/serving/Cargo.toml index de2f8bb820..0dbe0818c5 100644 --- a/rust/serving/Cargo.toml +++ b/rust/serving/Cargo.toml @@ -18,6 +18,7 @@ numaflow-models.workspace = true backoff.workspace = true axum.workspace = true axum-server.workspace = true +bytes.workspace = true async-nats = "0.35.1" axum-macros = "0.4.1" hyper-util = { version = "0.1.6", features = ["client-legacy"] } diff --git a/rust/serving/src/app.rs b/rust/serving/src/app.rs index 56d4a33cb3..b8bbc4c76e 100644 --- a/rust/serving/src/app.rs +++ b/rust/serving/src/app.rs @@ -1,9 +1,6 @@ use std::net::SocketAddr; -use std::sync::Arc; use std::time::Duration; -use async_nats::jetstream; -use async_nats::jetstream::Context; use axum::extract::{MatchedPath, State}; use axum::http::StatusCode; use axum::middleware::Next; @@ -25,12 +22,9 @@ use self::{ message_path::get_message_path, }; use crate::app::callback::store::Store; -use crate::app::tracker::MessageGraph; -use crate::config::JetStreamConfig; -use crate::pipeline::PipelineDCG; +use crate::metrics::capture_metrics; +use crate::AppState; use crate::Error::InitError; -use crate::Settings; -use crate::{app::callback::state::State as CallbackState, metrics::capture_metrics}; /// manage callbacks pub(crate) mod callback; @@ -41,7 +35,7 @@ mod jetstream_proxy; /// Return message path in response to UI requests mod message_path; // TODO: merge message_path and tracker mod response; -mod tracker; +pub(crate) mod tracker; /// Everything for numaserve starts here. The routing, middlewares, proxying, etc. // TODO @@ -49,16 +43,18 @@ mod tracker; // - [ ] outer fallback for /v1/direct /// Start the main application Router and the axum server. -pub(crate) async fn start_main_server( - settings: Arc, +pub(crate) async fn start_main_server( + app: AppState, tls_config: RustlsConfig, - pipeline_spec: PipelineDCG, -) -> crate::Result<()> { - let app_addr: SocketAddr = format!("0.0.0.0:{}", &settings.app_listen_port) +) -> crate::Result<()> +where + T: Clone + Send + Sync + Store + 'static, +{ + let app_addr: SocketAddr = format!("0.0.0.0:{}", &app.settings.app_listen_port) .parse() .map_err(|e| InitError(format!("{e:?}")))?; - let tid_header = settings.tid_header.clone(); + let tid_header = app.settings.tid_header.clone(); let layers = ServiceBuilder::new() // Add tracing to all requests .layer( @@ -85,35 +81,19 @@ pub(crate) async fn start_main_server( .layer( // Graceful shutdown will wait for outstanding requests to complete. Add a timeout so // requests don't hang forever. - TimeoutLayer::new(Duration::from_secs(settings.drain_timeout_secs)), + TimeoutLayer::new(Duration::from_secs(app.settings.drain_timeout_secs)), ) // Add auth middleware to all user facing routes .layer(middleware::from_fn_with_state( - settings.api_auth_token.clone(), + app.settings.api_auth_token.clone(), auth_middleware, )); - // Create the message graph from the pipeline spec and the redis store - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).map_err(|e| { - InitError(format!( - "Creating message graph from pipeline spec: {:?}", - e - )) - })?; - - // Create a redis store to store the callbacks and the custom responses - let redis_store = - callback::store::redisstore::RedisConnection::new(settings.redis.clone()).await?; - let state = CallbackState::new(msg_graph, redis_store).await?; - let handle = Handle::new(); // Spawn a task to gracefully shutdown server. tokio::spawn(graceful_shutdown(handle.clone())); - // Create a Jetstream context - let js_context = create_js_context(&settings.jetstream).await?; - - let router = setup_app(settings, js_context, state).await?.layer(layers); + let router = setup_app(app).await?.layer(layers); info!(?app_addr, "Starting application server"); @@ -154,30 +134,6 @@ async fn graceful_shutdown(handle: Handle) { handle.graceful_shutdown(Some(Duration::from_secs(30))); } -async fn create_js_context(js_config: &JetStreamConfig) -> crate::Result { - // Connect to Jetstream with user and password if they are set - let js_client = match js_config.auth.as_ref() { - Some(auth) => { - async_nats::connect_with_options( - &js_config.url, - async_nats::ConnectOptions::with_user_and_password( - auth.username.clone(), - auth.password.clone(), - ), - ) - .await - } - _ => async_nats::connect(&js_config.url).await, - } - .map_err(|e| { - InitError(format!( - "Connecting to jetstream server {}: {}", - &js_config.url, e - )) - })?; - Ok(jetstream::new(js_client)) -} - const PUBLISH_ENDPOINTS: [&str; 3] = [ "/v1/process/sync", "/v1/process/sync_serve", @@ -228,28 +184,14 @@ async fn auth_middleware( } } -#[derive(Clone)] -pub(crate) struct AppState { - pub(crate) settings: Arc, - pub(crate) callback_state: CallbackState, - pub(crate) context: Context, -} - async fn setup_app( - settings: Arc, - context: Context, - callback_state: CallbackState, + app: AppState, ) -> crate::Result { - let app_state = AppState { - settings, - callback_state: callback_state.clone(), - context: context.clone(), - }; let parent = Router::new() .route("/health", get(health_check)) .route("/livez", get(livez)) // Liveliness check .route("/readyz", get(readyz)) - .with_state(app_state.clone()); // Readiness check + .with_state(app.clone()); // Readiness check // a pool based client implementation for direct proxy, this client is cloneable. let client: direct_proxy::Client = @@ -260,9 +202,9 @@ async fn setup_app( let app = parent .nest( "/v1/direct", - direct_proxy(client, app_state.settings.upstream_addr.clone()), + direct_proxy(client, app.settings.upstream_addr.clone()), ) - .nest("/v1/process", routes(app_state).await?); + .nest("/v1/process", routes(app).await?); Ok(app) } @@ -278,13 +220,7 @@ async fn livez() -> impl IntoResponse { async fn readyz( State(app): State>, ) -> impl IntoResponse { - if app.callback_state.clone().ready().await - && app - .context - .get_stream(&app.settings.jetstream.stream) - .await - .is_ok() - { + if app.callback_state.clone().ready().await { StatusCode::NO_CONTENT } else { StatusCode::INTERNAL_SERVER_ERROR @@ -308,47 +244,24 @@ async fn routes( #[cfg(test)] mod tests { + use std::sync::Arc; + use async_nats::jetstream::stream; use axum::http::StatusCode; + use tokio::sync::mpsc; use tokio::time::{sleep, Duration}; use tower::ServiceExt; use super::*; use crate::app::callback::store::memstore::InMemoryStore; use crate::config::generate_certs; + use crate::Settings; const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; type Result = core::result::Result; type Error = Box; - #[tokio::test] - async fn test_start_main_server() -> Result<()> { - let (cert, key) = generate_certs()?; - - let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) - .await - .unwrap(); - - let settings = Arc::new(Settings { - app_listen_port: 0, - ..Settings::default() - }); - - let server = tokio::spawn(async move { - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let result = start_main_server(settings, tls_config, pipeline_spec).await; - assert!(result.is_ok()) - }); - - // Give the server a little bit of time to start - sleep(Duration::from_millis(50)).await; - - // Stop the server - server.abort(); - Ok(()) - } - #[cfg(feature = "all-tests")] #[tokio::test] async fn test_setup_app() -> Result<()> { diff --git a/rust/serving/src/app/jetstream_proxy.rs b/rust/serving/src/app/jetstream_proxy.rs index af7d3917ff..e5a7a661e1 100644 --- a/rust/serving/src/app/jetstream_proxy.rs +++ b/rust/serving/src/app/jetstream_proxy.rs @@ -1,6 +1,5 @@ -use std::{borrow::Borrow, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; -use async_nats::{jetstream::Context, HeaderMap as JSHeaderMap}; use axum::{ body::Bytes, extract::State, @@ -9,12 +8,13 @@ use axum::{ routing::post, Json, Router, }; +use tokio::sync::{mpsc, oneshot}; use tracing::error; use uuid::Uuid; use super::{callback::store::Store, AppState}; -use crate::app::callback::state; use crate::app::response::{ApiError, ServeResponse}; +use crate::{app::callback::state, Message, MessageWrapper}; // TODO: // - [ ] better health check @@ -37,10 +37,9 @@ const NUMAFLOW_RESP_ARRAY_LEN: &str = "Numaflow-Array-Len"; const NUMAFLOW_RESP_ARRAY_IDX_LEN: &str = "Numaflow-Array-Index-Len"; struct ProxyState { + message: mpsc::Sender, tid_header: String, - context: Context, callback: state::State, - stream: String, callback_url: String, } @@ -48,10 +47,9 @@ pub(crate) async fn jetstream_proxy( state: AppState, ) -> crate::Result { let proxy_state = Arc::new(ProxyState { + message: state.message.clone(), tid_header: state.settings.tid_header.clone(), - context: state.context.clone(), callback: state.callback_state.clone(), - stream: state.settings.jetstream.stream.clone(), callback_url: format!( "https://{}:{}/v1/process/callback", state.settings.host_ip, state.settings.app_listen_port @@ -76,20 +74,30 @@ async fn sync_publish_serve( // Register the ID in the callback proxy state let notify = proxy_state.callback.clone().register(id.clone()); - if let Err(e) = publish_to_jetstream( - proxy_state.stream.clone(), - &proxy_state.callback_url, - headers, - body, - proxy_state.context.clone(), - proxy_state.tid_header.as_str(), - id.as_str(), - ) - .await - { + let mut msg_headers: HashMap = HashMap::new(); + for (key, value) in headers.iter() { + msg_headers.insert( + key.to_string(), + String::from_utf8_lossy(value.as_bytes()).to_string(), + ); + } + + let (tx, rx) = oneshot::channel(); + let message = MessageWrapper { + confirm_save: tx, + message: Message { + value: body, + id: id.clone(), + headers: msg_headers, + }, + }; + + proxy_state.message.send(message).await.unwrap(); // FIXME: + + if let Err(e) = rx.await { // Deregister the ID in the callback proxy state if writing to Jetstream fails let _ = proxy_state.callback.clone().deregister(&id).await; - error!(error = ?e, "Publishing message to Jetstream for sync serve request"); + error!(error = ?e, "Waiting for acknowledgement for message"); return Err(ApiError::BadGateway( "Failed to write message to Jetstream".to_string(), )); @@ -143,21 +151,30 @@ async fn sync_publish( ) -> Result, ApiError> { let id = extract_id_from_headers(&proxy_state.tid_header, &headers); + let mut msg_headers: HashMap = HashMap::new(); + for (key, value) in headers.iter() { + msg_headers.insert( + key.to_string(), + String::from_utf8_lossy(value.as_bytes()).to_string(), + ); + } + + let (tx, rx) = oneshot::channel(); + let message = MessageWrapper { + confirm_save: tx, + message: Message { + value: body, + id: id.clone(), + headers: msg_headers, + }, + }; + // Register the ID in the callback proxy state let notify = proxy_state.callback.clone().register(id.clone()); + proxy_state.message.send(message).await.unwrap(); // FIXME: - if let Err(e) = publish_to_jetstream( - proxy_state.stream.clone(), - &proxy_state.callback_url, - headers, - body, - proxy_state.context.clone(), - &proxy_state.tid_header, - id.as_str(), - ) - .await - { - // Deregister the ID in the callback proxy state if writing to Jetstream fails + if let Err(e) = rx.await { + // Deregister the ID in the callback proxy state if waiting for ack fails let _ = proxy_state.callback.clone().deregister(&id).await; error!(error = ?e, "Publishing message to Jetstream for sync request"); return Err(ApiError::BadGateway( @@ -192,62 +209,40 @@ async fn async_publish( body: Bytes, ) -> Result, ApiError> { let id = extract_id_from_headers(&proxy_state.tid_header, &headers); - let result = publish_to_jetstream( - proxy_state.stream.clone(), - &proxy_state.callback_url, - headers, - body, - proxy_state.context.clone(), - &proxy_state.tid_header, - id.as_str(), - ) - .await; + let mut msg_headers: HashMap = HashMap::new(); + for (key, value) in headers.iter() { + msg_headers.insert( + key.to_string(), + String::from_utf8_lossy(value.as_bytes()).to_string(), + ); + } - match result { + let (tx, rx) = oneshot::channel(); + let message = MessageWrapper { + confirm_save: tx, + message: Message { + value: body, + id: id.clone(), + headers: msg_headers, + }, + }; + + proxy_state.message.send(message).await.unwrap(); // FIXME: + match rx.await { Ok(_) => Ok(Json(ServeResponse::new( "Successfully published message".to_string(), id, StatusCode::OK, ))), Err(e) => { - error!(error = ?e, "Publishing message to Jetstream"); + error!(error = ?e, "Waiting for message save confirmation"); Err(ApiError::InternalServerError( - "Failed to publish message to Jetstream".to_string(), + "Failed to save message".to_string(), )) } } } -/// Write to JetStream and return the metadata. It is responsible for getting the ID from the header. -async fn publish_to_jetstream( - stream: String, - callback_url: &str, - headers: HeaderMap, - body: Bytes, - js_context: Context, - id_header: &str, - id_header_value: &str, -) -> Result<(), async_nats::Error> { - let mut js_headers = JSHeaderMap::new(); - - // pass in the HTTP headers as jetstream headers - for (k, v) in headers.iter() { - js_headers.append(k.as_ref(), String::from_utf8_lossy(v.as_bytes()).borrow()) - } - - js_headers.append(id_header, id_header_value); // Use the passed ID - js_headers.append(CALLBACK_URL_KEY, callback_url); - - js_context - .publish_with_headers(stream, js_headers, body) - .await - .map_err(|e| format!("Publishing message to stream: {e:?}"))? - .await - .map_err(|e| format!("Waiting for acknowledgement of published message: {e:?}"))?; - - Ok(()) -} - // extracts the ID from the headers, if not found, generates a new UUID fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { headers.get(tid_header).map_or_else( diff --git a/rust/serving/src/config.rs b/rust/serving/src/config.rs index 7ba3778d00..03fb2beb17 100644 --- a/rust/serving/src/config.rs +++ b/rust/serving/src/config.rs @@ -7,17 +7,17 @@ use base64::Engine; use rcgen::{generate_simple_self_signed, Certificate, CertifiedKey, KeyPair}; use serde::{Deserialize, Serialize}; -use crate::Error::ParseConfig; +use crate::{ + pipeline::PipelineDCG, + Error::{self, ParseConfig}, +}; const ENV_NUMAFLOW_SERVING_SOURCE_OBJECT: &str = "NUMAFLOW_SERVING_SOURCE_OBJECT"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_URL: &str = "NUMAFLOW_ISBSVC_JETSTREAM_URL"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM: &str = "NUMAFLOW_SERVING_JETSTREAM_STREAM"; const ENV_NUMAFLOW_SERVING_STORE_TTL: &str = "NUMAFLOW_SERVING_STORE_TTL"; const ENV_NUMAFLOW_SERVING_HOST_IP: &str = "NUMAFLOW_SERVING_HOST_IP"; const ENV_NUMAFLOW_SERVING_APP_PORT: &str = "NUMAFLOW_SERVING_APP_LISTEN_PORT"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_USER: &str = "NUMAFLOW_ISBSVC_JETSTREAM_USER"; -const ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD: &str = "NUMAFLOW_ISBSVC_JETSTREAM_PASSWORD"; const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; +const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; pub fn generate_certs() -> std::result::Result<(Certificate, KeyPair), String> { let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); @@ -26,46 +26,6 @@ pub fn generate_certs() -> std::result::Result<(Certificate, KeyPair), String> { Ok((cert, key_pair)) } -#[derive(Deserialize, Clone, PartialEq)] -pub struct BasicAuth { - pub username: String, - pub password: String, -} - -impl Debug for BasicAuth { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let passwd_printable = if self.password.len() > 4 { - let passwd: String = self - .password - .chars() - .skip(self.password.len() - 2) - .take(2) - .collect(); - format!("***{}", passwd) - } else { - "*****".to_owned() - }; - write!(f, "{}:{}", self.username, passwd_printable) - } -} - -#[derive(Debug, Deserialize, Clone, PartialEq)] -pub struct JetStreamConfig { - pub stream: String, - pub url: String, - pub auth: Option, -} - -impl Default for JetStreamConfig { - fn default() -> Self { - Self { - stream: "default".to_owned(), - url: "localhost:4222".to_owned(), - auth: None, - } - } -} - #[derive(Debug, Deserialize, Clone, PartialEq)] pub struct RedisConfig { pub addr: String, @@ -95,11 +55,11 @@ pub struct Settings { pub metrics_server_listen_port: u16, pub upstream_addr: String, pub drain_timeout_secs: u64, - pub jetstream: JetStreamConfig, pub redis: RedisConfig, /// The IP address of the numaserve pod. This will be used to construct the value for X-Numaflow-Callback-Url header pub host_ip: String, pub api_auth_token: Option, + pub pipeline_spec: PipelineDCG, } impl Default for Settings { @@ -110,10 +70,10 @@ impl Default for Settings { metrics_server_listen_port: 3001, upstream_addr: "localhost:8888".to_owned(), drain_timeout_secs: 10, - jetstream: JetStreamConfig::default(), redis: RedisConfig::default(), host_ip: "127.0.0.1".to_owned(), api_auth_token: None, + pipeline_spec: Default::default(), } } } @@ -133,7 +93,7 @@ pub struct CallbackStorageConfig { /// This implementation is to load settings from env variables impl TryFrom> for Settings { - type Error = crate::Error; + type Error = Error; fn try_from(env_vars: HashMap) -> std::result::Result { let host_ip = env_vars .get(ENV_NUMAFLOW_SERVING_HOST_IP) @@ -144,19 +104,27 @@ impl TryFrom> for Settings { })? .to_owned(); + let pipeline_spec: PipelineDCG = env_vars + .get(ENV_MIN_PIPELINE_SPEC) + .ok_or_else(|| { + Error::ParseConfig(format!( + "Pipeline spec is not set using environment variable {ENV_MIN_PIPELINE_SPEC}" + )) + })? + .parse() + .map_err(|e| { + Error::ParseConfig(format!( + "Parsing pipeline spec: {}: error={e:?}", + env_vars.get(ENV_MIN_PIPELINE_SPEC).unwrap() + )) + })?; + let mut settings = Settings { host_ip, + pipeline_spec, ..Default::default() }; - if let Some(jetstream_url) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_URL) { - settings.jetstream.url = jetstream_url.to_owned(); - } - - if let Some(jetstream_stream) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM) { - settings.jetstream.stream = jetstream_stream.to_owned(); - } - if let Some(api_auth_token) = env_vars.get(ENV_NUMAFLOW_SERVING_AUTH_TOKEN) { settings.api_auth_token = Some(api_auth_token.to_owned()); } @@ -169,17 +137,6 @@ impl TryFrom> for Settings { })?; } - // If username is set, the password also must be set - if let Some(username) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_USER) { - let Some(password) = env_vars.get(ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD) else { - return Err(ParseConfig(format!("Env variable {ENV_NUMAFLOW_SERVING_JETSTREAM_USER} is set, but {ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD} is not set"))); - }; - settings.jetstream.auth = Some(BasicAuth { - username: username.to_owned(), - password: password.to_owned(), - }); - } - // Update redis.ttl_secs from environment variable if let Some(ttl_secs) = env_vars.get(ENV_NUMAFLOW_SERVING_STORE_TTL) { let ttl_secs: u32 = ttl_secs.parse().map_err(|e| { @@ -213,17 +170,9 @@ impl TryFrom> for Settings { #[cfg(test)] mod tests { - use super::*; + use crate::pipeline::{Edge, Vertex}; - #[test] - fn test_basic_auth_debug_print() { - let auth = BasicAuth { - username: "js-auth-user".into(), - password: "js-auth-password".into(), - }; - let auth_debug = format!("{auth:?}"); - assert_eq!(auth_debug, "js-auth-user:***rd"); - } + use super::*; #[test] fn test_default_config() { @@ -234,8 +183,6 @@ mod tests { assert_eq!(settings.metrics_server_listen_port, 3001); assert_eq!(settings.upstream_addr, "localhost:8888"); assert_eq!(settings.drain_timeout_secs, 10); - assert_eq!(settings.jetstream.stream, "default"); - assert_eq!(settings.jetstream.url, "localhost:4222"); assert_eq!(settings.redis.addr, "redis://127.0.0.1:6379"); assert_eq!(settings.redis.max_tasks, 50); assert_eq!(settings.redis.retries, 5); @@ -246,21 +193,12 @@ mod tests { fn test_config_parse() { // Set up the environment variables let env_vars = [ - ( - ENV_NUMAFLOW_SERVING_JETSTREAM_URL, - "nats://isbsvc-default-js-svc.default.svc:4222", - ), - ( - ENV_NUMAFLOW_SERVING_JETSTREAM_STREAM, - "ascii-art-pipeline-in-serving-source", - ), - (ENV_NUMAFLOW_SERVING_JETSTREAM_USER, "js-auth-user"), - (ENV_NUMAFLOW_SERVING_JETSTREAM_PASSWORD, "js-user-password"), (ENV_NUMAFLOW_SERVING_HOST_IP, "10.2.3.5"), (ENV_NUMAFLOW_SERVING_AUTH_TOKEN, "api-auth-token"), (ENV_NUMAFLOW_SERVING_APP_PORT, "8443"), (ENV_NUMAFLOW_SERVING_STORE_TTL, "86400"), - (ENV_NUMAFLOW_SERVING_SOURCE_OBJECT, "eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX0=") + (ENV_NUMAFLOW_SERVING_SOURCE_OBJECT, "eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX0="), + (ENV_MIN_PIPELINE_SPEC, "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6InNlcnZpbmctaW4iLCJzb3VyY2UiOnsic2VydmluZyI6eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX19LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciIsImVudiI6W3sibmFtZSI6IlJVU1RfTE9HIiwidmFsdWUiOiJpbmZvIn1dfSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2aW5nLXNpbmsiLCJzaW5rIjp7InVkc2luayI6eyJjb250YWluZXIiOnsiaW1hZ2UiOiJxdWF5LmlvL251bWFpby9udW1hZmxvdy1ycy9zaW5rLWxvZzpzdGFibGUiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9fX0sInJldHJ5U3RyYXRlZ3kiOnt9fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fV0sImVkZ2VzIjpbeyJmcm9tIjoic2VydmluZy1pbiIsInRvIjoic2VydmluZy1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH1dLCJsaWZlY3ljbGUiOnt9LCJ3YXRlcm1hcmsiOnt9fQ==") ]; // Call the config method @@ -277,14 +215,6 @@ mod tests { metrics_server_listen_port: 3001, upstream_addr: "localhost:8888".into(), drain_timeout_secs: 10, - jetstream: JetStreamConfig { - stream: "ascii-art-pipeline-in-serving-source".into(), - url: "nats://isbsvc-default-js-svc.default.svc:4222".into(), - auth: Some(BasicAuth { - username: "js-auth-user".into(), - password: "js-user-password".into(), - }), - }, redis: RedisConfig { addr: "redis://redis:6379".into(), max_tasks: 50, @@ -294,8 +224,22 @@ mod tests { }, host_ip: "10.2.3.5".into(), api_auth_token: Some("api-auth-token".into()), + pipeline_spec: PipelineDCG { + vertices: vec![ + Vertex { + name: "serving-in".into(), + }, + Vertex { + name: "serving-sink".into(), + }, + ], + edges: vec![Edge { + from: "serving-in".into(), + to: "serving-sink".into(), + conditions: None, + }], + }, }; - assert_eq!(settings, expected_config); } } diff --git a/rust/serving/src/error.rs b/rust/serving/src/error.rs index d53509c939..cfa252daad 100644 --- a/rust/serving/src/error.rs +++ b/rust/serving/src/error.rs @@ -1,4 +1,5 @@ use thiserror::Error; +use tokio::sync::oneshot; // TODO: introduce module level error handling @@ -44,6 +45,9 @@ pub enum Error { #[error("Init Error - {0}")] InitError(String), + #[error("Failed to receive message from channel. Actor task is terminated: {0:?}")] + ActorTaskTerminated(oneshot::error::RecvError), + #[error("Other Error - {0}")] // catch-all variant for now Other(String), diff --git a/rust/serving/src/lib.rs b/rust/serving/src/lib.rs index 796313bdb2..e275132386 100644 --- a/rust/serving/src/lib.rs +++ b/rust/serving/src/lib.rs @@ -1,12 +1,13 @@ -use std::env; use std::net::SocketAddr; use std::sync::Arc; +use crate::app::callback::state::State as CallbackState; +use app::callback::store::Store; use axum_server::tls_rustls::RustlsConfig; +use tokio::sync::mpsc; use tracing::info; pub use self::error::{Error, Result}; -use self::pipeline::PipelineDCG; use crate::app::start_main_server; use crate::config::generate_certs; use crate::metrics::start_https_metrics_server; @@ -21,41 +22,39 @@ mod error; mod metrics; mod pipeline; -const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; +pub mod source; +pub use source::{Message, MessageWrapper, ServingSource}; -pub async fn serve( - settings: Arc, -) -> std::result::Result<(), Box> { +#[derive(Clone)] +pub(crate) struct AppState { + pub message: mpsc::Sender, + pub settings: Arc, + pub callback_state: CallbackState, +} + +pub(crate) async fn serve( + app: AppState, +) -> std::result::Result<(), Box> +where + T: Clone + Send + Sync + Store + 'static, +{ let (cert, key) = generate_certs()?; let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) .await .map_err(|e| format!("Failed to create tls config {:?}", e))?; - // TODO: Move all env variables into one place. Some env variables are loaded when Settings is initialized - let pipeline_spec: PipelineDCG = env::var(ENV_MIN_PIPELINE_SPEC) - .map_err(|_| { - format!("Pipeline spec is not set using environment variable {ENV_MIN_PIPELINE_SPEC}") - })? - .parse() - .map_err(|e| { - format!( - "Parsing pipeline spec: {}: error={e:?}", - env::var(ENV_MIN_PIPELINE_SPEC).unwrap() - ) - })?; - - info!(config = ?settings, ?pipeline_spec, "Starting server with config and pipeline spec"); + info!(config = ?app.settings, "Starting server with config and pipeline spec"); // Start the metrics server, which serves the prometheus metrics. let metrics_addr: SocketAddr = - format!("0.0.0.0:{}", &settings.metrics_server_listen_port).parse()?; + format!("0.0.0.0:{}", &app.settings.metrics_server_listen_port).parse()?; let metrics_server_handle = tokio::spawn(start_https_metrics_server(metrics_addr, tls_config.clone())); // Start the main server, which serves the application. - let app_server_handle = tokio::spawn(start_main_server(settings, tls_config, pipeline_spec)); + let app_server_handle = tokio::spawn(start_main_server(app, tls_config)); // TODO: is try_join the best? we need to short-circuit at the first failure tokio::try_join!(flatten(app_server_handle), flatten(metrics_server_handle))?; diff --git a/rust/serving/src/pipeline.rs b/rust/serving/src/pipeline.rs index d782e3d73a..2adc169433 100644 --- a/rust/serving/src/pipeline.rs +++ b/rust/serving/src/pipeline.rs @@ -42,22 +42,20 @@ impl From for OperatorType { } // Tag is a struct that contains the information about the tags for the edge -#[cfg_attr(test, derive(PartialEq))] -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Tag { pub operator: Option, pub values: Vec, } // Conditions is a struct that contains the information about the conditions for the edge -#[cfg_attr(test, derive(PartialEq))] -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Conditions { pub tags: Option, } // Edge is a struct that contains the information about the edge in the pipeline. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Edge { pub from: String, pub to: String, @@ -66,14 +64,13 @@ pub struct Edge { /// DCG (directed compute graph) of the pipeline with minimal information build using vertices and edges /// from the pipeline spec -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde()] +#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)] pub struct PipelineDCG { pub vertices: Vec, pub edges: Vec, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Vertex { pub name: String, } diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs new file mode 100644 index 0000000000..b3f4c2e1ab --- /dev/null +++ b/rust/serving/src/source.rs @@ -0,0 +1,194 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::Instant; + +use crate::app::callback::state::State as CallbackState; +use crate::app::callback::store::redisstore::RedisConnection; +use crate::app::tracker::MessageGraph; +use crate::Settings; +use crate::{Error, Result}; + +pub struct MessageWrapper { + pub confirm_save: oneshot::Sender<()>, + pub message: Message, +} + +pub struct Message { + pub value: Bytes, + pub id: String, + pub headers: HashMap, +} + +enum ActorMessage { + Read { + batch_size: usize, + timeout_at: Instant, + reply_to: oneshot::Sender>, + }, + Ack { + offsets: Vec, + reply_to: oneshot::Sender<()>, + }, +} + +struct ServingSourceActor { + messages: mpsc::Receiver, + handler_rx: mpsc::Receiver, + tracker: HashMap>, +} + +impl ServingSourceActor { + async fn start( + settings: Arc, + handler_rx: mpsc::Receiver, + ) -> Result<()> { + let (messages_tx, messages_rx) = mpsc::channel(10000); + // Create a redis store to store the callbacks and the custom responses + let redis_store = RedisConnection::new(settings.redis.clone()).await?; + // Create the message graph from the pipeline spec and the redis store + let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec).map_err(|e| { + Error::InitError(format!( + "Creating message graph from pipeline spec: {:?}", + e + )) + })?; + let callback_state = CallbackState::new(msg_graph, redis_store).await?; + + tokio::spawn(async move { + let mut serving_actor = ServingSourceActor { + messages: messages_rx, + handler_rx, + tracker: HashMap::new(), + }; + serving_actor.run().await; + }); + let app = crate::AppState { + message: messages_tx, + settings, + callback_state, + }; + crate::serve(app).await.unwrap(); + Ok(()) + } + + async fn run(&mut self) { + while let Some(msg) = self.handler_rx.recv().await { + self.handle_message(msg).await; + } + } + + async fn handle_message(&mut self, actor_msg: ActorMessage) { + match actor_msg { + ActorMessage::Read { + batch_size, + timeout_at, + reply_to, + } => { + let messages = self.read(batch_size, timeout_at).await; + let _ = reply_to.send(messages); + } + ActorMessage::Ack { offsets, reply_to } => { + self.ack(offsets).await; + let _ = reply_to.send(()); + } + } + } + + async fn read(&mut self, count: usize, timeout_at: Instant) -> Vec { + let mut messages = vec![]; + loop { + if messages.len() >= count || Instant::now() >= timeout_at { + break; + } + let message = match self.messages.try_recv() { + Ok(msg) => msg, + Err(mpsc::error::TryRecvError::Empty) => break, + Err(e) => { + tracing::error!(?e, "Receiving messages from the serving channel"); // FIXME: + return messages; + } + }; + let MessageWrapper { + confirm_save, + message, + } = message; + + self.tracker.insert(message.id.clone(), confirm_save); + messages.push(message); + } + messages + } + + async fn ack(&mut self, offsets: Vec) { + for offset in offsets { + let offset = offset + .strip_suffix("-0") + .expect("offset does not end with '-0'"); // FIXME: we hardcode 0 as the partition index when constructing offset + let confirm_save_tx = self + .tracker + .remove(offset) + .expect("offset was not found in the tracker"); + confirm_save_tx + .send(()) + .expect("Sending on confirm_save channel"); + } + } +} + +#[derive(Clone)] +pub struct ServingSource { + batch_size: usize, + // timeout for each batch read request + timeout: Duration, + actor_tx: mpsc::Sender, +} + +impl ServingSource { + pub async fn new( + settings: Arc, + batch_size: usize, + timeout: Duration, + ) -> Result { + let (actor_tx, actor_rx) = mpsc::channel(10); + ServingSourceActor::start(settings, actor_rx).await?; + Ok(Self { + batch_size, + timeout, + actor_tx, + }) + } + + pub async fn read_messages(&self) -> Result> { + let start = Instant::now(); + let (tx, rx) = oneshot::channel(); + let actor_msg = ActorMessage::Read { + reply_to: tx, + batch_size: self.batch_size, + timeout_at: Instant::now() + self.timeout, + }; + let _ = self.actor_tx.send(actor_msg).await; + let messages = rx.await.map_err(Error::ActorTaskTerminated)?; + tracing::debug!( + count = messages.len(), + requested_count = self.batch_size, + time_taken_ms = start.elapsed().as_millis(), + "Got messages from Serving source" + ); + Ok(messages) + } + + pub async fn ack_messages(&self, offsets: Vec) -> Result<()> { + let (tx, rx) = oneshot::channel(); + let actor_msg = ActorMessage::Ack { + offsets, + reply_to: tx, + }; + let _ = self.actor_tx.send(actor_msg).await; + rx.await.map_err(Error::ActorTaskTerminated)?; + Ok(()) + } +} From f5486092529a2b5661a0c3939b040c5d3a849bcb Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 24 Dec 2024 15:50:52 +0530 Subject: [PATCH 07/15] Remove duplicate code, make serving a built-in source Signed-off-by: Sreekanth --- rust/Cargo.lock | 61 +- rust/Cargo.toml | 2 - rust/extns/numaflow-serving/Cargo.toml | 27 - rust/extns/numaflow-serving/src/app.rs | 276 ------ .../numaflow-serving/src/app/callback.rs | 219 ---- .../src/app/callback/state.rs | 378 ------- .../src/app/callback/store.rs | 35 - .../src/app/callback/store/memstore.rs | 217 ---- .../src/app/callback/store/redisstore.rs | 198 ---- .../numaflow-serving/src/app/response.rs | 60 -- .../extns/numaflow-serving/src/app/tracker.rs | 933 ------------------ rust/extns/numaflow-serving/src/config.rs | 235 ----- rust/extns/numaflow-serving/src/errors.rs | 50 - rust/extns/numaflow-serving/src/lib.rs | 12 - rust/extns/numaflow-serving/src/pipeline.rs | 154 --- rust/extns/numaflow-serving/src/source.rs | 194 ---- rust/numaflow-core/Cargo.toml | 1 - rust/numaflow-core/src/config/components.rs | 4 +- rust/numaflow-core/src/metrics.rs | 2 - .../src/shared/create_components.rs | 2 +- rust/numaflow-core/src/source.rs | 2 +- rust/numaflow-core/src/source/serving.rs | 10 +- rust/serving/Cargo.toml | 1 - rust/serving/src/app.rs | 132 +-- rust/serving/src/app/jetstream_proxy.rs | 104 +- rust/serving/src/config.rs | 2 - 26 files changed, 84 insertions(+), 3227 deletions(-) delete mode 100644 rust/extns/numaflow-serving/Cargo.toml delete mode 100644 rust/extns/numaflow-serving/src/app.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback/state.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback/store.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback/store/memstore.rs delete mode 100644 rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs delete mode 100644 rust/extns/numaflow-serving/src/app/response.rs delete mode 100644 rust/extns/numaflow-serving/src/app/tracker.rs delete mode 100644 rust/extns/numaflow-serving/src/config.rs delete mode 100644 rust/extns/numaflow-serving/src/errors.rs delete mode 100644 rust/extns/numaflow-serving/src/lib.rs delete mode 100644 rust/extns/numaflow-serving/src/pipeline.rs delete mode 100644 rust/extns/numaflow-serving/src/source.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 21a5f97246..8c9e480319 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -53,39 +53,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" -[[package]] -name = "async-nats" -version = "0.35.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab8df97cb8fc4a884af29ab383e9292ea0939cfcdd7d2a17179086dc6c427e7f" -dependencies = [ - "base64 0.22.1", - "bytes", - "futures", - "memchr", - "nkeys", - "nuid", - "once_cell", - "portable-atomic", - "rand", - "regex", - "ring", - "rustls-native-certs 0.7.3", - "rustls-pemfile 2.2.0", - "rustls-webpki 0.102.8", - "serde", - "serde_json", - "serde_nanos", - "serde_repr", - "thiserror 1.0.69", - "time", - "tokio", - "tokio-rustls 0.26.0", - "tracing", - "tryhard", - "url", -] - [[package]] name = "async-nats" version = "0.38.0" @@ -1748,7 +1715,7 @@ dependencies = [ name = "numaflow-core" version = "0.1.0" dependencies = [ - "async-nats 0.38.0", + "async-nats", "axum", "axum-server", "backoff", @@ -1762,7 +1729,6 @@ dependencies = [ "numaflow-models", "numaflow-pb", "numaflow-pulsar", - "numaflow-serving", "parking_lot", "pep440_rs", "pin-project", @@ -1828,30 +1794,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "numaflow-serving" -version = "0.1.0" -dependencies = [ - "axum", - "axum-server", - "backoff", - "base64 0.22.1", - "bytes", - "chrono", - "numaflow-models", - "rcgen", - "redis", - "serde", - "serde_json", - "thiserror 2.0.8", - "tokio", - "tower 0.4.13", - "tower-http", - "tracing", - "trait-variant", - "uuid", -] - [[package]] name = "object" version = "0.36.5" @@ -2885,7 +2827,6 @@ dependencies = [ name = "serving" version = "0.1.0" dependencies = [ - "async-nats 0.35.1", "axum", "axum-macros", "axum-server", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index b986f29316..db7deddb61 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -9,7 +9,6 @@ members = [ "numaflow-pb", "extns/numaflow-pulsar", "numaflow", - "extns/numaflow-serving", ] [workspace.lints.rust] @@ -60,7 +59,6 @@ numaflow-models = { path = "numaflow-models" } backoff = { path = "backoff" } numaflow-pb = { path = "numaflow-pb" } numaflow-pulsar = {path = "extns/numaflow-pulsar"} -numaflow-serving = {path = "extns/numaflow-serving"} tokio = "1.41.1" bytes = "1.7.1" tracing = "0.1.40" diff --git a/rust/extns/numaflow-serving/Cargo.toml b/rust/extns/numaflow-serving/Cargo.toml deleted file mode 100644 index 9a2abf9959..0000000000 --- a/rust/extns/numaflow-serving/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -name = "numaflow-serving" -version = "0.1.0" -edition = "2021" - -[lints] -workspace = true - -[dependencies] -axum.workspace = true -axum-server.workspace = true -tokio.workspace = true -bytes.workspace = true -tracing.workspace = true -serde = { workspace = true } -numaflow-models.workspace = true -backoff.workspace = true -rcgen = "0.13.1" -tower = "0.4.13" -tower-http = { version = "0.5.2", features = ["trace", "timeout"] } -uuid = { version = "1.10.0", features = ["v4"] } -thiserror = "2.0.8" -base64 = "0.22.1" -serde_json = "1.0.120" -trait-variant = "0.1.2" -redis = { version = "0.26.0", features = ["tokio-comp", "aio", "connection-manager"] } -chrono = { version = "0.4", features = ["serde"] } diff --git a/rust/extns/numaflow-serving/src/app.rs b/rust/extns/numaflow-serving/src/app.rs deleted file mode 100644 index 1d491bdace..0000000000 --- a/rust/extns/numaflow-serving/src/app.rs +++ /dev/null @@ -1,276 +0,0 @@ -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; - -use axum::body::Body; -use axum::extract::{MatchedPath, Request, State}; -use axum::http::{HeaderMap, StatusCode}; -use axum::middleware::{self, Next}; -use axum::response::{IntoResponse, Response}; -use axum::routing::{get, post}; -use axum::{Json, Router}; -use axum_server::tls_rustls::RustlsConfig; -use axum_server::Handle; -use bytes::Bytes; -use rcgen::{generate_simple_self_signed, Certificate, CertifiedKey, KeyPair}; -use tokio::sync::{mpsc, oneshot}; -use tower::ServiceBuilder; -use tower_http::timeout::TimeoutLayer; -use tower_http::trace::{DefaultOnResponse, TraceLayer}; -use uuid::Uuid; - -use crate::{Error, Message, MessageWrapper, Settings}; -/// -/// manage callbacks -pub(crate) mod callback; - -mod response; - -use crate::app::callback::state::State as CallbackState; - -pub(crate) mod tracker; - -use self::callback::store::Store; -use self::response::{ApiError, ServeResponse}; - -fn generate_certs() -> crate::Result<(Certificate, KeyPair)> { - let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) - .map_err(|e| Error::InitError(format!("Failed to generate cert {:?}", e)))?; - Ok((cert, key_pair)) -} - -pub(crate) async fn serve(app: AppState) -> crate::Result<()> -where - T: Clone + Send + Sync + Store + 'static, -{ - let (cert, key) = generate_certs()?; - - let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) - .await - .map_err(|e| Error::InitError(format!("Failed to create tls config {:?}", e)))?; - - // TODO: Move all env variables into one place. Some env variables are loaded when Settings is initialized - - tracing::info!(config = ?app.settings, "Starting server with config and pipeline spec"); - // Start the main server, which serves the application. - tokio::spawn(start_main_server(app, tls_config)); - - Ok(()) -} - -#[derive(Clone)] -pub struct AppState { - pub message: mpsc::Sender, - pub settings: Arc, - pub callback_state: CallbackState, -} - -const PUBLISH_ENDPOINTS: [&str; 3] = [ - "/v1/process/sync", - "/v1/process/sync_serve", - "/v1/process/async", -]; -// auth middleware to do token based authentication for all user facing routes -// if auth is enabled. -async fn auth_middleware( - State(api_auth_token): State>, - request: axum::extract::Request, - next: Next, -) -> Response { - let path = request.uri().path(); - - // we only need to check for the presence of the auth token in the request headers for the publish endpoints - if !PUBLISH_ENDPOINTS.contains(&path) { - return next.run(request).await; - } - - match api_auth_token { - Some(token) => { - // Check for the presence of the auth token in the request headers - let auth_token = match request.headers().get("Authorization") { - Some(token) => token, - None => { - return Response::builder() - .status(401) - .body(Body::empty()) - .expect("failed to build response") - } - }; - if auth_token.to_str().expect("auth token should be a string") - != format!("Bearer {}", token) - { - Response::builder() - .status(401) - .body(Body::empty()) - .expect("failed to build response") - } else { - next.run(request).await - } - } - None => { - // If the auth token is not set, we don't need to check for the presence of the auth token in the request headers - next.run(request).await - } - } -} - -async fn start_main_server(app: AppState, tls_config: RustlsConfig) -> crate::Result<()> -where - T: Clone + Send + Sync + Store + 'static, -{ - let app_addr: SocketAddr = format!("0.0.0.0:{}", &app.settings.app_listen_port) - .parse() - .map_err(|e| Error::InitError(format!("{e:?}")))?; - - let tid_header = app.settings.tid_header.clone(); - let layers = ServiceBuilder::new() - // Add tracing to all requests - .layer( - TraceLayer::new_for_http() - .make_span_with(move |req: &Request| { - let tid = req - .headers() - .get(&tid_header) - .and_then(|v| v.to_str().ok()) - .map(|v| v.to_string()) - .unwrap_or_else(|| Uuid::new_v4().to_string()); - - let matched_path = req - .extensions() - .get::() - .map(MatchedPath::as_str); - - tracing::info_span!("request", tid, method=?req.method(), matched_path) - }) - .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)), - ) - .layer( - // Graceful shutdown will wait for outstanding requests to complete. Add a timeout so - // requests don't hang forever. - TimeoutLayer::new(Duration::from_secs(app.settings.drain_timeout_secs)), - ) - // Add auth middleware to all user facing routes - .layer(middleware::from_fn_with_state( - app.settings.api_auth_token.clone(), - auth_middleware, - )); - - let handle = Handle::new(); - - let router = setup_app(app).await.layer(layers); - - tracing::info!(?app_addr, "Starting application server"); - - axum_server::bind_rustls(app_addr, tls_config) - .handle(handle) - .serve(router.into_make_service()) - .await - .map_err(|e| Error::InitError(format!("Starting web server for metrics: {}", e)))?; - - Ok(()) -} - -async fn setup_app(state: AppState) -> Router { - let router = Router::new() - .route("/health", get(health_check)) - .route("/livez", get(livez)) // Liveliness check - .route("/readyz", get(readyz)) - .with_state(state.clone()); - - let router = router.nest("/v1/process", routes(state.clone())); - router -} - -async fn health_check() -> impl IntoResponse { - "ok" -} - -async fn livez() -> impl IntoResponse { - StatusCode::NO_CONTENT -} - -async fn readyz( - State(app): State>, -) -> impl IntoResponse { - if app.callback_state.clone().ready().await { - StatusCode::NO_CONTENT - } else { - StatusCode::INTERNAL_SERVER_ERROR - } -} - -// extracts the ID from the headers, if not found, generates a new UUID -fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { - headers.get(tid_header).map_or_else( - || Uuid::new_v4().to_string(), - |v| String::from_utf8_lossy(v.as_bytes()).to_string(), - ) -} - -fn routes(app_state: AppState) -> Router { - let jetstream_proxy = jetstream_proxy(app_state); - jetstream_proxy -} - -const CALLBACK_URL_KEY: &str = "X-Numaflow-Callback-Url"; -const NUMAFLOW_RESP_ARRAY_LEN: &str = "Numaflow-Array-Len"; -const NUMAFLOW_RESP_ARRAY_IDX_LEN: &str = "Numaflow-Array-Index-Len"; - -struct ProxyState { - tid_header: String, - callback: CallbackState, - callback_url: String, - messages: mpsc::Sender, -} - -pub(crate) fn jetstream_proxy( - state: AppState, -) -> Router { - let proxy_state = Arc::new(ProxyState { - tid_header: state.settings.tid_header.clone(), - callback: state.callback_state.clone(), - messages: state.message.clone(), - callback_url: format!( - "https://{}:{}/v1/process/callback", - state.settings.host_ip, state.settings.app_listen_port - ), - }); - - let router = Router::new() - .route("/async", post(async_publish)) - .with_state(proxy_state); - router -} - -async fn async_publish( - State(proxy_state): State>>, - headers: HeaderMap, - body: Bytes, -) -> Result, ApiError> { - let id = extract_id_from_headers(&proxy_state.tid_header, &headers); - let mut msg_headers: HashMap = HashMap::new(); - for (k, v) in headers.iter() { - msg_headers.insert( - k.to_string(), - String::from_utf8_lossy(v.as_bytes()).to_string(), - ); - } - let (tx, rx) = oneshot::channel(); - let message = MessageWrapper { - confirm_save: tx, - message: Message { - value: body, - id: id.clone(), - headers: msg_headers, - }, - }; - proxy_state.messages.send(message).await.unwrap(); - rx.await.unwrap(); - - Ok(Json(ServeResponse::new( - "Successfully published message".to_string(), - id, - StatusCode::OK, - ))) -} diff --git a/rust/extns/numaflow-serving/src/app/callback.rs b/rust/extns/numaflow-serving/src/app/callback.rs deleted file mode 100644 index b4d43868ee..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback.rs +++ /dev/null @@ -1,219 +0,0 @@ -use axum::{body::Bytes, extract::State, http::HeaderMap, routing, Json, Router}; -use serde::{Deserialize, Serialize}; -use tracing::error; - -use self::store::Store; -use crate::app::response::ApiError; - -/// in-memory state store including connection tracking -pub(crate) mod state; -use state::State as CallbackState; - -/// store for storing the state -pub(crate) mod store; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub(crate) struct CallbackRequest { - pub(crate) id: String, - pub(crate) vertex: String, - pub(crate) cb_time: u64, - pub(crate) from_vertex: String, - pub(crate) tags: Option>, -} - -#[derive(Clone)] -struct CallbackAppState { - tid_header: String, - callback_state: CallbackState, -} - -pub fn callback_handler( - tid_header: String, - callback_state: CallbackState, -) -> Router { - let app_state = CallbackAppState { - tid_header, - callback_state, - }; - Router::new() - .route("/callback", routing::post(callback)) - .route("/callback_save", routing::post(callback_save)) - .with_state(app_state) -} - -async fn callback_save( - State(app_state): State>, - headers: HeaderMap, - body: Bytes, -) -> Result<(), ApiError> { - let id = headers - .get(&app_state.tid_header) - .map(|id| String::from_utf8_lossy(id.as_bytes()).to_string()) - .ok_or_else(|| ApiError::BadRequest("Missing id header".to_string()))?; - - app_state - .callback_state - .clone() - .save_response(id, body) - .await - .map_err(|e| { - error!(error=?e, "Saving body from callback save request"); - ApiError::InternalServerError( - "Failed to save body from callback save request".to_string(), - ) - })?; - - Ok(()) -} - -async fn callback( - State(app_state): State>, - Json(payload): Json>, -) -> Result<(), ApiError> { - app_state - .callback_state - .clone() - .insert_callback_requests(payload) - .await - .map_err(|e| { - error!(error=?e, "Inserting callback requests"); - ApiError::InternalServerError("Failed to insert callback requests".to_string()) - })?; - - Ok(()) -} - -#[cfg(test)] -mod tests { - use axum::body::Body; - use axum::extract::Request; - use axum::http::header::CONTENT_TYPE; - use axum::http::StatusCode; - use tower::ServiceExt; - - use super::*; - use crate::app::callback::state::State as CallbackState; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::app::tracker::MessageGraph; - use crate::pipeline::PipelineDCG; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - #[tokio::test] - async fn test_callback_failure() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let payload = vec![CallbackRequest { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }]; - - let res = Request::builder() - .method("POST") - .uri("/callback") - .header(CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_vec(&payload).unwrap())) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - // id is not registered, so it should return 500 - assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[tokio::test] - async fn test_callback_success() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let mut state = CallbackState::new(msg_graph, store).await.unwrap(); - - let x = state.register("test_id".to_string()); - // spawn a task which will be awaited later - let handle = tokio::spawn(async move { - let _ = x.await.unwrap(); - }); - - let app = callback_handler("ID".to_owned(), state); - - let payload = vec![ - CallbackRequest { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }, - CallbackRequest { - id: "test_id".to_string(), - vertex: "cat".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }, - CallbackRequest { - id: "test_id".to_string(), - vertex: "out".to_string(), - cb_time: 12345, - from_vertex: "cat".to_string(), - tags: None, - }, - ]; - - let res = Request::builder() - .method("POST") - .uri("/callback") - .header(CONTENT_TYPE, "application/json") - .body(Body::from(serde_json::to_vec(&payload).unwrap())) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - handle.await.unwrap(); - } - - #[tokio::test] - async fn test_callback_save() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let res = Request::builder() - .method("POST") - .uri("/callback_save") - .header(CONTENT_TYPE, "application/json") - .header("id", "test_id") - .body(Body::from("test_body")) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - } - - #[tokio::test] - async fn test_without_id_header() { - let store = InMemoryStore::new(); - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let state = CallbackState::new(msg_graph, store).await.unwrap(); - let app = callback_handler("ID".to_owned(), state); - - let res = Request::builder() - .method("POST") - .uri("/callback_save") - .body(Body::from("test_body")) - .unwrap(); - - let resp = app.oneshot(res).await.unwrap(); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - } -} diff --git a/rust/extns/numaflow-serving/src/app/callback/state.rs b/rust/extns/numaflow-serving/src/app/callback/state.rs deleted file mode 100644 index 293478ead2..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback/state.rs +++ /dev/null @@ -1,378 +0,0 @@ -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; - -use tokio::sync::oneshot; - -use super::store::Store; -use crate::app::callback::{store::PayloadToSave, CallbackRequest}; -use crate::app::tracker::MessageGraph; -use crate::Error; - -struct RequestState { - // Channel to notify when all callbacks for a message is received - tx: oneshot::Sender>, - // CallbackRequest is immutable, while vtx_visited can grow. - vtx_visited: Vec>, -} - -#[derive(Clone)] -pub(crate) struct State { - // hashmap of vertex infos keyed by ID - // it also contains tx to trigger to response to the syncHTTP call - callbacks: Arc>>, - // generator to generate subgraph - msg_graph_generator: Arc, - // conn is to be used while reading and writing to redis. - store: T, -} - -impl State -where - T: Store, -{ - /// Create a new State to track connections and callback data - pub(crate) async fn new(msg_graph: MessageGraph, store: T) -> crate::Result { - Ok(Self { - callbacks: Arc::new(Mutex::new(HashMap::new())), - msg_graph_generator: Arc::new(msg_graph), - store, - }) - } - - /// register a new connection - /// The oneshot receiver will be notified when all callbacks for this connection is received from the numaflow pipeline - pub(crate) fn register(&mut self, id: String) -> oneshot::Receiver> { - // TODO: add an entry in Redis to note that the entry has been registered - - let (tx, rx) = oneshot::channel(); - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - guard.insert( - id.clone(), - RequestState { - tx, - vtx_visited: Vec::new(), - }, - ); - rx - } - - /// Retrieves the output of the numaflow pipeline - pub(crate) async fn retrieve_saved(&mut self, id: &str) -> Result>, Error> { - self.store.retrieve_datum(id).await - } - - pub(crate) async fn save_response( - &mut self, - id: String, - body: axum::body::Bytes, - ) -> crate::Result<()> { - // we have to differentiate between the saved responses and the callback requests - // saved responses are stored in "id_SAVED", callback requests are stored in "id" - self.store - .save(vec![PayloadToSave::DatumFromPipeline { - key: id, - value: body, - }]) - .await - } - - /// insert_callback_requests is used to insert the callback requests. - pub(crate) async fn insert_callback_requests( - &mut self, - cb_requests: Vec, - ) -> Result<(), Error> { - /* - TODO: should we consider batching the requests and then processing them? - that way algorithm can be invoked only once for a batch of requests - instead of invoking it for each request. - */ - let cb_requests: Vec> = - cb_requests.into_iter().map(Arc::new).collect(); - let redis_payloads: Vec = cb_requests - .iter() - .cloned() - .map(|cbr| PayloadToSave::Callback { - key: cbr.id.clone(), - value: Arc::clone(&cbr), - }) - .collect(); - - self.store.save(redis_payloads).await?; - - for cbr in cb_requests { - let id = cbr.id.clone(); - { - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - guard - .get_mut(&cbr.id) - .ok_or(Error::IDNotFound( - "Connection for the received callback is not present in the in-memory store", - ))? - .vtx_visited - .push(cbr); - } - - // check if the sub graph can be generated - match self.get_subgraph_from_memory(&id) { - Ok(_) => { - // if the sub graph is generated, then we can send the response - self.deregister(&id).await? - } - Err(e) => { - match e { - Error::SubGraphNotFound(_) => { - // if the sub graph is not generated, then we can continue - continue; - } - _ => { - // if there is an error, deregister with the error - self.deregister(&id).await? - } - } - } - } - } - Ok(()) - } - - /// Get the subgraph for the given ID from in-memory. - fn get_subgraph_from_memory(&self, id: &str) -> Result { - let callbacks = self.get_callbacks_from_memory(id).ok_or(Error::IDNotFound( - "Connection for the received callback is not present in the in-memory store", - ))?; - - self.get_subgraph(id.to_string(), callbacks) - } - - /// Get the subgraph for the given ID from persistent store. This is used querying for the status from the service endpoint even after the - /// request has been completed. - pub(crate) async fn retrieve_subgraph_from_storage( - &mut self, - id: &str, - ) -> Result { - // If the id is not found in the in-memory store, fetch from Redis - let callbacks: Vec> = - match self.retrieve_callbacks_from_storage(id).await { - Ok(callbacks) => callbacks, - Err(e) => { - return Err(e); - } - }; - // check if the sub graph can be generated - self.get_subgraph(id.to_string(), callbacks) - } - - // Generate subgraph from the given callbacks - fn get_subgraph( - &self, - id: String, - callbacks: Vec>, - ) -> Result { - match self - .msg_graph_generator - .generate_subgraph_from_callbacks(id, callbacks) - { - Ok(Some(sub_graph)) => Ok(sub_graph), - Ok(None) => Err(Error::SubGraphNotFound( - "Subgraph could not be generated for the given ID", - )), - Err(e) => Err(e), - } - } - - /// deregister is called to trigger response and delete all the data persisted for that ID - pub(crate) async fn deregister(&mut self, id: &str) -> Result<(), Error> { - let state = { - let mut guard = self.callbacks.lock().expect("Getting lock on State"); - // we do not require the data stored in HashMap anymore - guard.remove(id) - }; - - let Some(state) = state else { - return Err(Error::IDNotFound( - "Connection for the received callback is not present in the in-memory store", - )); - }; - - state - .tx - .send(Ok(id.to_string())) - .map_err(|_| Error::Other("Application bug - Receiver is already dropped".to_string())) - } - - // Get the Callback value for the given ID - // TODO: Generate json serialized data here itself to avoid cloning. - fn get_callbacks_from_memory(&self, id: &str) -> Option>> { - let guard = self.callbacks.lock().expect("Getting lock on State"); - guard.get(id).map(|state| state.vtx_visited.clone()) - } - - // Get the Callback value for the given ID from persistent store - async fn retrieve_callbacks_from_storage( - &mut self, - id: &str, - ) -> Result>, Error> { - // If the id is not found in the in-memory store, fetch from Redis - let callbacks: Vec> = match self.store.retrieve_callbacks(id).await { - Ok(response) => response.into_iter().collect(), - Err(e) => { - return Err(e); - } - }; - Ok(callbacks) - } - - // Check if the store is ready - pub(crate) async fn ready(&mut self) -> bool { - self.store.ready().await - } -} - -#[cfg(test)] -mod tests { - use axum::body::Bytes; - - use super::*; - use crate::app::callback::store::memstore::InMemoryStore; - use crate::pipeline::PipelineDCG; - - const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; - - #[tokio::test] - async fn test_state() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - // Test register - let id = "test_id".to_string(); - let rx = state.register(id.clone()); - - let xid = id.clone(); - - // spawn a task to listen on the receiver, once we have received all the callbacks for the message - // we will get a response from the receiver with the message id - let handle = tokio::spawn(async move { - let result = rx.await.unwrap(); - // Tests deregister, and fetching the subgraph from the memory - assert_eq!(result.unwrap(), xid); - }); - - // Test save_response - let body = Bytes::from("Test Message"); - state.save_response(id.clone(), body).await.unwrap(); - - // Test retrieve_saved - let saved = state.retrieve_saved(&id).await.unwrap(); - assert_eq!(saved, vec!["Test Message".as_bytes()]); - - // Test insert_callback_requests - let cbs = vec![ - CallbackRequest { - id: id.clone(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }, - CallbackRequest { - id: id.clone(), - vertex: "planner".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: Some(vec!["tiger".to_owned(), "asciiart".to_owned()]), - }, - CallbackRequest { - id: id.clone(), - vertex: "tiger".to_string(), - cb_time: 12345, - from_vertex: "planner".to_string(), - tags: None, - }, - CallbackRequest { - id: id.clone(), - vertex: "asciiart".to_string(), - cb_time: 12345, - from_vertex: "planner".to_string(), - tags: None, - }, - CallbackRequest { - id: id.clone(), - vertex: "serve-sink".to_string(), - cb_time: 12345, - from_vertex: "tiger".to_string(), - tags: None, - }, - CallbackRequest { - id: id.clone(), - vertex: "serve-sink".to_string(), - cb_time: 12345, - from_vertex: "asciiart".to_string(), - tags: None, - }, - ]; - state.insert_callback_requests(cbs).await.unwrap(); - - let sub_graph = state.retrieve_subgraph_from_storage(&id).await; - assert!(sub_graph.is_ok()); - - handle.await.unwrap(); - } - - #[tokio::test] - async fn test_retrieve_saved_no_entry() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let id = "nonexistent_id".to_string(); - - // Try to retrieve saved data for an ID that doesn't exist - let result = state.retrieve_saved(&id).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_insert_callback_requests_invalid_id() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let cbs = vec![CallbackRequest { - id: "nonexistent_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }]; - - // Try to insert callback requests for an ID that hasn't been registered - let result = state.insert_callback_requests(cbs).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_retrieve_subgraph_from_storage_no_entry() { - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec).unwrap(); - let store = InMemoryStore::new(); - let mut state = State::new(msg_graph, store).await.unwrap(); - - let id = "nonexistent_id".to_string(); - - // Try to retrieve a subgraph for an ID that doesn't exist - let result = state.retrieve_subgraph_from_storage(&id).await; - - // Check that an error is returned - assert!(result.is_err()); - } -} diff --git a/rust/extns/numaflow-serving/src/app/callback/store.rs b/rust/extns/numaflow-serving/src/app/callback/store.rs deleted file mode 100644 index af5f3c4368..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback/store.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::sync::Arc; - -use crate::app::callback::CallbackRequest; - -// in-memory store -pub(crate) mod memstore; -// redis as the store -pub(crate) mod redisstore; - -pub(crate) enum PayloadToSave { - /// Callback as sent by Numaflow to track the progression - Callback { - key: String, - value: Arc, - }, - /// Data sent by the Numaflow pipeline which is to be delivered as the response - DatumFromPipeline { - key: String, - value: axum::body::Bytes, - }, -} - -/// Store trait to store the callback information. -#[trait_variant::make(Store: Send)] -#[allow(dead_code)] -pub(crate) trait LocalStore { - async fn save(&mut self, messages: Vec) -> crate::Result<()>; - /// retrieve the callback payloads - async fn retrieve_callbacks( - &mut self, - id: &str, - ) -> Result>, crate::Error>; - async fn retrieve_datum(&mut self, id: &str) -> Result>, crate::Error>; - async fn ready(&mut self) -> bool; -} diff --git a/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs b/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs deleted file mode 100644 index 59355c76c6..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback/store/memstore.rs +++ /dev/null @@ -1,217 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use super::PayloadToSave; -use crate::app::callback::CallbackRequest; -use crate::config::SAVED; -use crate::Error; - -/// `InMemoryStore` is an in-memory implementation of the `Store` trait. -/// It uses a `HashMap` to store data in memory. -#[derive(Clone)] -pub(crate) struct InMemoryStore { - /// The data field is a `HashMap` where the key is a `String` and the value is a `Vec>`. - /// It is wrapped in an `Arc>` to allow shared ownership and thread safety. - data: Arc>>>>, -} - -impl InMemoryStore { - /// Creates a new `InMemoryStore` with an empty `HashMap`. - #[allow(dead_code)] - pub(crate) fn new() -> Self { - Self { - data: Arc::new(std::sync::Mutex::new(HashMap::new())), - } - } -} - -impl super::Store for InMemoryStore { - /// Saves a vector of `PayloadToSave` into the `HashMap`. - /// Each `PayloadToSave` is serialized into bytes and stored in the `HashMap` under its key. - async fn save(&mut self, messages: Vec) -> crate::Result<()> { - let mut data = self.data.lock().unwrap(); - for msg in messages { - match msg { - PayloadToSave::Callback { key, value } => { - if key.is_empty() { - return Err(Error::StoreWrite("Key cannot be empty".to_string())); - } - let bytes = serde_json::to_vec(&*value) - .map_err(|e| Error::StoreWrite(format!("Serializing to bytes - {}", e)))?; - data.entry(key).or_default().push(bytes); - } - PayloadToSave::DatumFromPipeline { key, value } => { - if key.is_empty() { - return Err(Error::StoreWrite("Key cannot be empty".to_string())); - } - data.entry(format!("{}_{}", key, SAVED)) - .or_default() - .push(value.into()); - } - } - } - Ok(()) - } - - /// Retrieves callbacks for a given id from the `HashMap`. - /// Each callback is deserialized from bytes into a `CallbackRequest`. - async fn retrieve_callbacks(&mut self, id: &str) -> Result>, Error> { - let data = self.data.lock().unwrap(); - match data.get(id) { - Some(result) => { - let messages: Result, _> = result - .iter() - .map(|msg| { - let cbr: CallbackRequest = serde_json::from_slice(msg).map_err(|_| { - Error::StoreRead( - "Failed to parse CallbackRequest from bytes".to_string(), - ) - })?; - Ok(Arc::new(cbr)) - }) - .collect(); - messages - } - None => Err(Error::StoreRead(format!("No entry found for id: {}", id))), - } - } - - /// Retrieves data for a given id from the `HashMap`. - /// Each piece of data is deserialized from bytes into a `String`. - async fn retrieve_datum(&mut self, id: &str) -> Result>, Error> { - let id = format!("{}_{}", id, SAVED); - let data = self.data.lock().unwrap(); - match data.get(&id) { - Some(result) => Ok(result.to_vec()), - None => Err(Error::StoreRead(format!("No entry found for id: {}", id))), - } - } - - async fn ready(&mut self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - use crate::app::callback::store::{PayloadToSave, Store}; - use crate::app::callback::CallbackRequest; - - #[tokio::test] - async fn test_save_and_retrieve_callbacks() { - let mut store = InMemoryStore::new(); - let key = "test_key".to_string(); - let value = Arc::new(CallbackRequest { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }); - - // Save a callback - store - .save(vec![PayloadToSave::Callback { - key: key.clone(), - value: Arc::clone(&value), - }]) - .await - .unwrap(); - - // Retrieve the callback - let retrieved = store.retrieve_callbacks(&key).await.unwrap(); - - // Check that the retrieved callback is the same as the one we saved - assert_eq!(retrieved.len(), 1); - assert_eq!(retrieved[0].id, "test_id".to_string()) - } - - #[tokio::test] - async fn test_save_and_retrieve_datum() { - let mut store = InMemoryStore::new(); - let key = "test_key".to_string(); - let value = "test_value".to_string(); - - // Save a datum - store - .save(vec![PayloadToSave::DatumFromPipeline { - key: key.clone(), - value: value.clone().into(), - }]) - .await - .unwrap(); - - // Retrieve the datum - let retrieved = store.retrieve_datum(&key).await.unwrap(); - - // Check that the retrieved datum is the same as the one we saved - assert_eq!(retrieved.len(), 1); - assert_eq!(retrieved[0], value.as_bytes()); - } - - #[tokio::test] - async fn test_retrieve_callbacks_no_entry() { - let mut store = InMemoryStore::new(); - let key = "nonexistent_key".to_string(); - - // Try to retrieve a callback for a key that doesn't exist - let result = store.retrieve_callbacks(&key).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_retrieve_datum_no_entry() { - let mut store = InMemoryStore::new(); - let key = "nonexistent_key".to_string(); - - // Try to retrieve a datum for a key that doesn't exist - let result = store.retrieve_datum(&key).await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_save_invalid_callback() { - let mut store = InMemoryStore::new(); - let value = Arc::new(CallbackRequest { - id: "test_id".to_string(), - vertex: "in".to_string(), - cb_time: 12345, - from_vertex: "in".to_string(), - tags: None, - }); - - // Try to save a callback with an invalid key - let result = store - .save(vec![PayloadToSave::Callback { - key: "".to_string(), - value: Arc::clone(&value), - }]) - .await; - - // Check that an error is returned - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_save_invalid_datum() { - let mut store = InMemoryStore::new(); - - // Try to save a datum with an invalid key - let result = store - .save(vec![PayloadToSave::DatumFromPipeline { - key: "".to_string(), - value: "test_value".into(), - }]) - .await; - - // Check that an error is returned - assert!(result.is_err()); - } -} diff --git a/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs b/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs deleted file mode 100644 index deae8b42cf..0000000000 --- a/rust/extns/numaflow-serving/src/app/callback/store/redisstore.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::sync::Arc; - -use backoff::retry::Retry; -use backoff::strategy::fixed; -use redis::aio::ConnectionManager; -use redis::RedisError; -use tokio::sync::Semaphore; - -use super::PayloadToSave; -use crate::app::callback::CallbackRequest; -use crate::config::RedisConfig; -use crate::config::SAVED; -use crate::Error; - -const LPUSH: &str = "LPUSH"; -const LRANGE: &str = "LRANGE"; -const EXPIRE: &str = "EXPIRE"; - -// Handle to the Redis actor. -#[derive(Clone)] -pub(crate) struct RedisConnection { - conn_manager: ConnectionManager, - config: RedisConfig, -} - -impl RedisConnection { - /// Creates a new RedisConnection with concurrent operations on Redis set by max_tasks. - pub(crate) async fn new(config: RedisConfig) -> crate::Result { - let client = redis::Client::open(config.addr.as_str()) - .map_err(|e| Error::Connection(format!("Creating Redis client: {e:?}")))?; - let conn = client - .get_connection_manager() - .await - .map_err(|e| Error::Connection(format!("Connecting to Redis server: {e:?}")))?; - Ok(Self { - conn_manager: conn, - config, - }) - } - - async fn execute_redis_cmd( - conn_manager: &mut ConnectionManager, - ttl_secs: Option, - key: &str, - val: &Vec, - ) -> Result<(), RedisError> { - let mut pipe = redis::pipe(); - pipe.cmd(LPUSH).arg(key).arg(val); - - // if the ttl is configured, add the EXPIRE command to the pipeline - if let Some(ttl) = ttl_secs { - pipe.cmd(EXPIRE).arg(key).arg(ttl); - } - - // Execute the pipeline - pipe.query_async(conn_manager).await.map(|_: ()| ()) - } - - // write to Redis with retries - async fn write_to_redis(&self, key: &str, value: &Vec) -> crate::Result<()> { - let interval = fixed::Interval::from_millis(self.config.retries_duration_millis.into()) - .take(self.config.retries); - - Retry::retry( - interval, - || async { - // https://hackmd.io/@compiler-errors/async-closures - Self::execute_redis_cmd( - &mut self.conn_manager.clone(), - self.config.ttl_secs, - key, - value, - ) - .await - }, - |e: &RedisError| !e.is_unrecoverable_error(), - ) - .await - .map_err(|err| Error::StoreWrite(format!("Saving to redis: {}", err).to_string())) - } -} - -async fn handle_write_requests( - redis_conn: RedisConnection, - msg: PayloadToSave, -) -> crate::Result<()> { - match msg { - PayloadToSave::Callback { key, value } => { - // Convert the CallbackRequest to a byte array - let value = serde_json::to_vec(&*value) - .map_err(|e| Error::StoreWrite(format!("Serializing payload - {}", e)))?; - - redis_conn.write_to_redis(&key, &value).await - } - - // Write the byte array to Redis - PayloadToSave::DatumFromPipeline { key, value } => { - // we have to differentiate between the saved responses and the callback requests - // saved responses are stored in "id_SAVED", callback requests are stored in "id" - let key = format!("{}_{}", key, SAVED); - let value: Vec = value.into(); - - redis_conn.write_to_redis(&key, &value).await - } - } -} - -// It is possible to move the methods defined here to be methods on the Redis actor and communicate through channels. -// With that, all public APIs defined on RedisConnection can be on &self (immutable). -impl super::Store for RedisConnection { - // Attempt to save all payloads. Returns error if we fail to save at least one message. - async fn save(&mut self, messages: Vec) -> crate::Result<()> { - let mut tasks = vec![]; - // This is put in place not to overload Redis and also way some kind of - // flow control. - let sem = Arc::new(Semaphore::new(self.config.max_tasks)); - for msg in messages { - let permit = Arc::clone(&sem).acquire_owned().await; - let redis_conn = self.clone(); - let task = tokio::spawn(async move { - let _permit = permit; - handle_write_requests(redis_conn, msg).await - }); - tasks.push(task); - } - for task in tasks { - task.await.unwrap()?; - } - Ok(()) - } - - async fn retrieve_callbacks(&mut self, id: &str) -> Result>, Error> { - let result: Result>, RedisError> = redis::cmd(LRANGE) - .arg(id) - .arg(0) - .arg(-1) - .query_async(&mut self.conn_manager) - .await; - - match result { - Ok(result) => { - if result.is_empty() { - return Err(Error::StoreRead(format!("No entry found for id: {}", id))); - } - - let messages: Result, _> = result - .into_iter() - .map(|msg| { - let cbr: CallbackRequest = serde_json::from_slice(&msg).map_err(|e| { - Error::StoreRead(format!("Parsing payload from bytes - {}", e)) - })?; - Ok(Arc::new(cbr)) - }) - .collect(); - - messages - } - Err(e) => Err(Error::StoreRead(format!( - "Failed to read from redis: {:?}", - e - ))), - } - } - - async fn retrieve_datum(&mut self, id: &str) -> Result>, Error> { - // saved responses are stored in "id_SAVED" - let key = format!("{}_{}", id, SAVED); - let result: Result>, RedisError> = redis::cmd(LRANGE) - .arg(key) - .arg(0) - .arg(-1) - .query_async(&mut self.conn_manager) - .await; - - match result { - Ok(result) => { - if result.is_empty() { - return Err(Error::StoreRead(format!("No entry found for id: {}", id))); - } - - Ok(result) - } - Err(e) => Err(Error::StoreRead(format!( - "Failed to read from redis: {:?}", - e - ))), - } - } - - // Check if the Redis connection is healthy - async fn ready(&mut self) -> bool { - let mut conn = self.conn_manager.clone(); - match redis::cmd("PING").query_async::(&mut conn).await { - Ok(response) => response == "PONG", - Err(_) => false, - } - } -} diff --git a/rust/extns/numaflow-serving/src/app/response.rs b/rust/extns/numaflow-serving/src/app/response.rs deleted file mode 100644 index 40064a1f78..0000000000 --- a/rust/extns/numaflow-serving/src/app/response.rs +++ /dev/null @@ -1,60 +0,0 @@ -use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; -use axum::Json; -use chrono::{DateTime, Utc}; -use serde::Serialize; - -// Response sent by the serve handler sync/async to the client(user). -#[derive(Serialize)] -pub(crate) struct ServeResponse { - pub(crate) message: String, - pub(crate) id: String, - pub(crate) code: u16, - pub(crate) timestamp: DateTime, -} - -impl ServeResponse { - pub(crate) fn new(message: String, id: String, status: StatusCode) -> Self { - Self { - code: status.as_u16(), - message, - id, - timestamp: Utc::now(), - } - } -} - -// Error response sent by all the handlers to the client(user). -#[derive(Debug, Serialize)] -pub enum ApiError { - BadRequest(String), - InternalServerError(String), - BadGateway(String), -} - -impl IntoResponse for ApiError { - fn into_response(self) -> Response { - #[derive(Serialize)] - struct ErrorBody { - message: String, - code: u16, - timestamp: DateTime, - } - - let (status, message) = match self { - ApiError::BadRequest(message) => (StatusCode::BAD_REQUEST, message), - ApiError::InternalServerError(message) => (StatusCode::INTERNAL_SERVER_ERROR, message), - ApiError::BadGateway(message) => (StatusCode::BAD_GATEWAY, message), - }; - - ( - status, - Json(ErrorBody { - code: status.as_u16(), - message, - timestamp: Utc::now(), - }), - ) - .into_response() - } -} diff --git a/rust/extns/numaflow-serving/src/app/tracker.rs b/rust/extns/numaflow-serving/src/app/tracker.rs deleted file mode 100644 index 33137f45db..0000000000 --- a/rust/extns/numaflow-serving/src/app/tracker.rs +++ /dev/null @@ -1,933 +0,0 @@ -use std::collections::HashMap; -use std::string::ToString; -use std::sync::Arc; - -use serde::{Deserialize, Serialize}; - -use crate::app::callback::CallbackRequest; -use crate::pipeline::{Edge, OperatorType, PipelineDCG}; -use crate::Error; - -fn compare_slice(operator: &OperatorType, a: &[String], b: &[String]) -> bool { - match operator { - OperatorType::And => a.iter().all(|val| b.contains(val)), - OperatorType::Or => a.iter().any(|val| b.contains(val)), - OperatorType::Not => !a.iter().any(|val| b.contains(val)), - } -} - -type Graph = HashMap>; - -#[derive(Serialize, Deserialize, Debug)] -struct Subgraph { - id: String, - blocks: Vec, -} - -const DROP: &str = "U+005C__DROP__"; - -/// MessageGraph is a struct that generates the graph from the source vertex to the downstream vertices -/// for a message using the given callbacks. -pub(crate) struct MessageGraph { - dag: Graph, -} - -/// Block is a struct that contains the information about the block in the subgraph. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub(crate) struct Block { - from: String, - to: String, - cb_time: u64, -} - -// CallbackRequestWrapper is a struct that contains the information about the callback request and -// whether it has been visited or not. It is used to keep track of the visited callbacks. -#[derive(Debug)] -struct CallbackRequestWrapper { - callback_request: Arc, - visited: bool, -} - -impl MessageGraph { - /// This function generates a sub graph from a list of callbacks. - /// It first creates a HashMap to map each vertex to its corresponding callbacks. - /// Then it finds the source vertex by checking if the vertex and from_vertex fields are the same. - /// Finally, it calls the `generate_subgraph` function to generate the subgraph from the source vertex. - pub(crate) fn generate_subgraph_from_callbacks( - &self, - id: String, - callbacks: Vec>, - ) -> Result, Error> { - // Create a HashMap to map each vertex to its corresponding callbacks - let mut callback_map: HashMap> = HashMap::new(); - let mut source_vertex = None; - - // Find the source vertex, source vertex is the vertex that has the same vertex and from_vertex - for callback in callbacks { - // Check if the vertex is present in the graph - if !self.dag.contains_key(&callback.from_vertex) { - return Err(Error::SubGraphInvalidInput(format!( - "Invalid callback: {}, vertex: {}", - callback.id, callback.from_vertex - ))); - } - - if callback.vertex == callback.from_vertex { - source_vertex = Some(callback.vertex.clone()); - } - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::clone(&callback), - visited: false, - }); - } - - // If there is no source vertex, return None - let source_vertex = match source_vertex { - Some(vertex) => vertex, - None => return Ok(None), - }; - - // Create a new subgraph. - let mut subgraph: Subgraph = Subgraph { - id, - blocks: Vec::new(), - }; - // Call the `generate_subgraph` function to generate the subgraph from the source vertex - let result = self.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - // If the subgraph is generated successfully, serialize it into a JSON string and return it. - // Otherwise, return None - if result { - match serde_json::to_string(&subgraph) { - Ok(json) => Ok(Some(json)), - Err(e) => Err(Error::SubGraphGenerator(e.to_string())), - } - } else { - Ok(None) - } - } - - // generate_subgraph function is a recursive function that generates the sub graph from the source vertex for - // the given list of callbacks. The function returns true if the subgraph is generated successfully(if we are - // able to find a subgraph for the message using the given callbacks), it - // updates the subgraph with the path from the source vertex to the downstream vertices. - fn generate_subgraph( - &self, - current: String, - from: String, - callback_map: &mut HashMap>, - subgraph: &mut Subgraph, - ) -> bool { - let mut current_callback: Option> = None; - - // we need to borrow the callback_map as mutable to update the visited flag of the callback - // so that next time when we visit the same callback, we can skip it. Because there can be cases - // where there will be multiple callbacks for the same vertex. - // If there are no callbacks for the current vertex, we should not continue - // because it means we have not received all the callbacks for the given message. - let Some(callbacks) = callback_map.get_mut(¤t) else { - return false; - }; - - // iterate over the callbacks for the current vertex and find the one that has not been visited, - // and it is coming from the same vertex as the current vertex - for callback in callbacks { - if callback.callback_request.from_vertex == from && !callback.visited { - callback.visited = true; - current_callback = Some(Arc::clone(&callback.callback_request)); - break; - } - } - - // If there is no callback which is not visited and its parent is the "from" vertex, then we should - // return false because we have not received all the callbacks for the given message. - let Some(current_callback) = current_callback else { - return false; - }; - - // add the current block to the subgraph - subgraph.blocks.push(Block { - from: current_callback.from_vertex.clone(), - to: current_callback.vertex.clone(), - cb_time: current_callback.cb_time, - }); - - // if the current vertex has a DROP tag, then we should not proceed further - // and return true - if current_callback - .tags - .as_ref() - .map_or(false, |tags| tags.contains(&DROP.to_string())) - { - return true; - } - - // recursively invoke the downstream vertices of the current vertex, if any - if let Some(edges) = self.dag.get(¤t) { - for edge in edges { - // check if the edge should proceed based on the conditions - // if there are no conditions, we should proceed with the edge - // if there are conditions, we should check the tags of the current callback - // with the tags of the edge and the operator of the tags to decide if we should - // proceed with the edge - let should_proceed = edge - .conditions - .as_ref() - // If the edge has conditions, get the tags - .and_then(|conditions| conditions.tags.as_ref()) - // If there are no conditions or tags, default to true (i.e., proceed with the edge) - // If there are tags, compare the tags with the current callback's tags and the operator - // to decide if we should proceed with the edge. - .map_or(true, |tags| { - current_callback - .tags - .as_ref() - // If the current callback has no tags we should not proceed with the edge for "and" and "or" operators - // because we expect the current callback to have tags specified in the edge. - // If the current callback has no tags we should proceed with the edge for "not" operator. - // because we don't expect the current callback to have tags specified in the edge. - .map_or( - tags.operator.as_ref() == Some(&OperatorType::Not), - |callback_tags| { - tags.operator.as_ref().map_or(false, |operator| { - // If there is no operator, default to false (i.e., do not proceed with the edge) - // If there is an operator, compare the current callback's tags with the edge's tags - compare_slice(operator, callback_tags, &tags.values) - }) - }, - ) - }); - - // if the conditions are not met, then proceed to the next edge - if !should_proceed { - continue; - } - - // proceed to the downstream vertex - // if any of the downstream vertex returns false, then we should return false. - if !self.generate_subgraph(edge.to.clone(), current.clone(), callback_map, subgraph) - { - return false; - } - } - } - // if there are no downstream vertices, or all the downstream vertices returned true, - // we can return true - true - } - - // from_env reads the pipeline stored in the environment variable and creates a MessageGraph from it. - pub(crate) fn from_pipeline(pipeline_spec: &PipelineDCG) -> Result { - let mut dag = Graph::with_capacity(pipeline_spec.edges.len()); - for edge in &pipeline_spec.edges { - dag.entry(edge.from.clone()).or_default().push(edge.clone()); - } - - Ok(MessageGraph { dag }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::pipeline::{Conditions, Tag, Vertex}; - - #[test] - fn test_no_subgraph() { - let mut dag: Graph = HashMap::new(); - dag.insert( - "a".to_string(), - vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "a".to_string(), - to: "c".to_string(), - conditions: None, - }, - ], - ); - let message_graph = MessageGraph { dag }; - - let mut callback_map: HashMap> = HashMap::new(); - callback_map.insert( - "a".to_string(), - vec![CallbackRequestWrapper { - callback_request: Arc::new(CallbackRequest { - id: "uuid1".to_string(), - vertex: "a".to_string(), - cb_time: 1, - from_vertex: "a".to_string(), - tags: None, - }), - visited: false, - }], - ); - - let mut subgraph: Subgraph = Subgraph { - id: "uuid1".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - "a".to_string(), - "a".to_string(), - &mut callback_map, - &mut subgraph, - ); - - assert!(!result); - } - - #[test] - fn test_generate_subgraph() { - let mut dag: Graph = HashMap::new(); - dag.insert( - "a".to_string(), - vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "a".to_string(), - to: "c".to_string(), - conditions: None, - }, - ], - ); - let message_graph = MessageGraph { dag }; - - let mut callback_map: HashMap> = HashMap::new(); - callback_map.insert( - "a".to_string(), - vec![CallbackRequestWrapper { - callback_request: Arc::new(CallbackRequest { - id: "uuid1".to_string(), - vertex: "a".to_string(), - cb_time: 1, - from_vertex: "a".to_string(), - tags: None, - }), - visited: false, - }], - ); - - callback_map.insert( - "b".to_string(), - vec![CallbackRequestWrapper { - callback_request: Arc::new(CallbackRequest { - id: "uuid1".to_string(), - vertex: "b".to_string(), - cb_time: 1, - from_vertex: "a".to_string(), - tags: None, - }), - visited: false, - }], - ); - - callback_map.insert( - "c".to_string(), - vec![CallbackRequestWrapper { - callback_request: Arc::new(CallbackRequest { - id: "uuid1".to_string(), - vertex: "c".to_string(), - cb_time: 1, - from_vertex: "a".to_string(), - tags: None, - }), - visited: false, - }], - ); - - let mut subgraph: Subgraph = Subgraph { - id: "uuid1".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - "a".to_string(), - "a".to_string(), - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_generate_subgraph_complex() { - let pipeline = PipelineDCG { - vertices: vec![ - Vertex { - name: "a".to_string(), - }, - Vertex { - name: "b".to_string(), - }, - Vertex { - name: "c".to_string(), - }, - Vertex { - name: "d".to_string(), - }, - Vertex { - name: "e".to_string(), - }, - Vertex { - name: "f".to_string(), - }, - Vertex { - name: "g".to_string(), - }, - Vertex { - name: "h".to_string(), - }, - Vertex { - name: "i".to_string(), - }, - ], - edges: vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "a".to_string(), - to: "c".to_string(), - conditions: None, - }, - Edge { - from: "b".to_string(), - to: "d".to_string(), - conditions: None, - }, - Edge { - from: "c".to_string(), - to: "e".to_string(), - conditions: None, - }, - Edge { - from: "d".to_string(), - to: "f".to_string(), - conditions: None, - }, - Edge { - from: "e".to_string(), - to: "f".to_string(), - conditions: None, - }, - Edge { - from: "f".to_string(), - to: "g".to_string(), - conditions: None, - }, - Edge { - from: "g".to_string(), - to: "h".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::And), - values: vec!["even".to_string()], - }), - }), - }, - Edge { - from: "g".to_string(), - to: "i".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::Or), - values: vec!["odd".to_string()], - }), - }), - }, - ], - }; - - let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); - let source_vertex = "a".to_string(); - - let raw_callback = r#"[ - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "a", - "cb_time": 123456789 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456867 - }, - { - "id": "xxxx", - "vertex": "c", - "from_vertex": "a", - "cb_time": 123456819 - }, - { - "id": "xxxx", - "vertex": "d", - "from_vertex": "b", - "cb_time": 123456840 - }, - { - "id": "xxxx", - "vertex": "e", - "from_vertex": "c", - "cb_time": 123456843 - }, - { - "id": "xxxx", - "vertex": "f", - "from_vertex": "d", - "cb_time": 123456854 - }, - { - "id": "xxxx", - "vertex": "f", - "from_vertex": "e", - "cb_time": 123456886 - }, - { - "id": "xxxx", - "vertex": "g", - "from_vertex": "f", - "tags": ["even"], - "cb_time": 123456885 - }, - { - "id": "xxxx", - "vertex": "g", - "from_vertex": "f", - "tags": ["even"], - "cb_time": 123456888 - }, - { - "id": "xxxx", - "vertex": "h", - "from_vertex": "g", - "cb_time": 123456889 - }, - { - "id": "xxxx", - "vertex": "h", - "from_vertex": "g", - "cb_time": 123456890 - } - ]"#; - - let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); - let mut callback_map: HashMap> = HashMap::new(); - - for callback in callbacks { - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::new(callback), - visited: false, - }); - } - - let mut subgraph: Subgraph = Subgraph { - id: "xxxx".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_simple_dropped_message() { - let pipeline = PipelineDCG { - vertices: vec![ - Vertex { - name: "a".to_string(), - }, - Vertex { - name: "b".to_string(), - }, - Vertex { - name: "c".to_string(), - }, - ], - edges: vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "b".to_string(), - to: "c".to_string(), - conditions: None, - }, - ], - }; - - let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); - let source_vertex = "a".to_string(); - - let raw_callback = r#" - [ - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "a", - "cb_time": 123456789 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456867, - "tags": ["U+005C__DROP__"] - } - ]"#; - - let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); - let mut callback_map: HashMap> = HashMap::new(); - - for callback in callbacks { - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::new(callback), - visited: false, - }); - } - - let mut subgraph: Subgraph = Subgraph { - id: "xxxx".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_complex_dropped_message() { - let pipeline = PipelineDCG { - vertices: vec![ - Vertex { - name: "a".to_string(), - }, - Vertex { - name: "b".to_string(), - }, - Vertex { - name: "c".to_string(), - }, - Vertex { - name: "d".to_string(), - }, - Vertex { - name: "e".to_string(), - }, - Vertex { - name: "f".to_string(), - }, - Vertex { - name: "g".to_string(), - }, - Vertex { - name: "h".to_string(), - }, - Vertex { - name: "i".to_string(), - }, - ], - edges: vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "a".to_string(), - to: "c".to_string(), - conditions: None, - }, - Edge { - from: "b".to_string(), - to: "d".to_string(), - conditions: None, - }, - Edge { - from: "c".to_string(), - to: "e".to_string(), - conditions: None, - }, - Edge { - from: "d".to_string(), - to: "f".to_string(), - conditions: None, - }, - Edge { - from: "e".to_string(), - to: "f".to_string(), - conditions: None, - }, - Edge { - from: "f".to_string(), - to: "g".to_string(), - conditions: None, - }, - Edge { - from: "g".to_string(), - to: "h".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::And), - values: vec!["even".to_string()], - }), - }), - }, - Edge { - from: "g".to_string(), - to: "i".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::Or), - values: vec!["odd".to_string()], - }), - }), - }, - ], - }; - - let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); - let source_vertex = "a".to_string(); - - let raw_callback = r#" - [ - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "a", - "cb_time": 123456789 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456867 - }, - { - "id": "xxxx", - "vertex": "c", - "from_vertex": "a", - "cb_time": 123456819, - "tags": ["U+005C__DROP__"] - }, - { - "id": "xxxx", - "vertex": "d", - "from_vertex": "b", - "cb_time": 123456840 - }, - { - "id": "xxxx", - "vertex": "f", - "from_vertex": "d", - "cb_time": 123456854 - }, - { - "id": "xxxx", - "vertex": "g", - "from_vertex": "f", - "tags": ["even"], - "cb_time": 123456885 - }, - { - "id": "xxxx", - "vertex": "h", - "from_vertex": "g", - "cb_time": 123456889 - } - ]"#; - - let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); - let mut callback_map: HashMap> = HashMap::new(); - - for callback in callbacks { - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::new(callback), - visited: false, - }); - } - - let mut subgraph: Subgraph = Subgraph { - id: "xxxx".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_simple_cycle_pipeline() { - let pipeline = PipelineDCG { - vertices: vec![ - Vertex { - name: "a".to_string(), - }, - Vertex { - name: "b".to_string(), - }, - Vertex { - name: "c".to_string(), - }, - ], - edges: vec![ - Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }, - Edge { - from: "b".to_string(), - to: "a".to_string(), - conditions: Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::Not), - values: vec!["failed".to_string()], - }), - }), - }, - ], - }; - - let message_graph = MessageGraph::from_pipeline(&pipeline).unwrap(); - let source_vertex = "a".to_string(); - - let raw_callback = r#" - [ - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "a", - "cb_time": 123456789 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456867, - "tags": ["failed"] - }, - { - "id": "xxxx", - "vertex": "a", - "from_vertex": "b", - "cb_time": 123456819 - }, - { - "id": "xxxx", - "vertex": "b", - "from_vertex": "a", - "cb_time": 123456819 - }, - { - "id": "xxxx", - "vertex": "c", - "from_vertex": "b", - "cb_time": 123456819 - } - ]"#; - - let callbacks: Vec = serde_json::from_str(raw_callback).unwrap(); - let mut callback_map: HashMap> = HashMap::new(); - - for callback in callbacks { - callback_map - .entry(callback.vertex.clone()) - .or_default() - .push(CallbackRequestWrapper { - callback_request: Arc::new(callback), - visited: false, - }); - } - - let mut subgraph: Subgraph = Subgraph { - id: "xxxx".to_string(), - blocks: Vec::new(), - }; - let result = message_graph.generate_subgraph( - source_vertex.clone(), - source_vertex, - &mut callback_map, - &mut subgraph, - ); - - assert!(result); - } - - #[test] - fn test_generate_subgraph_from_callbacks_with_invalid_vertex() { - // Create a simple graph - let mut dag: Graph = HashMap::new(); - dag.insert( - "a".to_string(), - vec![Edge { - from: "a".to_string(), - to: "b".to_string(), - conditions: None, - }], - ); - let message_graph = MessageGraph { dag }; - - // Create a callback with an invalid vertex - let callbacks = vec![Arc::new(CallbackRequest { - id: "test".to_string(), - vertex: "invalid_vertex".to_string(), - from_vertex: "invalid_vertex".to_string(), - cb_time: 1, - tags: None, - })]; - - // Call the function with the invalid callback - let result = message_graph.generate_subgraph_from_callbacks("test".to_string(), callbacks); - - // Check that the function returned an error - assert!(result.is_err()); - assert!(matches!(result, Err(Error::SubGraphInvalidInput(_)))); - } -} diff --git a/rust/extns/numaflow-serving/src/config.rs b/rust/extns/numaflow-serving/src/config.rs deleted file mode 100644 index d306b47d70..0000000000 --- a/rust/extns/numaflow-serving/src/config.rs +++ /dev/null @@ -1,235 +0,0 @@ -use std::collections::HashMap; - -use base64::prelude::BASE64_STANDARD; -use base64::Engine; -use serde::{Deserialize, Serialize}; - -use crate::pipeline::PipelineDCG; -use crate::Error; - -const ENV_NUMAFLOW_SERVING_SOURCE_OBJECT: &str = "NUMAFLOW_SERVING_SOURCE_OBJECT"; -const ENV_NUMAFLOW_SERVING_STORE_TTL: &str = "NUMAFLOW_SERVING_STORE_TTL"; -const ENV_NUMAFLOW_SERVING_HOST_IP: &str = "NUMAFLOW_SERVING_HOST_IP"; -const ENV_NUMAFLOW_SERVING_APP_PORT: &str = "NUMAFLOW_SERVING_APP_LISTEN_PORT"; -const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; -const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; - -pub(crate) const SAVED: &str = "SAVED"; - -#[derive(Debug, Deserialize, Clone, PartialEq)] -pub struct RedisConfig { - pub addr: String, - pub max_tasks: usize, - pub retries: usize, - pub retries_duration_millis: u16, - pub ttl_secs: Option, -} - -impl Default for RedisConfig { - fn default() -> Self { - Self { - addr: "redis://127.0.0.1:6379".to_owned(), - max_tasks: 50, - retries: 5, - retries_duration_millis: 100, - // TODO: we might need an option type here. Zero value of u32 can be used instead of None - ttl_secs: Some(1), - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Settings { - pub tid_header: String, - pub app_listen_port: u16, - pub metrics_server_listen_port: u16, - pub upstream_addr: String, - pub drain_timeout_secs: u64, - /// The IP address of the numaserve pod. This will be used to construct the value for X-Numaflow-Callback-Url header - pub host_ip: String, - pub api_auth_token: Option, - pub redis: RedisConfig, - pub pipeline_spec: PipelineDCG, -} - -impl Default for Settings { - fn default() -> Self { - Self { - tid_header: "ID".to_owned(), - app_listen_port: 3000, - metrics_server_listen_port: 3001, - upstream_addr: "localhost:8888".to_owned(), - drain_timeout_secs: 10, - host_ip: "127.0.0.1".to_owned(), - api_auth_token: None, - redis: RedisConfig::default(), - pipeline_spec: Default::default(), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Serving { - #[serde(rename = "msgIDHeaderKey")] - pub msg_id_header_key: Option, - #[serde(rename = "store")] - pub callback_storage: CallbackStorageConfig, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct CallbackStorageConfig { - pub url: String, -} - -/// This implementation is to load settings from env variables -impl TryFrom> for Settings { - type Error = Error; - fn try_from(env_vars: HashMap) -> std::result::Result { - let host_ip = env_vars - .get(ENV_NUMAFLOW_SERVING_HOST_IP) - .ok_or_else(|| { - Error::ParseConfig(format!( - "Environment variable {ENV_NUMAFLOW_SERVING_HOST_IP} is not set" - )) - })? - .to_owned(); - - let pipeline_spec: PipelineDCG = env_vars - .get(ENV_MIN_PIPELINE_SPEC) - .ok_or_else(|| { - Error::ParseConfig(format!( - "Pipeline spec is not set using environment variable {ENV_MIN_PIPELINE_SPEC}" - )) - })? - .parse() - .map_err(|e| { - Error::ParseConfig(format!( - "Parsing pipeline spec: {}: error={e:?}", - env_vars.get(ENV_MIN_PIPELINE_SPEC).unwrap() - )) - })?; - - let mut settings = Settings { - host_ip, - pipeline_spec, - ..Default::default() - }; - - if let Some(api_auth_token) = env_vars.get(ENV_NUMAFLOW_SERVING_AUTH_TOKEN) { - settings.api_auth_token = Some(api_auth_token.to_owned()); - } - - if let Some(app_port) = env_vars.get(ENV_NUMAFLOW_SERVING_APP_PORT) { - settings.app_listen_port = app_port.parse().map_err(|e| { - Error::ParseConfig(format!( - "Parsing {ENV_NUMAFLOW_SERVING_APP_PORT}(set to '{app_port}'): {e:?}" - )) - })?; - } - - // Update redis.ttl_secs from environment variable - if let Some(ttl_secs) = env_vars.get(ENV_NUMAFLOW_SERVING_STORE_TTL) { - let ttl_secs: u32 = ttl_secs.parse().map_err(|e| { - Error::ParseConfig(format!("parsing {ENV_NUMAFLOW_SERVING_STORE_TTL}: {e:?}")) - })?; - settings.redis.ttl_secs = Some(ttl_secs); - } - - let Some(source_spec_encoded) = env_vars.get(ENV_NUMAFLOW_SERVING_SOURCE_OBJECT) else { - return Ok(settings); - }; - - let source_spec_decoded = BASE64_STANDARD - .decode(source_spec_encoded.as_bytes()) - .map_err(|e| Error::ParseConfig(format!("decoding NUMAFLOW_SERVING_SOURCE: {e:?}")))?; - - let source_spec = serde_json::from_slice::(&source_spec_decoded) - .map_err(|e| Error::ParseConfig(format!("parsing NUMAFLOW_SERVING_SOURCE: {e:?}")))?; - - // Update tid_header from source_spec - if let Some(msg_id_header_key) = source_spec.msg_id_header_key { - settings.tid_header = msg_id_header_key; - } - - // Update redis.addr from source_spec, currently we only support redis as callback storage - settings.redis.addr = source_spec.callback_storage.url; - - Ok(settings) - } -} - -#[cfg(test)] -mod tests { - use crate::pipeline::{Edge, Vertex}; - - use super::*; - - #[test] - fn test_default_config() { - let settings = Settings::default(); - - assert_eq!(settings.tid_header, "ID"); - assert_eq!(settings.app_listen_port, 3000); - assert_eq!(settings.metrics_server_listen_port, 3001); - assert_eq!(settings.upstream_addr, "localhost:8888"); - assert_eq!(settings.drain_timeout_secs, 10); - assert_eq!(settings.redis.addr, "redis://127.0.0.1:6379"); - assert_eq!(settings.redis.max_tasks, 50); - assert_eq!(settings.redis.retries, 5); - assert_eq!(settings.redis.retries_duration_millis, 100); - } - - #[test] - fn test_config_parse() { - // Set up the environment variables - let env_vars = [ - (ENV_NUMAFLOW_SERVING_HOST_IP, "10.2.3.5"), - (ENV_NUMAFLOW_SERVING_AUTH_TOKEN, "api-auth-token"), - (ENV_NUMAFLOW_SERVING_APP_PORT, "8443"), - (ENV_NUMAFLOW_SERVING_STORE_TTL, "86400"), - (ENV_NUMAFLOW_SERVING_SOURCE_OBJECT, "eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX0="), - (ENV_MIN_PIPELINE_SPEC, "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6InNlcnZpbmctaW4iLCJzb3VyY2UiOnsic2VydmluZyI6eyJhdXRoIjpudWxsLCJzZXJ2aWNlIjp0cnVlLCJtc2dJREhlYWRlcktleSI6IlgtTnVtYWZsb3ctSWQiLCJzdG9yZSI6eyJ1cmwiOiJyZWRpczovL3JlZGlzOjYzNzkifX19LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciIsImVudiI6W3sibmFtZSI6IlJVU1RfTE9HIiwidmFsdWUiOiJpbmZvIn1dfSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2aW5nLXNpbmsiLCJzaW5rIjp7InVkc2luayI6eyJjb250YWluZXIiOnsiaW1hZ2UiOiJxdWF5LmlvL251bWFpby9udW1hZmxvdy1ycy9zaW5rLWxvZzpzdGFibGUiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9fX0sInJldHJ5U3RyYXRlZ3kiOnt9fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fV0sImVkZ2VzIjpbeyJmcm9tIjoic2VydmluZy1pbiIsInRvIjoic2VydmluZy1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH1dLCJsaWZlY3ljbGUiOnt9LCJ3YXRlcm1hcmsiOnt9fQ==") - ]; - - // Call the config method - let settings: Settings = env_vars - .into_iter() - .map(|(key, val)| (key.to_owned(), val.to_owned())) - .collect::>() - .try_into() - .unwrap(); - - let expected_config = Settings { - tid_header: "X-Numaflow-Id".into(), - app_listen_port: 8443, - metrics_server_listen_port: 3001, - upstream_addr: "localhost:8888".into(), - drain_timeout_secs: 10, - redis: RedisConfig { - addr: "redis://redis:6379".into(), - max_tasks: 50, - retries: 5, - retries_duration_millis: 100, - ttl_secs: Some(86400), - }, - host_ip: "10.2.3.5".into(), - api_auth_token: Some("api-auth-token".into()), - pipeline_spec: PipelineDCG { - vertices: vec![ - Vertex { - name: "serving-in".into(), - }, - Vertex { - name: "serving-sink".into(), - }, - ], - edges: vec![Edge { - from: "serving-in".into(), - to: "serving-sink".into(), - conditions: None, - }], - }, - }; - assert_eq!(settings, expected_config); - } -} diff --git a/rust/extns/numaflow-serving/src/errors.rs b/rust/extns/numaflow-serving/src/errors.rs deleted file mode 100644 index 7446faee71..0000000000 --- a/rust/extns/numaflow-serving/src/errors.rs +++ /dev/null @@ -1,50 +0,0 @@ -use tokio::sync::oneshot; - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("Initialization error - {0}")] - InitError(String), - - #[error("Failed to parse configuration - {0}")] - ParseConfig(String), - // - // callback errors - // TODO: store the ID too? - #[error("IDNotFound Error - {0}")] - IDNotFound(&'static str), - - #[error("SubGraphGenerator Error - {0}")] - // subgraph generator errors - SubGraphGenerator(String), - - #[error("StoreWrite Error - {0}")] - // Store write errors - StoreWrite(String), - - #[error("SubGraphNotFound Error - {0}")] - // Sub Graph Not Found Error - SubGraphNotFound(&'static str), - - #[error("SubGraphInvalidInput Error - {0}")] - // Sub Graph Invalid Input Error - SubGraphInvalidInput(String), - - #[error("StoreRead Error - {0}")] - // Store read errors - StoreRead(String), - - #[error("Metrics Error - {0}")] - // Metrics errors - MetricsServer(String), - - #[error("Connection Error - {0}")] - Connection(String), - - #[error("Failed to receive message from channel. Actor task is terminated: {0:?}")] - ActorTaskTerminated(oneshot::error::RecvError), - - #[error("{0}")] - Other(String), -} - -pub type Result = std::result::Result; diff --git a/rust/extns/numaflow-serving/src/lib.rs b/rust/extns/numaflow-serving/src/lib.rs deleted file mode 100644 index a6a446203e..0000000000 --- a/rust/extns/numaflow-serving/src/lib.rs +++ /dev/null @@ -1,12 +0,0 @@ -mod source; -pub use source::{Message, MessageWrapper, ServingSource}; - -mod app; -pub mod config; - -mod errors; -pub use errors::{Error, Result}; - -pub(crate) mod pipeline; - -pub type Settings = config::Settings; diff --git a/rust/extns/numaflow-serving/src/pipeline.rs b/rust/extns/numaflow-serving/src/pipeline.rs deleted file mode 100644 index 50236f5381..0000000000 --- a/rust/extns/numaflow-serving/src/pipeline.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::str::FromStr; - -use base64::prelude::BASE64_STANDARD; -use base64::Engine; -use numaflow_models::models::PipelineSpec; -use serde::{Deserialize, Serialize}; - -use crate::Error::ParseConfig; - -// OperatorType is an enum that contains the types of operators -// that can be used in the conditions for the edge. -#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] -pub enum OperatorType { - #[serde(rename = "and")] - And, - #[serde(rename = "or")] - Or, - #[serde(rename = "not")] - Not, -} - -#[allow(dead_code)] -impl OperatorType { - fn as_str(&self) -> &'static str { - match self { - OperatorType::And => "and", - OperatorType::Or => "or", - OperatorType::Not => "not", - } - } -} - -impl From for OperatorType { - fn from(s: String) -> Self { - match s.as_str() { - "and" => OperatorType::And, - "or" => OperatorType::Or, - "not" => OperatorType::Not, - _ => panic!("Invalid operator type: {}", s), - } - } -} - -// Tag is a struct that contains the information about the tags for the edge -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Tag { - pub operator: Option, - pub values: Vec, -} - -// Conditions is a struct that contains the information about the conditions for the edge -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Conditions { - pub tags: Option, -} - -// Edge is a struct that contains the information about the edge in the pipeline. -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Edge { - pub from: String, - pub to: String, - pub conditions: Option, -} - -/// DCG (directed compute graph) of the pipeline with minimal information build using vertices and edges -/// from the pipeline spec -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] -pub struct PipelineDCG { - pub vertices: Vec, - pub edges: Vec, -} - -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Vertex { - pub name: String, -} - -impl FromStr for PipelineDCG { - type Err = crate::Error; - - fn from_str(pipeline_spec_encoded: &str) -> Result { - let full_pipeline_spec_decoded = BASE64_STANDARD - .decode(pipeline_spec_encoded) - .map_err(|e| ParseConfig(format!("Decoding pipeline from env: {e:?}")))?; - - let full_pipeline_spec = - serde_json::from_slice::(&full_pipeline_spec_decoded) - .map_err(|e| ParseConfig(format!("parsing pipeline from env: {e:?}")))?; - - let vertices: Vec = full_pipeline_spec - .vertices - .ok_or(ParseConfig("missing vertices in pipeline spec".to_string()))? - .iter() - .map(|v| Vertex { - name: v.name.clone(), - }) - .collect(); - - let edges: Vec = full_pipeline_spec - .edges - .ok_or(ParseConfig("missing edges in pipeline spec".to_string()))? - .iter() - .map(|e| { - let conditions = e.conditions.clone().map(|c| Conditions { - tags: Some(Tag { - operator: c.tags.operator.map(|o| o.into()), - values: c.tags.values.clone(), - }), - }); - - Edge { - from: e.from.clone(), - to: e.to.clone(), - conditions, - } - }) - .collect(); - - Ok(PipelineDCG { vertices, edges }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_pipeline_load() { - let pipeline: PipelineDCG = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0=".parse().unwrap(); - - assert_eq!(pipeline.vertices.len(), 8); - assert_eq!(pipeline.edges.len(), 10); - assert_eq!(pipeline.vertices[0].name, "in"); - assert_eq!(pipeline.edges[0].from, "in"); - assert_eq!(pipeline.edges[0].to, "planner"); - assert!(pipeline.edges[0].conditions.is_none()); - - assert_eq!(pipeline.vertices[1].name, "planner"); - assert_eq!(pipeline.edges[1].from, "planner"); - assert_eq!(pipeline.edges[1].to, "asciiart"); - assert_eq!( - pipeline.edges[1].conditions, - Some(Conditions { - tags: Some(Tag { - operator: Some(OperatorType::Or), - values: vec!["asciiart".to_owned()] - }) - }) - ); - - assert_eq!(pipeline.vertices[2].name, "tiger"); - assert_eq!(pipeline.vertices[3].name, "dog"); - } -} diff --git a/rust/extns/numaflow-serving/src/source.rs b/rust/extns/numaflow-serving/src/source.rs deleted file mode 100644 index af7d8aaa8f..0000000000 --- a/rust/extns/numaflow-serving/src/source.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use bytes::Bytes; -use tokio::sync::{mpsc, oneshot}; -use tokio::time::Instant; - -use crate::app::callback::state::State as CallbackState; -use crate::app::callback::store::redisstore::RedisConnection; -use crate::app::tracker::MessageGraph; -use crate::{app, Settings}; -use crate::{Error, Result}; - -pub struct MessageWrapper { - pub confirm_save: oneshot::Sender<()>, - pub message: Message, -} - -pub struct Message { - pub value: Bytes, - pub id: String, - pub headers: HashMap, -} - -enum ActorMessage { - Read { - batch_size: usize, - timeout_at: Instant, - reply_to: oneshot::Sender>, - }, - Ack { - offsets: Vec, - reply_to: oneshot::Sender<()>, - }, -} - -struct ServingSourceActor { - messages: mpsc::Receiver, - handler_rx: mpsc::Receiver, - tracker: HashMap>, -} - -impl ServingSourceActor { - async fn start( - settings: Arc, - handler_rx: mpsc::Receiver, - ) -> Result<()> { - let (messages_tx, messages_rx) = mpsc::channel(10000); - // Create a redis store to store the callbacks and the custom responses - let redis_store = RedisConnection::new(settings.redis.clone()).await?; - // Create the message graph from the pipeline spec and the redis store - let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec).map_err(|e| { - Error::InitError(format!( - "Creating message graph from pipeline spec: {:?}", - e - )) - })?; - let callback_state = CallbackState::new(msg_graph, redis_store).await?; - - tokio::spawn(async move { - let mut serving_actor = ServingSourceActor { - messages: messages_rx, - handler_rx, - tracker: HashMap::new(), - }; - serving_actor.run().await; - }); - let app = app::AppState { - message: messages_tx, - settings, - callback_state, - }; - app::serve(app).await.unwrap(); - Ok(()) - } - - async fn run(&mut self) { - while let Some(msg) = self.handler_rx.recv().await { - self.handle_message(msg).await; - } - } - - async fn handle_message(&mut self, actor_msg: ActorMessage) { - match actor_msg { - ActorMessage::Read { - batch_size, - timeout_at, - reply_to, - } => { - let messages = self.read(batch_size, timeout_at).await; - let _ = reply_to.send(messages); - } - ActorMessage::Ack { offsets, reply_to } => { - self.ack(offsets).await; - let _ = reply_to.send(()); - } - } - } - - async fn read(&mut self, count: usize, timeout_at: Instant) -> Vec { - let mut messages = vec![]; - loop { - if messages.len() >= count || Instant::now() >= timeout_at { - break; - } - let message = match self.messages.try_recv() { - Ok(msg) => msg, - Err(mpsc::error::TryRecvError::Empty) => break, - Err(e) => { - tracing::error!(?e, "Receiving messages from the serving channel"); // FIXME: - return messages; - } - }; - let MessageWrapper { - confirm_save, - message, - } = message; - - self.tracker.insert(message.id.clone(), confirm_save); - messages.push(message); - } - messages - } - - async fn ack(&mut self, offsets: Vec) { - for offset in offsets { - let offset = offset - .strip_suffix("-0") - .expect("offset does not end with '-0'"); // FIXME: we hardcode 0 as the partition index when constructing offset - let confirm_save_tx = self - .tracker - .remove(offset) - .expect("offset was not found in the tracker"); - confirm_save_tx - .send(()) - .expect("Sending on confirm_save channel"); - } - } -} - -#[derive(Clone)] -pub struct ServingSource { - batch_size: usize, - // timeout for each batch read request - timeout: Duration, - actor_tx: mpsc::Sender, -} - -impl ServingSource { - pub async fn new( - settings: Arc, - batch_size: usize, - timeout: Duration, - ) -> Result { - let (actor_tx, actor_rx) = mpsc::channel(10); - ServingSourceActor::start(settings, actor_rx).await?; - Ok(Self { - batch_size, - timeout, - actor_tx, - }) - } - - pub async fn read_messages(&self) -> Result> { - let start = Instant::now(); - let (tx, rx) = oneshot::channel(); - let actor_msg = ActorMessage::Read { - reply_to: tx, - batch_size: self.batch_size, - timeout_at: Instant::now() + self.timeout, - }; - let _ = self.actor_tx.send(actor_msg).await; - let messages = rx.await.map_err(Error::ActorTaskTerminated)?; - tracing::debug!( - count = messages.len(), - requested_count = self.batch_size, - time_taken_ms = start.elapsed().as_millis(), - "Got messages from Serving source" - ); - Ok(messages) - } - - pub async fn ack_messages(&self, offsets: Vec) -> Result<()> { - let (tx, rx) = oneshot::channel(); - let actor_msg = ActorMessage::Ack { - offsets, - reply_to: tx, - }; - let _ = self.actor_tx.send(actor_msg).await; - rx.await.map_err(Error::ActorTaskTerminated)?; - Ok(()) - } -} diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index dc0f6582da..ea7e22a04e 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -19,7 +19,6 @@ numaflow-models.workspace = true numaflow-pb.workspace = true serving.workspace = true backoff.workspace = true -numaflow-serving.workspace = true axum.workspace = true axum-server.workspace = true bytes.workspace = true diff --git a/rust/numaflow-core/src/config/components.rs b/rust/numaflow-core/src/config/components.rs index 29734c98a1..833ad8950b 100644 --- a/rust/numaflow-core/src/config/components.rs +++ b/rust/numaflow-core/src/config/components.rs @@ -38,7 +38,7 @@ pub(crate) mod source { Generator(GeneratorConfig), UserDefined(UserDefinedConfig), Pulsar(PulsarSourceConfig), - Serving(Arc), + Serving(Arc), } impl From> for SourceType { @@ -111,7 +111,7 @@ pub(crate) mod source { // There should be only one option (user-defined) to define the settings. fn try_from(cfg: Box) -> Result { let env_vars = env::vars().collect::>(); - let mut settings: numaflow_serving::Settings = env_vars.try_into()?; + let mut settings: serving::Settings = env_vars.try_into()?; settings.tid_header = cfg.msg_id_header_key; diff --git a/rust/numaflow-core/src/metrics.rs b/rust/numaflow-core/src/metrics.rs index fa79e457b8..2a672ec31d 100644 --- a/rust/numaflow-core/src/metrics.rs +++ b/rust/numaflow-core/src/metrics.rs @@ -600,8 +600,6 @@ pub(crate) async fn start_metrics_https_server( addr: SocketAddr, metrics_state: UserDefinedContainerState, ) -> crate::Result<()> { - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); - // Generate a self-signed certificate let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) .map_err(|e| Error::Metrics(format!("Generating self-signed certificate: {}", e)))?; diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs index eeae1136e8..c077a1f44f 100644 --- a/rust/numaflow-core/src/shared/create_components.rs +++ b/rust/numaflow-core/src/shared/create_components.rs @@ -5,7 +5,7 @@ use numaflow_pb::clients::map::map_client::MapClient; use numaflow_pb::clients::sink::sink_client::SinkClient; use numaflow_pb::clients::source::source_client::SourceClient; use numaflow_pb::clients::sourcetransformer::source_transform_client::SourceTransformClient; -use numaflow_serving::ServingSource; +use serving::ServingSource; use tokio_util::sync::CancellationToken; use tonic::transport::Channel; diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs index 73bc69b1ec..a30fc9777e 100644 --- a/rust/numaflow-core/src/source.rs +++ b/rust/numaflow-core/src/source.rs @@ -1,5 +1,4 @@ use numaflow_pulsar::source::PulsarSource; -use numaflow_serving::ServingSource; use std::sync::Arc; use tokio::sync::OwnedSemaphorePermit; use tokio::sync::Semaphore; @@ -39,6 +38,7 @@ pub(crate) mod generator; pub(crate) mod pulsar; pub(crate) mod serving; +use serving::ServingSource; /// Set of Read related items that has to be implemented to become a Source. pub(crate) trait SourceReader { diff --git a/rust/numaflow-core/src/source/serving.rs b/rust/numaflow-core/src/source/serving.rs index a06f418d66..02074904eb 100644 --- a/rust/numaflow-core/src/source/serving.rs +++ b/rust/numaflow-core/src/source/serving.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use numaflow_serving::ServingSource; +pub(crate) use serving::ServingSource; use crate::message::{MessageID, StringOffset}; use crate::Error; @@ -8,10 +8,10 @@ use crate::Result; use super::{get_vertex_name, Message, Offset}; -impl TryFrom for Message { +impl TryFrom for Message { type Error = Error; - fn try_from(message: numaflow_serving::Message) -> Result { + fn try_from(message: serving::Message) -> Result { let offset = Offset::String(StringOffset::new(message.id.clone(), 0)); Ok(Message { @@ -30,8 +30,8 @@ impl TryFrom for Message { } } -impl From for Error { - fn from(value: numaflow_serving::Error) -> Self { +impl From for Error { + fn from(value: serving::Error) -> Self { Error::Source(value.to_string()) } } diff --git a/rust/serving/Cargo.toml b/rust/serving/Cargo.toml index 0dbe0818c5..427bc84ce3 100644 --- a/rust/serving/Cargo.toml +++ b/rust/serving/Cargo.toml @@ -19,7 +19,6 @@ backoff.workspace = true axum.workspace = true axum-server.workspace = true bytes.workspace = true -async-nats = "0.35.1" axum-macros = "0.4.1" hyper-util = { version = "0.1.6", features = ["client-legacy"] } serde = { version = "1.0.204", features = ["derive"] } diff --git a/rust/serving/src/app.rs b/rust/serving/src/app.rs index b8bbc4c76e..5161d67ac0 100644 --- a/rust/serving/src/app.rs +++ b/rust/serving/src/app.rs @@ -246,16 +246,15 @@ async fn routes( mod tests { use std::sync::Arc; - use async_nats::jetstream::stream; use axum::http::StatusCode; - use tokio::sync::mpsc; - use tokio::time::{sleep, Duration}; use tower::ServiceExt; use super::*; use crate::app::callback::store::memstore::InMemoryStore; - use crate::config::generate_certs; use crate::Settings; + use callback::state::State as CallbackState; + use tokio::sync::mpsc; + use tracker::MessageGraph; const PIPELINE_SPEC_ENCODED: &str = "eyJ2ZXJ0aWNlcyI6W3sibmFtZSI6ImluIiwic291cmNlIjp7InNlcnZpbmciOnsiYXV0aCI6bnVsbCwic2VydmljZSI6dHJ1ZSwibXNnSURIZWFkZXJLZXkiOiJYLU51bWFmbG93LUlkIiwic3RvcmUiOnsidXJsIjoicmVkaXM6Ly9yZWRpczo2Mzc5In19fSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIiLCJlbnYiOlt7Im5hbWUiOiJSVVNUX0xPRyIsInZhbHVlIjoiZGVidWcifV19LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InBsYW5uZXIiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJwbGFubmVyIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6InRpZ2VyIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsidGlnZXIiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZG9nIiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZG9nIl0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sImJ1aWx0aW4iOm51bGwsImdyb3VwQnkiOm51bGx9LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19LHsibmFtZSI6ImVsZXBoYW50IiwidWRmIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6ImFzY2lpOjAuMSIsImFyZ3MiOlsiZWxlcGhhbnQiXSwicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwiYnVpbHRpbiI6bnVsbCwiZ3JvdXBCeSI6bnVsbH0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiYXNjaWlhcnQiLCJ1ZGYiOnsiY29udGFpbmVyIjp7ImltYWdlIjoiYXNjaWk6MC4xIiwiYXJncyI6WyJhc2NpaWFydCJdLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJidWlsdGluIjpudWxsLCJncm91cEJ5IjpudWxsfSwiY29udGFpbmVyVGVtcGxhdGUiOnsicmVzb3VyY2VzIjp7fSwiaW1hZ2VQdWxsUG9saWN5IjoiTmV2ZXIifSwic2NhbGUiOnsibWluIjoxfSwidXBkYXRlU3RyYXRlZ3kiOnsidHlwZSI6IlJvbGxpbmdVcGRhdGUiLCJyb2xsaW5nVXBkYXRlIjp7Im1heFVuYXZhaWxhYmxlIjoiMjUlIn19fSx7Im5hbWUiOiJzZXJ2ZS1zaW5rIiwic2luayI6eyJ1ZHNpbmsiOnsiY29udGFpbmVyIjp7ImltYWdlIjoic2VydmVzaW5rOjAuMSIsImVudiI6W3sibmFtZSI6Ik5VTUFGTE9XX0NBTExCQUNLX1VSTF9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctQ2FsbGJhY2stVXJsIn0seyJuYW1lIjoiTlVNQUZMT1dfTVNHX0lEX0hFQURFUl9LRVkiLCJ2YWx1ZSI6IlgtTnVtYWZsb3ctSWQifV0sInJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn19LCJyZXRyeVN0cmF0ZWd5Ijp7fX0sImNvbnRhaW5lclRlbXBsYXRlIjp7InJlc291cmNlcyI6e30sImltYWdlUHVsbFBvbGljeSI6Ik5ldmVyIn0sInNjYWxlIjp7Im1pbiI6MX0sInVwZGF0ZVN0cmF0ZWd5Ijp7InR5cGUiOiJSb2xsaW5nVXBkYXRlIiwicm9sbGluZ1VwZGF0ZSI6eyJtYXhVbmF2YWlsYWJsZSI6IjI1JSJ9fX0seyJuYW1lIjoiZXJyb3Itc2luayIsInNpbmsiOnsidWRzaW5rIjp7ImNvbnRhaW5lciI6eyJpbWFnZSI6InNlcnZlc2luazowLjEiLCJlbnYiOlt7Im5hbWUiOiJOVU1BRkxPV19DQUxMQkFDS19VUkxfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUNhbGxiYWNrLVVybCJ9LHsibmFtZSI6Ik5VTUFGTE9XX01TR19JRF9IRUFERVJfS0VZIiwidmFsdWUiOiJYLU51bWFmbG93LUlkIn1dLCJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9fSwicmV0cnlTdHJhdGVneSI6e319LCJjb250YWluZXJUZW1wbGF0ZSI6eyJyZXNvdXJjZXMiOnt9LCJpbWFnZVB1bGxQb2xpY3kiOiJOZXZlciJ9LCJzY2FsZSI6eyJtaW4iOjF9LCJ1cGRhdGVTdHJhdGVneSI6eyJ0eXBlIjoiUm9sbGluZ1VwZGF0ZSIsInJvbGxpbmdVcGRhdGUiOnsibWF4VW5hdmFpbGFibGUiOiIyNSUifX19XSwiZWRnZXMiOlt7ImZyb20iOiJpbiIsInRvIjoicGxhbm5lciIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImFzY2lpYXJ0IiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiYXNjaWlhcnQiXX19fSx7ImZyb20iOiJwbGFubmVyIiwidG8iOiJ0aWdlciIsImNvbmRpdGlvbnMiOnsidGFncyI6eyJvcGVyYXRvciI6Im9yIiwidmFsdWVzIjpbInRpZ2VyIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZG9nIiwiY29uZGl0aW9ucyI6eyJ0YWdzIjp7Im9wZXJhdG9yIjoib3IiLCJ2YWx1ZXMiOlsiZG9nIl19fX0seyJmcm9tIjoicGxhbm5lciIsInRvIjoiZWxlcGhhbnQiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlbGVwaGFudCJdfX19LHsiZnJvbSI6InRpZ2VyIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZG9nIiwidG8iOiJzZXJ2ZS1zaW5rIiwiY29uZGl0aW9ucyI6bnVsbH0seyJmcm9tIjoiZWxlcGhhbnQiLCJ0byI6InNlcnZlLXNpbmsiLCJjb25kaXRpb25zIjpudWxsfSx7ImZyb20iOiJhc2NpaWFydCIsInRvIjoic2VydmUtc2luayIsImNvbmRpdGlvbnMiOm51bGx9LHsiZnJvbSI6InBsYW5uZXIiLCJ0byI6ImVycm9yLXNpbmsiLCJjb25kaXRpb25zIjp7InRhZ3MiOnsib3BlcmF0b3IiOiJvciIsInZhbHVlcyI6WyJlcnJvciJdfX19XSwibGlmZWN5Y2xlIjp7fSwid2F0ZXJtYXJrIjp7fX0="; @@ -266,141 +265,78 @@ mod tests { #[tokio::test] async fn test_setup_app() -> Result<()> { let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); let mem_store = InMemoryStore::new(); let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; + let (tx, _) = mpsc::channel(10); + let app = AppState { + message: tx, + settings, + callback_state, + }; - let result = setup_app(settings, context, callback_state).await; + let result = setup_app(app).await; assert!(result.is_ok()); Ok(()) } #[cfg(feature = "all-tests")] #[tokio::test] - async fn test_livez() -> Result<()> { + async fn test_health_check_endpoints() -> Result<()> { let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); let mem_store = InMemoryStore::new(); - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - + let msg_graph = MessageGraph::from_pipeline(&settings.pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; - let result = setup_app(settings, context, callback_state).await; + let (messages_tx, _messages_rx) = mpsc::channel(10); + let app = AppState { + message: messages_tx, + settings, + callback_state, + }; + + let router = setup_app(app).await.unwrap(); let request = Request::builder().uri("/livez").body(Body::empty())?; - - let response = result?.oneshot(request).await?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::NO_CONTENT); - Ok(()) - } - - #[cfg(feature = "all-tests")] - #[tokio::test] - async fn test_readyz() -> Result<()> { - let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); - - let mem_store = InMemoryStore::new(); - let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; - - let callback_state = CallbackState::new(msg_graph, mem_store).await?; - - let result = setup_app(settings, context, callback_state).await; let request = Request::builder().uri("/readyz").body(Body::empty())?; - - let response = result.unwrap().oneshot(request).await?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::NO_CONTENT); - Ok(()) - } - #[tokio::test] - async fn test_health_check() { - let response = health_check().await; - let response = response.into_response(); + let request = Request::builder().uri("/health").body(Body::empty())?; + let response = router.clone().oneshot(request).await?; assert_eq!(response.status(), StatusCode::OK); + Ok(()) } #[cfg(feature = "all-tests")] #[tokio::test] async fn test_auth_middleware() -> Result<()> { - let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await?; - let context = jetstream::new(client); - let stream_name = &settings.jetstream.stream; - - let stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await; - - assert!(stream.is_ok()); + let settings = Settings { + api_auth_token: Some("test-token".into()), + ..Default::default() + }; let mem_store = InMemoryStore::new(); let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mem_store).await?; + let (messages_tx, _messages_rx) = mpsc::channel(10); let app_state = AppState { - settings, + message: messages_tx, + settings: Arc::new(settings), callback_state, - context, }; - let app = Router::new() - .nest("/v1/process", routes(app_state).await.unwrap()) - .layer(middleware::from_fn_with_state( - Some("test_token".to_owned()), - auth_middleware, - )); - - let res = app + let router = setup_app(app_state).await.unwrap(); + let res = router .oneshot( axum::extract::Request::builder() .uri("/v1/process/sync") diff --git a/rust/serving/src/app/jetstream_proxy.rs b/rust/serving/src/app/jetstream_proxy.rs index e5a7a661e1..30fb8e1ef9 100644 --- a/rust/serving/src/app/jetstream_proxy.rs +++ b/rust/serving/src/app/jetstream_proxy.rs @@ -256,8 +256,6 @@ fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { mod tests { use std::sync::Arc; - use async_nats::jetstream; - use async_nats::jetstream::stream; use axum::body::{to_bytes, Body}; use axum::extract::Request; use axum::http::header::{CONTENT_LENGTH, CONTENT_TYPE}; @@ -298,46 +296,38 @@ mod tests { #[tokio::test] async fn test_async_publish() -> Result<(), Box> { - let settings = Settings::default(); - let settings = Arc::new(settings); - let client = async_nats::connect(&settings.jetstream.url) - .await - .map_err(|e| format!("Connecting to Jetstream: {:?}", e))?; - - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "default"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e))?; + const ID_HEADER: &str = "X-Numaflow-ID"; + const ID_VALUE: &str = "foobar"; + let settings = Settings { + tid_header: ID_HEADER.into(), + ..Default::default() + }; let mock_store = MockStore {}; - let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); - let msg_graph = MessageGraph::from_pipeline(&pipeline_spec) - .map_err(|e| format!("Failed to create message graph from pipeline spec: {:?}", e))?; - + let pipeline_spec = PIPELINE_SPEC_ENCODED.parse().unwrap(); + let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mock_store).await?; + + let (messages_tx, mut messages_rx) = mpsc::channel(10); let app_state = AppState { + message: messages_tx, + settings: Arc::new(settings), callback_state, - context, - settings, }; + let app = jetstream_proxy(app_state).await?; let res = Request::builder() .method("POST") .uri("/async") .header(CONTENT_TYPE, "text/plain") - .header("id", id) + .header(ID_HEADER, ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.oneshot(res).await.unwrap(); + let message = messages_rx.recv().await.unwrap(); + assert_eq!(message.message.id, ID_VALUE); + message.confirm_save.send(()).unwrap(); assert_eq!(response.status(), StatusCode::OK); let result = extract_response_from_body(response.into_body()).await; @@ -345,7 +335,7 @@ mod tests { result, json!({ "message": "Successfully published message", - "id": id, + "id": ID_HEADER, "code": 200 }) ); @@ -387,20 +377,12 @@ mod tests { #[tokio::test] async fn test_sync_publish() { - let settings = Settings::default(); - let client = async_nats::connect(&settings.jetstream.url).await.unwrap(); - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "sync_pub"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e)); + const ID_HEADER: &str = "X-Numaflow-ID"; + const ID_VALUE: &str = "foobar"; + let settings = Settings { + tid_header: ID_HEADER.into(), + ..Default::default() + }; let mem_store = InMemoryStore::new(); let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); @@ -408,16 +390,17 @@ mod tests { let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); - let settings = Arc::new(settings); + let (messages_tx, mut messages_rx) = mpsc::channel(10); let app_state = AppState { - settings, + message: messages_tx, + settings: Arc::new(settings), callback_state: callback_state.clone(), - context, }; + let app = jetstream_proxy(app_state).await.unwrap(); tokio::spawn(async move { - let cbs = create_default_callbacks(id); + let cbs = create_default_callbacks(ID_VALUE); let mut retries = 0; loop { match callback_state.insert_callback_requests(cbs.clone()).await { @@ -437,11 +420,13 @@ mod tests { .method("POST") .uri("/sync") .header("Content-Type", "text/plain") - .header("id", id) + .header(ID_HEADER, ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.clone().oneshot(res).await.unwrap(); + let message = messages_rx.recv().await.unwrap(); + message.confirm_save.send(()).unwrap(); assert_eq!(response.status(), StatusCode::OK); let result = extract_response_from_body(response.into_body()).await; @@ -449,7 +434,7 @@ mod tests { result, json!({ "message": "Successfully processed the message", - "id": id, + "id": ID_VALUE, "code": 200 }) ); @@ -457,20 +442,8 @@ mod tests { #[tokio::test] async fn test_sync_publish_serve() { + const ID_VALUE: &str = "foobar"; let settings = Arc::new(Settings::default()); - let client = async_nats::connect(&settings.jetstream.url).await.unwrap(); - let context = jetstream::new(client); - let id = "foobar"; - let stream_name = "sync_serve_pub"; - - let _stream = context - .get_or_create_stream(stream::Config { - name: stream_name.into(), - subjects: vec![stream_name.into()], - ..Default::default() - }) - .await - .map_err(|e| format!("creating stream {}: {}", &settings.jetstream.url, e)); let mem_store = InMemoryStore::new(); let pipeline_spec: PipelineDCG = PIPELINE_SPEC_ENCODED.parse().unwrap(); @@ -478,16 +451,17 @@ mod tests { let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); + let (messages_tx, mut messages_rx) = mpsc::channel(10); let app_state = AppState { + message: messages_tx, settings, callback_state: callback_state.clone(), - context, }; let app = jetstream_proxy(app_state).await.unwrap(); // pipeline is in -> cat -> out, so we will have 3 callback requests - let cbs = create_default_callbacks(id); + let cbs = create_default_callbacks(ID_VALUE); // spawn a tokio task which will insert the callback requests to the callback state // if it fails, sleep for 10ms and retry @@ -526,11 +500,13 @@ mod tests { .method("POST") .uri("/sync_serve") .header("Content-Type", "text/plain") - .header("id", id) + .header("ID", ID_VALUE) .body(Body::from("Test Message")) .unwrap(); let response = app.oneshot(res).await.unwrap(); + let message = messages_rx.recv().await.unwrap(); + message.confirm_save.send(()).unwrap(); assert_eq!(response.status(), StatusCode::OK); let content_len = response.headers().get(CONTENT_LENGTH).unwrap(); diff --git a/rust/serving/src/config.rs b/rust/serving/src/config.rs index 03fb2beb17..16c2ee125c 100644 --- a/rust/serving/src/config.rs +++ b/rust/serving/src/config.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use std::fmt::Debug; -use async_nats::rustls; use base64::prelude::BASE64_STANDARD; use base64::Engine; use rcgen::{generate_simple_self_signed, Certificate, CertifiedKey, KeyPair}; @@ -20,7 +19,6 @@ const ENV_NUMAFLOW_SERVING_AUTH_TOKEN: &str = "NUMAFLOW_SERVING_AUTH_TOKEN"; const ENV_MIN_PIPELINE_SPEC: &str = "NUMAFLOW_SERVING_MIN_PIPELINE_SPEC"; pub fn generate_certs() -> std::result::Result<(Certificate, KeyPair), String> { - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()]) .map_err(|e| format!("Failed to generate cert {:?}", e))?; Ok((cert, key_pair)) From ca8cdebb69896fd003df31c7d374801d70a44ffe Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 24 Dec 2024 18:53:27 +0530 Subject: [PATCH 08/15] Unit tests for Serving endpoints Signed-off-by: Sreekanth --- rust/Cargo.lock | 1 + rust/Cargo.toml | 1 + rust/numaflow-core/Cargo.toml | 2 +- rust/serving/Cargo.toml | 5 ++- rust/serving/src/app.rs | 45 ++++++++++++---------- rust/serving/src/app/jetstream_proxy.rs | 51 ++++++++++++++++++++----- rust/serving/src/metrics.rs | 3 ++ 7 files changed, 75 insertions(+), 33 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8c9e480319..7d70eea7b5 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2840,6 +2840,7 @@ dependencies = [ "prometheus-client", "rcgen", "redis", + "rustls 0.23.19", "serde", "serde_json", "thiserror 1.0.69", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index db7deddb61..91e0cc7361 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -65,3 +65,4 @@ tracing = "0.1.40" axum = "0.7.5" axum-server = { version = "0.7.1", features = ["tls-rustls"] } serde = { version = "1.0.204", features = ["derive"] } +rustls = { version = "0.23.12", features = ["aws_lc_rs"] } \ No newline at end of file diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index ea7e22a04e..c7a33fe275 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -23,6 +23,7 @@ axum.workspace = true axum-server.workspace = true bytes.workspace = true serde.workspace = true +rustls.workspace = true tonic = "0.12.3" thiserror = "2.0.3" tokio-util = "0.7.11" @@ -36,7 +37,6 @@ tower = "0.4.13" serde_json = "1.0.122" trait-variant = "0.1.2" rcgen = "0.13.1" -rustls = { version = "0.23.12", features = ["aws_lc_rs"] } semver = "1.0" pep440_rs = "0.6.6" parking_lot = "0.12.3" diff --git a/rust/serving/Cargo.toml b/rust/serving/Cargo.toml index 427bc84ce3..8f8b86ff3b 100644 --- a/rust/serving/Cargo.toml +++ b/rust/serving/Cargo.toml @@ -5,8 +5,7 @@ edition = "2021" [features] redis-tests = [] -nats-tests = [] -all-tests = ["redis-tests", "nats-tests"] +all-tests = ["redis-tests"] [lints] workspace = true @@ -35,3 +34,5 @@ parking_lot = "0.12.3" prometheus-client = "0.22.3" thiserror = "1.0.63" +[dev-dependencies] +rustls.workspace = true diff --git a/rust/serving/src/app.rs b/rust/serving/src/app.rs index 5161d67ac0..82ef1ef62e 100644 --- a/rust/serving/src/app.rs +++ b/rust/serving/src/app.rs @@ -54,6 +54,27 @@ where .parse() .map_err(|e| InitError(format!("{e:?}")))?; + let handle = Handle::new(); + // Spawn a task to gracefully shutdown server. + tokio::spawn(graceful_shutdown(handle.clone())); + + info!(?app_addr, "Starting application server"); + + let router = router_with_auth(app).await?; + + axum_server::bind_rustls(app_addr, tls_config) + .handle(handle) + .serve(router.into_make_service()) + .await + .map_err(|e| InitError(format!("Starting web server for metrics: {}", e)))?; + + Ok(()) +} + +pub(crate) async fn router_with_auth(app: AppState) -> crate::Result +where + T: Clone + Send + Sync + Store + 'static, +{ let tid_header = app.settings.tid_header.clone(); let layers = ServiceBuilder::new() // Add tracing to all requests @@ -88,22 +109,7 @@ where app.settings.api_auth_token.clone(), auth_middleware, )); - - let handle = Handle::new(); - // Spawn a task to gracefully shutdown server. - tokio::spawn(graceful_shutdown(handle.clone())); - - let router = setup_app(app).await?.layer(layers); - - info!(?app_addr, "Starting application server"); - - axum_server::bind_rustls(app_addr, tls_config) - .handle(handle) - .serve(router.into_make_service()) - .await - .map_err(|e| InitError(format!("Starting web server for metrics: {}", e)))?; - - Ok(()) + Ok(setup_app(app).await?.layer(layers)) } // Gracefully shutdown the server on receiving SIGINT or SIGTERM @@ -261,7 +267,6 @@ mod tests { type Result = core::result::Result; type Error = Box; - #[cfg(feature = "all-tests")] #[tokio::test] async fn test_setup_app() -> Result<()> { let settings = Arc::new(Settings::default()); @@ -283,7 +288,6 @@ mod tests { Ok(()) } - #[cfg(feature = "all-tests")] #[tokio::test] async fn test_health_check_endpoints() -> Result<()> { let settings = Arc::new(Settings::default()); @@ -315,7 +319,6 @@ mod tests { Ok(()) } - #[cfg(feature = "all-tests")] #[tokio::test] async fn test_auth_middleware() -> Result<()> { let settings = Settings { @@ -329,16 +332,18 @@ mod tests { let callback_state = CallbackState::new(msg_graph, mem_store).await?; let (messages_tx, _messages_rx) = mpsc::channel(10); + let app_state = AppState { message: messages_tx, settings: Arc::new(settings), callback_state, }; - let router = setup_app(app_state).await.unwrap(); + let router = router_with_auth(app_state).await.unwrap(); let res = router .oneshot( axum::extract::Request::builder() + .method("POST") .uri("/v1/process/sync") .body(Body::empty()) .unwrap(), diff --git a/rust/serving/src/app/jetstream_proxy.rs b/rust/serving/src/app/jetstream_proxy.rs index 30fb8e1ef9..41d0c0df3e 100644 --- a/rust/serving/src/app/jetstream_proxy.rs +++ b/rust/serving/src/app/jetstream_proxy.rs @@ -251,7 +251,6 @@ fn extract_id_from_headers(tid_header: &str, headers: &HeaderMap) -> String { ) } -#[cfg(feature = "nats-tests")] #[cfg(test)] mod tests { use std::sync::Arc; @@ -308,7 +307,17 @@ mod tests { let msg_graph = MessageGraph::from_pipeline(&pipeline_spec)?; let callback_state = CallbackState::new(msg_graph, mock_store).await?; - let (messages_tx, mut messages_rx) = mpsc::channel(10); + let (messages_tx, mut messages_rx) = mpsc::channel::(10); + let response_collector = tokio::spawn(async move { + let message = messages_rx.recv().await.unwrap(); + let MessageWrapper { + confirm_save, + message, + } = message; + confirm_save.send(()).unwrap(); + message + }); + let app_state = AppState { message: messages_tx, settings: Arc::new(settings), @@ -325,9 +334,8 @@ mod tests { .unwrap(); let response = app.oneshot(res).await.unwrap(); - let message = messages_rx.recv().await.unwrap(); - assert_eq!(message.message.id, ID_VALUE); - message.confirm_save.send(()).unwrap(); + let message = response_collector.await.unwrap(); + assert_eq!(message.id, ID_VALUE); assert_eq!(response.status(), StatusCode::OK); let result = extract_response_from_body(response.into_body()).await; @@ -335,7 +343,7 @@ mod tests { result, json!({ "message": "Successfully published message", - "id": ID_HEADER, + "id": ID_VALUE, "code": 200 }) ); @@ -391,6 +399,17 @@ mod tests { let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); let (messages_tx, mut messages_rx) = mpsc::channel(10); + + let response_collector = tokio::spawn(async move { + let message = messages_rx.recv().await.unwrap(); + let MessageWrapper { + confirm_save, + message, + } = message; + confirm_save.send(()).unwrap(); + message + }); + let app_state = AppState { message: messages_tx, settings: Arc::new(settings), @@ -425,8 +444,8 @@ mod tests { .unwrap(); let response = app.clone().oneshot(res).await.unwrap(); - let message = messages_rx.recv().await.unwrap(); - message.confirm_save.send(()).unwrap(); + let message = response_collector.await.unwrap(); + assert_eq!(message.id, ID_VALUE); assert_eq!(response.status(), StatusCode::OK); let result = extract_response_from_body(response.into_body()).await; @@ -452,6 +471,17 @@ mod tests { let mut callback_state = CallbackState::new(msg_graph, mem_store).await.unwrap(); let (messages_tx, mut messages_rx) = mpsc::channel(10); + + let response_collector = tokio::spawn(async move { + let message = messages_rx.recv().await.unwrap(); + let MessageWrapper { + confirm_save, + message, + } = message; + confirm_save.send(()).unwrap(); + message + }); + let app_state = AppState { message: messages_tx, settings, @@ -505,8 +535,9 @@ mod tests { .unwrap(); let response = app.oneshot(res).await.unwrap(); - let message = messages_rx.recv().await.unwrap(); - message.confirm_save.send(()).unwrap(); + let message = response_collector.await.unwrap(); + assert_eq!(message.id, ID_VALUE); + assert_eq!(response.status(), StatusCode::OK); let content_len = response.headers().get(CONTENT_LENGTH).unwrap(); diff --git a/rust/serving/src/metrics.rs b/rust/serving/src/metrics.rs index 4c64760d4d..ff96513774 100644 --- a/rust/serving/src/metrics.rs +++ b/rust/serving/src/metrics.rs @@ -175,6 +175,9 @@ mod tests { #[tokio::test] async fn test_start_metrics_server() -> Result<()> { + rustls::crypto::aws_lc_rs::default_provider() + .install_default() + .unwrap(); let (cert, key) = generate_certs()?; let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) From e8939643b478429ed2f332ec9bf54e996cc212a1 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 3 Jan 2025 06:30:04 +0530 Subject: [PATCH 09/15] Unit tests for serving source Signed-off-by: Sreekanth --- rust/Cargo.lock | 155 ++++++++++++++++++++++++++++++++++--- rust/serving/Cargo.toml | 2 + rust/serving/src/lib.rs | 3 + rust/serving/src/source.rs | 76 +++++++++++++++++- 4 files changed, 225 insertions(+), 11 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 7d70eea7b5..6fc31dc61e 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -188,7 +188,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.2", "tokio", - "tower 0.5.1", + "tower 0.5.2", "tower-layer", "tower-service", "tracing", @@ -672,6 +672,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1092,6 +1107,22 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.5.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.10" @@ -1575,6 +1606,23 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nkeys" version = "0.4.4" @@ -1809,12 +1857,50 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "openssl" +version = "0.10.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-sys" +version = "0.9.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "ordered-float" version = "2.10.1" @@ -1992,6 +2078,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + [[package]] name = "portable-atomic" version = "1.10.0" @@ -2415,7 +2507,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 0.1.2", - "system-configuration", + "system-configuration 0.5.1", "tokio", "tokio-rustls 0.24.1", "tower-service", @@ -2429,24 +2521,28 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.9" +version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "http-body-util", "hyper 1.5.1", "hyper-rustls 0.27.3", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -2458,8 +2554,11 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 1.0.2", + "system-configuration 0.6.1", "tokio", + "tokio-native-tls", "tokio-rustls 0.26.0", + "tower 0.5.2", "tower-service", "url", "wasm-bindgen", @@ -2817,7 +2916,7 @@ name = "servesink" version = "0.1.0" dependencies = [ "numaflow 0.1.1", - "reqwest 0.12.9", + "reqwest 0.12.12", "tokio", "tonic", "tracing", @@ -2840,6 +2939,7 @@ dependencies = [ "prometheus-client", "rcgen", "redis", + "reqwest 0.12.12", "rustls 0.23.19", "serde", "serde_json", @@ -3035,7 +3135,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", "core-foundation 0.9.4", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.9.4", + "system-configuration-sys 0.6.0", ] [[package]] @@ -3048,6 +3159,16 @@ dependencies = [ "libc", ] +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.14.0" @@ -3195,6 +3316,16 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-retry" version = "0.3.0" @@ -3338,14 +3469,14 @@ dependencies = [ [[package]] name = "tower" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.2", "tokio", "tower-layer", "tower-service", @@ -3562,6 +3693,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" diff --git a/rust/serving/Cargo.toml b/rust/serving/Cargo.toml index 8f8b86ff3b..5d541c27e0 100644 --- a/rust/serving/Cargo.toml +++ b/rust/serving/Cargo.toml @@ -18,6 +18,7 @@ backoff.workspace = true axum.workspace = true axum-server.workspace = true bytes.workspace = true +rustls.workspace = true axum-macros = "0.4.1" hyper-util = { version = "0.1.6", features = ["client-legacy"] } serde = { version = "1.0.204", features = ["derive"] } @@ -35,4 +36,5 @@ prometheus-client = "0.22.3" thiserror = "1.0.63" [dev-dependencies] +reqwest = {version= "0.12.12", features = ["json"]} rustls.workspace = true diff --git a/rust/serving/src/lib.rs b/rust/serving/src/lib.rs index e275132386..20fa2c9587 100644 --- a/rust/serving/src/lib.rs +++ b/rust/serving/src/lib.rs @@ -38,6 +38,9 @@ pub(crate) async fn serve( where T: Clone + Send + Sync + Store + 'static, { + rustls::crypto::aws_lc_rs::default_provider() + .install_default() + .expect("Failed to set crypto provider"); let (cert, key) = generate_certs()?; let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs index b3f4c2e1ab..4b1b2ce3f7 100644 --- a/rust/serving/src/source.rs +++ b/rust/serving/src/source.rs @@ -17,6 +17,7 @@ pub struct MessageWrapper { pub message: Message, } +#[derive(Debug)] pub struct Message { pub value: Bytes, pub id: String, @@ -71,7 +72,9 @@ impl ServingSourceActor { settings, callback_state, }; - crate::serve(app).await.unwrap(); + tokio::spawn(async move { + crate::serve(app).await.unwrap(); + }); Ok(()) } @@ -153,7 +156,7 @@ impl ServingSource { batch_size: usize, timeout: Duration, ) -> Result { - let (actor_tx, actor_rx) = mpsc::channel(10); + let (actor_tx, actor_rx) = mpsc::channel(1000); ServingSourceActor::start(settings, actor_rx).await?; Ok(Self { batch_size, @@ -192,3 +195,72 @@ impl ServingSource { Ok(()) } } + +#[cfg(feature = "redis-tests")] +#[cfg(test)] +mod tests { + use std::{sync::Arc, time::Duration}; + + use crate::Settings; + + use super::ServingSource; + + type Result = std::result::Result>; + #[tokio::test] + async fn test_serving_source() -> Result<()> { + let settings = Arc::new(Settings::default()); + let serving_source = + ServingSource::new(Arc::clone(&settings), 10, Duration::from_millis(1)).await?; + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(2)) + .danger_accept_invalid_certs(true) + .build() + .unwrap(); + + // Wait for the server + for _ in 0..10 { + let resp = client + .get(format!( + "https://localhost:{}/livez", + settings.app_listen_port + )) + .send() + .await; + if resp.is_ok() { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_millis(10)).await; + let mut messages = serving_source.read_messages().await.unwrap(); + if messages.is_empty() { + // Server has not received any requests yet + continue; + } + assert_eq!(messages.len(), 1); + let msg = messages.remove(0); + serving_source + .ack_messages(vec![format!("{}-0", msg.id)]) + .await + .unwrap(); + break; + } + }); + + let resp = client + .post(format!( + "https://localhost:{}/v1/process/async", + settings.app_listen_port + )) + .json("test-payload") + .send() + .await?; + + assert!(resp.status().is_success()); + Ok(()) + } +} From a45411d31f58ff15015c55e015a7e7cd95e4769c Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 3 Jan 2025 07:03:31 +0530 Subject: [PATCH 10/15] More unit tests for serving Signed-off-by: Sreekanth --- rust/numaflow-core/src/message.rs | 8 ++-- rust/numaflow-core/src/source/serving.rs | 53 ++++++++++++++++++++++++ rust/serving/src/lib.rs | 4 +- 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/rust/numaflow-core/src/message.rs b/rust/numaflow-core/src/message.rs index 00f5cca663..fe20613dad 100644 --- a/rust/numaflow-core/src/message.rs +++ b/rust/numaflow-core/src/message.rs @@ -37,7 +37,7 @@ pub(crate) struct Message { } /// Offset of the message which will be used to acknowledge the message. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub(crate) enum Offset { Int(IntOffset), String(StringOffset), @@ -62,7 +62,7 @@ impl Message { } /// IntOffset is integer based offset enum type. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct IntOffset { pub(crate) offset: u64, pub(crate) partition_idx: u16, @@ -84,7 +84,7 @@ impl fmt::Display for IntOffset { } /// StringOffset is string based offset enum type. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub(crate) struct StringOffset { /// offset could be a complex base64 string. pub(crate) offset: Bytes, @@ -120,7 +120,7 @@ pub(crate) enum ReadAck { } /// Message ID which is used to uniquely identify a message. It cheap to clone this. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub(crate) struct MessageID { pub(crate) vertex_name: Bytes, pub(crate) offset: Bytes, diff --git a/rust/numaflow-core/src/source/serving.rs b/rust/numaflow-core/src/source/serving.rs index 02074904eb..a7f00a6c29 100644 --- a/rust/numaflow-core/src/source/serving.rs +++ b/rust/numaflow-core/src/source/serving.rs @@ -75,3 +75,56 @@ impl super::LagReader for ServingSource { Ok(None) } } + +#[cfg(test)] +mod tests { + use crate::message::{Message, MessageID, Offset, StringOffset}; + use std::collections::HashMap; + + use bytes::Bytes; + + type Result = std::result::Result>; + + #[test] + fn test_message_conversion() -> Result<()> { + const MSG_ID: &str = "b149ad7a-5690-4f0a"; + + let mut headers = HashMap::new(); + headers.insert("header-key".to_owned(), "header-value".to_owned()); + + let serving_message = serving::Message { + value: Bytes::from_static(b"test"), + id: MSG_ID.into(), + headers: headers.clone(), + }; + let message: Message = serving_message.try_into()?; + assert_eq!(message.value, Bytes::from_static(b"test")); + assert_eq!( + message.offset, + Some(Offset::String(StringOffset::new(MSG_ID.into(), 0))) + ); + assert_eq!( + message.id, + MessageID { + vertex_name: Bytes::new(), + offset: format!("{MSG_ID}-0").into(), + index: 0 + } + ); + + assert_eq!(message.headers, headers); + + Ok(()) + } + + #[test] + fn test_error_conversion() { + use crate::error::Error; + let error: Error = serving::Error::ParseConfig("Invalid config".to_owned()).into(); + if let Error::Source(val) = error { + assert_eq!(val, "ParseConfig Error - Invalid config".to_owned()); + } else { + panic!("Expected Error::Source() variant"); + } + } +} diff --git a/rust/serving/src/lib.rs b/rust/serving/src/lib.rs index 20fa2c9587..97d4fe3a7f 100644 --- a/rust/serving/src/lib.rs +++ b/rust/serving/src/lib.rs @@ -38,9 +38,7 @@ pub(crate) async fn serve( where T: Clone + Send + Sync + Store + 'static, { - rustls::crypto::aws_lc_rs::default_provider() - .install_default() - .expect("Failed to set crypto provider"); + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let (cert, key) = generate_certs()?; let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) From e8592bc298cd1895bf952177a0f4e684c14d0c30 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 3 Jan 2025 09:35:29 +0530 Subject: [PATCH 11/15] Unit tests for Sourcer trait implementation on Serving source Signed-off-by: Sreekanth --- rust/Cargo.lock | 1 + rust/Cargo.toml | 5 +- rust/numaflow-core/Cargo.toml | 5 +- rust/numaflow-core/src/source/serving.rs | 68 +++++++++++++++++++++++- rust/serving/Cargo.toml | 8 ++- 5 files changed, 80 insertions(+), 7 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 6fc31dc61e..e3d90e2f05 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1786,6 +1786,7 @@ dependencies = [ "pulsar", "rand", "rcgen", + "reqwest 0.12.12", "rustls 0.23.19", "semver", "serde", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 91e0cc7361..a7a4df3637 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -58,11 +58,12 @@ numaflow-core = { path = "numaflow-core" } numaflow-models = { path = "numaflow-models" } backoff = { path = "backoff" } numaflow-pb = { path = "numaflow-pb" } -numaflow-pulsar = {path = "extns/numaflow-pulsar"} +numaflow-pulsar = { path = "extns/numaflow-pulsar" } tokio = "1.41.1" bytes = "1.7.1" tracing = "0.1.40" axum = "0.7.5" axum-server = { version = "0.7.1", features = ["tls-rustls"] } serde = { version = "1.0.204", features = ["derive"] } -rustls = { version = "0.23.12", features = ["aws_lc_rs"] } \ No newline at end of file +rustls = { version = "0.23.12", features = ["aws_lc_rs"] } +reqwest = "0.12.12" diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index c7a33fe275..4a98303a1e 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -50,6 +50,9 @@ async-nats = "0.38.0" [dev-dependencies] tempfile = "3.11.0" numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "9ca9362ad511084501520e5a37d40cdcd0cdc9d9" } -pulsar = { version = "6.3.0", default-features = false, features = ["tokio-rustls-runtime"] } +pulsar = { version = "6.3.0", default-features = false, features = [ + "tokio-rustls-runtime", +] } +reqwest = { workspace = true, features = ["json"] } [build-dependencies] diff --git a/rust/numaflow-core/src/source/serving.rs b/rust/numaflow-core/src/source/serving.rs index a7f00a6c29..8e9794b51d 100644 --- a/rust/numaflow-core/src/source/serving.rs +++ b/rust/numaflow-core/src/source/serving.rs @@ -78,10 +78,14 @@ impl super::LagReader for ServingSource { #[cfg(test)] mod tests { - use crate::message::{Message, MessageID, Offset, StringOffset}; - use std::collections::HashMap; + use crate::{ + message::{Message, MessageID, Offset, StringOffset}, + source::{SourceAcker, SourceReader}, + }; + use std::{collections::HashMap, sync::Arc, time::Duration}; use bytes::Bytes; + use serving::{ServingSource, Settings}; type Result = std::result::Result>; @@ -127,4 +131,64 @@ mod tests { panic!("Expected Error::Source() variant"); } } + + #[tokio::test] + async fn test_serving_source_reader_acker() -> Result<()> { + let settings = Settings { + app_listen_port: 2000, + ..Default::default() + }; + let settings = Arc::new(settings); + let mut serving_source = + ServingSource::new(Arc::clone(&settings), 10, Duration::from_millis(1)).await?; + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(2)) + .danger_accept_invalid_certs(true) + .build() + .unwrap(); + + // Wait for the server + for _ in 0..10 { + let resp = client + .get(format!( + "https://localhost:{}/livez", + settings.app_listen_port + )) + .send() + .await; + if resp.is_ok() { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let task_handle = tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_millis(10)).await; + let mut messages = serving_source.read().await.unwrap(); + if messages.is_empty() { + // Server has not received any requests yet + continue; + } + assert_eq!(messages.len(), 1); + let msg = messages.remove(0); + serving_source.ack(vec![msg.offset.unwrap()]).await.unwrap(); + break; + } + }); + + let resp = client + .post(format!( + "https://localhost:{}/v1/process/async", + settings.app_listen_port + )) + .json("test-payload") + .send() + .await?; + + assert!(resp.status().is_success()); + assert!(task_handle.await.is_ok()); + Ok(()) + } } diff --git a/rust/serving/Cargo.toml b/rust/serving/Cargo.toml index 5d541c27e0..857d69db77 100644 --- a/rust/serving/Cargo.toml +++ b/rust/serving/Cargo.toml @@ -26,7 +26,11 @@ serde_json = "1.0.120" tower = "0.4.13" tower-http = { version = "0.5.2", features = ["trace", "timeout"] } uuid = { version = "1.10.0", features = ["v4"] } -redis = { version = "0.26.0", features = ["tokio-comp", "aio", "connection-manager"] } +redis = { version = "0.26.0", features = [ + "tokio-comp", + "aio", + "connection-manager", +] } trait-variant = "0.1.2" chrono = { version = "0.4", features = ["serde"] } base64 = "0.22.1" @@ -36,5 +40,5 @@ prometheus-client = "0.22.3" thiserror = "1.0.63" [dev-dependencies] -reqwest = {version= "0.12.12", features = ["json"]} +reqwest = { workspace = true, features = ["json"] } rustls.workspace = true From 54c2ccea5d864a54eb3341482a76b79bd2fdabf2 Mon Sep 17 00:00:00 2001 From: Vigith Maurice Date: Thu, 2 Jan 2025 21:57:37 -0800 Subject: [PATCH 12/15] chore: minor updates Signed-off-by: Vigith Maurice --- rust/serving/src/lib.rs | 9 +++++---- rust/serving/src/pipeline.rs | 30 +++++++++++++++--------------- rust/serving/src/source.rs | 9 ++++++--- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/rust/serving/src/lib.rs b/rust/serving/src/lib.rs index 97d4fe3a7f..b5579fae70 100644 --- a/rust/serving/src/lib.rs +++ b/rust/serving/src/lib.rs @@ -23,13 +23,14 @@ mod metrics; mod pipeline; pub mod source; -pub use source::{Message, MessageWrapper, ServingSource}; +pub use source::{Message, ServingSource}; +use crate::source::MessageWrapper; #[derive(Clone)] pub(crate) struct AppState { - pub message: mpsc::Sender, - pub settings: Arc, - pub callback_state: CallbackState, + pub(crate) message: mpsc::Sender, + pub(crate) settings: Arc, + pub(crate) callback_state: CallbackState, } pub(crate) async fn serve( diff --git a/rust/serving/src/pipeline.rs b/rust/serving/src/pipeline.rs index 2adc169433..cb491d7d88 100644 --- a/rust/serving/src/pipeline.rs +++ b/rust/serving/src/pipeline.rs @@ -10,7 +10,7 @@ use crate::Error::ParseConfig; // OperatorType is an enum that contains the types of operators // that can be used in the conditions for the edge. #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] -pub enum OperatorType { +pub(crate) enum OperatorType { #[serde(rename = "and")] And, #[serde(rename = "or")] @@ -43,36 +43,36 @@ impl From for OperatorType { // Tag is a struct that contains the information about the tags for the edge #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Tag { - pub operator: Option, - pub values: Vec, +pub(crate) struct Tag { + pub(crate) operator: Option, + pub(crate) values: Vec, } // Conditions is a struct that contains the information about the conditions for the edge #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Conditions { - pub tags: Option, +pub(crate) struct Conditions { + pub(crate) tags: Option, } // Edge is a struct that contains the information about the edge in the pipeline. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Edge { - pub from: String, - pub to: String, - pub conditions: Option, +pub(crate) struct Edge { + pub(crate) from: String, + pub(crate) to: String, + pub(crate) conditions: Option, } /// DCG (directed compute graph) of the pipeline with minimal information build using vertices and edges /// from the pipeline spec #[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)] -pub struct PipelineDCG { - pub vertices: Vec, - pub edges: Vec, +pub(crate) struct PipelineDCG { + pub(crate) vertices: Vec, + pub(crate) edges: Vec, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Vertex { - pub name: String, +pub(crate) struct Vertex { + pub(crate) name: String, } impl FromStr for PipelineDCG { diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs index 4b1b2ce3f7..d6f7f9c28f 100644 --- a/rust/serving/src/source.rs +++ b/rust/serving/src/source.rs @@ -12,11 +12,14 @@ use crate::app::tracker::MessageGraph; use crate::Settings; use crate::{Error, Result}; -pub struct MessageWrapper { - pub confirm_save: oneshot::Sender<()>, - pub message: Message, +/// [Message] with a oneshot for notifying when the message has been completed processed. +pub(crate) struct MessageWrapper { + // TODO: this might be more that saving to ISB. + pub(crate) confirm_save: oneshot::Sender<()>, + pub(crate) message: Message, } +/// Serving payload passed on to Numaflow. #[derive(Debug)] pub struct Message { pub value: Bytes, From 8f0389b0433dcb38e861eed3a13835e0714f97b1 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 3 Jan 2025 14:54:02 +0530 Subject: [PATCH 13/15] Fixes based on code review Signed-off-by: Sreekanth --- rust/Cargo.toml | 4 +- rust/numaflow-core/src/config/components.rs | 2 + rust/numaflow-core/src/lib.rs | 1 - .../src/shared/create_components.rs | 9 ++- rust/numaflow-core/src/source/serving.rs | 20 ++++-- rust/serving/src/app/jetstream_proxy.rs | 6 +- rust/serving/src/error.rs | 3 + rust/serving/src/lib.rs | 4 +- rust/serving/src/metrics.rs | 5 +- rust/serving/src/source.rs | 62 +++++++++++++------ 10 files changed, 83 insertions(+), 33 deletions(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index a7a4df3637..75fd036128 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -40,8 +40,8 @@ verbose_file_reads = "warn" # This profile optimizes for runtime performance and small binary size at the expense of longer build times. # Compared to default release profile, this profile reduced binary size from 29MB to 21MB # and increased build time (with only one line change in code) from 12 seconds to 133 seconds (tested on Mac M2 Max). -# [profile.release] -# lto = "fat" +[profile.release] +lto = "fat" # This profile optimizes for short build times at the expense of larger binary size and slower runtime performance. # If you have to rebuild image often, in Dockerfile you may replace `--release` passed to cargo command with `--profile quick-release` diff --git a/rust/numaflow-core/src/config/components.rs b/rust/numaflow-core/src/config/components.rs index 833ad8950b..3dc0bf2a66 100644 --- a/rust/numaflow-core/src/config/components.rs +++ b/rust/numaflow-core/src/config/components.rs @@ -38,6 +38,8 @@ pub(crate) mod source { Generator(GeneratorConfig), UserDefined(UserDefinedConfig), Pulsar(PulsarSourceConfig), + // Serving source starts an Axum HTTP server in the background. + // The settings will be used as application state which gets cloned in each handler on each request. Serving(Arc), } diff --git a/rust/numaflow-core/src/lib.rs b/rust/numaflow-core/src/lib.rs index 79ce4348b4..d65380f8d2 100644 --- a/rust/numaflow-core/src/lib.rs +++ b/rust/numaflow-core/src/lib.rs @@ -55,7 +55,6 @@ mod tracker; mod mapper; pub async fn run() -> Result<()> { - let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let cln_token = CancellationToken::new(); let shutdown_cln_token = cln_token.clone(); diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs index c077a1f44f..b28f4caeee 100644 --- a/rust/numaflow-core/src/shared/create_components.rs +++ b/rust/numaflow-core/src/shared/create_components.rs @@ -12,6 +12,7 @@ use tonic::transport::Channel; use crate::config::components::sink::{SinkConfig, SinkType}; use crate::config::components::source::{SourceConfig, SourceType}; use crate::config::components::transformer::TransformerConfig; +use crate::config::get_vertex_replica; use crate::config::pipeline::map::{MapMode, MapType, MapVtxConfig}; use crate::config::pipeline::{DEFAULT_BATCH_MAP_SOCKET, DEFAULT_STREAM_MAP_SOCKET}; use crate::error::Error; @@ -337,7 +338,13 @@ pub async fn create_source( )) } SourceType::Serving(config) => { - let serving = ServingSource::new(Arc::clone(config), batch_size, read_timeout).await?; + let serving = ServingSource::new( + Arc::clone(config), + batch_size, + read_timeout, + *get_vertex_replica(), + ) + .await?; Ok(( Source::new( batch_size, diff --git a/rust/numaflow-core/src/source/serving.rs b/rust/numaflow-core/src/source/serving.rs index 8e9794b51d..b9fb6c72ed 100644 --- a/rust/numaflow-core/src/source/serving.rs +++ b/rust/numaflow-core/src/source/serving.rs @@ -2,6 +2,7 @@ use std::sync::Arc; pub(crate) use serving::ServingSource; +use crate::config::get_vertex_replica; use crate::message::{MessageID, StringOffset}; use crate::Error; use crate::Result; @@ -12,9 +13,10 @@ impl TryFrom for Message { type Error = Error; fn try_from(message: serving::Message) -> Result { - let offset = Offset::String(StringOffset::new(message.id.clone(), 0)); + let offset = Offset::String(StringOffset::new(message.id.clone(), *get_vertex_replica())); Ok(Message { + // we do not support keys from HTTP client keys: Arc::from(vec![]), tags: None, value: message.value, @@ -50,11 +52,14 @@ impl super::SourceReader for ServingSource { } fn partitions(&self) -> Vec { - vec![] + vec![*get_vertex_replica()] } } impl super::SourceAcker for ServingSource { + /// HTTP response is sent only once we have confirmation that the message has been written to the ISB. + // TODO: Current implementation only works for `/v1/process/async` endpoint. + // For `/v1/process/{sync,sync_serve}` endpoints: https://github.com/numaproj/numaflow/issues/2308 async fn ack(&mut self, offsets: Vec) -> Result<()> { let mut serving_offsets = vec![]; for offset in offsets { @@ -87,6 +92,8 @@ mod tests { use bytes::Bytes; use serving::{ServingSource, Settings}; + use super::get_vertex_replica; + type Result = std::result::Result>; #[test] @@ -139,8 +146,13 @@ mod tests { ..Default::default() }; let settings = Arc::new(settings); - let mut serving_source = - ServingSource::new(Arc::clone(&settings), 10, Duration::from_millis(1)).await?; + let mut serving_source = ServingSource::new( + Arc::clone(&settings), + 10, + Duration::from_millis(1), + *get_vertex_replica(), + ) + .await?; let client = reqwest::Client::builder() .timeout(Duration::from_secs(2)) diff --git a/rust/serving/src/app/jetstream_proxy.rs b/rust/serving/src/app/jetstream_proxy.rs index 41d0c0df3e..6f61a0530f 100644 --- a/rust/serving/src/app/jetstream_proxy.rs +++ b/rust/serving/src/app/jetstream_proxy.rs @@ -92,7 +92,11 @@ async fn sync_publish_serve( }, }; - proxy_state.message.send(message).await.unwrap(); // FIXME: + proxy_state + .message + .send(message) + .await + .expect("Failed to send request payload to Serving channel"); if let Err(e) = rx.await { // Deregister the ID in the callback proxy state if writing to Jetstream fails diff --git a/rust/serving/src/error.rs b/rust/serving/src/error.rs index cfa252daad..8d03c48234 100644 --- a/rust/serving/src/error.rs +++ b/rust/serving/src/error.rs @@ -48,6 +48,9 @@ pub enum Error { #[error("Failed to receive message from channel. Actor task is terminated: {0:?}")] ActorTaskTerminated(oneshot::error::RecvError), + #[error("Serving source error - {0}")] + Source(String), + #[error("Other Error - {0}")] // catch-all variant for now Other(String), diff --git a/rust/serving/src/lib.rs b/rust/serving/src/lib.rs index b5579fae70..001065ddfe 100644 --- a/rust/serving/src/lib.rs +++ b/rust/serving/src/lib.rs @@ -23,8 +23,8 @@ mod metrics; mod pipeline; pub mod source; -pub use source::{Message, ServingSource}; use crate::source::MessageWrapper; +pub use source::{Message, ServingSource}; #[derive(Clone)] pub(crate) struct AppState { @@ -39,7 +39,9 @@ pub(crate) async fn serve( where T: Clone + Send + Sync + Store + 'static, { + // Setup the CryptoProvider (controls core cryptography used by rustls) for the process let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + let (cert, key) = generate_certs()?; let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) diff --git a/rust/serving/src/metrics.rs b/rust/serving/src/metrics.rs index ff96513774..a605cc9988 100644 --- a/rust/serving/src/metrics.rs +++ b/rust/serving/src/metrics.rs @@ -175,9 +175,8 @@ mod tests { #[tokio::test] async fn test_start_metrics_server() -> Result<()> { - rustls::crypto::aws_lc_rs::default_provider() - .install_default() - .unwrap(); + // Setup the CryptoProvider (controls core cryptography used by rustls) for the process + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let (cert, key) = generate_certs()?; let tls_config = RustlsConfig::from_pem(cert.pem().into(), key.serialize_pem().into()) diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs index d6f7f9c28f..4366338a25 100644 --- a/rust/serving/src/source.rs +++ b/rust/serving/src/source.rs @@ -31,26 +31,36 @@ enum ActorMessage { Read { batch_size: usize, timeout_at: Instant, - reply_to: oneshot::Sender>, + reply_to: oneshot::Sender>>, }, Ack { offsets: Vec, - reply_to: oneshot::Sender<()>, + reply_to: oneshot::Sender>, }, } +/// Background actor that starts Axum server for accepting HTTP requests. struct ServingSourceActor { + /// The HTTP handlers will put the message received from the payload to this channel messages: mpsc::Receiver, + /// Channel for the actor handle to communicate with this actor handler_rx: mpsc::Receiver, + /// Mapping from request's ID header (usually `X-Numaflow-Id` header) to a channel. + /// This sending a message on this channel notifies the HTTP handler function that the message + /// has been successfully processed. tracker: HashMap>, + vertex_replica_id: u16, } impl ServingSourceActor { async fn start( settings: Arc, handler_rx: mpsc::Receiver, + request_channel_buffer_size: usize, + vertex_replica_id: u16, ) -> Result<()> { - let (messages_tx, messages_rx) = mpsc::channel(10000); + // Channel to which HTTP handlers will send request payload + let (messages_tx, messages_rx) = mpsc::channel(request_channel_buffer_size); // Create a redis store to store the callbacks and the custom responses let redis_store = RedisConnection::new(settings.redis.clone()).await?; // Create the message graph from the pipeline spec and the redis store @@ -67,6 +77,7 @@ impl ServingSourceActor { messages: messages_rx, handler_rx, tracker: HashMap::new(), + vertex_replica_id, }; serving_actor.run().await; }); @@ -98,24 +109,32 @@ impl ServingSourceActor { let _ = reply_to.send(messages); } ActorMessage::Ack { offsets, reply_to } => { - self.ack(offsets).await; - let _ = reply_to.send(()); + let status = self.ack(offsets).await; + let _ = reply_to.send(status); } } } - async fn read(&mut self, count: usize, timeout_at: Instant) -> Vec { + async fn read(&mut self, count: usize, timeout_at: Instant) -> Result> { let mut messages = vec![]; loop { + // Stop if the read timeout has reached or if we have collected the requested number of messages if messages.len() >= count || Instant::now() >= timeout_at { break; } let message = match self.messages.try_recv() { Ok(msg) => msg, Err(mpsc::error::TryRecvError::Empty) => break, - Err(e) => { - tracing::error!(?e, "Receiving messages from the serving channel"); // FIXME: - return messages; + Err(mpsc::error::TryRecvError::Disconnected) => { + // If we have collected at-least one message, we return those messages. + // The error will happen on all the subsequent read attempts too. + if messages.is_empty() { + return Err(Error::Other( + "Sending half of the Serving channel has disconnected".into(), + )); + } + tracing::error!("Sending half of the Serving channel has disconnected"); + return Ok(messages); } }; let MessageWrapper { @@ -126,22 +145,24 @@ impl ServingSourceActor { self.tracker.insert(message.id.clone(), confirm_save); messages.push(message); } - messages + Ok(messages) } - async fn ack(&mut self, offsets: Vec) { + async fn ack(&mut self, offsets: Vec) -> Result<()> { + let offset_suffix = format!("-{}", self.vertex_replica_id); for offset in offsets { - let offset = offset - .strip_suffix("-0") - .expect("offset does not end with '-0'"); // FIXME: we hardcode 0 as the partition index when constructing offset + let offset = offset.strip_suffix(&offset_suffix).ok_or_else(|| { + Error::Source(format!("offset does not end with '{}'", &offset_suffix)) + })?; let confirm_save_tx = self .tracker .remove(offset) - .expect("offset was not found in the tracker"); + .ok_or_else(|| Error::Source("offset was not found in the tracker".into()))?; confirm_save_tx .send(()) - .expect("Sending on confirm_save channel"); + .map_err(|e| Error::Source(format!("Sending on confirm_save channel: {e:?}")))?; } + Ok(()) } } @@ -158,9 +179,10 @@ impl ServingSource { settings: Arc, batch_size: usize, timeout: Duration, + vertex_replica_id: u16, ) -> Result { - let (actor_tx, actor_rx) = mpsc::channel(1000); - ServingSourceActor::start(settings, actor_rx).await?; + let (actor_tx, actor_rx) = mpsc::channel(2 * batch_size); + ServingSourceActor::start(settings, actor_rx, 2 * batch_size, vertex_replica_id).await?; Ok(Self { batch_size, timeout, @@ -177,7 +199,7 @@ impl ServingSource { timeout_at: Instant::now() + self.timeout, }; let _ = self.actor_tx.send(actor_msg).await; - let messages = rx.await.map_err(Error::ActorTaskTerminated)?; + let messages = rx.await.map_err(Error::ActorTaskTerminated)??; tracing::debug!( count = messages.len(), requested_count = self.batch_size, @@ -194,7 +216,7 @@ impl ServingSource { reply_to: tx, }; let _ = self.actor_tx.send(actor_msg).await; - rx.await.map_err(Error::ActorTaskTerminated)?; + rx.await.map_err(Error::ActorTaskTerminated)??; Ok(()) } } From d85079a2849885746f6a8348594b1c53b54f1383 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Fri, 3 Jan 2025 15:12:19 +0530 Subject: [PATCH 14/15] Fix test Signed-off-by: Sreekanth --- rust/serving/src/source.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs index 4366338a25..418c07ce75 100644 --- a/rust/serving/src/source.rs +++ b/rust/serving/src/source.rs @@ -235,7 +235,7 @@ mod tests { async fn test_serving_source() -> Result<()> { let settings = Arc::new(Settings::default()); let serving_source = - ServingSource::new(Arc::clone(&settings), 10, Duration::from_millis(1)).await?; + ServingSource::new(Arc::clone(&settings), 10, Duration::from_millis(1), 0).await?; let client = reqwest::Client::builder() .timeout(Duration::from_secs(2)) From 1dcad81037745cccbb42c6a24cee7374ce26efdd Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 6 Jan 2025 11:03:13 +0530 Subject: [PATCH 15/15] Avoid hot loop by using timeout on channel receive Signed-off-by: Sreekanth --- rust/serving/src/source.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/rust/serving/src/source.rs b/rust/serving/src/source.rs index 418c07ce75..d038179672 100644 --- a/rust/serving/src/source.rs +++ b/rust/serving/src/source.rs @@ -122,10 +122,10 @@ impl ServingSourceActor { if messages.len() >= count || Instant::now() >= timeout_at { break; } - let message = match self.messages.try_recv() { - Ok(msg) => msg, - Err(mpsc::error::TryRecvError::Empty) => break, - Err(mpsc::error::TryRecvError::Disconnected) => { + let next_msg = self.messages.recv(); + let message = match tokio::time::timeout_at(timeout_at, next_msg).await { + Ok(Some(msg)) => msg, + Ok(None) => { // If we have collected at-least one message, we return those messages. // The error will happen on all the subsequent read attempts too. if messages.is_empty() { @@ -136,6 +136,7 @@ impl ServingSourceActor { tracing::error!("Sending half of the Serving channel has disconnected"); return Ok(messages); } + Err(_) => return Ok(messages), }; let MessageWrapper { confirm_save,