Skip to content

Commit

Permalink
feat: expand
Browse files Browse the repository at this point in the history
  • Loading branch information
chachaleo committed Mar 20, 2024
1 parent fb6f4a0 commit d19704e
Show file tree
Hide file tree
Showing 25 changed files with 381 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
* [tensor.blackman_window](framework/operators/tensor/tensor.blackman_window.md)
* [tensor.random_uniform_like](framework/operators/tensor/tensor.random_uniform_like.md)
* [tensor.label_encoder](framework/operators/tensor/tensor.label_encoder.md)
* [tensor.expand](framework/operators/tensor/tensor.expand.md)
* [Neural Network](framework/operators/neural-network/README.md)
* [nn.relu](framework/operators/neural-network/nn.relu.md)
* [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md)
Expand Down
4 changes: 3 additions & 1 deletion docs/framework/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ You can see below the list of current supported ONNX Operators:
| [HannWindow](operators/tensor/tensor.tensor.hann_window.md) | :white\_check\_mark: |
| [HammingWindow](operators/tensor/tensor.tensor.hamming_window.md) | :white\_check\_mark: |
| [BlackmanWindow](operators/tensor/tensor.tensor.blackman_window.md) | :white\_check\_mark: |
| [RandomUniformLike](operators/tensor/tensor.tensor.random_uniform_like.md) | :white\_check\_mark: |
| [RandomUniformLike](operators/tensor/tensor.tensor.random_uniform_like.md) | :white\_check\_mark: |
| [LabelEncoder](operators/tensor/tensor.label_encoder.md) | :white\_check\_mark: |
| [Expand](operators/tensor/tensor.expand.md) | :white\_check\_mark: |


Current Operators support: **118/156 (75%)**
54 changes: 54 additions & 0 deletions docs/framework/operators/tensor/tensor.expand.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
## tensor.expand

```rust
fn expand(self: @Tensor<T>, Tensor: Span<usize>,) -> Tensor<T>;
```

Broadcast the input tensor following the given shape and the broadcast rule. The broadcast rule is similar to numpy.array(input) * numpy.ones(shape): Dimensions are right alignment; Two corresponding dimensions must have the same value, or one of them is equal to 1.

## Args

* `self`(`@Tensor<T>`) - The input tensor.
* `shape`(`Tensor<usize>`) - A 1-D tensor indicates the shape you want to expand to, following the broadcast rule

## Panics

* If the shape doesn't follow the broadcast rule.

## Returns

A new `Tensor<T>` result of the expansion.

## Examples

```rust
use orion::operators::tensor::{FP16x16Tensor};
use orion::operators::tensor::{TensorTrait, Tensor};
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{U32Tensor};
use orion::numbers::FP16x16;

fn test_expand() -> Tensor<FP16x16> {
let mut shape = ArrayTrait::<usize>::new();
shape.append(1);
shape.append(3);
shape.append(1);

let mut data = ArrayTrait::new();
data.append(FP16x16 { mag: 65536, sign: false });
data.append(FP16x16 { mag: 131072, sign: false });
data.append(FP16x16 { mag: 196608, sign: false });
let mut X = TensorTrait::new(shape.span(), data.span());

return X.expand(TensorTrait::new(array![3].span(),array![2, 1, 6].span()));

}

>>> [[[1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2.]
[3. 3. 3. 3. 3. 3.]]

[[1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2.]
[3. 3. 3. 3. 3. 3.]]]
```
32 changes: 32 additions & 0 deletions nodegen/node/expand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
from nodegen.node import RunAll
from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait

class Expand(RunAll):
@staticmethod
def expand_with_broadcast() -> None:
shape = [1, 3, 1]
x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)

new_shape = [2, 1, 6]
y = x * np.ones(new_shape, dtype=np.float32)

x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16))

name = "expand_with_broadcast"
make_test([x], y, "input_0.expand(TensorTrait::new(array![3].span(),array![2, 1, 6].span()))", name)

@staticmethod
def expand_without_broadcast() -> None:
shape = [3, 1]
new_shape = [3, 4]

x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape)
y = x * np.ones(new_shape, dtype=np.float32)

x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16))

