Skip to content

Commit

Permalink
refactor: size_t is usize
Browse files Browse the repository at this point in the history
bashing my head into a wall 3,000 times
  • Loading branch information
decahedron1 committed Oct 20, 2024
1 parent 6bb3e71 commit 28a127d
Show file tree
Hide file tree
Showing 16 changed files with 228 additions and 235 deletions.
255 changes: 124 additions & 131 deletions ort-sys/src/lib.rs

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/execution_providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ impl ExecutionProviderOptionsFFI {
self.value_ptrs.as_ptr()
}

pub fn len(&self) -> ort_sys::size_t {
self.key_ptrs.len() as _
pub fn len(&self) -> usize {
self.key_ptrs.len()
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/execution_providers/openvino.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
pub struct OpenVINOExecutionProvider {
device_type: Option<String>,
device_id: Option<String>,
num_threads: ort_sys::size_t,
num_threads: usize,
cache_dir: Option<String>,
context: *mut c_void,
enable_opencl_throttling: bool,
Expand Down Expand Up @@ -57,7 +57,7 @@ impl OpenVINOExecutionProvider {
/// explicitly set, default value of 8 is used during build time.
#[must_use]
pub fn with_num_threads(mut self, num_threads: usize) -> Self {
self.num_threads = num_threads as _;
self.num_threads = num_threads;
self
}

Expand Down
8 changes: 4 additions & 4 deletions src/execution_providers/rocm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
pub struct ROCmExecutionProvider {
device_id: i32,
miopen_conv_exhaustive_search: bool,
gpu_mem_limit: ort_sys::size_t,
gpu_mem_limit: usize,
arena_extend_strategy: ArenaExtendStrategy,
do_copy_in_default_stream: bool,
user_compute_stream: Option<*mut c_void>,
Expand All @@ -29,7 +29,7 @@ impl Default for ROCmExecutionProvider {
Self {
device_id: 0,
miopen_conv_exhaustive_search: false,
gpu_mem_limit: ort_sys::size_t::MAX,
gpu_mem_limit: usize::MAX,
arena_extend_strategy: ArenaExtendStrategy::NextPowerOfTwo,
do_copy_in_default_stream: true,
user_compute_stream: None,
Expand Down Expand Up @@ -57,7 +57,7 @@ impl ROCmExecutionProvider {

#[must_use]
pub fn with_mem_limit(mut self, limit: usize) -> Self {
self.gpu_mem_limit = limit as _;
self.gpu_mem_limit = limit;
self
}

Expand Down Expand Up @@ -137,7 +137,7 @@ impl ExecutionProvider for ROCmExecutionProvider {
let rocm_options = ort_sys::OrtROCMProviderOptions {
device_id: self.device_id,
miopen_conv_exhaustive_search: self.miopen_conv_exhaustive_search.into(),
gpu_mem_limit: self.gpu_mem_limit as _,
gpu_mem_limit: self.gpu_mem_limit,
arena_extend_strategy: match self.arena_extend_strategy {
ArenaExtendStrategy::NextPowerOfTwo => 0,
ArenaExtendStrategy::SameAsRequested => 1
Expand Down
4 changes: 2 additions & 2 deletions src/io_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,12 @@ impl IoBinding {
ortsys![unsafe RunWithBinding(self.session.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr())?];

let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc<ValueInner>> = self.output_values.values().map(|c| (c.ptr(), &c.inner)).collect();
let mut count = self.output_names.len() as ort_sys::size_t;
let mut count = self.output_names.len();
if count > 0 {
let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut();
ortsys![unsafe GetBoundOutputValues(self.ptr.as_ptr(), self.session.allocator.ptr.as_ptr(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)];

let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() }
let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count).to_vec() }
.into_iter()
.map(|v| unsafe {
if let Some(inner) = owned_ptrs.get(&v) {
Expand Down
30 changes: 15 additions & 15 deletions src/operator/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl<O: Operator> BoundOperator<O> {
pub(crate) fn new(name: CString, execution_provider_type: Option<CString>) -> Self {
Self {
implementation: ort_sys::OrtCustomOp {
version: ort_sys::ORT_API_VERSION as _,
version: ort_sys::ORT_API_VERSION,
GetStartVersion: Some(BoundOperator::<O>::GetStartVersion),
GetEndVersion: Some(BoundOperator::<O>::GetEndVersion),
CreateKernel: None,
Expand Down Expand Up @@ -119,41 +119,41 @@ impl<O: Operator> BoundOperator<O> {
}

extern_system_fn! {
pub(crate) unsafe fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType {
O::inputs()[index as usize].memory_type.into()
pub(crate) unsafe fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtMemType {
O::inputs()[index].memory_type.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::inputs()[index as usize].characteristic.into()
pub(crate) unsafe fn GetInputCharacteristic(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::inputs()[index].characteristic.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::outputs()[index as usize].characteristic.into()
pub(crate) unsafe fn GetOutputCharacteristic(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::outputs()[index].characteristic.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::inputs().len() as _
pub(crate) unsafe fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> usize {
O::inputs().len()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::outputs().len() as _
pub(crate) unsafe fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> usize {
O::outputs().len()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::inputs()[index as usize]
pub(crate) unsafe fn GetInputType(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
O::inputs()[index]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::outputs()[index as usize]
pub(crate) unsafe fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: usize) -> ort_sys::ONNXTensorElementDataType {
O::outputs()[index]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
Expand Down
68 changes: 34 additions & 34 deletions src/operator/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,53 +38,53 @@ impl KernelAttributes {
}

pub fn inputs(&self) -> Result<Vec<Input>> {
let mut num_inputs: ort_sys::size_t = 0;
let mut num_inputs = 0;
ortsys![unsafe KernelInfo_GetInputCount(self.0.as_ptr(), &mut num_inputs)?];

let mut inputs = Vec::with_capacity(num_inputs as _);
for idx in 0..num_inputs as usize {
let mut name_len: ort_sys::size_t = 0;
ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx as _, ptr::null_mut(), &mut name_len)?];
let mut name = vec![0u8; name_len as _];
ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx as _, name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
let mut inputs = Vec::with_capacity(num_inputs);
for idx in 0..num_inputs {
let mut name_len = 0;
ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx, ptr::null_mut(), &mut name_len)?];
let mut name = vec![0u8; name_len];
ortsys![unsafe KernelInfo_GetInputName(self.0.as_ptr(), idx, name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
let name = CString::from_vec_with_nul(name)
.map_err(Error::wrap)?
.into_string()
.map_err(Error::wrap)?;
let mut type_info = ptr::null_mut();
ortsys![unsafe KernelInfo_GetInputTypeInfo(self.0.as_ptr(), idx as _, &mut type_info)?; nonNull(type_info)];
ortsys![unsafe KernelInfo_GetInputTypeInfo(self.0.as_ptr(), idx, &mut type_info)?; nonNull(type_info)];
let input_type = ValueType::from_type_info(type_info);
inputs.push(Input { name, input_type })
}
Ok(inputs)
}

pub fn outputs(&self) -> Result<Vec<Output>> {
let mut num_outputs: ort_sys::size_t = 0;
let mut num_outputs = 0;
ortsys![unsafe KernelInfo_GetOutputCount(self.0.as_ptr(), &mut num_outputs)?];

let mut outputs = Vec::with_capacity(num_outputs as _);
for idx in 0..num_outputs as usize {
let mut name_len: ort_sys::size_t = 0;
ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx as _, ptr::null_mut(), &mut name_len)?];
let mut name = vec![0u8; name_len as _];
ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx as _, name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
let mut outputs = Vec::with_capacity(num_outputs);
for idx in 0..num_outputs {
let mut name_len = 0;
ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx, ptr::null_mut(), &mut name_len)?];
let mut name = vec![0u8; name_len];
ortsys![unsafe KernelInfo_GetOutputName(self.0.as_ptr(), idx, name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
let name = CString::from_vec_with_nul(name)
.map_err(Error::wrap)?
.into_string()
.map_err(Error::wrap)?;
let mut type_info = ptr::null_mut();
ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.0.as_ptr(), idx as _, &mut type_info)?; nonNull(type_info)];
ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.0.as_ptr(), idx, &mut type_info)?; nonNull(type_info)];
let output_type = ValueType::from_type_info(type_info);
outputs.push(Output { name, output_type })
}
Ok(outputs)
}

pub fn node_name(&self) -> Result<String> {
let mut name_len: ort_sys::size_t = 0;
let mut name_len = 0;
ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), ptr::null_mut(), &mut name_len)?];
let mut name = vec![0u8; name_len as _];
let mut name = vec![0u8; name_len];
ortsys![unsafe KernelInfo_GetNodeName(self.0.as_ptr(), name.as_mut_ptr().cast::<c_char>(), &mut name_len)?];
CString::from_vec_with_nul(name).map_err(Error::wrap)?.into_string().map_err(Error::wrap)
}
Expand Down Expand Up @@ -127,9 +127,9 @@ impl GetKernelAttribute<'_> for String {
where
Self: Sized
{
let mut size = ort_sys::size_t::default();
let mut size = 0;
status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, ptr::null_mut(), &mut size)]).ok()?;
let mut out = vec![0u8; size as _];
let mut out = vec![0u8; size];
status_to_result(ortsys![unsafe KernelInfoGetAttribute_string(info, name, out.as_mut_ptr().cast::<c_char>(), &mut size)]).ok()?;
CString::from_vec_with_nul(out).ok().and_then(|c| c.into_string().ok())
}
Expand All @@ -140,9 +140,9 @@ impl GetKernelAttribute<'_> for Vec<f32> {
where
Self: Sized
{
let mut size = ort_sys::size_t::default();
let mut size = 0;
status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, ptr::null_mut(), &mut size)]).ok()?;
let mut out = vec![0f32; size as _];
let mut out = vec![0f32; size];
status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, out.as_mut_ptr(), &mut size)]).ok()?;
Some(out)
}
Expand All @@ -153,9 +153,9 @@ impl GetKernelAttribute<'_> for Vec<i64> {
where
Self: Sized
{
let mut size = ort_sys::size_t::default();
let mut size = 0;
status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, ptr::null_mut(), &mut size)]).ok()?;
let mut out = vec![0i64; size as _];
let mut out = vec![0i64; size];
status_to_result(ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, out.as_mut_ptr(), &mut size)]).ok()?;
Some(out)
}
Expand Down Expand Up @@ -218,27 +218,27 @@ impl KernelContext {

pub fn input(&self, idx: usize) -> Result<Option<ValueRef<'_>>> {
let mut value_ptr: *const ort_sys::OrtValue = ptr::null();
ortsys![unsafe KernelContext_GetInput(self.ptr.as_ptr(), idx as ort_sys::size_t, &mut value_ptr)?];
ortsys![unsafe KernelContext_GetInput(self.ptr.as_ptr(), idx, &mut value_ptr)?];
Ok(NonNull::new(value_ptr.cast_mut()).map(|c| ValueRef::new(unsafe { Value::from_ptr_nodrop(c, None) })))
}

pub fn output(&self, idx: usize, shape: impl IntoIterator<Item = i64>) -> Result<Option<ValueRefMut<'_>>> {
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
let shape = shape.into_iter().collect::<Vec<i64>>();
ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx as ort_sys::size_t, shape.as_ptr(), shape.len() as _, &mut value_ptr)?];
ortsys![unsafe KernelContext_GetOutput(self.ptr.as_ptr(), idx, shape.as_ptr(), shape.len(), &mut value_ptr)?];
Ok(NonNull::new(value_ptr).map(|c| ValueRefMut::new(unsafe { Value::from_ptr_nodrop(c, None) })))
}

