diff --git a/Cargo.toml b/Cargo.toml index 2073cd37..5469eba8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,7 +89,7 @@ libloading = { version = "0.8", optional = true } ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] } sha2 = { version = "0.10", optional = true } -tracing = { version = "0.1", default-features = false, features = [ "std" ] } +tracing = { version = "0.1", optional = true, default-features = false, features = [ "std" ] } half = { version = "2.1", optional = true } [dev-dependencies] diff --git a/src/environment.rs b/src/environment.rs index 7c13affc..e91e5721 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -13,15 +13,12 @@ use std::{ any::Any, - ffi::{self, CStr, CString}, + ffi::CString, os::raw::c_void, ptr::{self, NonNull}, sync::{Arc, RwLock} }; -use ort_sys::c_char; -use tracing::{Level, debug}; - #[cfg(feature = "load-dynamic")] use crate::G_ORT_DYLIB_PATH; use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, ortsys}; @@ -66,7 +63,7 @@ impl AsPointer for Environment { impl Drop for Environment { fn drop(&mut self) { - debug!(ptr = ?self.ptr(), "Releasing environment"); + crate::debug!(ptr = ?self.ptr(), "Releasing environment"); ortsys![unsafe ReleaseEnv(self.ptr_mut())]; } } @@ -81,7 +78,7 @@ pub fn get_environment() -> Result> { // drop our read lock so we dont deadlock when `commit` takes a write lock drop(env); - debug!("Environment not yet initialized, creating a new one"); + crate::debug!("Environment not yet initialized, creating a new one"); Ok(EnvironmentBuilder::new().commit()?) } } @@ -191,11 +188,13 @@ pub(crate) unsafe extern "system" fn thread_create( .cast_const() .cast::(), Ok(Err(e)) => { - tracing::error!("Failed to create thread using manager: {e}"); + crate::error!("Failed to create thread using manager: {e}"); + let _ = e; ptr::null() } Err(e) => { - tracing::error!("Thread manager panicked: {e:?}"); + crate::error!("Thread manager panicked: {e:?}"); + let _ = e; ptr::null() } } @@ -204,7 +203,8 @@ pub(crate) unsafe extern "system" fn thread_create( pub(crate) unsafe extern "system" fn thread_join(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) { let handle = Box::from_raw(ort_custom_thread_handle.cast_mut().cast::<::Thread>()); if let Err(e) = ::join(*handle) { - tracing::error!("Failed to join thread using manager: {e}"); + crate::error!("Failed to join thread using manager: {e}"); + let _ = e; } } @@ -279,14 +279,13 @@ impl EnvironmentBuilder { pub fn commit(self) -> Result> { let (env_ptr, thread_manager, has_global_threadpool) = if let Some(mut thread_pool_options) = self.global_thread_pool_options { let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); - let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); - let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!()); + #[cfg(feature = "tracing")] ortsys![ unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools( - logging_function, - logger_param, + Some(crate::logging::custom_logger), + ptr::null_mut(), ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), thread_pool_options.ptr(), @@ -294,28 +293,47 @@ impl EnvironmentBuilder { )?; nonNull(env_ptr) ]; + #[cfg(not(feature = "tracing"))] + ortsys![ + unsafe CreateEnvWithGlobalThreadPools( + crate::logging::default_log_level(), + cname.as_ptr(), + thread_pool_options.ptr(), + &mut env_ptr + )?; + nonNull(env_ptr) + ]; let thread_manager = thread_pool_options.thread_manager.take(); (env_ptr, thread_manager, true) } else { let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut(); - let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger); - // FIXME: What should go here? - let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!()); + + #[cfg(feature = "tracing")] ortsys![ unsafe CreateEnvWithCustomLogger( - logging_function, - logger_param, + Some(crate::logging::custom_logger), + ptr::null_mut(), ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, cname.as_ptr(), &mut env_ptr )?; nonNull(env_ptr) ]; + #[cfg(not(feature = "tracing"))] + ortsys![ + unsafe CreateEnv( + crate::logging::default_log_level(), + cname.as_ptr(), + &mut env_ptr + )?; + nonNull(env_ptr) + ]; + (env_ptr, None, false) }; - debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created"); + crate::debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created"); if self.telemetry { ortsys![unsafe EnableTelemetryEvents(env_ptr)?]; @@ -394,30 +412,3 @@ pub fn init_from(path: impl ToString) -> EnvironmentBuilder { let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string())); EnvironmentBuilder::new() } - -/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate. -pub(crate) extern "system" fn custom_logger( - _params: *mut ffi::c_void, - severity: ort_sys::OrtLoggingLevel, - _: *const c_char, - id: *const c_char, - code_location: *const c_char, - message: *const c_char -) { - assert_ne!(code_location, ptr::null()); - let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or(""); - assert_ne!(message, ptr::null()); - let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or(""); - assert_ne!(id, ptr::null()); - let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or(""); - - let span = tracing::span!(Level::TRACE, "ort", id = id, location = code_location); - - match severity { - ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::TRACE, "{message}"), - ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::INFO, "{message}"), - ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::WARN, "{message}"), - ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, Level::ERROR, "{message}"), - ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => tracing::event!(parent: &span, Level::ERROR, "(FATAL): {message}") - } -} diff --git a/src/execution_providers/mod.rs b/src/execution_providers/mod.rs index 5bd1883d..086cc5e0 100644 --- a/src/execution_providers/mod.rs +++ b/src/execution_providers/mod.rs @@ -263,20 +263,20 @@ pub(crate) fn apply_execution_providers( .ends_with("was not registered because its corresponding Cargo feature is not enabled.") { if ex.inner.supported_by_platform() { - tracing::warn!("{e}"); + crate::warn!("{e}"); } else { - tracing::debug!("{e} (note: additionally, `{}` is not supported on this platform)", ex.inner.as_str()); + crate::debug!("{e} (note: additionally, `{}` is not supported on this platform)", ex.inner.as_str()); } } else { - tracing::error!("An error occurred when attempting to register `{}`: {e}", ex.inner.as_str()); + crate::error!("An error occurred when attempting to register `{}`: {e}", ex.inner.as_str()); } } else { - tracing::info!("Successfully registered `{}`", ex.inner.as_str()); + crate::info!("Successfully registered `{}`", ex.inner.as_str()); fallback_to_cpu = false; } } if fallback_to_cpu { - tracing::warn!("No execution providers registered successfully. Falling back to CPU."); + crate::warn!("No execution providers registered successfully. Falling back to CPU."); } Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 18a51a2c..75e6008a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ pub mod environment; pub mod error; pub mod execution_providers; pub mod io_binding; +pub(crate) mod logging; pub mod memory; pub mod metadata; pub mod operator; @@ -39,6 +40,7 @@ pub use ort_sys as sys; #[cfg(feature = "load-dynamic")] pub use self::environment::init_from; +pub(crate) use self::logging::{debug, error, info, trace, warning as warn}; pub use self::{ environment::init, error::{Error, ErrorCode, Result} @@ -138,7 +140,7 @@ pub fn api() -> &'static ort_sys::OrtApi { let version_string = ((*base).GetVersionString)(); let version_string = CStr::from_ptr(version_string).to_string_lossy(); - tracing::info!("Loaded ONNX Runtime dylib with version '{version_string}'"); + crate::info!("Loaded ONNX Runtime dylib with version '{version_string}'"); let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::().unwrap_or(0)); match lib_minor_version.cmp(&MINOR_VERSION) { @@ -147,7 +149,7 @@ pub fn api() -> &'static ort_sys::OrtApi { env!("CARGO_PKG_VERSION"), dylib_path() ), - std::cmp::Ordering::Greater => tracing::warn!( + std::cmp::Ordering::Greater => crate::warn!( "ort {} may have compatibility issues with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.{MINOR_VERSION}.x', but got '{version_string}'", env!("CARGO_PKG_VERSION"), dylib_path() diff --git a/src/logging.rs b/src/logging.rs new file mode 100644 index 00000000..cfd42a0d --- /dev/null +++ b/src/logging.rs @@ -0,0 +1,81 @@ +#[cfg(feature = "tracing")] +use std::{ + ffi::{self, CStr}, + ptr +}; + +macro_rules! trace { + ($($arg:tt)+) => { + #[cfg(feature = "tracing")] + tracing::trace!($($arg)+); + } +} +macro_rules! debug { + ($($arg:tt)+) => { + #[cfg(feature = "tracing")] + tracing::debug!($($arg)+); + } +} +macro_rules! info { + ($($arg:tt)+) => { + #[cfg(feature = "tracing")] + tracing::info!($($arg)+); + } +} +macro_rules! warning { + ($($arg:tt)+) => { + #[cfg(feature = "tracing")] + tracing::warn!($($arg)+); + } +} +macro_rules! error { + ($($arg:tt)+) => { + #[cfg(feature = "tracing")] + tracing::error!($($arg)+); + } +} +pub(crate) use debug; +pub(crate) use error; +pub(crate) use info; +pub(crate) use trace; +pub(crate) use warning; + +#[cfg(not(feature = "tracing"))] +pub fn default_log_level() -> ort_sys::OrtLoggingLevel { + match std::env::var("ORT_LOG").as_deref() { + Ok("fatal") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL, + Ok("error") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, + Ok("warning") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + Ok("info") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + Ok("verbose") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + _ => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR + } +} + +/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate. +#[cfg(feature = "tracing")] +pub(crate) extern "system" fn custom_logger( + _params: *mut ffi::c_void, + severity: ort_sys::OrtLoggingLevel, + _: *const ffi::c_char, + id: *const ffi::c_char, + code_location: *const ffi::c_char, + message: *const ffi::c_char +) { + assert_ne!(code_location, ptr::null()); + let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or(""); + assert_ne!(message, ptr::null()); + let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or(""); + assert_ne!(id, ptr::null()); + let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or(""); + + let span = tracing::span!(tracing::Level::TRACE, "ort", id = id, location = code_location); + + match severity { + ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, tracing::Level::TRACE, "{message}"), + ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, tracing::Level::INFO, "{message}"), + ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, tracing::Level::WARN, "{message}"), + ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, tracing::Level::ERROR, "{message}"), + ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => tracing::event!(parent: &span, tracing::Level::ERROR, "(FATAL): {message}") + } +} diff --git a/src/session/builder/impl_commit.rs b/src/session/builder/impl_commit.rs index bc6e6a67..692f7ae2 100644 --- a/src/session/builder/impl_commit.rs +++ b/src/session/builder/impl_commit.rs @@ -32,10 +32,10 @@ impl SessionBuilder { }); let model_filepath = download_dir.join(&model_filename); let downloaded_path = if model_filepath.exists() { - tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); + crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download"); model_filepath } else { - tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model"); + crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model"); let resp = ureq::get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?; @@ -43,7 +43,7 @@ impl SessionBuilder { .header("Content-Length") .and_then(|s| s.parse::().ok()) .expect("Missing Content-Length header"); - tracing::info!(len, "Downloading {} bytes", len); + crate::info!(len, "Downloading {} bytes", len); let mut reader = resp.into_reader(); let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier())); diff --git a/src/session/mod.rs b/src/session/mod.rs index b65b317d..732c6b27 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -74,7 +74,7 @@ impl AsPointer for SharedSessionInner { impl Drop for SharedSessionInner { fn drop(&mut self) { - tracing::debug!(ptr = ?self.session_ptr.as_ptr(), "dropping SharedSessionInner"); + crate::debug!(ptr = ?self.session_ptr.as_ptr(), "dropping SharedSessionInner"); ortsys![unsafe ReleaseSession(self.session_ptr.as_ptr())]; } } diff --git a/src/training/mod.rs b/src/training/mod.rs index 42c19013..e5a087bb 100644 --- a/src/training/mod.rs +++ b/src/training/mod.rs @@ -119,7 +119,7 @@ impl AsPointer for Checkpoint { impl Drop for Checkpoint { fn drop(&mut self) { - tracing::trace!("dropping checkpoint"); + crate::trace!("dropping checkpoint"); trainsys![unsafe ReleaseCheckpointState(self.ptr.as_ptr())]; } } diff --git a/src/training/trainer.rs b/src/training/trainer.rs index 86ceb42b..b8113c40 100644 --- a/src/training/trainer.rs +++ b/src/training/trainer.rs @@ -230,7 +230,7 @@ impl AsPointer for Trainer { impl Drop for Trainer { fn drop(&mut self) { - tracing::trace!("dropping trainer"); + crate::trace!("dropping trainer"); trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())]; } } diff --git a/src/value/mod.rs b/src/value/mod.rs index 5769cb0e..a8658752 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -70,7 +70,7 @@ impl AsPointer for ValueInner { impl Drop for ValueInner { fn drop(&mut self) { let ptr = self.ptr_mut(); - tracing::trace!("dropping value at {ptr:p}"); + crate::trace!("dropping value at {ptr:p}"); if self.drop { ortsys![unsafe ReleaseValue(ptr)]; }