From 5bd2405c13b1409d8af7a3faab72f88ca4adca44 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Wed, 25 Oct 2023 09:36:22 +0700 Subject: [PATCH] modification transpose 2D --- src/operators/tensor/linalg/transpose.cairo | 38 +++++++++++++++++++ tests/src/lib.cairo | 1 + tests/src/operators.cairo | 1 + tests/src/operators/transpose_test.cairo | 41 +++++++++++++++++++++ 4 files changed, 81 insertions(+) create mode 100644 tests/src/operators.cairo create mode 100644 tests/src/operators/transpose_test.cairo diff --git a/src/operators/tensor/linalg/transpose.cairo b/src/operators/tensor/linalg/transpose.cairo index dc972e9a3..74a5a36fc 100644 --- a/src/operators/tensor/linalg/transpose.cairo +++ b/src/operators/tensor/linalg/transpose.cairo @@ -14,6 +14,10 @@ fn transpose, impl TCopy: Copy, impl TDrop: D assert((*self.shape).len() > 1, 'cannot transpose a 1D tensor'); assert(axes.len() == (*self.shape).len(), 'shape and axes length unequal'); + if (*self.shape).len() == 2 { + return transpose2D(@(*self)); + } + let output_shape = permutation_output_shape(*self.shape, axes); let output_data_len = len_from_shape(output_shape); @@ -47,3 +51,37 @@ fn transpose, impl TCopy: Copy, impl TDrop: D return TensorTrait::new(output_shape, output_data.span()); } + + +fn transpose2D, impl TCopy: Copy, impl TDrop: Drop>( + self: @Tensor +) -> Tensor { + assert((*self.shape).len() == 2, 'transpose a 2D tensor'); + + let mut output_data = ArrayTrait::new(); + let mut output_shape = ArrayTrait::new(); + + let n = *self.shape[0]; + let m = *self.shape[1]; + + output_shape.append(m); + output_shape.append(n); + + let mut j: usize = 0; + loop { + if j == m { + break (); + } + let mut i = 0; + loop { + if i == n { + break (); + } + output_data.append(*(*self.data)[i * m + j]); + i += 1; + }; + j += 1; + }; + + return TensorTrait::new(output_shape.span(), output_data.span()); +} diff --git a/tests/src/lib.cairo b/tests/src/lib.cairo index 081ff4807..0d88c1101 100644 --- a/tests/src/lib.cairo +++ b/tests/src/lib.cairo @@ -4,3 +4,4 @@ mod tensor_core; mod nodes; mod helpers; mod ml; +mod operators; diff --git a/tests/src/operators.cairo b/tests/src/operators.cairo new file mode 100644 index 000000000..3c2ffc47b --- /dev/null +++ b/tests/src/operators.cairo @@ -0,0 +1 @@ +mod transpose_test; \ No newline at end of file diff --git a/tests/src/operators/transpose_test.cairo b/tests/src/operators/transpose_test.cairo new file mode 100644 index 000000000..de9b048f6 --- /dev/null +++ b/tests/src/operators/transpose_test.cairo @@ -0,0 +1,41 @@ +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; +use debug::PrintTrait; + + +#[test] +#[available_gas(200000000000)] +fn transpose_test_shape() { + let tensor = TensorTrait::::new( + shape: array![4, 2].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7].span(), + ); + + let result = tensor.transpose(axes: array![1, 0].span()); + assert(result.shape == array![2, 4].span(), 'wrong dim'); +} + +#[test] +#[available_gas(200000000000)] +fn transpose_test_values() { + let tensor = TensorTrait::::new( + shape: array![4, 2].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7].span(), + ); + + let result = tensor.transpose(axes: array![1, 0].span()); + assert(result.data == array![0, 2, 4, 6, 1, 3, 5, 7].span(), 'wrong data'); +} + + +#[test] +#[available_gas(200000000000)] +fn transpose_test_3D() { + let tensor = TensorTrait::::new( + shape: array![2, 2, 2].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7].span(), + ); + + let result = tensor.transpose(axes: array![1, 2, 0].span()); + + assert(result.shape == array![2, 2, 2].span(), 'wrong shape'); + assert(result.data == array![0, 4, 1, 5, 2, 6, 3, 7].span(), 'wrong data'); +} +