Skip to content

Commit

Permalink
test: more Value tests
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Aug 5, 2024
1 parent b30b621 commit b5e9074
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 8 deletions.
33 changes: 26 additions & 7 deletions src/value/impl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{

use super::{
impl_tensor::{calculate_tensor_size, DynTensor, Tensor},
DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker
DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker
};
use crate::{
error::{Error, Result},
Expand All @@ -31,6 +31,14 @@ impl MapValueTypeMarker for DynMapValueType {
crate::private_impl!();
}

impl DowncastableTarget for DynMapValueType {
fn can_downcast(dtype: &ValueType) -> bool {
matches!(dtype, ValueType::Map { .. })
}

crate::private_impl!();
}

#[derive(Debug)]
pub struct MapValueType<K: IntoTensorElementType + Clone + Hash + Eq, V: IntoTensorElementType + Debug>(PhantomData<(K, V)>);
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> ValueTypeMarker for MapValueType<K, V> {
Expand All @@ -40,6 +48,17 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT
crate::private_impl!();
}

impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> DowncastableTarget for MapValueType<K, V> {
fn can_downcast(dtype: &ValueType) -> bool {
match dtype {
ValueType::Map { key, value } => *key == K::into_tensor_element_type() && *value == V::into_tensor_element_type(),
_ => false
}
}

crate::private_impl!();
}

pub type DynMap = Value<DynMapValueType>;
pub type Map<K, V> = Value<MapValueType<K, V>>;

Expand Down Expand Up @@ -166,14 +185,14 @@ impl<V: PrimitiveTensorElementType + Debug + Clone + 'static> Value<MapValueType
/// # use std::collections::HashMap;
/// # use ort::Map;
/// # fn main() -> ort::Result<()> {
/// let mut map = HashMap::<i64, f32>::new();
/// map.insert(0, 1.0);
/// map.insert(1, 2.0);
/// map.insert(2, 3.0);
/// let mut map = HashMap::<String, f32>::new();
/// map.insert("one".to_string(), 1.0);
/// map.insert("two".to_string(), 2.0);
/// map.insert("three".to_string(), 3.0);
///
/// let value = Map::<i64, f32>::new(map)?;
/// let value = Map::<String, f32>::new(map)?;
///
/// assert_eq!(*value.extract_map().get(&0).unwrap(), 1.0);
/// assert_eq!(*value.extract_map().get("one").unwrap(), 1.0);
/// # Ok(())
/// # }
/// ```
Expand Down
24 changes: 23 additions & 1 deletion src/value/impl_tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ mod tests {

use ndarray::{ArcArray1, Array1, CowArray};

use crate::{Tensor, TensorElementType, ValueType};
use crate::{Allocator, Tensor, TensorElementType, ValueType};

#[test]
#[cfg(feature = "ndarray")]
Expand Down Expand Up @@ -387,4 +387,26 @@ mod tests {

Ok(())
}

#[test]
fn test_tensor_index() -> crate::Result<()> {
let mut tensor = Tensor::new(&Allocator::default(), [1, 3, 224, 224])?;

assert_eq!(tensor[[0, 2, 42, 42]], 0.0);
tensor[[0, 2, 42, 42]] = 1.0;
assert_eq!(tensor[[0, 2, 42, 42]], 1.0);

for y in 0..224 {
for x in 0..224 {
tensor[[0, 1, y, x]] = -1.0;
}
}
assert_eq!(tensor[[0, 1, 0, 0]], -1.0);
assert_eq!(tensor[[0, 1, 223, 223]], -1.0);

assert_eq!(tensor[[0, 0, 0, 0]], 0.0);
assert_eq!(tensor[[0, 2, 42, 42]], 1.0);

Ok(())
}
}
68 changes: 68 additions & 0 deletions src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,71 @@ pub(crate) unsafe fn extract_data_type_from_map_info(info_ptr: *const ort_sys::O
value: value_type_sys.into()
})
}

#[cfg(test)]
mod tests {
use super::{DynTensorValueType, Map, Sequence, Tensor, TensorRef, TensorValueType};
use crate::{Allocator, TensorRefMut};

#[test]
fn test_casting_tensor() -> crate::Result<()> {
let tensor: Tensor<i32> = Tensor::from_array((vec![5], vec![1, 2, 3, 4, 5]))?;

let dyn_tensor = tensor.into_dyn();
let mut tensor: Tensor<i32> = dyn_tensor.downcast()?;

{
let dyn_tensor_ref = tensor.view().into_dyn();
let tensor_ref: TensorRef<i32> = dyn_tensor_ref.downcast()?;
assert_eq!(tensor_ref.extract_raw_tensor(), tensor.extract_raw_tensor());
}
{
let dyn_tensor_ref = tensor.view().into_dyn();
let tensor_ref: TensorRef<i32> = dyn_tensor_ref.downcast_ref()?;
assert_eq!(tensor_ref.extract_raw_tensor(), tensor.extract_raw_tensor());
}

// Ensure mutating a TensorRefMut mutates the original tensor.
{
let mut dyn_tensor_ref = tensor.view_mut().into_dyn();
let mut tensor_ref: TensorRefMut<i32> = dyn_tensor_ref.downcast_mut()?;
let (_, data) = tensor_ref.extract_raw_tensor_mut();
data[2] = 42;
}
{
let (_, data) = tensor.extract_raw_tensor_mut();
assert_eq!(data[2], 42);
}

// chain a bunch of up/downcasts
{
let tensor = tensor
.into_dyn()
.downcast::<DynTensorValueType>()?
.into_dyn()
.downcast::<TensorValueType<i32>>()?
.upcast()
.into_dyn();
let tensor = tensor.view();
let tensor = tensor.downcast_ref::<TensorValueType<i32>>()?;
let (_, data) = tensor.extract_raw_tensor();
assert_eq!(data, [1, 2, 42, 4, 5]);
}

Ok(())
}

#[test]
fn test_sequence_map() -> crate::Result<()> {
let map_contents = [("meaning".to_owned(), 42.0), ("pi".to_owned(), std::f32::consts::PI)];
let value = Sequence::new([Map::<String, f32>::new(map_contents)?])?;

for map in value.extract_sequence(&Allocator::default()) {
let map = map.extract_map();
assert_eq!(map["meaning"], 42.0);
assert_eq!(map["pi"], std::f32::consts::PI);
}

Ok(())
}
}

0 comments on commit b5e9074

Please sign in to comment.