name = "expand_without_broadcast"
make_test([x], y, "input_0.expand(TensorTrait::new(array![2].span(),array![3, 4].span()))", name)
57 changes: 57 additions & 0 deletions src/operators/tensor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ impl TensorSerde<T, impl TSerde: Serde<T>, impl TDrop: Drop<T>> of Serde<Tensor<
/// dynamic_quantize_linear - Computes the Scale, Zero Point and FP32->8Bit conversion of FP32 Input data.
/// scatter_nd - The output of the operation is produced by creating a copy of the input data, and then updating its value to values specified by updates at specific index positions specified by indices. Its output shape is the same as the shape of data
/// label_encoder - Maps each element in the input tensor to another value.
/// expand - Broadcast the input tensor following the given shape and the broadcast rule.
trait TensorTrait<T> {
/// # tensor.new
///
Expand Down Expand Up @@ -5850,6 +5851,62 @@ trait TensorTrait<T> {
values: Option<Span<T>>,
values_tensor: Option<Tensor<T>>
) -> Tensor<T>;
/// ## tensor.expand
///
/// ```rust
/// fn expand(self: @Tensor<T>, Tensor: Span<usize>,) -> Tensor<T>;
/// ```
///
/// Broadcast the input tensor following the given shape and the broadcast rule. The broadcast rule is similar to numpy.array(input) * numpy.ones(shape): Dimensions are right alignment; Two corresponding dimensions must have the same value, or one of them is equal to 1.
///
/// ## Args
///
/// * `self`(`@Tensor<T>`) - The input tensor.
/// * `shape`(`Tensor<usize>`) - A 1-D tensor indicates the shape you want to expand to, following the broadcast rule
///
/// ## Panics
///
/// * If the shape doesn't follow the broadcast rule.
///
/// ## Returns
///
/// A new `Tensor<T>` result of the expansion.
///
/// ## Examples
///
/// ```rust
/// use orion::operators::tensor::{FP16x16Tensor};
/// use orion::operators::tensor::{TensorTrait, Tensor};
/// use core::array::{ArrayTrait, SpanTrait};
/// use orion::operators::tensor::{U32Tensor};
/// use orion::numbers::FP16x16;
///
/// fn test_expand() -> Tensor<FP16x16> {
/// let mut shape = ArrayTrait::<usize>::new();
/// shape.append(1);
/// shape.append(3);
/// shape.append(1);
///
/// let mut data = ArrayTrait::new();
/// data.append(FP16x16 { mag: 65536, sign: false });
/// data.append(FP16x16 { mag: 131072, sign: false });
/// data.append(FP16x16 { mag: 196608, sign: false });
/// let mut X = TensorTrait::new(shape.span(), data.span());
///
/// return X.expand(TensorTrait::new(array![3].span(),array![2, 1, 6].span()));
///
/// }
///
/// >>> [[[1. 1. 1. 1. 1. 1.]
/// [2. 2. 2. 2. 2. 2.]
/// [3. 3. 3. 3. 3. 3.]]
///
/// [[1. 1. 1. 1. 1. 1.]
/// [2. 2. 2. 2. 2. 2.]
/// [3. 3. 3. 3. 3. 3.]]]
/// ```
///
fn expand(self: @Tensor<T>, shape: Tensor<usize>,) -> Tensor<T>;
}

/// Cf: TensorTrait::new docstring
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_bool.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,10 @@ impl BoolTensor of TensorTrait<bool> {
) -> Tensor<bool> {
panic(array!['not supported!'])
}

fn expand(self: @Tensor<bool>, shape: Tensor<usize>,) -> Tensor<bool> {
panic(array!['not supported!'])
}
}

/// Implements partial equal for two `Tensor<bool>` using the `PartialEq` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_complex64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,10 @@ impl Complex64Tensor of TensorTrait<complex64> {
) -> Tensor<complex64> {
panic(array!['not supported!'])
}

