From d19704e5e375e7693cea924ca465e5b3af8c1743 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Wed, 20 Mar 2024 23:24:41 +0100 Subject: [PATCH] feat: expand --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 4 +- .../operators/tensor/tensor.expand.md | 54 ++++++++++++++++++ nodegen/node/expand.py | 32 +++++++++++ src/operators/tensor/core.cairo | 57 +++++++++++++++++++ .../tensor/implementations/tensor_bool.cairo | 4 ++ .../implementations/tensor_complex64.cairo | 4 ++ .../implementations/tensor_fp16x16.cairo | 4 ++ .../implementations/tensor_fp16x16wide.cairo | 4 ++ .../implementations/tensor_fp32x32.cairo | 4 ++ .../implementations/tensor_fp64x64.cairo | 4 ++ .../implementations/tensor_fp8x23.cairo | 4 ++ .../implementations/tensor_fp8x23wide.cairo | 4 ++ .../tensor/implementations/tensor_i32.cairo | 4 ++ .../tensor/implementations/tensor_i8.cairo | 4 ++ .../tensor/implementations/tensor_u32.cairo | 4 ++ src/operators/tensor/manipulation.cairo | 1 + .../tensor/manipulation/expand.cairo | 36 ++++++++++++ tests/nodes.cairo | 2 + tests/nodes/expand_with_broadcast.cairo | 22 +++++++ .../nodes/expand_with_broadcast/input_0.cairo | 17 ++++++ .../expand_with_broadcast/output_0.cairo | 50 ++++++++++++++++ tests/nodes/expand_without_broadcast.cairo | 21 +++++++ .../expand_without_broadcast/input_0.cairo | 16 ++++++ .../expand_without_broadcast/output_0.cairo | 25 ++++++++ 25 files changed, 381 insertions(+), 1 deletion(-) create mode 100644 docs/framework/operators/tensor/tensor.expand.md create mode 100644 nodegen/node/expand.py create mode 100644 src/operators/tensor/manipulation/expand.cairo create mode 100644 tests/nodes/expand_with_broadcast.cairo create mode 100644 tests/nodes/expand_with_broadcast/input_0.cairo create mode 100644 tests/nodes/expand_with_broadcast/output_0.cairo create mode 100644 tests/nodes/expand_without_broadcast.cairo create mode 100644 tests/nodes/expand_without_broadcast/input_0.cairo create mode 100644 
tests/nodes/expand_without_broadcast/output_0.cairo diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 477601b37..d01023677 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -159,6 +159,7 @@ * [tensor.blackman_window](framework/operators/tensor/tensor.blackman_window.md) * [tensor.random_uniform_like](framework/operators/tensor/tensor.random_uniform_like.md) * [tensor.label_encoder](framework/operators/tensor/tensor.label_encoder.md) + * [tensor.expand](framework/operators/tensor/tensor.expand.md) * [Neural Network](framework/operators/neural-network/README.md) * [nn.relu](framework/operators/neural-network/nn.relu.md) * [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index f3f84ac3f..3e2a4c849 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -124,7 +124,9 @@ You can see below the list of current supported ONNX Operators: | [HannWindow](operators/tensor/tensor.tensor.hann_window.md) | :white\_check\_mark: | | [HammingWindow](operators/tensor/tensor.tensor.hamming_window.md) | :white\_check\_mark: | | [BlackmanWindow](operators/tensor/tensor.tensor.blackman_window.md) | :white\_check\_mark: | -| [RandomUniformLike](operators/tensor/tensor.tensor.random_uniform_like.md) | :white\_check\_mark: | +| [RandomUniformLike](operators/tensor/tensor.tensor.random_uniform_like.md) | :white\_check\_mark: | | [LabelEncoder](operators/tensor/tensor.label_encoder.md) | :white\_check\_mark: | +| [Expand](operators/tensor/tensor.expand.md) | :white\_check\_mark: | + Current Operators support: **118/156 (75%)** diff --git a/docs/framework/operators/tensor/tensor.expand.md b/docs/framework/operators/tensor/tensor.expand.md new file mode 100644 index 000000000..f0a326727 --- /dev/null +++ b/docs/framework/operators/tensor/tensor.expand.md @@ -0,0 +1,54 @@ +## tensor.expand + +```rust + fn expand(self: @Tensor<T>, shape: Tensor<usize>,) -> Tensor<T>; +``` + 
+Broadcast the input tensor following the given shape and the broadcast rule. The broadcast rule is similar to numpy.array(input) * numpy.ones(shape): Dimensions are right alignment; Two corresponding dimensions must have the same value, or one of them is equal to 1. + +## Args + +* `self`(`@Tensor`) - The input tensor. +* `shape`(`Tensor`) - A 1-D tensor indicates the shape you want to expand to, following the broadcast rule + +## Panics + +* If the shape doesn't follow the broadcast rule. + +## Returns + +A new `Tensor` result of the expansion. + +## Examples + +```rust +use orion::operators::tensor::{FP16x16Tensor}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{U32Tensor}; +use orion::numbers::FP16x16; + +fn test_expand() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(3); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + return X.expand(TensorTrait::new(array![3].span(),array![2, 1, 6].span())); + +} + +>>> [[[1. 1. 1. 1. 1. 1.] + [2. 2. 2. 2. 2. 2.] + [3. 3. 3. 3. 3. 3.]] + + [[1. 1. 1. 1. 1. 1.] + [2. 2. 2. 2. 2. 2.] + [3. 3. 3. 3. 3. 
3.]]] +``` diff --git a/nodegen/node/expand.py b/nodegen/node/expand.py new file mode 100644 index 000000000..4e60d9b7f --- /dev/null +++ b/nodegen/node/expand.py @@ -0,0 +1,32 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait + +class Expand(RunAll): + @staticmethod + def expand_with_broadcast() -> None: + shape = [1, 3, 1] + x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + + new_shape = [2, 1, 6] + y = x * np.ones(new_shape, dtype=np.float32) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "expand_with_broadcast" + make_test([x], y, "input_0.expand(TensorTrait::new(array![3].span(),array![2, 1, 6].span()))", name) + + @staticmethod + def expand_without_broadcast() -> None: + shape = [3, 1] + new_shape = [3, 4] + + x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + y = x * np.ones(new_shape, dtype=np.float32) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "expand_without_broadcast" + make_test([x], y, "input_0.expand(TensorTrait::new(array![2].span(),array![3, 4].span()))", name) diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 0d21a4de3..e06c11624 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -131,6 +131,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde8Bit conversion of FP32 Input data. /// scatter_nd - The output of the operation is produced by creating a copy of the input data, and then updating its value to values specified by updates at specific index positions specified by indices. Its output shape is the same as the shape of data /// label_encoder - Maps each element in the input tensor to another value. 
+/// expand - Broadcast the input tensor following the given shape and the broadcast rule. trait TensorTrait { /// # tensor.new /// @@ -5850,6 +5851,62 @@ trait TensorTrait { values: Option>, values_tensor: Option> ) -> Tensor; + /// ## tensor.expand + /// + /// ```rust + /// fn expand(self: @Tensor<T>, shape: Tensor<usize>,) -> Tensor<T>; + /// ``` + /// + /// Broadcast the input tensor following the given shape and the broadcast rule. The broadcast rule is similar to numpy.array(input) * numpy.ones(shape): Dimensions are right alignment; Two corresponding dimensions must have the same value, or one of them is equal to 1. + /// + /// ## Args + /// + /// * `self`(`@Tensor`) - The input tensor. + /// * `shape`(`Tensor`) - A 1-D tensor indicates the shape you want to expand to, following the broadcast rule + /// + /// ## Panics + /// + /// * If the shape doesn't follow the broadcast rule. + /// + /// ## Returns + /// + /// A new `Tensor` result of the expansion. + /// + /// ## Examples + /// + /// ```rust + /// use orion::operators::tensor::{FP16x16Tensor}; + /// use orion::operators::tensor::{TensorTrait, Tensor}; + /// use core::array::{ArrayTrait, SpanTrait}; + /// use orion::operators::tensor::{U32Tensor}; + /// use orion::numbers::FP16x16; + /// + /// fn test_expand() -> Tensor { + /// let mut shape = ArrayTrait::::new(); + /// shape.append(1); + /// shape.append(3); + /// shape.append(1); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 65536, sign: false }); + /// data.append(FP16x16 { mag: 131072, sign: false }); + /// data.append(FP16x16 { mag: 196608, sign: false }); + /// let mut X = TensorTrait::new(shape.span(), data.span()); + /// + /// return X.expand(TensorTrait::new(array![3].span(),array![2, 1, 6].span())); + /// + /// } + /// + /// >>> [[[1. 1. 1. 1. 1. 1.] + /// [2. 2. 2. 2. 2. 2.] + /// [3. 3. 3. 3. 3. 3.]] + /// + /// [[1. 1. 1. 1. 1. 1.] + /// [2. 2. 2. 2. 2. 2.] + /// [3. 3. 3. 3. 3. 
3.]]] + /// ``` + /// + fn expand(self: @Tensor, shape: Tensor,) -> Tensor; } /// Cf: TensorTrait::new docstring diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index 612a397cc..25bd8f7cc 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -547,6 +547,10 @@ impl BoolTensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + panic(array!['not supported!']) + } } /// Implements partial equal for two `Tensor` using the `PartialEq` trait. diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index c9c31ae23..75b4b7744 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -585,6 +585,10 @@ impl Complex64Tensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index a37ed0442..a536335d0 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -639,6 +639,10 @@ impl FP16x16Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. 
diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 2003b28ff..06e428050 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -599,6 +599,10 @@ impl FP16x16WTensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 4870226a1..bd178a863 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -635,6 +635,10 @@ impl FP32x32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 3a7214d18..ca6747b35 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -635,6 +635,10 @@ impl FP64x64Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. 
diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index b4a26d749..90b5a697b 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -633,6 +633,10 @@ impl FP8x23Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 06a297b69..7a22bdd36 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -576,6 +576,10 @@ impl FP8x23WTensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 296876516..84aefb14b 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -599,6 +599,10 @@ impl I32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. 
diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 42d807c68..3b57d87ef 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -602,6 +602,10 @@ impl I8Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index efb681a86..94c0a45bc 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -546,6 +546,10 @@ impl U32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn expand(self: @Tensor, shape: Tensor,) -> Tensor { + manipulation::expand::expand(self, shape) + } } /// Implements addition for `Tensor` using the `Add` trait. 
diff --git a/src/operators/tensor/manipulation.cairo b/src/operators/tensor/manipulation.cairo index 057e5afad..71ba89f82 100644 --- a/src/operators/tensor/manipulation.cairo +++ b/src/operators/tensor/manipulation.cairo @@ -3,3 +3,4 @@ mod split; mod split_to_sequence; mod reverse_sequence; mod optional; +mod expand; diff --git a/src/operators/tensor/manipulation/expand.cairo b/src/operators/tensor/manipulation/expand.cairo new file mode 100644 index 000000000..95c7ed4b4 --- /dev/null +++ b/src/operators/tensor/manipulation/expand.cairo @@ -0,0 +1,36 @@ +use core::array::ArrayTrait; +use orion::numbers::NumberTrait; +use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; +use orion::operators::tensor::helpers::check_compatibility; +//use orion::operators::nn::helpers::prod; + +/// Cf: TensorTrait::expand docstring +fn expand, +NumberTrait, +Copy, +Drop, +Mul>,>( + X: @Tensor, shape: Tensor, +) -> Tensor { + check_compatibility((*X).shape, shape.data); + + let mut ones = ArrayTrait::new(); + let dim = prod(shape.data); + + let mut i = 0; + while i != dim { + ones.append(NumberTrait::one()); + i += 1; + }; + + return *X * TensorTrait::new(shape.data, ones.span()); +} + +///from orion::operators::nn::helpers::prod; -> delete when nn refactor merged +fn prod, +Copy, +NumberTrait, +TensorTrait, +MulEq,>( + mut a: Span +) -> T { + let mut prod = NumberTrait::one(); + loop { + match a.pop_front() { + Option::Some(v) => { prod *= *v; }, + Option::None => { break prod; } + }; + } +} diff --git a/tests/nodes.cairo b/tests/nodes.cairo index 29bebb762..3de4f2f09 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -1047,3 +1047,5 @@ mod label_encoder_fp8x23_default; mod label_encoder_i8_default; mod label_encoder_i32_default; mod label_encoder_u32_default; +mod expand_with_broadcast; +mod expand_without_broadcast; diff --git a/tests/nodes/expand_with_broadcast.cairo b/tests/nodes/expand_with_broadcast.cairo new file mode 100644 index 000000000..0a59bc983 --- 
/dev/null +++ b/tests/nodes/expand_with_broadcast.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{U32Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_expand_with_broadcast() { + let input_0 = input_0::input_0(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.expand(TensorTrait::new(array![3].span(), array![2, 1, 6].span())); + + assert_eq(y_0, z_0); +} + diff --git a/tests/nodes/expand_with_broadcast/input_0.cairo b/tests/nodes/expand_with_broadcast/input_0.cairo new file mode 100644 index 000000000..ce5ad7c94 --- /dev/null +++ b/tests/nodes/expand_with_broadcast/input_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(3); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/expand_with_broadcast/output_0.cairo b/tests/nodes/expand_with_broadcast/output_0.cairo new file mode 100644 index 000000000..ace33e262 --- /dev/null +++ b/tests/nodes/expand_with_broadcast/output_0.cairo @@ -0,0 +1,50 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape 
= ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, 
sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/expand_without_broadcast.cairo b/tests/nodes/expand_without_broadcast.cairo new file mode 100644 index 000000000..66cf9fd1e --- /dev/null +++ b/tests/nodes/expand_without_broadcast.cairo @@ -0,0 +1,21 @@ +mod input_0; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{U32Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_expand_without_broadcast() { + let input_0 = input_0::input_0(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.expand(TensorTrait::new(array![2].span(), array![3, 4].span())); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/expand_without_broadcast/input_0.cairo b/tests/nodes/expand_without_broadcast/input_0.cairo new file mode 100644 index 000000000..51a0ec433 --- /dev/null +++ b/tests/nodes/expand_without_broadcast/input_0.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/expand_without_broadcast/output_0.cairo b/tests/nodes/expand_without_broadcast/output_0.cairo new file mode 100644 index 000000000..0b385118e --- /dev/null +++ b/tests/nodes/expand_without_broadcast/output_0.cairo @@ -0,0 +1,25 @@ +use core::array::{ArrayTrait, 
SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +}