diff --git a/examples/async-gpt2-api/examples/async-gpt2-api.rs b/examples/async-gpt2-api/examples/async-gpt2-api.rs index 961ab606..1fe10e0c 100644 --- a/examples/async-gpt2-api/examples/async-gpt2-api.rs +++ b/examples/async-gpt2-api/examples/async-gpt2-api.rs @@ -11,7 +11,7 @@ use axum::{ }; use futures::Stream; use ndarray::{array, concatenate, s, Array1, ArrayViewD, Axis}; -use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Value}; +use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session}; use rand::Rng; use tokenizers::Tokenizer; use tokio::net::TcpListener; @@ -67,7 +67,7 @@ fn generate_stream(tokenizer: Arc, session: Arc, tokens: Vec for _ in 0..gen_tokens { let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1)); let outputs = session.run_async(inputs![array]?)?.await?; - let generated_tokens: ArrayViewD = outputs["output1"].extract_tensor()?; + let generated_tokens: ArrayViewD = outputs["output1"].try_extract_tensor()?; // Collect and sort logits let probabilities = &mut generated_tokens diff --git a/examples/custom-ops/examples/custom-ops.rs b/examples/custom-ops/examples/custom-ops.rs index 431c6058..2d590f0c 100644 --- a/examples/custom-ops/examples/custom-ops.rs +++ b/examples/custom-ops/examples/custom-ops.rs @@ -28,11 +28,11 @@ impl Kernel for CustomOpOneKernel { fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> { let x = ctx.input(0)?.unwrap(); let y = ctx.input(1)?.unwrap(); - let (x_shape, x) = x.extract_raw_tensor::()?; - let (y_shape, y) = y.extract_raw_tensor::()?; + let (x_shape, x) = x.try_extract_raw_tensor::()?; + let (y_shape, y) = y.try_extract_raw_tensor::()?; let mut z = ctx.output(0, x_shape)?.unwrap(); - let (_, z_ref) = z.extract_raw_tensor_mut::()?; + let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; for i in 0..y_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize { if i % 2 == 0 { z_ref[i] = x[i]; @@ -70,9 +70,9 @@ impl Operator for CustomOpTwo { impl Kernel for CustomOpTwoKernel { fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> { let x = ctx.input(0)?.unwrap(); - let (x_shape, x) = x.extract_raw_tensor::()?; + let (x_shape, x) = x.try_extract_raw_tensor::()?; let mut z = ctx.output(0, x_shape.clone())?.unwrap(); - let (_, z_ref) = z.extract_raw_tensor_mut::()?; + let (_, z_ref) = z.try_extract_raw_tensor_mut::()?; for i in 0..x_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize { z_ref[i] = (x[i] * i as f32) as i32; } @@ -86,7 +86,7 @@ fn main() -> ort::Result<()> { .commit_from_file("tests/data/custom_op_test.onnx")?; let values = session.run(ort::inputs![Array2::::zeros((3, 5)), Array2::::ones((3, 5))]?)?; - println!("{:?}", values[0].extract_tensor::()?); + println!("{:?}", values[0].try_extract_tensor::()?); Ok(()) } diff --git a/examples/gpt2/examples/gpt2-no-ndarray.rs b/examples/gpt2/examples/gpt2-no-ndarray.rs index 8f38aa81..08219b46 100644 --- a/examples/gpt2/examples/gpt2-no-ndarray.rs +++ b/examples/gpt2/examples/gpt2-no-ndarray.rs @@ -51,7 +51,7 @@ fn main() -> ort::Result<()> { // The model expects our input to have shape [B, _, S] let input = (vec![1, 1, tokens.len() as i64], Arc::clone(&tokens)); let outputs = session.run(inputs![input]?)?; - let (dim, mut probabilities) = outputs["output1"].extract_raw_tensor()?; + let (dim, mut probabilities) = outputs["output1"].try_extract_raw_tensor()?; // The output tensor will have shape [B, _, S + 1, V] // We want only the probabilities for the last token in this sequence, which will be the token generated by the model diff --git a/examples/gpt2/examples/gpt2.rs b/examples/gpt2/examples/gpt2.rs index 1f3afd31..0689d024 100644 --- a/examples/gpt2/examples/gpt2.rs +++ b/examples/gpt2/examples/gpt2.rs @@ -51,7 +51,7 @@ fn main() -> ort::Result<()> { for _ in 0..GEN_TOKENS { let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1)); let outputs = session.run(inputs![array]?)?; - let generated_tokens: ArrayViewD = outputs["output1"].extract_tensor()?; + let generated_tokens: ArrayViewD = outputs["output1"].try_extract_tensor()?; // Collect and sort logits let probabilities = &mut generated_tokens diff --git a/examples/modnet/examples/modnet.rs b/examples/modnet/examples/modnet.rs index c108c933..7198cfed 100644 --- a/examples/modnet/examples/modnet.rs +++ b/examples/modnet/examples/modnet.rs @@ -33,7 +33,7 @@ fn main() -> ort::Result<()> { let outputs = model.run(inputs!["input" => input.view()]?)?; - let output = outputs["output"].extract_tensor::()?; + let output = outputs["output"].try_extract_tensor::()?; // convert to 8-bit let output = output.mul(255.0).map(|x| *x as u8); diff --git a/examples/yolov8/examples/yolov8.rs b/examples/yolov8/examples/yolov8.rs index 7cd52e6c..f82ebeb3 100644 --- a/examples/yolov8/examples/yolov8.rs +++ b/examples/yolov8/examples/yolov8.rs @@ -63,7 +63,7 @@ fn main() -> ort::Result<()> { // Run YOLOv8 inference let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?; - let output = outputs["output0"].extract_tensor::()?.t().into_owned(); + let output = outputs["output0"].try_extract_tensor::()?.t().into_owned(); let mut boxes = Vec::new(); let output = output.slice(s![.., .., 0]); diff --git a/src/environment.rs b/src/environment.rs index 241e9619..9b57dba3 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -271,7 +271,7 @@ extern_system_fn! { #[cfg(test)] mod tests { - use std::sync::{atomic::Ordering, Arc, OnceLock, RwLock, RwLockWriteGuard}; + use std::sync::{OnceLock, RwLock, RwLockWriteGuard}; use test_log::test; diff --git a/src/error.rs b/src/error.rs index 442b86a1..c729cb35 100644 --- a/src/error.rs +++ b/src/error.rs @@ -103,6 +103,12 @@ pub enum Error { /// Error occurred when creating ONNX tensor with specific data #[error("Failed to create tensor with data: {0}")] CreateTensorWithData(ErrorInternal), + /// Error occurred when attempting to create a [`crate::Sequence`]. + #[error("Failed to create sequence value: {0}")] + CreateSequence(ErrorInternal), + /// Error occurred when attempting to create a [`crate::Map`]. + #[error("Failed to create map value: {0}")] + CreateMap(ErrorInternal), /// Invalid dimension when creating tensor from raw data #[error("Invalid dimension at {0}; all dimensions must be >= 1 when creating a tensor from raw data")] InvalidDimension(usize), @@ -230,6 +236,8 @@ pub enum Error { InvalidMapKeyType { expected: TensorElementType, actual: TensorElementType }, #[error("Tried to extract a map with a value type of {expected:?}, but the map has value type {actual:?}")] InvalidMapValueType { expected: TensorElementType, actual: TensorElementType }, + #[error("Tried to extract a sequence with a different element type than its actual type {actual:?}")] + InvalidSequenceElementType { actual: ValueType }, #[error("Error occurred while attempting to extract data from sequence value: {0}")] ExtractSequence(ErrorInternal), #[error("Error occurred while attempting to extract data from map value: {0}")] diff --git a/src/io_binding.rs b/src/io_binding.rs index 588c6741..a94f2913 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -10,7 +10,7 @@ use crate::{ ortsys, session::{output::SessionOutputs, RunOptions}, value::{Value, ValueRefMut}, - Error, Result, Session + Error, Result, Session, ValueTypeMarker }; /// Enables binding of session inputs and/or outputs to pre-allocated memory. @@ -41,7 +41,7 @@ impl<'s> IoBinding<'s> { } /// Bind a [`Value`] to a session input. - pub fn bind_input<'i: 's, S: AsRef>(&mut self, name: S, ort_value: &'i mut Value) -> Result> { + pub fn bind_input<'i: 's, T: ValueTypeMarker, S: AsRef>(&mut self, name: S, ort_value: &'i mut Value) -> Result> { let name = name.as_ref(); let cname = CString::new(name)?; ortsys![unsafe BindInput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindInput]; @@ -49,7 +49,7 @@ impl<'s> IoBinding<'s> { } /// Bind a session output to a pre-allocated [`Value`]. - pub fn bind_output<'o: 's, S: AsRef>(&mut self, name: S, ort_value: &'o mut Value) -> Result> { + pub fn bind_output<'o: 's, T: ValueTypeMarker, S: AsRef>(&mut self, name: S, ort_value: &'o mut Value) -> Result> { let name = name.as_ref(); let cname = CString::new(name)?; ortsys![unsafe BindOutput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput]; diff --git a/src/lib.rs b/src/lib.rs index 6cb38ae7..667befc5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,7 @@ pub(crate) mod tensor; pub(crate) mod value; #[cfg(feature = "load-dynamic")] -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::Arc; use std::{ ffi::{self, CStr}, os::raw::c_char, @@ -60,7 +60,12 @@ pub use self::session::{ #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use self::tensor::ArrayExtensions; pub use self::tensor::{IntoTensorElementType, TensorElementType}; -pub use self::value::{Value, ValueRef, ValueType}; +pub use self::value::{ + DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, + DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, SequenceRef, + SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, UpcastableTarget, Value, ValueRef, + ValueRefMut, ValueType, ValueTypeMarker +}; #[cfg(not(all(target_arch = "x86", target_os = "windows")))] macro_rules! extern_system_fn { @@ -198,6 +203,12 @@ macro_rules! ortsys { (unsafe $method:ident($($n:expr),+ $(,)?)) => { unsafe { $crate::api().as_ref().$method.unwrap()($($n),+) } }; + ($method:ident($($n:expr),+ $(,)?).unwrap()) => { + $crate::error::status_to_result($crate::api().as_ref().$method.unwrap()($($n),+)).unwrap() + }; + (unsafe $method:ident($($n:expr),+ $(,)?).unwrap()) => { + $crate::error::status_to_result(unsafe { $crate::api().as_ref().$method.unwrap()($($n),+) }).unwrap() + }; ($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => { $crate::api().as_ref().$method.unwrap()($($n),+); $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ diff --git a/src/session/input.rs b/src/session/input.rs index 42ab028b..34686f0d 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -1,14 +1,18 @@ use std::{borrow::Cow, collections::HashMap, ops::Deref}; -use crate::{Value, ValueRef}; +use crate::{ + value::{DynValueTypeMarker, ValueTypeMarker}, + Value, ValueRef +}; pub enum SessionInputValue<'v> { - View(ValueRef<'v>), - Owned(Value) + View(ValueRef<'v, DynValueTypeMarker>), + Owned(Value) } impl<'v> Deref for SessionInputValue<'v> { type Target = Value; + fn deref(&self) -> &Self::Target { match self { SessionInputValue::View(v) => v, @@ -17,14 +21,14 @@ impl<'v> Deref for SessionInputValue<'v> { } } -impl<'v> From> for SessionInputValue<'v> { - fn from(value: ValueRef<'v>) -> Self { - SessionInputValue::View(value) +impl<'v, T: ValueTypeMarker + ?Sized> From> for SessionInputValue<'v> { + fn from(value: ValueRef<'v, T>) -> Self { + SessionInputValue::View(value.into_dyn()) } } -impl<'v> From for SessionInputValue<'v> { - fn from(value: Value) -> Self { - SessionInputValue::Owned(value) +impl<'v, T: ValueTypeMarker + ?Sized> From> for SessionInputValue<'v> { + fn from(value: Value) -> Self { + SessionInputValue::Owned(value.into_dyn()) } } @@ -113,13 +117,13 @@ impl<'i, 'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs< macro_rules! inputs { ($($v:expr),+ $(,)?) => ( (|| -> $crate::Result<_> { - Ok([$(::std::convert::Into::<$crate::SessionInputValue<'_>>::into(::std::convert::TryInto::<$crate::Value>::try_into($v).map_err($crate::Error::from)?)),+]) + Ok([$(::std::convert::Into::<$crate::SessionInputValue<'_>>::into(::std::convert::TryInto::<$crate::DynValue>::try_into($v).map_err($crate::Error::from)?)),+]) })() ); ($($n:expr => $v:expr),+ $(,)?) => ( (|| -> $crate::Result<_> { Ok(vec![$( - ::std::convert::TryInto::<$crate::Value>::try_into($v) + ::std::convert::TryInto::<$crate::DynValue>::try_into($v) .map_err($crate::Error::from) .map(|v| (::std::borrow::Cow::::from($n), $crate::SessionInputValue::from(v)))?,)+]) })() @@ -138,7 +142,7 @@ mod tests { let arc = Arc::new(v.clone().into_boxed_slice()); let shape = vec![v.len() as i64]; - let mut inputs: HashMap<&str, Value> = HashMap::new(); + let mut inputs: HashMap<&str, DynTensor> = HashMap::new(); inputs.insert("test", (shape, arc).try_into()?); let _ = SessionInputs::from(inputs); @@ -151,7 +155,7 @@ mod tests { let arc = Arc::new(v.clone().into_boxed_slice()); let shape = vec![v.len() as i64]; - let mut inputs: HashMap = HashMap::new(); + let mut inputs: HashMap = HashMap::new(); inputs.insert("test".to_string(), (shape, arc).try_into()?); let _ = SessionInputs::from(inputs); diff --git a/src/session/output.rs b/src/session/output.rs index 36812075..c0fed437 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -4,7 +4,7 @@ use std::{ ops::{Deref, DerefMut, Index} }; -use crate::{Allocator, Value}; +use crate::{Allocator, DynValue}; /// The outputs returned by a [`crate::Session`] inference call. /// @@ -26,7 +26,7 @@ use crate::{Allocator, Value}; /// ``` #[derive(Debug)] pub struct SessionOutputs<'s> { - map: BTreeMap<&'s str, Value>, + map: BTreeMap<&'s str, DynValue>, idxs: Vec<&'s str>, backing_ptr: Option<(&'s Allocator, *mut c_void)> } @@ -34,7 +34,7 @@ pub struct SessionOutputs<'s> { unsafe impl<'s> Send for SessionOutputs<'s> {} impl<'s> SessionOutputs<'s> { - pub(crate) fn new(output_names: impl Iterator + Clone, output_values: impl IntoIterator) -> Self { + pub(crate) fn new(output_names: impl Iterator + Clone, output_values: impl IntoIterator) -> Self { let map = output_names.clone().zip(output_values).collect(); Self { map, @@ -45,7 +45,7 @@ impl<'s> SessionOutputs<'s> { pub(crate) fn new_backed( output_names: impl Iterator + Clone, - output_values: impl IntoIterator, + output_values: impl IntoIterator, allocator: &'s Allocator, backing_ptr: *mut c_void ) -> Self { @@ -75,7 +75,7 @@ impl<'s> Drop for SessionOutputs<'s> { } impl<'s> Deref for SessionOutputs<'s> { - type Target = BTreeMap<&'s str, Value>; + type Target = BTreeMap<&'s str, DynValue>; fn deref(&self) -> &Self::Target { &self.map @@ -89,21 +89,21 @@ impl<'s> DerefMut for SessionOutputs<'s> { } impl<'s> Index<&str> for SessionOutputs<'s> { - type Output = Value; + type Output = DynValue; fn index(&self, index: &str) -> &Self::Output { self.map.get(index).expect("no entry found for key") } } impl<'s> Index for SessionOutputs<'s> { - type Output = Value; + type Output = DynValue; fn index(&self, index: String) -> &Self::Output { self.map.get(index.as_str()).expect("no entry found for key") } } impl<'s> Index for SessionOutputs<'s> { - type Output = Value; + type Output = DynValue; fn index(&self, index: usize) -> &Self::Output { self.map.get(&self.idxs[index]).expect("no entry found for key") } diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index 187e882a..608979ee 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -1,13 +1,41 @@ use std::{ collections::HashMap, + fmt::Debug, hash::Hash, + marker::PhantomData, ptr::{self, NonNull} }; -use crate::{memory::Allocator, ortsys, Error, IntoTensorElementType, Result, Value, ValueType}; +use super::{ValueInner, ValueTypeMarker}; +use crate::{ + memory::Allocator, ortsys, value::impl_tensor::DynTensor, DynValue, Error, IntoTensorElementType, Result, Tensor, Value, ValueRef, ValueRefMut, ValueType +}; + +pub trait MapValueTypeMarker: ValueTypeMarker {} + +#[derive(Debug)] +pub struct DynMapValueType; +impl ValueTypeMarker for DynMapValueType {} +impl MapValueTypeMarker for DynMapValueType {} + +#[derive(Debug)] +pub struct MapValueType(PhantomData<(K, V)>); +impl ValueTypeMarker for MapValueType {} +impl MapValueTypeMarker for MapValueType {} + +pub type DynMap = Value; +pub type Map = Value>; -impl Value { - pub fn extract_map(&self, allocator: &Allocator) -> Result> { +pub type DynMapRef<'v> = ValueRef<'v, DynMapValueType>; +pub type DynMapRefMut<'v> = ValueRefMut<'v, DynMapValueType>; +pub type MapRef<'v, K, V> = ValueRef<'v, MapValueType>; +pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType>; + +impl Value { + pub fn try_extract_map( + &self, + allocator: &Allocator + ) -> Result> { match self.dtype()? { ValueType::Map { key, value } => { let k_type = K::into_tensor_element_type(); @@ -21,13 +49,13 @@ impl Value { let mut key_tensor_ptr = ptr::null_mut(); ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr.as_ptr(), &mut key_tensor_ptr) -> Error::ExtractMap; nonNull(key_tensor_ptr)]; - let key_value = unsafe { Value::from_ptr(NonNull::new_unchecked(key_tensor_ptr), None) }; - let (key_tensor_shape, key_tensor) = key_value.extract_raw_tensor::()?; + let key_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(key_tensor_ptr), None) }; + let (key_tensor_shape, key_tensor) = key_value.try_extract_raw_tensor::()?; let mut value_tensor_ptr = ptr::null_mut(); ortsys![unsafe GetValue(self.ptr(), 1, allocator.ptr.as_ptr(), &mut value_tensor_ptr) -> Error::ExtractMap; nonNull(value_tensor_ptr)]; - let value_value = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; - let (value_tensor_shape, value_tensor) = value_value.extract_raw_tensor::()?; + let value_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(value_tensor_ptr), None) }; + let (value_tensor_shape, value_tensor) = value_value.try_extract_raw_tensor::()?; assert_eq!(key_tensor_shape.len(), 1); assert_eq!(value_tensor_shape.len(), 1); @@ -43,3 +71,97 @@ impl Value { } } } + +impl Value> { + /// Creates a [`Map`] from an iterable emitting `K` and `V`. + /// + /// ``` + /// # use std::collections::HashMap; + /// # use ort::{Allocator, Map}; + /// # fn main() -> ort::Result<()> { + /// # let allocator = Allocator::default(); + /// let mut map = HashMap::::new(); + /// map.insert(0, 1.0); + /// map.insert(1, 2.0); + /// map.insert(2, 3.0); + /// + /// let value = Map::new(map)?; + /// + /// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0); + /// # Ok(()) + /// # } + /// ``` + pub fn new(data: impl IntoIterator) -> Result { + let (keys, values): (Vec, Vec) = data.into_iter().unzip(); + Self::new_kv(Tensor::from_array((vec![keys.len()], keys))?, Tensor::from_array((vec![values.len()], values))?) + } + + /// Creates a [`Map`] from two tensors of keys & values respectively. + /// + /// ``` + /// # use std::collections::HashMap; + /// # use ort::{Allocator, Map, Tensor}; + /// # fn main() -> ort::Result<()> { + /// # let allocator = Allocator::default(); + /// let keys = Tensor::::from_array(([4], vec![0, 1, 2, 3]))?; + /// let values = Tensor::::from_array(([4], vec![1., 2., 3., 4.]))?; + /// + /// let value = Map::new_kv(keys, values)?; + /// + /// assert_eq!(*value.extract_map(&allocator).get(&0).unwrap(), 1.0); + /// # Ok(()) + /// # } + /// ``` + pub fn new_kv(keys: Tensor, values: Tensor) -> Result { + let mut value_ptr = ptr::null_mut(); + let values: [DynValue; 2] = [keys.into_dyn(), values.into_dyn()]; + let value_ptrs: Vec<*const ort_sys::OrtValue> = values.iter().map(|c| c.ptr().cast_const()).collect(); + ortsys![ + unsafe CreateValue(value_ptrs.as_ptr(), 2, ort_sys::ONNXType::ONNX_TYPE_MAP, &mut value_ptr) + -> Error::CreateMap; + nonNull(value_ptr) + ]; + Ok(Value { + inner: ValueInner::RustOwned { + ptr: unsafe { NonNull::new_unchecked(value_ptr) }, + _array: Box::new(values), + _memory_info: None + }, + _markers: PhantomData + }) + } +} + +impl Value> { + pub fn extract_map(&self, allocator: &Allocator) -> HashMap { + self.try_extract_map(allocator).unwrap() + } + + /// Converts from a strongly-typed [`Map`] to a type-erased [`DynMap`]. + #[inline] + pub fn downcast(self) -> DynMap { + unsafe { std::mem::transmute(self) } + } + + /// Converts from a strongly-typed [`Map`] to a reference to a type-erased [`DynMap`]. + #[inline] + pub fn downcast_ref(&self) -> DynMapRef { + DynMapRef::new(unsafe { + Value::from_ptr_nodrop( + NonNull::new_unchecked(self.ptr()), + if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + ) + }) + } + + /// Converts from a strongly-typed [`Map`] to a mutable reference to a type-erased [`DynMap`]. + #[inline] + pub fn downcast_mut(&mut self) -> DynMapRefMut { + DynMapRefMut::new(unsafe { + Value::from_ptr_nodrop( + NonNull::new_unchecked(self.ptr()), + if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + ) + }) + } +} diff --git a/src/value/impl_sequence.rs b/src/value/impl_sequence.rs index 14751c20..3f59cbc1 100644 --- a/src/value/impl_sequence.rs +++ b/src/value/impl_sequence.rs @@ -1,12 +1,37 @@ use std::{ + fmt::Debug, marker::PhantomData, ptr::{self, NonNull} }; -use crate::{memory::Allocator, ortsys, Error, Result, Value, ValueRef, ValueType}; +use super::{UpcastableTarget, ValueInner, ValueTypeMarker}; +use crate::{memory::Allocator, ortsys, Error, Result, Value, ValueRef, ValueRefMut, ValueType}; -impl Value { - pub fn extract_sequence<'s>(&'s self, allocator: &Allocator) -> Result>> { +pub trait SequenceValueTypeMarker: ValueTypeMarker {} + +#[derive(Debug)] +pub struct DynSequenceValueType; +impl ValueTypeMarker for DynSequenceValueType {} +impl SequenceValueTypeMarker for DynSequenceValueType {} + +#[derive(Debug)] +pub struct SequenceValueType(PhantomData); +impl ValueTypeMarker for SequenceValueType {} +impl SequenceValueTypeMarker for SequenceValueType {} + +pub type DynSequence = Value; +pub type Sequence = Value>; + +pub type DynSequenceRef<'v> = ValueRef<'v, DynSequenceValueType>; +pub type DynSequenceRefMut<'v> = ValueRefMut<'v, DynSequenceValueType>; +pub type SequenceRef<'v, T> = ValueRef<'v, SequenceValueType>; +pub type SequenceRefMut<'v, T> = ValueRefMut<'v, SequenceValueType>; + +impl Value { + pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + UpcastableTarget + Debug + Sized>( + &'s self, + allocator: &Allocator + ) -> Result>> { match self.dtype()? { ValueType::Sequence(_) => { let mut len: ort_sys::size_t = 0; @@ -17,10 +42,16 @@ impl Value { let mut value_ptr = ptr::null_mut(); ortsys![unsafe GetValue(self.ptr(), i as _, allocator.ptr.as_ptr(), &mut value_ptr) -> Error::ExtractSequence; nonNull(value_ptr)]; - vec.push(ValueRef { + let value = ValueRef { inner: unsafe { Value::from_ptr(NonNull::new_unchecked(value_ptr), None) }, lifetime: PhantomData - }); + }; + let value_type = value.dtype()?; + if !OtherType::can_upcast(&value.dtype()?) { + return Err(Error::InvalidSequenceElementType { actual: value_type }); + } + + vec.push(value); } Ok(vec) } @@ -28,3 +59,76 @@ impl Value { } } } + +impl Value> { + /// Creates a [`Sequence`] from an array of [`Value`]. + /// + /// This `Value` must be either a [`crate::Tensor`] or [`crate::Map`]. + /// + /// ``` + /// # use ort::{Allocator, Sequence, Tensor}; + /// # fn main() -> ort::Result<()> { + /// # let allocator = Allocator::default(); + /// let tensor1 = Tensor::::new(&allocator, [1, 128, 128, 3])?; + /// let tensor2 = Tensor::::new(&allocator, [1, 224, 224, 3])?; + /// let value = Sequence::new([tensor1, tensor2])?; + /// + /// for tensor in value.extract_sequence(&allocator) { + /// println!("{:?}", tensor.shape()?); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn new(values: impl IntoIterator>) -> Result { + let mut value_ptr = ptr::null_mut(); + let values: Vec> = values.into_iter().collect(); + let value_ptrs: Vec<*const ort_sys::OrtValue> = values.iter().map(|c| c.ptr().cast_const()).collect(); + ortsys![ + unsafe CreateValue(value_ptrs.as_ptr(), values.len() as _, ort_sys::ONNXType::ONNX_TYPE_SEQUENCE, &mut value_ptr) + -> Error::CreateSequence; + nonNull(value_ptr) + ]; + Ok(Value { + inner: ValueInner::RustOwned { + ptr: unsafe { NonNull::new_unchecked(value_ptr) }, + _array: Box::new(values), + _memory_info: None + }, + _markers: PhantomData + }) + } +} + +impl Value> { + pub fn extract_sequence<'s>(&'s self, allocator: &Allocator) -> Vec> { + self.try_extract_sequence(allocator).unwrap() + } + + /// Converts from a strongly-typed [`Sequence`] to a type-erased [`DynSequence`]. + #[inline] + pub fn downcast(self) -> DynSequence { + unsafe { std::mem::transmute(self) } + } + + /// Converts from a strongly-typed [`Sequence`] to a reference to a type-erased [`DynTensor`]. + #[inline] + pub fn downcast_ref(&self) -> DynSequenceRef { + DynSequenceRef::new(unsafe { + Value::from_ptr_nodrop( + NonNull::new_unchecked(self.ptr()), + if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + ) + }) + } + + /// Converts from a strongly-typed [`Sequence`] to a mutable reference to a type-erased [`DynTensor`]. + #[inline] + pub fn downcast_mut(&mut self) -> DynSequenceRefMut { + DynSequenceRefMut::new(unsafe { + Value::from_ptr_nodrop( + NonNull::new_unchecked(self.ptr()), + if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + ) + }) + } +} diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index 46eb1201..b6cd57a0 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -2,6 +2,7 @@ use std::{ any::Any, ffi, fmt::Debug, + marker::PhantomData, ptr::{self, NonNull}, sync::Arc }; @@ -9,23 +10,98 @@ use std::{ #[cfg(feature = "ndarray")] use ndarray::{ArcArray, Array, ArrayView, CowArray, Dimension}; +use super::{DynTensor, Tensor}; use crate::{ error::assert_non_null_pointer, memory::{Allocator, MemoryInfo}, ortsys, tensor::{IntoTensorElementType, TensorElementType, Utf8Data}, value::ValueInner, - AllocatorType, Error, MemoryType, Result, Value + AllocatorType, DynValue, Error, MemoryType, Result, Value }; -impl Value { +impl DynTensor { + /// Construct a [`Value`] from an array of strings. + /// + /// Just like numeric tensors, string tensor `Value`s can be created from: + /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); + /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); + /// - (with feature `ndarray`) an owned [`ndarray::Array`]; + /// - (with feature `ndarray`) a borrowed view of another array, as an [`ndarray::ArrayView`] (`ArrayView<'_, T, + /// D>`); + /// - a tuple of `(dimensions, data)` where: + /// * `dimensions` is one of `Vec`, `[I]` or `&[I]`, where `I` is `i64` or `usize`; + /// * and `data` is one of `Vec`, `Box<[T]>`, `Arc>`, or `&[T]`. + /// + /// ``` + /// # use ort::{Session, Value}; + /// # fn main() -> ort::Result<()> { + /// # let session = Session::builder()?.commit_from_file("tests/data/vectorizer.onnx")?; + /// // You'll need to obtain an `Allocator` from a session in order to create string tensors. + /// let allocator = session.allocator(); + /// + /// // Create a string tensor from a raw data vector + /// let data = vec!["hello", "world"]; + /// let value = Value::from_string_array(allocator, ([data.len()], data.into_boxed_slice()))?; + /// + /// // Create a string tensor from an `ndarray::Array` + /// #[cfg(feature = "ndarray")] + /// let value = Value::from_string_array( + /// allocator, + /// ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap() + /// )?; + /// # Ok(()) + /// # } + /// ``` + /// + /// Note that string data will *always* be copied, no matter what form the data is provided in. + pub fn from_string_array(allocator: &Allocator, input: impl IntoValueTensor) -> Result { + let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); + + let (shape, data) = input.ref_parts()?; + let shape_ptr: *const i64 = shape.as_ptr(); + let shape_len = shape.len(); + + // create tensor without data -- data is filled in later + ortsys![ + unsafe CreateTensorAsOrtValue(allocator.ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr) + -> Error::CreateTensor; + nonNull(value_ptr) + ]; + + // create null-terminated copies of each string, as per `FillStringTensor` docs + let null_terminated_copies: Vec = data + .iter() + .map(|elt| { + let slice = elt.as_utf8_bytes(); + ffi::CString::new(slice) + }) + .collect::, _>>() + .map_err(Error::FfiStringNull)?; + + let string_pointers = null_terminated_copies.iter().map(|cstring| cstring.as_ptr()).collect::>(); + + ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> Error::FillStringTensor]; + + Ok(Value { + inner: ValueInner::RustOwned { + ptr: unsafe { NonNull::new_unchecked(value_ptr) }, + _array: Box::new(()), + _memory_info: None + }, + _markers: PhantomData + }) + } +} + +impl Tensor { /// Construct a tensor [`Value`] in a given allocator with a given shape and datatype. The data contained in the /// value will be zero-allocated on the allocation device. /// /// This can be used to create a tensor with data on a certain device. For example, to create a tensor with pinned /// (CPU) memory for use with CUDA: /// ```no_run - /// # use ort::{Allocator, Session, Value, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; + /// # use ort::{Allocator, Session, Tensor, MemoryInfo, MemoryType, AllocationDevice, AllocatorType}; /// # fn main() -> ort::Result<()> { /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// let allocator = Allocator::new( @@ -33,11 +109,11 @@ impl Value { /// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)? /// )?; /// - /// let mut img_input = Value::new_tensor::(&allocator, [1, 128, 128, 3])?; + /// let mut img_input = Tensor::::new(&allocator, [1, 128, 128, 3])?; /// # Ok(()) /// # } /// ``` - pub fn new_tensor(allocator: &Allocator, shape: impl ToDimensions) -> Result { + pub fn new(allocator: &Allocator, shape: impl ToDimensions) -> Result> { let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); let shape = shape.to_dimensions(None)?; @@ -60,7 +136,8 @@ impl Value { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: Box::new(()), _memory_info: None - } + }, + _markers: PhantomData }) } @@ -99,7 +176,7 @@ impl Value { /// /// Raw data provided as a `Arc>`, `Box<[T]>`, or `Vec` will never be copied. Raw data is expected to be /// in standard, contigous layout. - pub fn from_array(input: impl IntoValueTensor) -> Result { + pub fn from_array(input: impl IntoValueTensor) -> Result> { let memory_info = MemoryInfo::new_cpu(AllocatorType::Arena, MemoryType::Default)?; let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); @@ -130,78 +207,8 @@ impl Value { ptr: unsafe { NonNull::new_unchecked(value_ptr) }, _array: guard, _memory_info: Some(memory_info) - } - }) - } - - /// Construct a [`Value`] from an array of strings. - /// - /// Just like numeric tensors, string tensor `Value`s can be created from: - /// - (with feature `ndarray`) a shared reference to a [`ndarray::CowArray`] (`&CowArray<'_, T, D>`); - /// - (with feature `ndarray`) a mutable/exclusive reference to an [`ndarray::ArcArray`] (`&mut ArcArray`); - /// - (with feature `ndarray`) an owned [`ndarray::Array`]; - /// - (with feature `ndarray`) a borrowed view of another array, as an [`ndarray::ArrayView`] (`ArrayView<'_, T, - /// D>`); - /// - a tuple of `(dimensions, data)` where: - /// * `dimensions` is one of `Vec`, `[I]` or `&[I]`, where `I` is `i64` or `usize`; - /// * and `data` is one of `Vec`, `Box<[T]>`, `Arc>`, or `&[T]`. - /// - /// ``` - /// # use ort::{Session, Value}; - /// # fn main() -> ort::Result<()> { - /// # let session = Session::builder()?.commit_from_file("tests/data/vectorizer.onnx")?; - /// // You'll need to obtain an `Allocator` from a session in order to create string tensors. - /// let allocator = session.allocator(); - /// - /// // Create a string tensor from a raw data vector - /// let data = vec!["hello", "world"]; - /// let value = Value::from_string_array(allocator, ([data.len()], data.into_boxed_slice()))?; - /// - /// // Create a string tensor from an `ndarray::Array` - /// #[cfg(feature = "ndarray")] - /// let value = Value::from_string_array( - /// allocator, - /// ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap() - /// )?; - /// # Ok(()) - /// # } - /// ``` - /// - /// Note that string data will *always* be copied, no matter what form the data is provided in. - pub fn from_string_array(allocator: &Allocator, input: impl IntoValueTensor) -> Result { - let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut(); - - let (shape, data) = input.ref_parts()?; - let shape_ptr: *const i64 = shape.as_ptr(); - let shape_len = shape.len(); - - // create tensor without data -- data is filled in later - ortsys![ - unsafe CreateTensorAsOrtValue(allocator.ptr.as_ptr(), shape_ptr, shape_len as _, TensorElementType::String.into(), &mut value_ptr) - -> Error::CreateTensor; - nonNull(value_ptr) - ]; - - // create null-terminated copies of each string, as per `FillStringTensor` docs - let null_terminated_copies: Vec = data - .iter() - .map(|elt| { - let slice = elt.as_utf8_bytes(); - ffi::CString::new(slice) - }) - .collect::, _>>() - .map_err(Error::FfiStringNull)?; - - let string_pointers = null_terminated_copies.iter().map(|cstring| cstring.as_ptr()).collect::>(); - - ortsys![unsafe FillStringTensor(value_ptr, string_pointers.as_ptr(), string_pointers.len() as _) -> Error::FillStringTensor]; - - Ok(Value { - inner: ValueInner::RustOwned { - ptr: unsafe { NonNull::new_unchecked(value_ptr) }, - _array: Box::new(()), - _memory_info: None - } + }, + _markers: PhantomData }) } } @@ -210,6 +217,7 @@ pub trait IntoValueTensor { type Item; fn ref_parts(&self) -> Result<(Vec, &[Self::Item])>; + #[allow(clippy::type_complexity)] fn into_parts(self) -> Result<(Vec, *mut Self::Item, usize, Box)>; } @@ -432,3 +440,120 @@ impl IntoValueTensor for (D, Arc TryFrom<&'i CowArray<'v, T, D>> for Tensor +where + 'i: 'v +{ + type Error = Error; + fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { + Tensor::from_array(arr) + } +} + +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Tensor { + type Error = Error; + fn try_from(arr: ArrayView<'v, T, D>) -> Result { + Tensor::from_array(arr) + } +} + +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynTensor +where + 'i: 'v +{ + type Error = Error; + fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { + Tensor::from_array(arr).map(|c| c.downcast()) + } +} + +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { + type Error = Error; + fn try_from(arr: ArrayView<'v, T, D>) -> Result { + Tensor::from_array(arr).map(|c| c.downcast()) + } +} + +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for DynValue +where + 'i: 'v +{ + type Error = Error; + fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { + Tensor::from_array(arr).map(|c| c.into_dyn()) + } +} + +#[cfg(feature = "ndarray")] +#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] +impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynValue { + type Error = Error; + fn try_from(arr: ArrayView<'v, T, D>) -> Result { + Tensor::from_array(arr).map(|c| c.into_dyn()) + } +} + +macro_rules! impl_try_from { + (@T,I $($t:ty),+) => { + $( + impl TryFrom<$t> for Tensor { + type Error = Error; + fn try_from(value: $t) -> Result { + Tensor::from_array(value) + } + } + impl TryFrom<$t> for DynTensor { + type Error = Error; + fn try_from(value: $t) -> Result { + Tensor::from_array(value).map(|c| c.downcast()) + } + } + impl TryFrom<$t> for crate::DynValue { + type Error = Error; + fn try_from(value: $t) -> Result { + Tensor::from_array(value).map(|c| c.into_dyn()) + } + } + )+ + }; + (@T,D $($t:ty),+) => { + $( + #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] + impl TryFrom<$t> for Tensor { + type Error = Error; + fn try_from(value: $t) -> Result { + Tensor::from_array(value) + } + } + #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] + impl TryFrom<$t> for DynTensor { + type Error = Error; + fn try_from(value: $t) -> Result { + Tensor::from_array(value).map(|c| c.downcast()) + } + } + #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] + impl TryFrom<$t> for crate::DynValue { + type Error = Error; + fn try_from(value: $t) -> Result { + Tensor::from_array(value).map(|c| c.into_dyn()) + } + } + )+ + }; +} + +#[cfg(feature = "ndarray")] +impl_try_from!(@T,D &mut ArcArray, Array); +impl_try_from!(@T,I (I, Arc>), (I, Vec), (I, Box<[T]>), (I, &[T])); diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index 06a9492e..a973a047 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -1,34 +1,46 @@ -use std::{os::raw::c_char, ptr, string::FromUtf8Error}; +use std::{fmt::Debug, os::raw::c_char, ptr, string::FromUtf8Error}; #[cfg(feature = "ndarray")] use ndarray::IxDyn; +use super::TensorValueTypeMarker; #[cfg(feature = "ndarray")] use crate::tensor::{extract_primitive_array, extract_primitive_array_mut}; use crate::{ ortsys, tensor::{IntoTensorElementType, TensorElementType}, - Error, Result, Value + Error, Result, Tensor, Value }; -impl Value { - /// Attempt to extract the underlying data into a Rust `ndarray`. +impl Value { + /// Attempt to extract the underlying data of type `T` into a read-only [`ndarray::ArrayView`]. + /// + /// See also: + /// - the mutable counterpart of this function, [`Tensor::try_extract_tensor_mut`]. + /// - the infallible counterpart, [`Tensor::extract_tensor`], for typed [`Tensor`]s. + /// - the alternative function for strings, [`Tensor::try_extract_string_tensor`]. /// /// ``` /// # use std::sync::Arc; - /// # use ort::{Session, Value, ValueType, TensorElementType}; + /// # use ort::{Session, Value}; /// # fn main() -> ort::Result<()> { /// let array = ndarray::Array4::::ones((1, 16, 16, 3)); /// let value = Value::from_array(array.view())?; /// - /// let extracted = value.extract_tensor::()?; + /// let extracted = value.try_extract_tensor::()?; /// assert_eq!(array.into_dyn(), extracted); /// # Ok(()) /// # } /// ``` + /// + /// # Errors + /// May return an error if: + /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the + /// infallible [`Tensor::extract_tensor`] instead)* + /// - The provided type `T` does not match the tensor's element type. #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn extract_tensor(&self) -> Result> { + pub fn try_extract_tensor(&self) -> Result> { let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; @@ -60,16 +72,18 @@ impl Value { res } - /// Attempt to extract the underlying data into a mutable Rust `ndarray`. + /// Attempt to extract the underlying data of type `T` into a mutable read-only [`ndarray::ArrayViewMut`]. + /// + /// See also the infallible counterpart, [`Tensor::extract_tensor_mut`], for typed [`Tensor`]s. /// /// ``` /// # use std::sync::Arc; - /// # use ort::{Session, Value, ValueType, TensorElementType}; + /// # use ort::{Session, Value}; /// # fn main() -> ort::Result<()> { /// let array = ndarray::Array4::::ones((1, 16, 16, 3)); /// let mut value = Value::from_array(array.view())?; /// - /// let mut extracted = value.extract_tensor_mut::()?; + /// let mut extracted = value.try_extract_tensor_mut::()?; /// extracted[[0, 0, 0, 1]] = 0.0; /// /// let mut array = array.into_dyn(); @@ -79,9 +93,15 @@ impl Value { /// # Ok(()) /// # } /// ``` + /// + /// # Errors + /// May return an error if: + /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the + /// infallible [`Tensor::extract_tensor_mut`] instead)* + /// - The provided type `T` does not match the tensor's element type. #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn extract_tensor_mut(&mut self) -> Result> { + pub fn try_extract_tensor_mut(&mut self) -> Result> { let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; @@ -113,22 +133,33 @@ impl Value { res } - /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a view - /// into its data. + /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an + /// immutable view into its data. + /// + /// See also: + /// - the mutable counterpart of this function, [`Tensor::try_extract_raw_tensor_mut`]. + /// - the infallible counterpart, [`Tensor::extract_raw_tensor`], for typed [`Tensor`]s. + /// - the alternative function for strings, [`Tensor::try_extract_raw_string_tensor`]. /// /// ``` - /// # use ort::{Session, Value, ValueType, TensorElementType}; + /// # use ort::{Session, Value}; /// # fn main() -> ort::Result<()> { /// let array = vec![1_i64, 2, 3, 4, 5]; /// let value = Value::from_array(([array.len()], array.clone().into_boxed_slice()))?; /// - /// let (extracted_shape, extracted_data) = value.extract_raw_tensor::()?; + /// let (extracted_shape, extracted_data) = value.try_extract_raw_tensor::()?; /// assert_eq!(extracted_data, &array); /// assert_eq!(extracted_shape, [5]); /// # Ok(()) /// # } /// ``` - pub fn extract_raw_tensor(&self) -> Result<(Vec, &[T])> { + /// + /// # Errors + /// May return an error if: + /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the + /// infallible [`Tensor::extract_raw_tensor`] instead)* + /// - The provided type `T` does not match the tensor's element type. + pub fn try_extract_raw_tensor(&self) -> Result<(Vec, &[T])> { let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; @@ -164,22 +195,30 @@ impl Value { res } - /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a view - /// into its data. + /// Attempt to extract the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a + /// mutable view into its data. + /// + /// See also the infallible counterpart, [`Tensor::extract_raw_tensor_mut`], for typed [`Tensor`]s. /// /// ``` - /// # use ort::{Session, Value, ValueType, TensorElementType}; + /// # use ort::{Session, Value}; /// # fn main() -> ort::Result<()> { /// let array = vec![1_i64, 2, 3, 4, 5]; - /// let value = Value::from_array(([array.len()], array.clone().into_boxed_slice()))?; + /// let mut value = Value::from_array(([array.len()], array.clone().into_boxed_slice()))?; /// - /// let (extracted_shape, extracted_data) = value.extract_raw_tensor::()?; + /// let (extracted_shape, extracted_data) = value.try_extract_raw_tensor_mut::()?; /// assert_eq!(extracted_data, &array); /// assert_eq!(extracted_shape, [5]); /// # Ok(()) /// # } /// ``` - pub fn extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { + /// + /// # Errors + /// May return an error if: + /// - This is a [`crate::DynValue`], and the value is not actually a tensor. *(for typed [`Tensor`]s, use the + /// infallible [`Tensor::extract_raw_tensor_mut`] instead)* + /// - The provided type `T` does not match the tensor's element type. + pub fn try_extract_raw_tensor_mut(&mut self) -> Result<(Vec, &mut [T])> { let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; @@ -218,20 +257,20 @@ impl Value { /// Attempt to extract the underlying data into a Rust `ndarray`. /// /// ``` - /// # use std::sync::Arc; - /// # use ort::{Session, Value, ValueType, TensorElementType}; + /// # use ort::{Allocator, Session, DynTensor, TensorElementType}; /// # fn main() -> ort::Result<()> { - /// let array = ndarray::Array4::::ones((1, 16, 16, 3)); - /// let value = Value::from_array(array.view())?; + /// # let allocator = Allocator::default(); + /// let array = ndarray::Array1::from_vec(vec!["hello", "world"]); + /// let tensor = DynTensor::from_string_array(&allocator, array.clone())?; /// - /// let extracted = value.extract_tensor::()?; + /// let extracted = tensor.try_extract_string_tensor()?; /// assert_eq!(array.into_dyn(), extracted); /// # Ok(()) /// # } /// ``` #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - pub fn extract_string_tensor(&self) -> Result> { + pub fn try_extract_string_tensor(&self) -> Result> { let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; @@ -299,19 +338,19 @@ impl Value { /// an owned `Vec` of its data. /// /// ``` - /// # use ort::{Allocator, Session, Value, ValueType, TensorElementType}; + /// # use ort::{Allocator, Session, DynTensor, TensorElementType}; /// # fn main() -> ort::Result<()> { /// # let allocator = Allocator::default(); /// let array = vec!["hello", "world"]; - /// let value = Value::from_string_array(&allocator, ([array.len()], array.clone().into_boxed_slice()))?; + /// let tensor = DynTensor::from_string_array(&allocator, ([array.len()], array.clone().into_boxed_slice()))?; /// - /// let (extracted_shape, extracted_data) = value.extract_raw_string_tensor()?; + /// let (extracted_shape, extracted_data) = tensor.try_extract_raw_string_tensor()?; /// assert_eq!(extracted_data, array); /// assert_eq!(extracted_shape, [2]); /// # Ok(()) /// # } /// ``` - pub fn extract_raw_string_tensor(&self) -> Result<(Vec, Vec)> { + pub fn try_extract_raw_string_tensor(&self) -> Result<(Vec, Vec)> { let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; @@ -375,4 +414,118 @@ impl Value { ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; res } + + /// Returns the shape of the tensor. + /// + /// ``` + /// # use ort::{Allocator, Sequence, Tensor}; + /// # fn main() -> ort::Result<()> { + /// # let allocator = Allocator::default(); + /// let tensor = Tensor::::new(&allocator, [1, 128, 128, 3])?; + /// + /// assert_eq!(tensor.shape()?, &[1, 128, 128, 3]); + /// # Ok(()) + /// # } + /// ``` + pub fn shape(&self) -> Result> { + let mut tensor_info_ptr: *mut ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + ortsys![unsafe GetTensorTypeAndShape(self.ptr(), &mut tensor_info_ptr) -> Error::GetTensorTypeAndShape]; + + let res = { + let mut num_dims = 0; + ortsys![unsafe GetDimensionsCount(tensor_info_ptr, &mut num_dims) -> Error::GetDimensionsCount]; + + let mut node_dims: Vec = vec![0; num_dims as _]; + ortsys![unsafe GetDimensions(tensor_info_ptr, node_dims.as_mut_ptr(), num_dims as _) -> Error::GetDimensions]; + + Ok(node_dims) + }; + ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_info_ptr)]; + res + } +} + +impl Tensor { + /// Extracts the underlying data into a read-only [`ndarray::ArrayView`]. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Tensor, TensorElementType}; + /// # fn main() -> ort::Result<()> { + /// let array = ndarray::Array4::::ones((1, 16, 16, 3)); + /// let tensor = Tensor::from_array(array.view())?; + /// + /// let extracted = tensor.extract_tensor(); + /// assert_eq!(array.into_dyn(), extracted); + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "ndarray")] + #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] + pub fn extract_tensor(&self) -> ndarray::ArrayViewD<'_, T> { + self.try_extract_tensor().unwrap() + } + + /// Extracts the underlying data into a mutable [`ndarray::ArrayViewMut`]. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Tensor, TensorElementType}; + /// # fn main() -> ort::Result<()> { + /// let array = ndarray::Array4::::ones((1, 16, 16, 3)); + /// let mut tensor = Tensor::from_array(array.view())?; + /// + /// let mut extracted = tensor.extract_tensor_mut(); + /// extracted[[0, 0, 0, 1]] = 0.0; + /// + /// let mut array = array.into_dyn(); + /// assert_ne!(array, extracted); + /// array[[0, 0, 0, 1]] = 0.0; + /// assert_eq!(array, extracted); + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "ndarray")] + #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] + pub fn extract_tensor_mut(&mut self) -> ndarray::ArrayViewMutD<'_, T> { + self.try_extract_tensor_mut().unwrap() + } + + /// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and an immutable + /// view into its data. + /// + /// ``` + /// # use ort::{Session, Tensor, TensorElementType}; + /// # fn main() -> ort::Result<()> { + /// let array = vec![1_i64, 2, 3, 4, 5]; + /// let tensor = Tensor::from_array(([array.len()], array.clone().into_boxed_slice()))?; + /// + /// let (extracted_shape, extracted_data) = tensor.extract_raw_tensor(); + /// assert_eq!(extracted_data, &array); + /// assert_eq!(extracted_shape, [5]); + /// # Ok(()) + /// # } + /// ``` + pub fn extract_raw_tensor(&self) -> (Vec, &[T]) { + self.try_extract_raw_tensor().unwrap() + } + + /// Extracts the underlying data into a "raw" view tuple, consisting of the tensor's dimensions and a mutable view + /// into its data. + /// + /// ``` + /// # use ort::{Session, Tensor, TensorElementType}; + /// # fn main() -> ort::Result<()> { + /// let array = vec![1_i64, 2, 3, 4, 5]; + /// let tensor = Tensor::from_array(([array.len()], array.clone().into_boxed_slice()))?; + /// + /// let (extracted_shape, extracted_data) = tensor.extract_raw_tensor(); + /// assert_eq!(extracted_data, &array); + /// assert_eq!(extracted_shape, [5]); + /// # Ok(()) + /// # } + /// ``` + pub fn extract_raw_tensor_mut(&mut self) -> (Vec, &mut [T]) { + self.try_extract_raw_tensor_mut().unwrap() + } } diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index 8e972a23..10d069ae 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -1,7 +1,107 @@ mod create; mod extract; -pub use self::create::ToDimensions; +use std::{ + fmt::Debug, + marker::PhantomData, + ops::{Index, IndexMut}, + ptr::NonNull +}; + +use super::{UpcastableTarget, Value, ValueInner, ValueTypeMarker}; +use crate::{ortsys, DynValue, IntoTensorElementType, ValueRef, ValueRefMut, ValueType}; + +pub trait TensorValueTypeMarker: ValueTypeMarker {} + +#[derive(Debug)] +pub struct DynTensorValueType; +impl ValueTypeMarker for DynTensorValueType {} +impl TensorValueTypeMarker for DynTensorValueType {} + +#[derive(Debug)] +pub struct TensorValueType(PhantomData); +impl ValueTypeMarker for TensorValueType {} +impl TensorValueTypeMarker for TensorValueType {} + +pub type DynTensor = Value; +pub type Tensor = Value>; + +pub type DynTensorRef<'v> = ValueRef<'v, DynTensorValueType>; +pub type DynTensorRefMut<'v> = ValueRefMut<'v, DynTensorValueType>; +pub type TensorRef<'v, T> = ValueRef<'v, TensorValueType>; +pub type TensorRefMut<'v, T> = ValueRefMut<'v, TensorValueType>; + +impl UpcastableTarget for DynTensorValueType { + fn can_upcast(dtype: &ValueType) -> bool { + matches!(dtype, ValueType::Tensor { .. }) + } +} + +impl Tensor { + /// Converts from a strongly-typed [`Tensor`] to a type-erased [`DynTensor`]. + #[inline] + pub fn downcast(self) -> DynTensor { + unsafe { std::mem::transmute(self) } + } + + /// Converts from a strongly-typed [`Tensor`] to a reference to a type-erased [`DynTensor`]. + #[inline] + pub fn downcast_ref(&self) -> DynTensorRef { + DynTensorRef::new(unsafe { + Value::from_ptr_nodrop( + NonNull::new_unchecked(self.ptr()), + if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + ) + }) + } + + /// Converts from a strongly-typed [`Tensor`] to a mutable reference to a type-erased [`DynTensor`]. + #[inline] + pub fn downcast_mut(&mut self) -> DynTensorRefMut { + DynTensorRefMut::new(unsafe { + Value::from_ptr_nodrop( + NonNull::new_unchecked(self.ptr()), + if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + ) + }) + } +} + +impl UpcastableTarget for TensorValueType { + fn can_upcast(dtype: &ValueType) -> bool { + match dtype { + ValueType::Tensor { ty, .. } => *ty == T::into_tensor_element_type(), + _ => false + } + } +} + +impl From>> for DynValue { + fn from(value: Value>) -> Self { + value.into_dyn() + } +} +impl From> for DynValue { + fn from(value: Value) -> Self { + value.into_dyn() + } +} + +impl Index<[i64; N]> for Tensor { + type Output = T; + fn index(&self, index: [i64; N]) -> &Self::Output { + let mut out: *mut ort_sys::c_void = std::ptr::null_mut(); + ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).unwrap()]; + unsafe { &*out.cast::() } + } +} +impl IndexMut<[i64; N]> for Tensor { + fn index_mut(&mut self, index: [i64; N]) -> &mut Self::Output { + let mut out: *mut ort_sys::c_void = std::ptr::null_mut(); + ortsys![unsafe TensorAt(self.ptr(), index.as_ptr(), N as _, &mut out).unwrap()]; + unsafe { &mut *out.cast::() } + } +} #[cfg(test)] mod tests { @@ -26,7 +126,7 @@ mod tests { } ); - let (shape, data) = value.extract_raw_tensor::()?; + let (shape, data) = value.extract_raw_tensor(); assert_eq!(shape, vec![v.len() as i64]); assert_eq!(data, &v); @@ -43,16 +143,16 @@ mod tests { let value = Value::from_array(&mut arc2)?; drop((arc1, arc2)); - assert_eq!(value.extract_raw_tensor::()?.1, &v); + assert_eq!(value.extract_raw_tensor().1, &v); let cow = CowArray::from(Array1::from_vec(v.clone())); let value = Value::from_array(&cow)?; - assert_eq!(value.extract_raw_tensor::()?.1, &v); + assert_eq!(value.extract_raw_tensor().1, &v); let owned = Array1::from_vec(v.clone()); let value = Value::from_array(owned.view())?; drop(owned); - assert_eq!(value.extract_raw_tensor::()?.1, &v); + assert_eq!(value.extract_raw_tensor().1, &v); Ok(()) } @@ -65,7 +165,7 @@ mod tests { let shape = vec![v.len() as i64]; let value = Value::from_array((shape, Arc::clone(&arc)))?; drop(arc); - assert_eq!(value.extract_raw_tensor::()?.1, &v); + assert_eq!(value.try_extract_raw_tensor::()?.1, &v); Ok(()) } @@ -76,8 +176,8 @@ mod tests { let allocator = Allocator::default(); let v = Array1::from_vec(vec!["hello world".to_string(), "こんにちは世界".to_string()]); - let value = Value::from_string_array(&allocator, v.view())?; - let extracted = value.extract_string_tensor()?; + let value = DynTensor::from_string_array(&allocator, v.view())?; + let extracted = value.try_extract_string_tensor()?; assert_eq!(extracted, v.into_dyn()); Ok(()) @@ -88,8 +188,8 @@ mod tests { let allocator = Allocator::default(); let v = vec!["hello world".to_string(), "こんにちは世界".to_string()]; - let value = Value::from_string_array(&allocator, (vec![v.len() as i64], v.clone().into_boxed_slice()))?; - let (extracted_shape, extracted_view) = value.extract_raw_string_tensor()?; + let value = DynTensor::from_string_array(&allocator, (vec![v.len() as i64], v.clone().into_boxed_slice()))?; + let (extracted_shape, extracted_view) = value.try_extract_raw_string_tensor()?; assert_eq!(extracted_shape, [v.len() as i64]); assert_eq!(extracted_view, v); @@ -106,10 +206,10 @@ mod tests { let value_vec = Value::from_array((shape, v.clone()))?; let value_slice = Value::from_array((shape, &v[..]))?; - assert_eq!(value_arc_box.extract_raw_tensor::()?.1, &v); - assert_eq!(value_box.extract_raw_tensor::()?.1, &v); - assert_eq!(value_vec.extract_raw_tensor::()?.1, &v); - assert_eq!(value_slice.extract_raw_tensor::()?.1, &v); + assert_eq!(value_arc_box.extract_raw_tensor().1, &v); + assert_eq!(value_box.extract_raw_tensor().1, &v); + assert_eq!(value_vec.extract_raw_tensor().1, &v); + assert_eq!(value_slice.extract_raw_tensor().1, &v); Ok(()) } diff --git a/src/value/mod.rs b/src/value/mod.rs index 7c0772f8..60e5e046 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -7,29 +7,24 @@ use std::{ sync::Arc }; -#[cfg(feature = "ndarray")] -use ndarray::{ArcArray, Array, ArrayView, CowArray, Dimension}; - -use crate::{ - error::status_to_result, - memory::MemoryInfo, - ortsys, - session::SharedSessionInner, - tensor::{IntoTensorElementType, TensorElementType}, - Error, Result -}; - mod impl_map; mod impl_sequence; mod impl_tensor; -use self::impl_tensor::ToDimensions; +pub use self::{ + impl_map::{DynMap, DynMapRef, DynMapRefMut, DynMapValueType, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker}, + impl_sequence::{ + DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker + }, + impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker} +}; +use crate::{error::status_to_result, memory::MemoryInfo, ortsys, session::SharedSessionInner, tensor::TensorElementType, Error, Result}; /// The type of a [`Value`], or a session input/output. /// /// ``` /// # use std::sync::Arc; -/// # use ort::{Session, Value, ValueType, TensorElementType}; +/// # use ort::{Session, Tensor, ValueType, TensorElementType}; /// # fn main() -> ort::Result<()> { /// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; /// // `ValueType`s can be obtained from session inputs/outputs: @@ -44,7 +39,7 @@ use self::impl_tensor::ToDimensions; /// ); /// /// // Or by `Value`s created in Rust or output by a session. -/// let value = Value::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; +/// let value = Tensor::from_array(([5usize], vec![1_i64, 2, 3, 4, 5].into_boxed_slice()))?; /// assert_eq!( /// value.dtype()?, /// ValueType::Tensor { @@ -171,98 +166,147 @@ pub(crate) enum ValueInner { } } -/// A temporary version of [`Value`] with a lifetime specifier. +/// A temporary version of a [`Value`] with a lifetime specifier. #[derive(Debug)] -pub struct ValueRef<'v> { - inner: Value, +pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> { + inner: Value, lifetime: PhantomData<&'v ()> } -impl<'v> ValueRef<'v> { - pub(crate) fn new(inner: Value) -> Self { +impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> { + pub(crate) fn new(inner: Value) -> Self { ValueRef { inner, lifetime: PhantomData } } + + pub fn into_dyn(self) -> ValueRef<'v, DynValueTypeMarker> { + unsafe { std::mem::transmute(self) } + } } -impl<'v> Deref for ValueRef<'v> { - type Target = Value; +impl<'v, Type: ValueTypeMarker + ?Sized> Deref for ValueRef<'v, Type> { + type Target = Value; fn deref(&self) -> &Self::Target { &self.inner } } -/// A mutable temporary version of [`Value`] with a lifetime specifier. +/// A mutable temporary version of a [`Value`] with a lifetime specifier. #[derive(Debug)] -pub struct ValueRefMut<'v> { - inner: Value, +pub struct ValueRefMut<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> { + inner: Value, lifetime: PhantomData<&'v ()> } -impl<'v> ValueRefMut<'v> { - pub(crate) fn new(inner: Value) -> Self { +impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> { + pub(crate) fn new(inner: Value) -> Self { ValueRefMut { inner, lifetime: PhantomData } } + + pub fn into_dyn(self) -> ValueRefMut<'v, DynValueTypeMarker> { + unsafe { std::mem::transmute(self) } + } } -impl<'v> Deref for ValueRefMut<'v> { - type Target = Value; +impl<'v, Type: ValueTypeMarker + ?Sized> Deref for ValueRefMut<'v, Type> { + type Target = Value; fn deref(&self) -> &Self::Target { &self.inner } } -impl<'v> DerefMut for ValueRefMut<'v> { +impl<'v, Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'v, Type> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.inner } } -/// A [`Value`] contains data for inputs/outputs in ONNX Runtime graphs. [`Value`]s can hold a tensor, sequence -/// (array/vector), or map. +/// A [`Value`] contains data for inputs/outputs in ONNX Runtime graphs. [`Value`]s can be a [`Tensor`], [`Sequence`] +/// (aka array/vector), or [`Map`]. /// /// ## Creation -/// `Value`s can be created via methods like [`Value::from_array`], or as the output from running a [`crate::Session`]. +/// Values can be created via methods like [`Tensor::from_array`], or as the output from running a [`crate::Session`]. /// /// ``` -/// # use ort::{Session, Value, ValueType, TensorElementType}; +/// # use ort::{Session, Tensor, ValueType, TensorElementType}; /// # fn main() -> ort::Result<()> { /// # let upsample = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; -/// // Create a value from a raw data vector -/// let value = Value::from_array(([1usize, 1, 1, 3], vec![1.0_f32, 2.0, 3.0].into_boxed_slice()))?; +/// // Create a Tensor value from a raw data vector +/// let value = Tensor::from_array(([1usize, 1, 1, 3], vec![1.0_f32, 2.0, 3.0].into_boxed_slice()))?; /// -/// // Create a value from an `ndarray::Array` +/// // Create a Tensor value from an `ndarray::Array` /// #[cfg(feature = "ndarray")] -/// let value = Value::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; +/// let value = Tensor::from_array(ndarray::Array4::::zeros((1, 16, 16, 3)))?; /// -/// // Get a value from a session's output +/// // Get a DynValue from a session's output /// let value = &upsample.run(ort::inputs![value]?)?[0]; /// # Ok(()) /// # } /// ``` /// -/// See [`Value::from_array`] for more details on what tensor values are accepted. +/// See [`Tensor::from_array`] for more details on what tensor values are accepted. /// /// ## Usage -/// You can access the data in a `Value` by using the relevant `extract` methods: [`Value::extract_tensor`] & -/// [`Value::extract_raw_tensor`], [`Value::extract_sequence`], and [`Value::extract_map`]. +/// You can access the data contained in a `Value` by using the relevant `extract` methods. +/// You can also use [`DynValue::upcast`] to attempt to convert from a [`DynValue`] to a more strongly typed value. +/// +/// For dynamic values, where the type is not known at compile time, see the `try_extract_*` methods: +/// - [`Tensor::try_extract_tensor`], [`Tensor::try_extract_raw_tensor`] +/// - [`Sequence::try_extract_sequence`] +/// - [`Map::try_extract_map`] +/// +/// If the type was created from Rust (via a method like [`Tensor::from_array`] or via upcasting), you can directly +/// extract the data using the infallible extract methods: +/// - [`Tensor::extract_tensor`], [`Tensor::extract_raw_tensor`] #[derive(Debug)] -pub struct Value { - inner: ValueInner +pub struct Value { + inner: ValueInner, + _markers: PhantomData } +/// A dynamic value, which could be a [`Tensor`], [`Sequence`], or [`Map`]. +/// +/// To attempt to convert a dynamic value to a strongly typed value, use [`DynValue::upcast`]. You can also attempt to +/// extract data from dynamic values directly using `try_extract_*` methods; see [`Value`] for more information. +pub type DynValue = Value; + +/// Marker trait used to determine what operations can and cannot be performed on a [`Value`] of a given type. +/// +/// For example, [`Tensor::try_extract_tensor`] can only be used on [`Value`]s with the [`TensorValueTypeMarker`] (which +/// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s. +pub trait ValueTypeMarker: Debug {} + +/// Represents a type that a [`DynValue`] can be upcast to. +pub trait UpcastableTarget: ValueTypeMarker { + fn can_upcast(dtype: &ValueType) -> bool; +} + +// this implementation is used in case we want to extract `DynValue`s from a [`Sequence`]; see `try_extract_sequence` +impl UpcastableTarget for DynValueTypeMarker { + fn can_upcast(_: &ValueType) -> bool { + true + } +} + +/// The dynamic type marker, used for values which can be of any type. +#[derive(Debug)] +pub struct DynValueTypeMarker; +impl ValueTypeMarker for DynValueTypeMarker {} +impl MapValueTypeMarker for DynValueTypeMarker {} +impl SequenceValueTypeMarker for DynValueTypeMarker {} +impl TensorValueTypeMarker for DynValueTypeMarker {} + unsafe impl Send for Value {} -impl Value { +impl Value { /// Returns the data type of this [`Value`]. pub fn dtype(&self) -> Result { let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = std::ptr::null_mut(); ortsys![unsafe GetTypeInfo(self.ptr(), &mut typeinfo_ptr) -> Error::GetTypeInfo; nonNull(typeinfo_ptr)]; let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN; - let status = ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty)]; - status_to_result(status).map_err(Error::GetOnnxTypeFromTypeInfo)?; + ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr, &mut ty) -> Error::GetOnnxTypeFromTypeInfo]; let io_type = match ty { ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => { let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); @@ -297,18 +341,20 @@ impl Value { /// - `ptr` must be a valid pointer to an [`ort_sys::OrtValue`]. /// - `session` must be `Some` for values returned from a session. #[must_use] - pub unsafe fn from_ptr(ptr: NonNull, session: Option>) -> Value { + pub unsafe fn from_ptr(ptr: NonNull, session: Option>) -> Value { Value { - inner: ValueInner::CppOwned { ptr, drop: true, _session: session } + inner: ValueInner::CppOwned { ptr, drop: true, _session: session }, + _markers: PhantomData } } /// A variant of [`Value::from_ptr`] that does not release the value upon dropping. Used in operator kernel /// contexts. #[must_use] - pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull, session: Option>) -> Value { + pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull, session: Option>) -> Value { Value { - inner: ValueInner::CppOwned { ptr, drop: false, _session: session } + inner: ValueInner::CppOwned { ptr, drop: false, _session: session }, + _markers: PhantomData } } @@ -320,7 +366,7 @@ impl Value { } /// Create a view of this value's data. - pub fn view(&self) -> ValueRef<'_> { + pub fn view(&self) -> ValueRef<'_, Type> { ValueRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), @@ -329,8 +375,8 @@ impl Value { }) } - /// Create a view of this value's data. - pub fn view_mut(&mut self) -> ValueRefMut<'_> { + /// Create a mutable view of this value's data. + pub fn view_mut(&mut self) -> ValueRefMut<'_, Type> { ValueRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), @@ -355,9 +401,56 @@ impl Value { ortsys![unsafe IsTensor(self.ptr(), &mut result) -> Error::GetTensorElementType]; Ok(result == 1) } + + /// Converts this value into a type-erased [`DynValue`]. + pub fn into_dyn(self) -> DynValue { + unsafe { std::mem::transmute(self) } + } + + /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed variant, + /// like [`Tensor`]. + #[inline] + pub fn upcast(self) -> Result> { + let dt = self.dtype()?; + if OtherType::can_upcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() } + } + + /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed reference + /// variant, like [`TensorRef`]. + #[inline] + pub fn upcast_ref(&self) -> Result> { + let dt = self.dtype()?; + if OtherType::can_upcast(&dt) { + Ok(ValueRef::new(unsafe { + Value::from_ptr_nodrop( + NonNull::new_unchecked(self.ptr()), + if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + ) + })) + } else { + panic!() + } + } + + /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed + /// mutable-reference variant, like [`TensorRefMut`]. + #[inline] + pub fn upcast_mut(&mut self) -> Result> { + let dt = self.dtype()?; + if OtherType::can_upcast(&dt) { + Ok(ValueRefMut::new(unsafe { + Value::from_ptr_nodrop( + NonNull::new_unchecked(self.ptr()), + if let ValueInner::CppOwned { _session, .. } = &self.inner { _session.clone() } else { None } + ) + })) + } else { + panic!() + } + } } -impl Drop for Value { +impl Drop for Value { fn drop(&mut self) { let ptr = self.ptr(); tracing::trace!( @@ -373,55 +466,6 @@ impl Drop for Value { } } -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'i, 'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<&'i CowArray<'v, T, D>> for Value -where - 'i: 'v -{ - type Error = Error; - fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { - Value::from_array(arr) - } -} - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for Value { - type Error = Error; - fn try_from(arr: ArrayView<'v, T, D>) -> Result { - Value::from_array(arr) - } -} - -macro_rules! impl_try_from { - (@T,I $($t:ty),+) => { - $( - impl TryFrom<$t> for Value { - type Error = Error; - fn try_from(value: $t) -> Result { - Value::from_array(value) - } - } - )+ - }; - (@T,D $($t:ty),+) => { - $( - #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] - impl TryFrom<$t> for Value { - type Error = Error; - fn try_from(value: $t) -> Result { - Value::from_array(value) - } - } - )+ - }; -} - -#[cfg(feature = "ndarray")] -impl_try_from!(@T,D &mut ArcArray, Array); -impl_try_from!(@T,I (I, Arc>), (I, Vec), (I, Box<[T]>), (I, &[T])); - pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo) -> Result { let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; ortsys![GetTensorElementType(info_ptr, &mut type_sys) -> Error::GetTensorElementType]; diff --git a/tests/mnist.rs b/tests/mnist.rs index 015dd812..cbad935e 100644 --- a/tests/mnist.rs +++ b/tests/mnist.rs @@ -44,7 +44,7 @@ fn mnist_5() -> ort::Result<()> { let outputs = session.run(inputs![array]?)?; let mut probabilities: Vec<(usize, f32)> = outputs[0] - .extract_tensor()? + .try_extract_tensor()? .softmax(ndarray::Axis(1)) .iter() .copied() diff --git a/tests/squeezenet.rs b/tests/squeezenet.rs index 3fda0202..da7c596a 100644 --- a/tests/squeezenet.rs +++ b/tests/squeezenet.rs @@ -71,7 +71,7 @@ fn squeezenet_mushroom() -> ort::Result<()> { // Downloaded model does not have a softmax as final layer; call softmax on second axis // and iterate on resulting probabilities, creating an index to later access labels. let mut probabilities: Vec<(usize, f32)> = outputs[0] - .extract_tensor()? + .try_extract_tensor()? .softmax(ndarray::Axis(1)) .iter() .copied() diff --git a/tests/upsample.rs b/tests/upsample.rs index 6124f670..268fd792 100644 --- a/tests/upsample.rs +++ b/tests/upsample.rs @@ -69,7 +69,7 @@ fn upsample() -> ort::Result<()> { let outputs = session.run(inputs![&array]?)?; assert_eq!(outputs.len(), 1); - let output: ArrayViewD = outputs[0].extract_tensor()?; + let output: ArrayViewD = outputs[0].try_extract_tensor()?; // The image should have doubled in size assert_eq!(output.shape(), [1, 448, 448, 3]); @@ -106,7 +106,7 @@ fn upsample_with_ort_model() -> ort::Result<()> { let outputs = session.run(inputs![&array]?)?; assert_eq!(outputs.len(), 1); - let output: ArrayViewD = outputs[0].extract_tensor()?; + let output: ArrayViewD = outputs[0].try_extract_tensor()?; // The image should have doubled in size assert_eq!(output.shape(), [1, 448, 448, 3]); diff --git a/tests/vectorizer.rs b/tests/vectorizer.rs index c54b74aa..bef1a572 100644 --- a/tests/vectorizer.rs +++ b/tests/vectorizer.rs @@ -3,7 +3,7 @@ use std::path::Path; use ndarray::{ArrayD, IxDyn}; -use ort::{inputs, GraphOptimizationLevel, Session, Value}; +use ort::{inputs, DynTensor, GraphOptimizationLevel, Session}; use test_log::test; #[test] @@ -22,11 +22,11 @@ fn vectorizer() -> ort::Result<()> { let array = ndarray::CowArray::from(ndarray::Array::from_shape_vec((1,), vec!["document".to_owned()]).unwrap()); // Just one input - let input_tensor_values = inputs![Value::from_string_array(session.allocator(), &array)?]?; + let input_tensor_values = inputs![DynTensor::from_string_array(session.allocator(), &array)?]?; // Perform the inference let outputs = session.run(input_tensor_values)?; - assert_eq!(outputs[0].extract_tensor::()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap()); + assert_eq!(outputs[0].try_extract_tensor::()?, ArrayD::from_shape_vec(IxDyn(&[1, 9]), vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).unwrap()); Ok(()) }