From 7f626ae66d9bd8a20a61fb85c6741f6d71d9e31d Mon Sep 17 00:00:00 2001 From: "Carson M." <carson@pyke.io> Date: Thu, 26 Dec 2024 16:04:11 -0600 Subject: [PATCH] feat: accept creating `TensorRef` from `Arc<[T]>` --- src/value/impl_tensor/create.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index c0e27171..b2cb7cd6 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -620,6 +620,14 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T } } +impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<[T]>) { + fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> { + let shape = self.0.to_dimensions(Some(self.1.len()))?; + let data = &*self.1; + Ok((shape, data, Some(Box::new(self.1.clone())))) + } +} + impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<Box<[T]>>) { fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> { let shape = self.0.to_dimensions(Some(self.1.len()))?;