Skip to content

Commit

Permalink
feat: value specialization (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 authored Mar 28, 2024
1 parent 414befa commit 393f25f
Show file tree
Hide file tree
Showing 22 changed files with 958 additions and 287 deletions.
4 changes: 2 additions & 2 deletions examples/async-gpt2-api/examples/async-gpt2-api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use axum::{
};
use futures::Stream;
use ndarray::{array, concatenate, s, Array1, ArrayViewD, Axis};
use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session, Value};
use ort::{inputs, CUDAExecutionProvider, GraphOptimizationLevel, Session};
use rand::Rng;
use tokenizers::Tokenizer;
use tokio::net::TcpListener;
Expand Down Expand Up @@ -67,7 +67,7 @@ fn generate_stream(tokenizer: Arc<Tokenizer>, session: Arc<Session>, tokens: Vec
for _ in 0..gen_tokens {
let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1));
let outputs = session.run_async(inputs![array]?)?.await?;
let generated_tokens: ArrayViewD<f32> = outputs["output1"].extract_tensor()?;
let generated_tokens: ArrayViewD<f32> = outputs["output1"].try_extract_tensor()?;

// Collect and sort logits
let probabilities = &mut generated_tokens
Expand Down
12 changes: 6 additions & 6 deletions examples/custom-ops/examples/custom-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ impl Kernel for CustomOpOneKernel {
fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> {
let x = ctx.input(0)?.unwrap();
let y = ctx.input(1)?.unwrap();
let (x_shape, x) = x.extract_raw_tensor::<f32>()?;
let (y_shape, y) = y.extract_raw_tensor::<f32>()?;
let (x_shape, x) = x.try_extract_raw_tensor::<f32>()?;
let (y_shape, y) = y.try_extract_raw_tensor::<f32>()?;

let mut z = ctx.output(0, x_shape)?.unwrap();
let (_, z_ref) = z.extract_raw_tensor_mut::<f32>()?;
let (_, z_ref) = z.try_extract_raw_tensor_mut::<f32>()?;
for i in 0..y_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize {
if i % 2 == 0 {
z_ref[i] = x[i];
Expand Down Expand Up @@ -70,9 +70,9 @@ impl Operator for CustomOpTwo {
impl Kernel for CustomOpTwoKernel {
fn compute(&mut self, ctx: &KernelContext) -> ort::Result<()> {
let x = ctx.input(0)?.unwrap();
let (x_shape, x) = x.extract_raw_tensor::<f32>()?;
let (x_shape, x) = x.try_extract_raw_tensor::<f32>()?;
let mut z = ctx.output(0, x_shape.clone())?.unwrap();
let (_, z_ref) = z.extract_raw_tensor_mut::<i32>()?;
let (_, z_ref) = z.try_extract_raw_tensor_mut::<i32>()?;
for i in 0..x_shape.into_iter().reduce(|acc, e| acc * e).unwrap() as usize {
z_ref[i] = (x[i] * i as f32) as i32;
}
Expand All @@ -86,7 +86,7 @@ fn main() -> ort::Result<()> {
.commit_from_file("tests/data/custom_op_test.onnx")?;

let values = session.run(ort::inputs![Array2::<f32>::zeros((3, 5)), Array2::<f32>::ones((3, 5))]?)?;
println!("{:?}", values[0].extract_tensor::<i32>()?);
println!("{:?}", values[0].try_extract_tensor::<i32>()?);

Ok(())
}
2 changes: 1 addition & 1 deletion examples/gpt2/examples/gpt2-no-ndarray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn main() -> ort::Result<()> {
// The model expects our input to have shape [B, _, S]
let input = (vec![1, 1, tokens.len() as i64], Arc::clone(&tokens));
let outputs = session.run(inputs![input]?)?;
let (dim, mut probabilities) = outputs["output1"].extract_raw_tensor()?;
let (dim, mut probabilities) = outputs["output1"].try_extract_raw_tensor()?;

// The output tensor will have shape [B, _, S + 1, V]
// We want only the probabilities for the last token in this sequence, which will be the token generated by the model
Expand Down
2 changes: 1 addition & 1 deletion examples/gpt2/examples/gpt2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn main() -> ort::Result<()> {
for _ in 0..GEN_TOKENS {
let array = tokens.view().insert_axis(Axis(0)).insert_axis(Axis(1));
let outputs = session.run(inputs![array]?)?;
let generated_tokens: ArrayViewD<f32> = outputs["output1"].extract_tensor()?;
let generated_tokens: ArrayViewD<f32> = outputs["output1"].try_extract_tensor()?;

// Collect and sort logits
let probabilities = &mut generated_tokens
Expand Down
2 changes: 1 addition & 1 deletion examples/modnet/examples/modnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn main() -> ort::Result<()> {

let outputs = model.run(inputs!["input" => input.view()]?)?;

let output = outputs["output"].extract_tensor::<f32>()?;
let output = outputs["output"].try_extract_tensor::<f32>()?;

// convert to 8-bit
let output = output.mul(255.0).map(|x| *x as u8);
Expand Down
2 changes: 1 addition & 1 deletion examples/yolov8/examples/yolov8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fn main() -> ort::Result<()> {

// Run YOLOv8 inference
let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?;
let output = outputs["output0"].extract_tensor::<f32>()?.t().into_owned();
let output = outputs["output0"].try_extract_tensor::<f32>()?.t().into_owned();

let mut boxes = Vec::new();
let output = output.slice(s![.., .., 0]);
Expand Down
2 changes: 1 addition & 1 deletion src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ extern_system_fn! {

#[cfg(test)]
mod tests {
use std::sync::{atomic::Ordering, Arc, OnceLock, RwLock, RwLockWriteGuard};
use std::sync::{OnceLock, RwLock, RwLockWriteGuard};

use test_log::test;

Expand Down
8 changes: 8 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ pub enum Error {
/// Error occurred when creating ONNX tensor with specific data
#[error("Failed to create tensor with data: {0}")]
CreateTensorWithData(ErrorInternal),
/// Error occurred when attempting to create a [`crate::Sequence`].
#[error("Failed to create sequence value: {0}")]
CreateSequence(ErrorInternal),
/// Error occurred when attempting to create a [`crate::Map`].
#[error("Failed to create map value: {0}")]
CreateMap(ErrorInternal),
/// Invalid dimension when creating tensor from raw data
#[error("Invalid dimension at {0}; all dimensions must be >= 1 when creating a tensor from raw data")]
InvalidDimension(usize),
Expand Down Expand Up @@ -230,6 +236,8 @@ pub enum Error {
InvalidMapKeyType { expected: TensorElementType, actual: TensorElementType },
#[error("Tried to extract a map with a value type of {expected:?}, but the map has value type {actual:?}")]
InvalidMapValueType { expected: TensorElementType, actual: TensorElementType },
#[error("Tried to extract a sequence with a different element type than its actual type {actual:?}")]
InvalidSequenceElementType { actual: ValueType },
#[error("Error occurred while attempting to extract data from sequence value: {0}")]
ExtractSequence(ErrorInternal),
#[error("Error occurred while attempting to extract data from map value: {0}")]
Expand Down
6 changes: 3 additions & 3 deletions src/io_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
ortsys,
session::{output::SessionOutputs, RunOptions},
value::{Value, ValueRefMut},
Error, Result, Session
Error, Result, Session, ValueTypeMarker
};

/// Enables binding of session inputs and/or outputs to pre-allocated memory.
Expand Down Expand Up @@ -41,15 +41,15 @@ impl<'s> IoBinding<'s> {
}

/// Bind a [`Value`] to a session input.
///
/// `name` is converted to a C string and handed, together with the raw pointer of `ort_value`,
/// to the `BindInput` API call. On success a mutable view of `ort_value` is returned.
///
/// The `'i: 's` bound means the mutable borrow of `ort_value` must outlive the `'s` lifetime of
/// this binding.
///
/// # Errors
///
/// Returns early (via `?`) if `CString::new` rejects `name`, which happens when it contains an
/// interior NUL byte. A failing `BindInput` call is presumably surfaced as `Error::BindInput` by
/// the `ortsys!` macro — confirm against the macro arm that handles `-> Error::...`.
pub fn bind_input<'i: 's, S: AsRef<str>>(&mut self, name: S, ort_value: &'i mut Value) -> Result<ValueRefMut<'i>> {
pub fn bind_input<'i: 's, T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: &'i mut Value<T>) -> Result<ValueRefMut<'i, T>> {
let name = name.as_ref();
// `cname` must stay alive until the call below returns, since only its pointer is passed.
let cname = CString::new(name)?;
ortsys![unsafe BindInput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindInput];
Ok(ort_value.view_mut())
}

/// Bind a session output to a pre-allocated [`Value`].
pub fn bind_output<'o: 's, S: AsRef<str>>(&mut self, name: S, ort_value: &'o mut Value) -> Result<ValueRefMut<'o>> {
pub fn bind_output<'o: 's, T: ValueTypeMarker, S: AsRef<str>>(&mut self, name: S, ort_value: &'o mut Value<T>) -> Result<ValueRefMut<'o, T>> {
let name = name.as_ref();
let cname = CString::new(name)?;
ortsys![unsafe BindOutput(self.ptr.as_ptr(), cname.as_ptr(), ort_value.ptr()) -> Error::BindOutput];
Expand Down
15 changes: 13 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub(crate) mod tensor;
pub(crate) mod value;

#[cfg(feature = "load-dynamic")]
use std::sync::{Arc, Mutex, MutexGuard};
use std::sync::Arc;
use std::{
ffi::{self, CStr},
os::raw::c_char,
Expand Down Expand Up @@ -60,7 +60,12 @@ pub use self::session::{
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
pub use self::tensor::ArrayExtensions;
pub use self::tensor::{IntoTensorElementType, TensorElementType};
pub use self::value::{Value, ValueRef, ValueType};
pub use self::value::{
DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef,
DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, SequenceRef,
SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, UpcastableTarget, Value, ValueRef,
ValueRefMut, ValueType, ValueTypeMarker
};

#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
macro_rules! extern_system_fn {
Expand Down Expand Up @@ -198,6 +203,12 @@ macro_rules! ortsys {
(unsafe $method:ident($($n:expr),+ $(,)?)) => {
unsafe { $crate::api().as_ref().$method.unwrap()($($n),+) }
};
($method:ident($($n:expr),+ $(,)?).unwrap()) => {
$crate::error::status_to_result($crate::api().as_ref().$method.unwrap()($($n),+)).unwrap()
};
(unsafe $method:ident($($n:expr),+ $(,)?).unwrap()) => {
$crate::error::status_to_result(unsafe { $crate::api().as_ref().$method.unwrap()($($n),+) }).unwrap()
};
($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::api().as_ref().$method.unwrap()($($n),+);
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
Expand Down
30 changes: 17 additions & 13 deletions src/session/input.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use std::{borrow::Cow, collections::HashMap, ops::Deref};

use crate::{Value, ValueRef};
use crate::{
value::{DynValueTypeMarker, ValueTypeMarker},
Value, ValueRef
};

/// A single value supplied as a session input: either a borrowed view (`View`, tied to the
/// lifetime `'v`) or a value owned by the input itself (`Owned`).
pub enum SessionInputValue<'v> {
View(ValueRef<'v>),
Owned(Value)
View(ValueRef<'v, DynValueTypeMarker>),
Owned(Value<DynValueTypeMarker>)
}

impl<'v> Deref for SessionInputValue<'v> {
type Target = Value;

fn deref(&self) -> &Self::Target {
match self {
SessionInputValue::View(v) => v,
Expand All @@ -17,14 +21,14 @@ impl<'v> Deref for SessionInputValue<'v> {
}
}

/// Converts a borrowed [`ValueRef`] into [`SessionInputValue::View`].
impl<'v> From<ValueRef<'v>> for SessionInputValue<'v> {
fn from(value: ValueRef<'v>) -> Self {
SessionInputValue::View(value)
impl<'v, T: ValueTypeMarker + ?Sized> From<ValueRef<'v, T>> for SessionInputValue<'v> {
fn from(value: ValueRef<'v, T>) -> Self {
// The generic form passes the view through `into_dyn()` before wrapping it; presumably this
// yields the `DynValueTypeMarker`-typed view that the `View` variant stores — confirm
// against the definition of `into_dyn`.
SessionInputValue::View(value.into_dyn())
}
}
/// Converts an owned [`Value`] into [`SessionInputValue::Owned`], taking ownership of it.
impl<'v> From<Value> for SessionInputValue<'v> {
fn from(value: Value) -> Self {
SessionInputValue::Owned(value)
impl<'v, T: ValueTypeMarker + ?Sized> From<Value<T>> for SessionInputValue<'v> {
fn from(value: Value<T>) -> Self {
// The generic form passes the value through `into_dyn()` before wrapping it; presumably this
// yields the `DynValueTypeMarker`-typed value that the `Owned` variant stores — confirm
// against the definition of `into_dyn`.
SessionInputValue::Owned(value.into_dyn())
}
}

Expand Down Expand Up @@ -113,13 +117,13 @@ impl<'i, 'v, const N: usize> From<[SessionInputValue<'v>; N]> for SessionInputs<
macro_rules! inputs {
($($v:expr),+ $(,)?) => (
(|| -> $crate::Result<_> {
Ok([$(::std::convert::Into::<$crate::SessionInputValue<'_>>::into(::std::convert::TryInto::<$crate::Value>::try_into($v).map_err($crate::Error::from)?)),+])
Ok([$(::std::convert::Into::<$crate::SessionInputValue<'_>>::into(::std::convert::TryInto::<$crate::DynValue>::try_into($v).map_err($crate::Error::from)?)),+])
})()
);
($($n:expr => $v:expr),+ $(,)?) => (
(|| -> $crate::Result<_> {
Ok(vec![$(
::std::convert::TryInto::<$crate::Value>::try_into($v)
::std::convert::TryInto::<$crate::DynValue>::try_into($v)
.map_err($crate::Error::from)
.map(|v| (::std::borrow::Cow::<str>::from($n), $crate::SessionInputValue::from(v)))?,)+])
})()
Expand All @@ -138,7 +142,7 @@ mod tests {
let arc = Arc::new(v.clone().into_boxed_slice());
let shape = vec![v.len() as i64];

let mut inputs: HashMap<&str, Value> = HashMap::new();
let mut inputs: HashMap<&str, DynTensor> = HashMap::new();
inputs.insert("test", (shape, arc).try_into()?);
let _ = SessionInputs::from(inputs);

Expand All @@ -151,7 +155,7 @@ mod tests {
let arc = Arc::new(v.clone().into_boxed_slice());
let shape = vec![v.len() as i64];

let mut inputs: HashMap<String, Value> = HashMap::new();
let mut inputs: HashMap<String, DynTensor> = HashMap::new();
inputs.insert("test".to_string(), (shape, arc).try_into()?);
let _ = SessionInputs::from(inputs);

Expand Down
16 changes: 8 additions & 8 deletions src/session/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
ops::{Deref, DerefMut, Index}
};

use crate::{Allocator, Value};
use crate::{Allocator, DynValue};

/// The outputs returned by a [`crate::Session`] inference call.
///
Expand All @@ -26,15 +26,15 @@ use crate::{Allocator, Value};
/// ```
#[derive(Debug)]
pub struct SessionOutputs<'s> {
/// Output values keyed by name; `new` builds it by zipping the output names with the values.
map: BTreeMap<&'s str, Value>,
map: BTreeMap<&'s str, DynValue>,
/// Keys of `map` addressable by position; `Index<usize>` looks up `idxs[index]` in `map`.
idxs: Vec<&'s str>,
/// Optional allocator / raw-pointer pair. NOTE(review): presumably populated by `new_backed`;
/// who owns and releases the pointer is not visible in this excerpt — confirm.
backing_ptr: Option<(&'s Allocator, *mut c_void)>
}

// `backing_ptr` contains a raw `*mut c_void`, so `Send` is not derived automatically for
// `SessionOutputs`; this impl opts in by hand.
// NOTE(review): the invariant that makes it sound to move the raw pointer (and the `&Allocator`
// stored beside it) to another thread is not visible here — confirm it and record it as a
// `SAFETY:` comment.
unsafe impl<'s> Send for SessionOutputs<'s> {}

impl<'s> SessionOutputs<'s> {
pub(crate) fn new(output_names: impl Iterator<Item = &'s str> + Clone, output_values: impl IntoIterator<Item = Value>) -> Self {
pub(crate) fn new(output_names: impl Iterator<Item = &'s str> + Clone, output_values: impl IntoIterator<Item = DynValue>) -> Self {
let map = output_names.clone().zip(output_values).collect();
Self {
map,
Expand All @@ -45,7 +45,7 @@ impl<'s> SessionOutputs<'s> {

pub(crate) fn new_backed(
output_names: impl Iterator<Item = &'s str> + Clone,
output_values: impl IntoIterator<Item = Value>,
output_values: impl IntoIterator<Item = DynValue>,
allocator: &'s Allocator,
backing_ptr: *mut c_void
) -> Self {
Expand Down Expand Up @@ -75,7 +75,7 @@ impl<'s> Drop for SessionOutputs<'s> {
}

impl<'s> Deref for SessionOutputs<'s> {
type Target = BTreeMap<&'s str, Value>;
type Target = BTreeMap<&'s str, DynValue>;

fn deref(&self) -> &Self::Target {
&self.map
Expand All @@ -89,21 +89,21 @@ impl<'s> DerefMut for SessionOutputs<'s> {
}

/// Looks up an output value by its name.
///
/// # Panics
///
/// Panics with the message "no entry found for key" if `map` has no entry for `index`.
/// The inner map is reachable through `Deref`, so `outputs.get(name)` gives a non-panicking lookup.
impl<'s> Index<&str> for SessionOutputs<'s> {
type Output = Value;
type Output = DynValue;
fn index(&self, index: &str) -> &Self::Output {
self.map.get(index).expect("no entry found for key")
}
}

/// Looks up an output value by its name, given as an owned `String`.
///
/// The `String` is consumed by the call and only borrowed as `&str` for the lookup.
///
/// # Panics
///
/// Panics with the message "no entry found for key" if `map` has no entry for `index`.
impl<'s> Index<String> for SessionOutputs<'s> {
type Output = Value;
type Output = DynValue;
fn index(&self, index: String) -> &Self::Output {
self.map.get(index.as_str()).expect("no entry found for key")
}
}

impl<'s> Index<usize> for SessionOutputs<'s> {
type Output = Value;
type Output = DynValue;
fn index(&self, index: usize) -> &Self::Output {
self.map.get(&self.idxs[index]).expect("no entry found for key")
}
Expand Down
Loading

0 comments on commit 393f25f

Please sign in to comment.