Skip to content

Commit

Permalink
refactor: take thread counts as usize
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Feb 19, 2024
1 parent cc40e6b commit 2cc17c0
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,8 @@ impl SessionBuilder {
///
/// For configuring the number of threads used when the session execution mode is set to `Parallel`, see
/// [`SessionBuilder::with_inter_threads()`].
pub fn with_intra_threads(self, num_threads: i16) -> Result<Self> {
// We use a u16 in the builder to cover the 16-bits positive values of a i32.
let num_threads = num_threads as i32;
ortsys![unsafe SetIntraOpNumThreads(self.session_options_ptr, num_threads) -> Error::CreateSessionOptions];
pub fn with_intra_threads(self, num_threads: usize) -> Result<Self> {
ortsys![unsafe SetIntraOpNumThreads(self.session_options_ptr, num_threads as _) -> Error::CreateSessionOptions];
Ok(self)
}

Expand All @@ -191,10 +189,8 @@ impl SessionBuilder {
///
/// For configuring the number of threads used to parallelize the execution within nodes, see
/// [`SessionBuilder::with_intra_threads()`].
pub fn with_inter_threads(self, num_threads: i16) -> Result<Self> {
// We use a u16 in the builder to cover the 16-bits positive values of a i32.
let num_threads = num_threads as i32;
ortsys![unsafe SetInterOpNumThreads(self.session_options_ptr, num_threads) -> Error::CreateSessionOptions];
pub fn with_inter_threads(self, num_threads: usize) -> Result<Self> {
ortsys![unsafe SetInterOpNumThreads(self.session_options_ptr, num_threads as _) -> Error::CreateSessionOptions];
Ok(self)
}

Expand Down

0 comments on commit 2cc17c0

Please sign in to comment.