fix: add missing EP config keys
decahedron1 committed Jul 12, 2024
1 parent 04da381 commit 9d25514
Showing 3 changed files with 215 additions and 28 deletions.
23 changes: 21 additions & 2 deletions src/execution_providers/cuda.rs
@@ -50,7 +50,9 @@ pub struct CUDAExecutionProvider {
     cudnn_conv_use_max_workspace: Option<bool>,
     cudnn_conv1d_pad_to_nc1d: Option<bool>,
     enable_cuda_graph: Option<bool>,
-    enable_skip_layer_norm_strict_mode: Option<bool>
+    enable_skip_layer_norm_strict_mode: Option<bool>,
+    use_tf32: Option<bool>,
+    prefer_nhwc: Option<bool>
 }

 impl CUDAExecutionProvider {
@@ -156,6 +158,21 @@ impl CUDAExecutionProvider {
         self
     }

+    /// TF32 is a math mode available on NVIDIA GPUs since Ampere. It allows certain float32 matrix multiplications and
+    /// convolutions to run much faster on tensor cores with TensorFloat-32 reduced precision: float32 inputs are
+    /// rounded with 10 bits of mantissa and results are accumulated with float32 precision.
+    #[must_use]
+    pub fn with_tf32(mut self, enable: bool) -> Self {
+        self.use_tf32 = Some(enable);
+        self
+    }
+
+    #[must_use]
+    pub fn with_prefer_nhwc(mut self) -> Self {
+        self.prefer_nhwc = Some(true);
+        self
+    }
+
     #[must_use]
     pub fn build(self) -> ExecutionProviderDispatch {
         self.into()
@@ -199,7 +216,9 @@ impl ExecutionProvider for CUDAExecutionProvider {
             cudnn_conv_use_max_workspace = self.cudnn_conv_use_max_workspace.map(<bool as Into<i32>>::into),
             cudnn_conv1d_pad_to_nc1d = self.cudnn_conv1d_pad_to_nc1d.map(<bool as Into<i32>>::into),
             enable_cuda_graph = self.enable_cuda_graph.map(<bool as Into<i32>>::into),
-            enable_skip_layer_norm_strict_mode = self.enable_skip_layer_norm_strict_mode.map(<bool as Into<i32>>::into)
+            enable_skip_layer_norm_strict_mode = self.enable_skip_layer_norm_strict_mode.map(<bool as Into<i32>>::into),
+            use_tf32 = self.use_tf32.map(<bool as Into<i32>>::into),
+            prefer_nhwc = self.prefer_nhwc.map(<bool as Into<i32>>::into)
         };
         if let Err(e) =
             crate::error::status_to_result(crate::ortsys![unsafe UpdateCUDAProviderOptions(cuda_options, key_ptrs.as_ptr(), value_ptrs.as_ptr(), len as _)])
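For reference, here is how the two new CUDA options might be enabled from user code — a minimal sketch, assuming ort's 2.0 session-builder API (`Session::builder`, `with_execution_providers`, `commit_from_file`) and a placeholder model path; only `with_tf32` and `with_prefer_nhwc` come from this commit:

use ort::{CUDAExecutionProvider, Session};

fn main() -> ort::Result<()> {
    // TF32 trades a small amount of float32 precision for tensor-core throughput on Ampere+ GPUs;
    // prefer_nhwc asks the CUDA EP to favor NHWC-layout kernels over the default NCHW.
    let cuda = CUDAExecutionProvider::default()
        .with_tf32(true)
        .with_prefer_nhwc()
        .build();

    let _session = Session::builder()?
        .with_execution_providers([cuda])?
        .commit_from_file("model.onnx")?;
    Ok(())
}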
114 changes: 93 additions & 21 deletions src/execution_providers/qnn.rs
@@ -4,7 +4,7 @@ use crate::{
     session::SessionBuilder
 };

-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum QNNExecutionProviderPerformanceMode {
     Default,
     Burst,
@@ -34,7 +34,7 @@ impl QNNExecutionProviderPerformanceMode {
     }
 }

-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum QNNExecutionProviderProfilingLevel {
     Off,
     Basic,
@@ -51,14 +51,41 @@ impl QNNExecutionProviderProfilingLevel {
     }
 }

+#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum QNNExecutionProviderContextPriority {
+    Low,
+    #[default]
+    Normal,
+    NormalHigh,
+    High
+}
+
+impl QNNExecutionProviderContextPriority {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            QNNExecutionProviderContextPriority::Low => "low",
+            QNNExecutionProviderContextPriority::Normal => "normal",
+            QNNExecutionProviderContextPriority::NormalHigh => "normal_high",
+            QNNExecutionProviderContextPriority::High => "high"
+        }
+    }
+}
+
 #[derive(Debug, Default, Clone)]
 pub struct QNNExecutionProvider {
     backend_path: Option<String>,
-    qnn_context_cache_enable: Option<bool>,
-    qnn_context_cache_path: Option<String>,
     profiling_level: Option<QNNExecutionProviderProfilingLevel>,
     profiling_file_path: Option<String>,
     rpc_control_latency: Option<u32>,
-    htp_performance_mode: Option<QNNExecutionProviderPerformanceMode>
+    vtcm_mb: Option<usize>,
+    htp_performance_mode: Option<QNNExecutionProviderPerformanceMode>,
+    qnn_saver_path: Option<String>,
+    qnn_context_priority: Option<QNNExecutionProviderContextPriority>,
+    htp_graph_finalization_optimization_mode: Option<u8>,
+    soc_model: Option<String>,
+    htp_arch: Option<u32>,
+    device_id: Option<i32>,
+    enable_htp_fp16_precision: Option<bool>
 }

 impl QNNExecutionProvider {
@@ -70,32 +97,28 @@ impl QNNExecutionProvider {
         self
     }

-    /// Configure whether to enable QNN graph creation from a cached QNN context file. If enabled, the QNN EP
-    /// will load from the cached QNN context binary if it exists, or create one if it does not exist.
     #[must_use]
-    pub fn with_enable_context_cache(mut self, enable: bool) -> Self {
-        self.qnn_context_cache_enable = Some(enable);
+    pub fn with_profiling(mut self, level: QNNExecutionProviderProfilingLevel) -> Self {
+        self.profiling_level = Some(level);
         self
     }

-    /// Explicitly provide the QNN context cache file (see [`QNNExecutionProvider::with_enable_context_cache`]).
-    /// Defaults to `model_file.onnx.bin` if not provided.
     #[must_use]
-    pub fn with_context_cache_path(mut self, path: impl ToString) -> Self {
-        self.qnn_context_cache_path = Some(path.to_string());
+    pub fn with_profiling_path(mut self, path: impl ToString) -> Self {
+        self.profiling_file_path = Some(path.to_string());
         self
     }

+    /// Allows client to set up RPC control latency in microseconds.
     #[must_use]
-    pub fn with_profiling(mut self, level: QNNExecutionProviderProfilingLevel) -> Self {
-        self.profiling_level = Some(level);
+    pub fn with_rpc_control_latency(mut self, latency: u32) -> Self {
+        self.rpc_control_latency = Some(latency);
         self
     }

-    /// Allows client to set up RPC control latency in microseconds.
     #[must_use]
-    pub fn with_rpc_control_latency(mut self, latency: u32) -> Self {
-        self.rpc_control_latency = Some(latency);
+    pub fn with_vtcm_mb(mut self, mb: usize) -> Self {
+        self.vtcm_mb = Some(mb);
         self
     }

@@ -105,6 +128,48 @@ impl QNNExecutionProvider {
         self
     }

+    #[must_use]
+    pub fn with_saver_path(mut self, path: impl ToString) -> Self {
+        self.qnn_saver_path = Some(path.to_string());
+        self
+    }
+
+    #[must_use]
+    pub fn with_context_priority(mut self, priority: QNNExecutionProviderContextPriority) -> Self {
+        self.qnn_context_priority = Some(priority);
+        self
+    }
+
+    #[must_use]
+    pub fn with_htp_graph_finalization_optimization_mode(mut self, mode: u8) -> Self {
+        self.htp_graph_finalization_optimization_mode = Some(mode);
+        self
+    }
+
+    #[must_use]
+    pub fn with_soc_model(mut self, model: impl ToString) -> Self {
+        self.soc_model = Some(model.to_string());
+        self
+    }
+
+    #[must_use]
+    pub fn with_htp_arch(mut self, arch: u32) -> Self {
+        self.htp_arch = Some(arch);
+        self
+    }
+
+    #[must_use]
+    pub fn with_device_id(mut self, device: i32) -> Self {
+        self.device_id = Some(device);
+        self
+    }
+
+    #[must_use]
+    pub fn with_htp_fp16_precision(mut self, enable: bool) -> Self {
+        self.enable_htp_fp16_precision = Some(enable);
+        self
+    }
+
     #[must_use]
     pub fn build(self) -> ExecutionProviderDispatch {
         self.into()
@@ -133,10 +198,17 @@ impl ExecutionProvider for QNNExecutionProvider {
         let (key_ptrs, value_ptrs, len, _keys, _values) = super::map_keys! {
             backend_path = self.backend_path.clone(),
             profiling_level = self.profiling_level.as_ref().map(QNNExecutionProviderProfilingLevel::as_str),
-            qnn_context_cache_enable = self.qnn_context_cache_enable.map(<bool as Into<i32>>::into),
-            qnn_context_cache_path = self.qnn_context_cache_path.clone(),
             profiling_file_path = self.profiling_file_path.clone(),
+            rpc_control_latency = self.rpc_control_latency,
+            vtcm_mb = self.vtcm_mb,
             htp_performance_mode = self.htp_performance_mode.as_ref().map(QNNExecutionProviderPerformanceMode::as_str),
-            rpc_control_latency = self.rpc_control_latency
+            qnn_saver_path = self.qnn_saver_path.clone(),
+            qnn_context_priority = self.qnn_context_priority.as_ref().map(QNNExecutionProviderContextPriority::as_str),
+            htp_graph_finalization_optimization_mode = self.htp_graph_finalization_optimization_mode,
+            soc_model = self.soc_model.clone(),
+            htp_arch = self.htp_arch,
+            device_id = self.device_id,
+            enable_htp_fp16_precision = self.enable_htp_fp16_precision.map(<bool as Into<i32>>::into)
         };
         let ep_name = std::ffi::CString::new("QNN").unwrap_or_else(|_| unreachable!());
         return crate::error::status_to_result(crate::ortsys![unsafe SessionOptionsAppendExecutionProvider(
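Likewise, a sketch of the expanded QNN builder surface — setter names are taken from this diff, while `with_backend_path` (not shown here) and the HTP backend filename are assumptions, since the QNN EP needs a backend library to load:

use ort::{QNNExecutionProvider, QNNExecutionProviderContextPriority, Session};

fn main() -> ort::Result<()> {
    let qnn = QNNExecutionProvider::default()
        // Assumed pre-existing setter pointing the EP at a QNN backend library.
        .with_backend_path("libQnnHtp.so")
        .with_context_priority(QNNExecutionProviderContextPriority::High)
        .with_vtcm_mb(8)
        .with_htp_fp16_precision(true)
        .build();

    let _session = Session::builder()?
        .with_execution_providers([qnn])?
        .commit_from_file("model.onnx")?;
    Ok(())
}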