fn expand(self: @Tensor<complex64>, shape: Tensor<usize>,) -> Tensor<complex64> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<complex64>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp16x16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,10 @@ impl FP16x16Tensor of TensorTrait<FP16x16> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<FP16x16>, shape: Tensor<usize>,) -> Tensor<FP16x16> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<FP16x16>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp16x16wide.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,10 @@ impl FP16x16WTensor of TensorTrait<FP16x16W> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<FP16x16W>, shape: Tensor<usize>,) -> Tensor<FP16x16W> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<FP16x16W>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp32x32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ impl FP32x32Tensor of TensorTrait<FP32x32> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<FP32x32>, shape: Tensor<usize>,) -> Tensor<FP32x32> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<FP32x32>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp64x64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ impl FP64x64Tensor of TensorTrait<FP64x64> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<FP64x64>, shape: Tensor<usize>,) -> Tensor<FP64x64> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<FP64x64>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp8x23.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,10 @@ impl FP8x23Tensor of TensorTrait<FP8x23> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<FP8x23>, shape: Tensor<usize>,) -> Tensor<FP8x23> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<FP8x23>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp8x23wide.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,10 @@ impl FP8x23WTensor of TensorTrait<FP8x23W> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<FP8x23W>, shape: Tensor<usize>,) -> Tensor<FP8x23W> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<FP8x23W>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_i32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,10 @@ impl I32Tensor of TensorTrait<i32> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<i32>, shape: Tensor<usize>,) -> Tensor<i32> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<i32>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_i8.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,10 @@ impl I8Tensor of TensorTrait<i8> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<i8>, shape: Tensor<usize>,) -> Tensor<i8> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<i8>` using the `Add` trait.
Expand Down
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_u32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,10 @@ impl U32Tensor of TensorTrait<u32> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn expand(self: @Tensor<u32>, shape: Tensor<usize>,) -> Tensor<u32> {
manipulation::expand::expand(self, shape)
}
}

/// Implements addition for `Tensor<u32>` using the `Add` trait.
Expand Down
1 change: 1 addition & 0 deletions src/operators/tensor/manipulation.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ mod split;
mod split_to_sequence;
mod reverse_sequence;
mod optional;
mod expand;
36 changes: 36 additions & 0 deletions src/operators/tensor/manipulation/expand.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use core::array::ArrayTrait;
use orion::numbers::NumberTrait;
use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};
use orion::operators::tensor::helpers::check_compatibility;
//use orion::operators::nn::helpers::prod;

/// Cf: TensorTrait::expand docstring
fn expand<T, MAG, +TensorTrait<T>, +NumberTrait<T, MAG>, +Copy<T>, +Drop<T>, +Mul<Tensor<T>>,>(
X: @Tensor<T>, shape: Tensor<usize>,
) -> Tensor<T> {
check_compatibility((*X).shape, shape.data);

let mut ones = ArrayTrait::new();
let dim = prod(shape.data);

let mut i = 0;
while i != dim {
ones.append(NumberTrait::one());
i += 1;
};

return *X * TensorTrait::new(shape.data, ones.span());
}

///from orion::operators::nn::helpers::prod; -> delete when nn refactor merged
fn prod<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +MulEq<T>,>(
mut a: Span<T>
) -> T {
let mut prod = NumberTrait::one();
loop {
match a.pop_front() {
Option::Some(v) => { prod *= *v; },
Option::None => { break prod; }
};
}
}
2 changes: 2 additions & 0 deletions tests/nodes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1047,3 +1047,5 @@ mod label_encoder_fp8x23_default;
mod label_encoder_i8_default;
mod label_encoder_i32_default;
mod label_encoder_u32_default;
mod expand_with_broadcast;
mod expand_without_broadcast;
22 changes: 22 additions & 0 deletions tests/nodes/expand_with_broadcast.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
mod input_0;
mod output_0;


use orion::utils::{assert_eq, assert_seq_eq};
use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
use orion::operators::tensor::FP16x16TensorPartialEq;
use orion::operators::tensor::{TensorTrait, Tensor};
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{U32Tensor};

#[test]
#[available_gas(2000000000)]
fn test_expand_with_broadcast() {
let input_0 = input_0::input_0();
let z_0 = output_0::output_0();

let y_0 = input_0.expand(TensorTrait::new(array![3].span(), array![2, 1, 6].span()));

assert_eq(y_0, z_0);
}

17 changes: 17 additions & 0 deletions tests/nodes/expand_with_broadcast/input_0.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
use orion::numbers::{FixedTrait, FP16x16};

fn input_0() -> Tensor<FP16x16> {
let mut shape = ArrayTrait::<usize>::new();
shape.append(1);
shape.append(3);
shape.append(1);

let mut data = ArrayTrait::new();
data.append(FP16x16 { mag: 65536, sign: false });
data.append(FP16x16 { mag: 131072, sign: false });
data.append(FP16x16 { mag: 196608, sign: false });
TensorTrait::new(shape.span(), data.span())
}
Loading

0 comments on commit d19704e

Please sign in to comment.