From 7f626ae66d9bd8a20a61fb85c6741f6d71d9e31d Mon Sep 17 00:00:00 2001
From: "Carson M." <carson@pyke.io>
Date: Thu, 26 Dec 2024 16:04:11 -0600
Subject: [PATCH] feat: accept creating `TensorRef` from `Arc<[T]>`

---
 src/value/impl_tensor/create.rs | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs
index c0e27171..b2cb7cd6 100644
--- a/src/value/impl_tensor/create.rs
+++ b/src/value/impl_tensor/create.rs
@@ -620,6 +620,14 @@ impl<T: Clone + 'static, D: ToDimensions> OwnedTensorArrayData<T> for (D, Box<[T
 	}
 }
 
+impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<[T]>) {
+	fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
+		let shape = self.0.to_dimensions(Some(self.1.len()))?;
+		let data = &*self.1;
+		Ok((shape, data, Some(Box::new(self.1.clone()))))
+	}
+}
+
 impl<T: Clone + 'static, D: ToDimensions> TensorArrayData<T> for (D, Arc<Box<[T]>>) {
 	fn ref_parts(&self) -> Result<(Vec<i64>, &[T], Option<Box<dyn Any>>)> {
 		let shape = self.0.to_dimensions(Some(self.1.len()))?;