Skip to content

Commit

Permalink
refactor init and exit logic
Browse files Browse the repository at this point in the history
  • Loading branch information
OlofBlomqvist committed Sep 6, 2024
1 parent 26561c6 commit 450f98a
Show file tree
Hide file tree
Showing 13 changed files with 540 additions and 635 deletions.
2 changes: 1 addition & 1 deletion odd-box.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ip = "127.0.0.1"
tls_port = 4343
auto_start = false
root_dir = "~"
log_level = "warn"
log_level = "info"
port_range_start = 4200
default_log_format = "standard"
env_vars = [
Expand Down
88 changes: 43 additions & 45 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,59 +72,57 @@ async fn set_cors(request: axum::extract::Request, next: axum::middleware::Next,
}


pub async fn run(globally_shared_state: Arc<crate::global_state::GlobalState>,port:Option<u16>,tracing_broadcaster:tokio::sync::broadcast::Sender::<String>) {
pub async fn run(globally_shared_state: Arc<crate::global_state::GlobalState>,port:u16,tracing_broadcaster:tokio::sync::broadcast::Sender::<String>) {

if let Some(p) = port {
let websocket_state = WebSocketGlobalState {

let websocket_state = WebSocketGlobalState {

broadcast_channel: tokio::sync::broadcast::channel(10).0,
global_state: globally_shared_state.clone()
};

let socket_address: SocketAddr = format!("127.0.0.1:{p}").parse().unwrap();
let listener = tokio::net::TcpListener::bind(socket_address).await.unwrap();
broadcast_channel: tokio::sync::broadcast::channel(10).0,
global_state: globally_shared_state.clone()
};

let socket_address: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap();
let listener = tokio::net::TcpListener::bind(socket_address).await.unwrap();


let cors_env_var = std::env::vars().find(|(key,_)| key=="ODDBOX_CORS_ALLOWED_ORIGIN").map(|x|x.1.to_lowercase());
let cors_env_var_cloned_for_ws = cors_env_var.clone();
let cors_env_var = std::env::vars().find(|(key,_)| key=="ODDBOX_CORS_ALLOWED_ORIGIN").map(|x|x.1.to_lowercase());
let cors_env_var_cloned_for_ws = cors_env_var.clone();

let mut router = Router::new()
let mut router = Router::new()

// API DOCS
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.merge(Redoc::with_url("/redoc", ApiDoc::openapi()))
.merge(RapiDoc::new("/api-docs/openapi.json").path("/rapidoc"))
// API ROUTES
.merge(crate::api::controllers::routes(globally_shared_state.clone()).await)
// API DOCS
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.merge(Redoc::with_url("/redoc", ApiDoc::openapi()))
.merge(RapiDoc::new("/api-docs/openapi.json").path("/rapidoc"))

// API ROUTES
.merge(crate::api::controllers::routes(globally_shared_state.clone()).await)

.route("/", axum::routing::get(root))
.route("/script.js", axum::routing::get(script))
.route("/", axum::routing::get(root))
.route("/script.js", axum::routing::get(script))

// WEBSOCKET ROUTE FOR LOGS
.route("/ws/live_logs", axum::routing::get( move|ws,user_agent,origin,addr,state|
ws_log_messages_handler(ws,user_agent,origin,addr,state, cors_env_var_cloned_for_ws)).with_state(websocket_state.clone()));


// in some cases one might want to allow CORS from a specific origin. this is not currently allowed to do from the config file
// so we use an environment variable to set this. might change in the future if it becomes a common use case
if let Some(cors_var) = cors_env_var {
router = router.layer(
CorsLayer::new()
.allow_methods(Any)
.allow_headers(Any)
.expose_headers(Any))
.layer(axum::middleware::from_fn(move |request: axum::extract::Request, next: axum::middleware::Next|set_cors(request,next,cors_var.clone())));
};

tokio::spawn(broadcast_manager(websocket_state,tracing_broadcaster));

axum::serve(listener, router.into_make_service_with_connect_info::<SocketAddr>())
.await
.unwrap()
}

// WEBSOCKET ROUTE FOR LOGS
.route("/ws/live_logs", axum::routing::get( move|ws,user_agent,origin,addr,state|
ws_log_messages_handler(ws,user_agent,origin,addr,state, cors_env_var_cloned_for_ws)).with_state(websocket_state.clone()));


// in some cases one might want to allow CORS from a specific origin. this is not currently allowed to do from the config file
// so we use an environment variable to set this. might change in the future if it becomes a common use case
if let Some(cors_var) = cors_env_var {
router = router.layer(
CorsLayer::new()
.allow_methods(Any)
.allow_headers(Any)
.expose_headers(Any))
.layer(axum::middleware::from_fn(move |request: axum::extract::Request, next: axum::middleware::Next|set_cors(request,next,cors_var.clone())));
};

tokio::spawn(broadcast_manager(websocket_state,tracing_broadcaster));

axum::serve(listener, router.into_make_service_with_connect_info::<SocketAddr>())
.await
.unwrap()

}

// Define the handler function for the root path
Expand Down
136 changes: 136 additions & 0 deletions src/certs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use dashmap::DashMap;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_rustls::rustls::server::{ClientHello, ResolvesServerCert};

#[derive(Debug)]
pub struct DynamicCertResolver {
cache: DashMap<String, std::sync::Arc<tokio_rustls::rustls::sign::CertifiedKey>>,
}

impl DynamicCertResolver {
pub fn new() -> Self {
DynamicCertResolver {
cache: DashMap::new(),
}
}
}

impl ResolvesServerCert for DynamicCertResolver {
fn resolve(&self, client_hello: ClientHello) -> Option<std::sync::Arc<tokio_rustls::rustls::sign::CertifiedKey>> {

let server_name = client_hello.server_name()?;

if let Some(certified_key) = self.cache.get(server_name) {
tracing::trace!("Returning a cached certificate for {:?}",server_name);
return Some(certified_key.clone());
}


let odd_cache_base = ".odd_box_cache";

let base_path = std::path::Path::new(odd_cache_base);
let host_name_cert_path = base_path.join(server_name);

if let Err(e) = std::fs::create_dir_all(&host_name_cert_path) {
tracing::error!("Could not create directory: {:?}", e);
return None;
}

let cert_path = format!("{}/{}/cert.pem",odd_cache_base,server_name);
let key_path = format!("{}/{}/key.pem",odd_cache_base,server_name);

if let Err(e) = generate_cert_if_not_exist(server_name, &cert_path, &key_path) {
tracing::error!("Could not generate cert: {:?}", e);
return None
}


if let Ok(cert_chain) = my_certs(&cert_path) {

if cert_chain.is_empty() {
tracing::warn!("EMPTY CERT CHAIN FOR {}",server_name);
return None
}
if let Ok(private_key) = my_rsa_private_keys(&key_path) {
if let Ok(rsa_signing_key) = tokio_rustls::rustls::crypto::aws_lc_rs::sign::any_supported_type(&private_key) {
let result = std::sync::Arc::new(tokio_rustls::rustls::sign::CertifiedKey::new(
cert_chain,
rsa_signing_key
));
self.cache.insert(server_name.into(), result.clone());
Some(result)

} else {
tracing::error!("rustls::crypto::ring::sign::any_supported_type - failed to read cert: {cert_path}");
None
}
} else {
tracing::error!("my_rsa_private_keys - failed to read cert: {cert_path}");
None
}
} else {
tracing::error!("generate_cert_if_not_exist - failed to read cert: {cert_path}");
None
}
}
}

use std::io::BufReader;
use std::fs::File;


fn generate_cert_if_not_exist(hostname: &str, cert_path: &str,key_path: &str) -> Result<(),String> {

let crt_exists = std::fs::metadata(cert_path).is_ok();
let key_exists = std::fs::metadata(key_path).is_ok();

if crt_exists && key_exists {
tracing::debug!("Using existing certificate for {}",hostname);
return Ok(())
}

if crt_exists != key_exists {
return Err(String::from("Missing key or crt for this hostname. Remove both if you want to generate a new set, or add the missing one."))
}

tracing::debug!("Generating new certificate for site '{}'",hostname);


match rcgen::generate_simple_self_signed(
vec![hostname.to_owned()]
) {
Ok(cert) => {
tracing::trace!("Generating new self-signed certificate for host '{}'!",hostname);
let _ = std::fs::write(&cert_path, cert.cert.pem());
let _ = std::fs::write(&key_path, &cert.key_pair.serialize_pem());
Ok(())
},
Err(e) => Err(e.to_string())
}
}


fn my_certs(path: &str) -> Result<Vec<CertificateDer<'static>>, std::io::Error> {
let cert_file = File::open(path)?;
let mut reader = BufReader::new(cert_file);
let certs = rustls_pemfile::certs(&mut reader);
Ok(certs.filter_map(|cert|match cert {
Ok(x) => Some(x),
Err(_) => None,
}).collect())
}

fn my_rsa_private_keys(path: &str) -> Result<PrivateKeyDer, String> {

let file = File::open(&path).map_err(|e|format!("{e:?}"))?;
let mut reader = BufReader::new(file);
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
.collect::<Result<Vec<tokio_rustls::rustls::pki_types::PrivatePkcs8KeyDer>,_>>().map_err(|e|format!("{e:?}"))?;

match keys.len() {
0 => Err(format!("No PKCS8-encoded private key found in {path}").into()),
1 => Ok(PrivateKeyDer::Pkcs8(keys.remove(0))),
_ => Err(format!("More than one PKCS8-encoded private key found in {path}").into()),
}

}
8 changes: 4 additions & 4 deletions src/configuration/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use serde::Serialize;
use serde::Deserialize;
use utoipa::ToSchema;
use crate::global_state::GlobalState;
use crate::ProcId;
use crate::types::proc_info::ProcId;

use super::EnvVar;
use super::LogFormat;
Expand All @@ -17,7 +17,7 @@ use super::LogLevel;
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, Hash)]
pub struct InProcessSiteConfig {

#[serde(skip, default = "crate::ProcId::new")]
#[serde(skip, default = "crate::types::proc_info::ProcId::new")]
proc_id : ProcId,

/// This is set automatically each time we start a process so that we know which ports are in use
Expand Down Expand Up @@ -436,7 +436,7 @@ impl crate::configuration::OddBoxConfiguration<OddBoxV2Config> for OddBoxV2Confi
port_range_start: 4200,
hosted_process: Some(vec![
InProcessSiteConfig {
proc_id: crate::ProcId::new(),
proc_id: ProcId::new(),
active_port: None,
forward_subdomains: None,
disable_tcp_tunnel_mode: Some(false),
Expand Down Expand Up @@ -561,7 +561,7 @@ impl TryFrom<super::v1::OddBoxV1Config> for super::v2::OddBoxV2Config{
hosted_process: Some(old_config.hosted_process.unwrap_or_default().into_iter().map(|x|{
super::v2::InProcessSiteConfig {
exclude_from_start_all: None,
proc_id: crate::ProcId::new(),
proc_id: ProcId::new(),
active_port: None,
forward_subdomains: x.forward_subdomains,
disable_tcp_tunnel_mode: x.disable_tcp_tunnel_mode,
Expand Down
Loading

0 comments on commit 450f98a

Please sign in to comment.