From 74a9743321c7adc8f8f145b075c0b6ef31af7733 Mon Sep 17 00:00:00 2001 From: YISH Date: Sat, 18 Jan 2025 18:45:16 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20OpenAPI/Swagger-ui?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit docs url: https://localhost:8453/api/docs --- Cargo.lock | 120 ++++++++++++++++++++++++++++++++++ Cargo.toml | 16 ++++- src/api/address.rs | 7 +- src/api/audit.rs | 6 +- src/api/cache.rs | 25 +++---- src/api/forward.rs | 6 +- src/api/listener.rs | 6 +- src/api/log.rs | 7 +- src/api/mod.rs | 41 ++++++++++-- src/api/nameserver.rs | 6 +- src/api/openapi.rs | 62 ++++++++++++++++++ src/api/serve_dns.rs | 147 +++++++++++++++++++++++++++--------------- src/api/settings.rs | 9 ++- src/dns_conf.rs | 7 +- src/server/https.rs | 73 ++++++++++++--------- swagger-ui.html | 44 +++++++++++++ 16 files changed, 464 insertions(+), 118 deletions(-) create mode 100644 src/api/openapi.rs create mode 100644 swagger-ui.html diff --git a/Cargo.lock b/Cargo.lock index 669df59b..7ba1aa40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1428,6 +1428,7 @@ checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", "hashbrown 0.15.2", + "serde", ] [[package]] @@ -1617,6 +1618,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1793,6 +1804,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "percent-encoding" version = "2.3.1" @@ -2235,6 +2252,40 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "rust-embed" +version = "8.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa66af4a4fdd5e7ebc276f115e895611a34739a9c1c01028383d612d550953c0" +dependencies = [ + "rust-embed-impl", + "rust-embed-utils", + "walkdir", +] + +[[package]] +name = "rust-embed-impl" +version = "8.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6125dbc8867951125eec87294137f4e9c2c96566e61bf72c45095a7c77761478" +dependencies = [ + "proc-macro2", + "quote", + "rust-embed-utils", + "syn 2.0.96", + "walkdir", +] + +[[package]] +name = "rust-embed-utils" +version = "8.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e5347777e9aacb56039b0e1f28785929a8a3b709e87482e7442c72e7c12529d" +dependencies = [ + "sha2", + "walkdir", +] + [[package]] name = "rust_decimal" version = "1.36.0" @@ -2680,6 +2731,9 @@ dependencies = [ "tracing", "tracing-subscriber", "url", + "utoipa", + "utoipa-axum", + "utoipa-swagger-ui", "uzers", "webpki-roots", "which", @@ -3104,6 +3158,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.14" @@ -3170,6 +3230,56 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "utoipa" +version = "5.3.1" +source = "git+https://github.com/mokeyish/utoipa.git?rev=smartdns.1#7752ae94513a335d97d984cb7274963775643a32" +dependencies = [ + "indexmap", + "serde", + "serde_json", + "utoipa-gen", +] + +[[package]] +name = "utoipa-axum" +version = "0.2.0" +source = "git+https://github.com/mokeyish/utoipa.git?rev=smartdns.1#7752ae94513a335d97d984cb7274963775643a32" +dependencies = [ + "axum", + "paste", + "tower-layer", + "tower-service", + "utoipa", +] + +[[package]] +name = "utoipa-gen" +version = "5.3.1" +source = "git+https://github.com/mokeyish/utoipa.git?rev=smartdns.1#7752ae94513a335d97d984cb7274963775643a32" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 2.0.96", +] + +[[package]] +name = "utoipa-swagger-ui" +version = "9.0.0" +source = "git+https://github.com/mokeyish/utoipa.git?rev=smartdns.1#7752ae94513a335d97d984cb7274963775643a32" +dependencies = [ + "axum", + "base64 0.22.1", + "mime_guess", + "regex", + "rust-embed", + "serde", + "serde_json", + "utoipa", + "zip", +] + [[package]] name = "uuid" version = "1.12.0" @@ -3200,6 +3310,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index a762e825..16a068fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,9 +25,11 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(nightly)'] } [features] -default = ["resolve-cli", "dns-over-tls", "dns-over-https", "dns-over-quic", "dns-over-h3", "dnssec", "service", "nft", "nom-recipes-all", "self-update" ] +default = ["common", "self-update" ] -homebrew = ["resolve-cli", "dns-over-tls", "dns-over-https", "dns-over-quic", "dns-over-h3", "dnssec", "service", "nft", "nom-recipes-all" ] +homebrew = [ "common" ] + +common = ["resolve-cli", "dns-over-tls", "dns-over-https", "dns-over-quic", "dns-over-h3", "dnssec", "service", "nft", "nom-recipes-all", "swagger-ui-cdn" ] nom-recipes-all =["nom-recipes-ip"] @@ -75,6 +77,13 @@ nft = ["dep:which", "dep:either"] dnssec = [ "hickory-proto/dnssec-ring", "rustls/ring"] + +swagger-ui-cdn = [] + +swagger-ui-embed = [ + "dep:utoipa-swagger-ui" +] + experimental = ["experimental-trie", "experimental-phf"] experimental-trie = [] @@ -114,6 +123,9 @@ axum = { version = "0.8.1" } hyper = { version = "1.1.0", default-features = false } hyper-util = { version = "0.1.3", features = ["http2"]} tower = { version = "0.5.2", default-features = false } +utoipa = { git = "https://github.com/mokeyish/utoipa.git", rev = "smartdns.1", package = "utoipa", features = ["axum_extras"] } +utoipa-axum = { git = "https://github.com/mokeyish/utoipa.git", rev = "smartdns.1", package = "utoipa-axum", features = []} +utoipa-swagger-ui = { git = "https://github.com/mokeyish/utoipa.git", rev = "smartdns.1", package = "utoipa-swagger-ui", optional = true, default-features = false, features = ["axum"] } # serde serde = { version = "1.0", features = ["derive"]} diff --git a/src/api/address.rs b/src/api/address.rs index 12c2381c..8f2b1840 100644 --- a/src/api/address.rs +++ b/src/api/address.rs @@ -1,13 +1,14 @@ use std::sync::Arc; -use axum::{extract::State, response::IntoResponse, routing::get, Json, Router}; - +use super::openapi::{http::get, routes, IntoRouter}; use super::{IntoDataListPayload, ServeState, StatefulRouter}; +use axum::{extract::State, response::IntoResponse, Json}; pub fn routes() -> StatefulRouter { - Router::new().route("/addresses", get(addresses)) + routes![addresses].into_router() } +#[get("/addresses")] async fn addresses(State(state): State>) -> impl IntoResponse { Json( state diff --git a/src/api/audit.rs b/src/api/audit.rs index 4cedf527..3fa36755 100644 --- a/src/api/audit.rs +++ b/src/api/audit.rs @@ -1,13 +1,15 @@ use std::sync::Arc; -use axum::{extract::State, response::IntoResponse, routing::get, Json, Router}; +use super::openapi::{http::get, routes, IntoRouter}; +use axum::{extract::State, response::IntoResponse, Json}; use super::{ServeState, StatefulRouter}; pub fn routes() -> StatefulRouter { - Router::new().route("/audits/config", get(audit_config)) + routes![audit_config,].into_router() } +#[get("/audits/config")] async fn audit_config(State(state): State>) -> impl IntoResponse { Json(state.app.cfg().await.audit_config()).into_response() } diff --git a/src/api/cache.rs b/src/api/cache.rs index a6c960e9..8a9053e0 100644 --- a/src/api/cache.rs +++ b/src/api/cache.rs @@ -1,22 +1,22 @@ use std::sync::Arc; -use axum::{ - extract::State, - http::StatusCode, - response::IntoResponse, - routing::{get, post}, - Json, Router, +use super::openapi::{ + http::{get, post}, + routes, IntoRouter, }; - use super::{IntoDataListPayload, ServeState, StatefulRouter}; +use crate::log; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; pub fn routes() -> StatefulRouter { - Router::new() - .route("/caches", get(caches)) - .route("/caches/config", get(cache_config)) - .route("/caches/flush", post(flush_cache)) + let route1 = routes![flush_cache, caches].into_router(); + let route2 = routes![cache_config].into_router(); + route1.merge(route2) + + // routes![flush_cache, caches, cache_config].into_router() } +#[get("/caches")] async fn caches(State(state): State>) -> impl IntoResponse { Json( (if let Some(c) = state.app.cache().await { @@ -28,13 +28,16 @@ async fn caches(State(state): State>) -> impl IntoResponse { ) } +#[post("/caches/flush")] async fn flush_cache(State(state): State>) -> StatusCode { if let Some(c) = state.app.cache().await { c.clear().await; } + log::info!("flushed cache"); StatusCode::NO_CONTENT } +#[get("/caches/config")] async fn cache_config(State(state): State>) -> impl IntoResponse { Json(state.app.cfg().await.cache_config()).into_response() } diff --git a/src/api/forward.rs b/src/api/forward.rs index f095ad00..c16e50e0 100644 --- a/src/api/forward.rs +++ b/src/api/forward.rs @@ -1,13 +1,15 @@ use std::sync::Arc; -use axum::{extract::State, response::IntoResponse, routing::get, Json, Router}; +use axum::{extract::State, response::IntoResponse, Json}; +use super::openapi::{http::get, routes, IntoRouter}; use super::{IntoDataListPayload, ServeState, StatefulRouter}; pub fn routes() -> StatefulRouter { - Router::new().route("/forwards", get(forwards)) + routes![forwards].into_router() } +#[get("/forwards")] async fn forwards(State(state): State>) -> impl IntoResponse { Json( state diff --git a/src/api/listener.rs b/src/api/listener.rs index 1c6b925b..24d5682a 100644 --- a/src/api/listener.rs +++ b/src/api/listener.rs @@ -1,13 +1,15 @@ use std::sync::Arc; -use axum::{extract::State, response::IntoResponse, routing::get, Json, Router}; +use axum::{extract::State, response::IntoResponse, Json}; +use super::openapi::{http::get, routes, IntoRouter}; use super::{IntoDataListPayload, ServeState, StatefulRouter}; pub fn routes() -> StatefulRouter { - Router::new().route("/listeners", get(listeners)) + routes![listeners].into_router() } +#[get("/listeners")] async fn listeners(State(state): State>) -> impl IntoResponse { Json( state diff --git a/src/api/log.rs b/src/api/log.rs index eda8a90e..14de475c 100644 --- a/src/api/log.rs +++ b/src/api/log.rs @@ -1,13 +1,14 @@ use std::sync::Arc; -use axum::{extract::State, response::IntoResponse, routing::get, Json, Router}; - +use super::openapi::{http::get, routes, IntoRouter}; use super::{ServeState, StatefulRouter}; +use axum::{extract::State, response::IntoResponse, Json}; pub fn routes() -> StatefulRouter { - Router::new().route("/logs/config", get(log_config)) + routes![log_config].into_router() } +#[get("/logs/config")] async fn log_config(State(state): State>) -> impl IntoResponse { Json(state.app.cfg().await.log_config()).into_response() } diff --git a/src/api/mod.rs b/src/api/mod.rs index 9921c044..7e930129 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,12 +1,13 @@ -use std::sync::Arc; - use axum::{ http::StatusCode, response::{IntoResponse, Response}, routing::get, - Json, Router, + Json, }; +use cfg_if::cfg_if; +use openapi::Router; use serde::{Deserialize, Serialize}; +use std::sync::Arc; mod address; mod audit; @@ -15,6 +16,7 @@ mod forward; mod listener; mod log; mod nameserver; +mod openapi; mod serve_dns; mod settings; @@ -27,10 +29,39 @@ pub struct ServeState { pub dns_handle: DnsHandle, } -pub fn routes() -> StatefulRouter { - Router::new() +pub fn routes() -> axum::Router> { + use utoipa::openapi::InfoBuilder; + let (router, mut openapi) = Router::new() .merge(serve_dns::routes()) .nest("/api", api_routes()) + .split_for_parts(); + openapi.info = InfoBuilder::new() + .title(crate::NAME) + .version(crate::version()) + .build(); + + cfg_if! { + if #[cfg(feature = "swagger-ui-cdn")] + { + router.merge(openapi::swagger_cdn("/api/docs", "/api/openapi.json", openapi, None)) + } + else if #[cfg(feature = "swagger-ui-embed")] + { + use utoipa_swagger_ui::{Config, SwaggerUi}; + router.merge( + SwaggerUi::new("/api/docs") + .config( + Config::default() + .show_extensions(true) + .show_common_extensions(true) + .use_base_layout(), + ) + .url("/api/openapi.json", openapi), + ) + } else { + router + } + } } fn api_routes() -> StatefulRouter { diff --git a/src/api/nameserver.rs b/src/api/nameserver.rs index f00fccb7..8ba9dd9d 100644 --- a/src/api/nameserver.rs +++ b/src/api/nameserver.rs @@ -1,13 +1,15 @@ use std::sync::Arc; -use axum::{extract::State, response::IntoResponse, routing::get, Json, Router}; +use axum::{extract::State, response::IntoResponse, Json}; +use super::openapi::{http::get, routes, IntoRouter}; use super::{IntoDataListPayload, ServeState, StatefulRouter}; pub fn routes() -> StatefulRouter { - Router::new().route("/nameservers", get(nameservers)) + routes![nameservers].into_router() } +#[get("/nameservers")] async fn nameservers(State(state): State>) -> impl IntoResponse { Json( state diff --git a/src/api/openapi.rs b/src/api/openapi.rs new file mode 100644 index 00000000..102f33c4 --- /dev/null +++ b/src/api/openapi.rs @@ -0,0 +1,62 @@ +#![allow(unused_imports)] + +use axum::routing::MethodRouter; +use utoipa::{ + openapi::{Paths, RefOr, Schema}, + OpenApi, +}; + +pub use utoipa::{IntoParams, ToSchema}; +pub use utoipa_axum::{router::OpenApiRouter as Router, routes}; + +pub mod http { + pub use utoipa::{any, delete, get, head, options, patch, post, put}; +} + +pub trait IntoRouter { + fn into_router(self) -> Router; +} + +impl IntoRouter + for (Vec<(String, RefOr)>, Paths, MethodRouter) +{ + fn into_router(self) -> Router { + Router::new().routes(self) + } +} + +#[cfg(feature = "swagger-ui-cdn")] +pub fn swagger_cdn( + doc_url: &str, + openapi_url: &str, + openapi: utoipa::openapi::OpenApi, + cdn: Option<&str>, +) -> axum::Router { + use axum::{ + extract::State, + response::{Html, Json}, + routing::get, + Router, + }; + use std::sync::Arc; + use utoipa::openapi::OpenApi; + + // https://unpkg.com/swagger-ui-dist/index.html + let cdn = cdn.unwrap_or("https://unpkg.com/swagger-ui-dist"); + let html = include_str!("../../swagger-ui.html") + .replace("{cdn}", cdn) + .replace("{openapi}", openapi_url) + .replace("{title}", crate::NAME); + + async fn doc(State(doc): State>) -> Json { + Json(doc.as_ref().clone()) + } + + async fn index(State(html): State>) -> Html { + Html(html.to_string()) + } + + Router::new() + .route(doc_url, get(index).with_state(Arc::new(html))) + .route(openapi_url, get(doc).with_state(Arc::new(openapi))) +} diff --git a/src/api/serve_dns.rs b/src/api/serve_dns.rs index 50a9ab92..1261c4ed 100644 --- a/src/api/serve_dns.rs +++ b/src/api/serve_dns.rs @@ -1,31 +1,59 @@ use std::net::SocketAddr; -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; use axum::body::Body; +use axum::extract::Query; use axum::http::{header, HeaderValue, StatusCode}; use axum::response::{IntoResponse, Response}; use axum::{ body::Bytes, - extract::{self, ConnectInfo, FromRequest, Request, State}, - routing::any, - Router, + extract::{ConnectInfo, FromRequest, Request, State}, }; -use serde::Serialize; +use serde::{Deserialize, Serialize}; +use super::openapi::{ + http::{get, post}, + routes, IntoParams, IntoRouter, ToSchema, +}; use super::{ServeState, StatefulRouter}; use crate::{dns::SerialMessage, libdns::Protocol, log}; pub fn routes() -> StatefulRouter { - Router::new().route("/dns-query", any(serve_dns)) + routes![serve_dns_get, serve_dns].into_router() +} + +#[get("/dns-query", params(QueryParam), responses( + (status = 200, description = "DNS response", body = DnsResponse) +))] +async fn serve_dns_get( + State(state): State>, + Query(parameters): Query, + ConnectInfo(addr): ConnectInfo, + req: Request, +) -> Response { + // https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json/ + match process(&state, req, addr, Some(parameters)).await { + Ok((content_type, bytes)) => { + let mut res = Body::from(bytes).into_response(); + res.headers_mut() + .insert(header::CONTENT_TYPE, HeaderValue::from_static(content_type)); + res + } + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!(r#"{{ "error": "{0}" }}"#, err), + ) + .into_response(), + } } +#[post("/dns-query")] async fn serve_dns( State(state): State>, - extract::Query(parameters): extract::Query>, ConnectInfo(addr): ConnectInfo, req: Request, ) -> Response { - match process(&state, req, addr, parameters).await { + match process(&state, req, addr, None).await { Ok((content_type, bytes)) => { let mut res = Body::from(bytes).into_response(); res.headers_mut() @@ -44,7 +72,7 @@ async fn process( state: &ServeState, req: Request, addr: SocketAddr, - parameters: HashMap, + query_param: Option, ) -> anyhow::Result<(&'static str, Bytes)> { const APPLICATION_DNS_MESSAGE: &str = "application/dns-message"; const APPLICATION_JSON: &str = "application/json"; @@ -63,42 +91,35 @@ async fn process( let accept_dns_message = accept == APPLICATION_DNS_MESSAGE; - let req_msg = if !accept_dns_message && parameters.contains_key("name") - || parameters.contains_key("query") - { - // https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json/ - use crate::libdns::proto::{ - op::{Edns, Message, Query}, - rr::{Name, RecordType}, - }; - - let name: Name = parameters - .get("name") - .or_else(|| parameters.get("query")) - .ok_or_else(|| anyhow::anyhow!("Query name is required"))? - .parse()?; - - let query_type = match parameters.get("type") { - Some(s) => s.parse::().map(RecordType::from).or(s.parse())?, - None => RecordType::A, - }; - - let dnssec = matches!(parameters.get("do"), Some(s) if s == "true" || s == "1"); - let checking_disabled = matches!(parameters.get("cd"), Some(s) if s == "true" || s == "1"); - - let mut message = Message::new(); - message.add_query(Query::query(name, query_type)); - message.set_checking_disabled(checking_disabled); - if dnssec { - let mut edns = Edns::new(); - edns.set_dnssec_ok(dnssec); - message.set_edns(edns); + let req_msg = match query_param { + Some(query_param) if !accept_dns_message => { + // https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https/make-api-requests/dns-json/ + use crate::libdns::proto::{ + op::{Edns, Message, Query}, + rr::{Name, RecordType}, + }; + + let name: Name = query_param.name.parse()?; + let query_type: RecordType = query_param.query_type.parse().unwrap_or(RecordType::A); + + let dnssec = query_param.dnssec; + let checking_disabled = query_param.checking_disabled; + + let mut message = Message::new(); + message.add_query(Query::query(name, query_type)); + message.set_checking_disabled(checking_disabled); + if dnssec { + let mut edns = Edns::new(); + edns.set_dnssec_ok(dnssec); + message.set_edns(edns); + } + + SerialMessage::raw(message, addr, Protocol::Https) + } + _ => { + let bytes = Bytes::from_request(req, &state).await?; + SerialMessage::binary(bytes.into(), addr, Protocol::Https) } - - SerialMessage::raw(message, addr, Protocol::Https) - } else { - let bytes = Bytes::from_request(req, &state).await?; - SerialMessage::binary(bytes.into(), addr, Protocol::Https) }; let res_msg = state.dns_handle.send(req_msg).await; @@ -112,16 +133,40 @@ async fn process( }; ( APPLICATION_JSON, - serde_json::to_vec(&JsonMessage::from(message))?, + serde_json::to_vec(&DnsResponse::from(message))?, ) }; Ok((content_type, bytes.into())) } -#[derive(Serialize)] +#[derive(Deserialize, IntoParams)] +struct QueryParam { + /// Query name + name: String, + + /// Query type (either a numeric value or text ↗). + #[serde(default = "QueryParam::default_query_type", rename = "type")] + query_type: String, + + /// DO bit - whether the client wants DNSSEC data (either empty or one of 0, false, 1, or true). + #[serde(default, rename = "do")] + dnssec: bool, + + /// CD bit - disable validation (either empty or one of 0, false, 1, or true). + #[serde(default, rename = "cd")] + checking_disabled: bool, +} + +impl QueryParam { + fn default_query_type() -> String { + "A".to_string() + } +} + +#[derive(Serialize, ToSchema)] #[allow(non_snake_case)] -struct JsonMessage { +struct DnsResponse { /// The Response Code of the DNS Query status: u16, @@ -157,13 +202,13 @@ struct JsonMessage { Answer: Vec, } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] struct Question { name: String, r#type: u16, } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] #[allow(non_snake_case)] struct Answer { name: String, @@ -172,9 +217,9 @@ struct Answer { data: String, } -impl From for JsonMessage { +impl From for DnsResponse { fn from(message: crate::libdns::proto::op::Message) -> Self { - JsonMessage { + DnsResponse { status: message.response_code().into(), TC: message.truncated(), RD: message.recursion_desired(), diff --git a/src/api/settings.rs b/src/api/settings.rs index 058d7b7c..e75a81dd 100644 --- a/src/api/settings.rs +++ b/src/api/settings.rs @@ -1,13 +1,16 @@ use std::sync::Arc; -use axum::{extract::State, response::IntoResponse, routing::get, Json, Router}; - +use super::openapi::{http::get, routes, IntoRouter}; use super::{ServeState, StatefulRouter}; +use axum::{extract::State, response::IntoResponse, Json}; pub fn routes() -> StatefulRouter { - Router::new().route("/server-name", get(server_name)) + routes![server_name].into_router() } +#[get("/server-name", responses( + (status = 200, description = "Server Name", content_type="application/json", body = String ) +))] async fn server_name(State(state): State>) -> impl IntoResponse { Json(state.app.cfg().await.server_name()) } diff --git a/src/dns_conf.rs b/src/dns_conf.rs index ba24bb51..010f33f5 100644 --- a/src/dns_conf.rs +++ b/src/dns_conf.rs @@ -1,5 +1,6 @@ use cfg_if::cfg_if; use ipnet::IpNet; +use std::borrow::Cow; use std::collections::{HashMap, HashSet}; use std::ffi::OsStr; use std::fs::File; @@ -45,8 +46,10 @@ pub struct RuntimeConfig { impl RuntimeConfig { pub fn load>(path: Option

) -> Arc { if let Some(ref conf) = path { - let path = conf.as_ref(); - + let mut path = Cow::Borrowed(conf.as_ref()); + if path.is_dir() { + path = Cow::Owned(path.join(format!("{}.conf", crate::NAME.to_lowercase()))); + } RuntimeConfig::load_from_file(path) } else { #[cfg(feature = "service")] diff --git a/src/server/https.rs b/src/server/https.rs index c6758702..1d2345c2 100644 --- a/src/server/https.rs +++ b/src/server/https.rs @@ -1,11 +1,18 @@ -use axum::extract::Request; +use axum::{ + extract::{connect_info::IntoMakeServiceWithConnectInfo, Request}, + Router, +}; use hyper::body::Incoming; use hyper_util::{ rt::{TokioExecutor, TokioIo}, server, }; use std::{convert::Infallible, io, net::SocketAddr, sync::Arc}; -use tokio::{net, task::JoinSet}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net, + task::JoinSet, +}; use tokio_util::sync::CancellationToken; use tower::{Service as _, ServiceExt}; @@ -24,21 +31,20 @@ pub fn serve( app: Arc, listener: net::TcpListener, dns_handle: DnsHandle, - certificate_and_key: (Vec, PrivateKey), + (cert, key): (Vec, PrivateKey), ) -> io::Result { let token = CancellationToken::new(); let cancellation_token = token.clone(); - let tls_config = tls_server_config(b"h2", certificate_and_key.0, certificate_and_key.1) - .map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("error creating TLS acceptor: {e}"), - ) - })?; - log::debug!("registered HTTPS: {:?}", listener); + let tls_config = tls_server_config(b"h2", cert, key).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("error creating TLS acceptor: {e}"), + ) + })?; + let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config)); let state = Arc::new(ServeState { app, dns_handle }); @@ -83,32 +89,15 @@ pub fn serve( // perform the TLS let tls_stream = tls_acceptor.accept(tcp_stream).await; - let socket = match tls_stream { - Ok(tls_stream) => TokioIo::new(tls_stream), + Ok(tls_stream) => tls_stream, Err(e) => { log::debug!("https handshake src: {} error: {}", src_addr, e); return; } }; - log::debug!("accepted HTTPS request from: {}", src_addr); - - let tower_service = unwrap_infallible(make_service.call(src_addr).await); - - let hyper_service = - hyper::service::service_fn(move |request: Request| { - tower_service.clone().oneshot(request) - }); - - if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) - .http2() - .enable_connect_protocol() - .serve_connection_with_upgrades(socket, hyper_service) - .await - { - eprintln!("failed to serve connection: {err:#}"); - } + serve_connection(&mut make_service, socket, src_addr).await; }); reap_tasks(&mut inner_join_set); @@ -118,6 +107,30 @@ pub fn serve( Ok(token) } +async fn serve_connection( + make_service: &mut IntoMakeServiceWithConnectInfo, + io: I, + src_addr: SocketAddr, +) where + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let socket = TokioIo::new(io); + let tower_service = unwrap_infallible(make_service.call(src_addr).await); + + let hyper_service = hyper::service::service_fn(move |request: Request| { + tower_service.clone().oneshot(request) + }); + + if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) + .http2() + .enable_connect_protocol() + .serve_connection_with_upgrades(socket, hyper_service) + .await + { + eprintln!("failed to serve connection: {err:#}"); + } +} + fn unwrap_infallible(result: Result) -> T { match result { Ok(value) => value, diff --git a/swagger-ui.html b/swagger-ui.html new file mode 100644 index 00000000..f895be1f --- /dev/null +++ b/swagger-ui.html @@ -0,0 +1,44 @@ + + + + + + + {title} | API Docs + + + + + + + +

+ + + + + +