diff --git a/crates/sdk/src/default_main.rs b/crates/sdk/src/default_main.rs index 7debf3e..74397e7 100644 --- a/crates/sdk/src/default_main.rs +++ b/crates/sdk/src/default_main.rs @@ -295,7 +295,7 @@ where C::Configuration: Clone, C::State: Clone, { - let router = Router::new() + Router::new() .route("/capabilities", get(get_capabilities::)) .route("/health", get(get_health::)) .route("/metrics", get(get_metrics::)) @@ -319,53 +319,52 @@ where ); }), ) - .with_state(state); + .with_state(state) + .layer(ValidateRequestHeaderLayer::custom(auth_handler( + service_token_secret, + ))) +} +fn auth_handler( + service_token_secret: Option, +) -> impl Fn(&mut Request) -> Result<(), axum::response::Response> + Clone { let expected_auth_header: Option = service_token_secret.and_then(|service_token_secret| { let expected_bearer = format!("Bearer {service_token_secret}"); HeaderValue::from_str(&expected_bearer).ok() }); - router - .layer( - TraceLayer::new_for_http() - .make_span_with(make_span) - .on_response(on_response), + move |request| { + // Validate the request + let auth_header = request.headers().get("Authorization").cloned(); + + // NOTE: The comparison should probably be more permissive to allow for whitespace, etc. + if auth_header == expected_auth_header { + return Ok(()); + } + + let message = "Bearer token does not match.".to_string(); + + tracing::error!( + meta.signal_type = "log", + event.domain = "ndc", + event.name = "Authorization error", + name = "Authorization error", + body = message, + error = true, + ); + Err(( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + message: "Internal error".into(), + details: serde_json::Value::Object(serde_json::Map::from_iter([( + "cause".into(), + serde_json::Value::String(message), + )])), + }), ) - .layer(ValidateRequestHeaderLayer::custom( - move |request: &mut Request| { - // Validate the request - let auth_header = request.headers().get("Authorization").cloned(); - - // NOTE: The comparison should probably be more permissive to allow for whitespace, etc. - if auth_header == expected_auth_header { - return Ok(()); - } - - let message = "Bearer token does not match.".to_string(); - - tracing::error!( - meta.signal_type = "log", - event.domain = "ndc", - event.name = "Authorization error", - name = "Authorization error", - body = message, - error = true, - ); - Err(( - StatusCode::UNAUTHORIZED, - Json(ErrorResponse { - message: "Internal error".into(), - details: serde_json::Value::Object(serde_json::Map::from_iter([( - "cause".into(), - serde_json::Value::String(message), - )])), - }), - ) - .into_response()) - }, - )) + .into_response()) + } } async fn get_metrics(