pub fn num_inputs(&self) -> Result<usize> {
let mut num: ort_sys::size_t = 0;
let mut num = 0;
ortsys![unsafe KernelContext_GetInputCount(self.ptr.as_ptr(), &mut num)?];
Ok(num as _)
Ok(num)
}

pub fn num_outputs(&self) -> Result<usize> {
let mut num: ort_sys::size_t = 0;
let mut num = 0;
ortsys![unsafe KernelContext_GetOutputCount(self.ptr.as_ptr(), &mut num)?];
Ok(num as _)
Ok(num)
}

pub fn allocator(&self, memory_info: &MemoryInfo) -> Result<Allocator> {
Expand All @@ -258,7 +258,7 @@ impl KernelContext {
F: Fn(usize) + Sync + Send
{
let executor = Box::new(f) as Box<dyn Fn(usize) + Sync + Send>;
ortsys![unsafe KernelContext_ParallelFor(self.ptr.as_ptr(), Some(parallel_for_cb), total as _, max_num_batches as _, &executor as *const _ as *mut c_void)?];
ortsys![unsafe KernelContext_ParallelFor(self.ptr.as_ptr(), Some(parallel_for_cb), total, max_num_batches, &executor as *const _ as *mut c_void)?];
Ok(())
}

Expand All @@ -272,7 +272,7 @@ impl KernelContext {
// unsafe KernelContext_GetScratchBuffer(
// self.ptr.as_ptr(),
// memory_info.ptr.as_ptr(),
// (len * std::mem::size_of::<T>()) as ort_sys::size_t,
// len * std::mem::size_of::<T>(),
// &mut buffer
// )?;
// nonNull(buffer)
Expand All @@ -298,7 +298,7 @@ impl KernelContext {
}
}

extern "C" fn parallel_for_cb(user_data: *mut c_void, iterator: ort_sys::size_t) {
extern "C" fn parallel_for_cb(user_data: *mut c_void, iterator: usize) {
let executor = unsafe { &*user_data.cast::<Box<dyn Fn(usize) + Sync + Send>>() };
executor(iterator as _)
executor(iterator)
}
2 changes: 1 addition & 1 deletion src/session/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ pub(crate) struct AsyncInferenceContext<'r, 's> {
}

crate::extern_system_fn! {
pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: ort_sys::size_t, status: *mut OrtStatus) {
pub(crate) fn async_callback(user_data: *mut c_void, _: *mut *mut ort_sys::OrtValue, _: usize, status: *mut OrtStatus) {
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };

// Reconvert name ptrs to CString so drop impl is called and memory is freed
Expand Down
2 changes: 1 addition & 1 deletion src/session/builder/impl_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl SessionBuilder {
let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>();
let model_data_length = model_bytes.len();
ortsys![
unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr)?;
unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length, self.session_options_ptr.as_ptr(), &mut session_ptr)?;
nonNull(session_ptr)
];

Expand Down
2 changes: 1 addition & 1 deletion src/session/builder/impl_options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ impl SessionBuilder {
// either accept a &'static [u8] or Vec<u8> via Cow<'_, [u8]>, which still allows users to use include_bytes!.

let file_name = crate::util::path_to_os_char(file_name);
let sizes = [buffer.len() as ort_sys::size_t];
let sizes = [buffer.len()];
ortsys![unsafe AddExternalInitializersFromMemory(self.session_options_ptr.as_ptr(), &file_name.as_ptr(), &buffer.as_ptr().cast::<c_char>().cast_mut(), sizes.as_ptr(), 1)?];
self.external_initializer_buffers.push(buffer);
Ok(self)
Expand Down
Loading

0 comments on commit 28a127d

Please sign in to comment.