From 2973dd1f30506939a9a9c7b5243137ec7efbdadc Mon Sep 17 00:00:00 2001 From: chachaleo Date: Tue, 2 Apr 2024 22:53:07 +0200 Subject: [PATCH] feat: momentum --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 2 + .../machine-learning/tree-ensemble/README.md | 3 +- .../tree-ensemble/tree_ensemble.predict.md | 36 +- .../operators/tensor/tensor.momentum.md | 79 ++ nodegen/node/momentum.py | 159 +++ src/operators/matrix.cairo | 319 ++--- src/operators/ml/svm/core.cairo | 9 +- src/operators/ml/svm/svm_classifier.cairo | 228 ++-- src/operators/ml/svm/svm_regressor.cairo | 63 +- src/operators/nn/functional/col2im.cairo | 2 +- src/operators/nn/functional/conv.cairo | 1057 +++++++++-------- src/operators/nn/functional/grid_sample.cairo | 281 ++--- src/operators/tensor.cairo | 1 + src/operators/tensor/core.cairo | 89 +- src/operators/tensor/helpers.cairo | 82 +- .../tensor/implementations/tensor_bool.cairo | 16 +- .../implementations/tensor_complex64.cairo | 21 +- .../implementations/tensor_fp16x16.cairo | 38 +- .../implementations/tensor_fp16x16wide.cairo | 39 +- .../implementations/tensor_fp32x32.cairo | 38 +- .../implementations/tensor_fp64x64.cairo | 38 +- .../implementations/tensor_fp8x23.cairo | 16 +- .../implementations/tensor_fp8x23wide.cairo | 38 +- .../tensor/implementations/tensor_i32.cairo | 34 +- .../tensor/implementations/tensor_i8.cairo | 34 +- .../tensor/implementations/tensor_u32.cairo | 34 +- src/operators/tensor/linalg/transpose.cairo | 11 +- src/operators/tensor/manipulation/split.cairo | 135 ++- .../manipulation/split_to_sequence.cairo | 2 +- src/operators/tensor/math/cumsum.cairo | 146 +-- src/operators/tensor/math/gather_nd.cairo | 9 +- .../tensor/math/layer_normalization.cairo | 3 +- src/operators/tensor/math/less_equal.cairo | 7 +- src/operators/tensor/math/max.cairo | 45 +- src/operators/tensor/math/min.cairo | 45 +- src/operators/tensor/math/range.cairo | 11 +- src/operators/tensor/math/reduce_l1.cairo | 7 +- src/operators/tensor/math/resize.cairo | 296 ++--- src/operators/tensor/preview_training.cairo | 1 + .../tensor/preview_training/momentum.cairo | 145 +++ .../tensor/quantization/qlinear_matmul.cairo | 28 +- tests/lib.cairo | 1 - tests/nodes.cairo | 2 + tests/nodes/gather_elements_axis1.cairo | 2 +- tests/nodes/gather_elements_axis2.cairo | 2 +- tests/nodes/gather_elements_default.cairo | 2 +- .../gather_elements_negative_indices.cairo | 2 +- tests/nodes/gather_fp16x16_3d_axis1.cairo | 2 +- tests/nodes/gather_fp16x16_3d_axis2.cairo | 2 +- tests/nodes/gather_fp16x16_3d_default.cairo | 2 +- tests/nodes/gather_negative_axis.cairo | 2 +- tests/nodes/gather_negative_indices.cairo | 2 +- tests/nodes/momentum_nesterov.cairo | 33 + tests/nodes/momentum_nesterov/input_0.cairo | 18 + tests/nodes/momentum_nesterov/input_1.cairo | 17 + tests/nodes/momentum_nesterov/output_0.cairo | 28 + tests/nodes/momentum_standard.cairo | 33 + tests/nodes/momentum_standard/input_0.cairo | 18 + tests/nodes/momentum_standard/input_1.cairo | 17 + tests/nodes/momentum_standard/output_0.cairo | 28 + tests/nodes/reshape_reduced_dims.cairo | 2 +- tests/nodes/reshape_reordered_all_dims.cairo | 2 +- tests/nodes/reshape_reordered_last_dims.cairo | 2 +- 64 files changed, 2431 insertions(+), 1436 deletions(-) create mode 100644 docs/framework/operators/tensor/tensor.momentum.md create mode 100644 nodegen/node/momentum.py create mode 100644 src/operators/tensor/preview_training.cairo create mode 100644 src/operators/tensor/preview_training/momentum.cairo create mode 100644 tests/nodes/momentum_nesterov.cairo create mode 100644 tests/nodes/momentum_nesterov/input_0.cairo create mode 100644 tests/nodes/momentum_nesterov/input_1.cairo create mode 100644 tests/nodes/momentum_nesterov/output_0.cairo create mode 100644 tests/nodes/momentum_standard.cairo create mode 100644 tests/nodes/momentum_standard/input_0.cairo create mode 100644 tests/nodes/momentum_standard/input_1.cairo create mode 100644 tests/nodes/momentum_standard/output_0.cairo diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 477601b37..a11d31a7e 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -159,6 +159,7 @@ * [tensor.blackman_window](framework/operators/tensor/tensor.blackman_window.md) * [tensor.random_uniform_like](framework/operators/tensor/tensor.random_uniform_like.md) * [tensor.label_encoder](framework/operators/tensor/tensor.label_encoder.md) + * [tensor.momentum](framework/operators/tensor/tensor.momentum.md) * [Neural Network](framework/operators/neural-network/README.md) * [nn.relu](framework/operators/neural-network/nn.relu.md) * [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index f3f84ac3f..4fd9979c7 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -126,5 +126,7 @@ You can see below the list of current supported ONNX Operators: | [BlackmanWindow](operators/tensor/tensor.tensor.blackman_window.md) | :white\_check\_mark: | | [RandomUniformLike](operators/tensor/tensor.tensor.random_uniform_like.md) | :white\_check\_mark: | | [LabelEncoder](operators/tensor/tensor.label_encoder.md) | :white\_check\_mark: | +| [Momentum](operators/tensor/tensor.momentum.md) | :white\_check\_mark: | + Current Operators support: **118/156 (75%)** diff --git a/docs/framework/operators/machine-learning/tree-ensemble/README.md b/docs/framework/operators/machine-learning/tree-ensemble/README.md index 26fcfb205..dc3c06bea 100644 --- a/docs/framework/operators/machine-learning/tree-ensemble/README.md +++ b/docs/framework/operators/machine-learning/tree-ensemble/README.md @@ -19,4 +19,5 @@ Orion supports currently only fixed point data types for `TreeEnsembleTrait`. | function | description | | --- | --- | -| [`tree_ensemble.predict`](tree_ensemble.predict.md) | Returns the regressed values for each input in a batch. | \ No newline at end of file +| [`tree_ensemble.predict`](tree_ensemble.predict.md) | Returns the regressed values for each input in a batch. | + diff --git a/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md b/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md index a7f97e96d..d1ab33641 100644 --- a/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md +++ b/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md @@ -2,23 +2,23 @@ ```rust fn predict(X: @Tensor, - nodes_splits: Tensor, - nodes_featureids: Span, - nodes_modes: Span, - nodes_truenodeids: Span, - nodes_falsenodeids: Span, - nodes_trueleafs: Span, - nodes_falseleafs: Span, - leaf_targetids: Span, - leaf_weights: Tensor, - tree_roots: Span, - post_transform: POST_TRANSFORM, - aggregate_function: AGGREGATE_FUNCTION, - nodes_hitrates: Option>, - nodes_missing_value_tracks_true: Option>, - membership_values: Option>, - n_targets: usize - ) -> MutMatrix::; + nodes_splits: Tensor, + nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + tree_roots: Span, + post_transform: POST_TRANSFORM, + aggregate_function: AGGREGATE_FUNCTION, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Option>, + membership_values: Option>, + n_targets: usize + ) -> MutMatrix::; ``` Tree Ensemble operator. Returns the regressed values for each input in a batch. Inputs have dimensions [N, F] where N is the input batch size and F is the number of input features. Outputs have dimensions [N, num_targets] where N is the batch size and num_targets is the number of targets, which is a configurable attribute. @@ -50,7 +50,7 @@ Tree Ensemble operator. Returns the regressed values for each input in a batch. ## Type Constraints -`TreeEnsembleClassifier` and `X` must be fixed points +`T` must be fixed point ## Examples diff --git a/docs/framework/operators/tensor/tensor.momentum.md b/docs/framework/operators/tensor/tensor.momentum.md new file mode 100644 index 000000000..11030e03b --- /dev/null +++ b/docs/framework/operators/tensor/tensor.momentum.md @@ -0,0 +1,79 @@ +# tensor.momentum + +```rust +fn momentum( + r: T, t: T, inputs: @Tensor, alpha: T, beta: T, mode: MODE, norm_coefficient: T, +) -> (Tensor, Tensor); +``` + +Compute one iteration of stochastic gradient update with momentum. + +## Args + +* `r`(`T`) - The learning rate. +* `i`(`T`) - Update count of "X". +* `inputs`(`@Tensor`) - It sequentially contains the current values of optimized tensors, then their gradient tensors, and finally their momentum tensors. For example, if two tensors "X_1" and "X_2" are optimized, The expected input list would be ["X_1", "X_2", gradient of "X_1", gradient of "X_2", momentum of "X_1", momentum of "X_2"]. +* `alpha`(`T`) - The decay factor of momentum. +* `beta`(`T`) - The coefficient of gradient in computing new momentum. +* `mode`(`MODE`) - Its value should be either "nesterov" or "standard". The value "nesterov" leads to the use of Nesterov's momentum while "standard" invokes stochastic gradient method using standard momentum +* `norm_coefficient`(`T`) - Coefficient of 0.5 * norm_coefficient * ||X||^2. +## Returns + +Two `Tensor` containing the new values of optimized tensors and then the new values of their momentum tensors. + +## Type Constraints + +* `T` in (`Tensor`, `Tensor`, `Tensor`, `tensor,`) + +## Examples + +```rust +use array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{FP16x16Tensor}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::preview_training::momentum::MODE; + +fn example_momentum() -> (Tensor, Tensor){ + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 183500, sign: false }); + data.append(FP16x16 { mag: 61603, sign: true }); + data.append(FP16x16 { mag: 163840, sign: true }); + data.append(FP16x16 { mag: 111411, sign: false }); + data.append(FP16x16 { mag: 235929, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65, sign: false }); + data.append(FP16x16 { mag: 62259, sign: false }); + data.append(FP16x16 { mag: 6553, sign: false }); + let param = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 74211, sign: false }); + data.append(FP16x16 { mag: 177453, sign: false }); + let expected_output = TensorTrait::new(shape.span(), data.span()); + + + return TensorTrait::momentum( + FP16x16 { mag: 6553, sign: false }, + FP16x16 { mag: 0, sign: false }, + @X, + *param.data.at(1), + *param.data.at(2), + MODE::STANDARD, + *param.data.at(0), + ); +} +>>> ([1.13238 2.70772],[0.67620003 0.9227998 ]) + +``` diff --git a/nodegen/node/momentum.py b/nodegen/node/momentum.py new file mode 100644 index 000000000..e135b4b02 --- /dev/null +++ b/nodegen/node/momentum.py @@ -0,0 +1,159 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait +from typing import List + +import numpy as np + +#from onnx.reference.ops.op_resize import _get_all_coords + +def _run1( r, t, x, g, v, mode="standard", norm_coefficient=None, alpha=None, beta=None): # type: ignore + if mode == "standard": + x_new, v_new = _apply_momentum(r, t, x, g, v, norm_coefficient, alpha, beta) + else: + x_new, v_new = _apply_nesterov(r, t, x, g, v, norm_coefficient, alpha, beta) + return x_new, v_new + +def _apply_momentum(r, t, x, g, v, norm_coefficient, alpha, beta): # type: ignore + # Add gradient of regularization term. + g_regularized = norm_coefficient * x + g + # Coefficient of gradient should be 1 at the first iteration. + beta_adjusted = beta if t > 0 else 1 + # Update momentum. + v_new = alpha * v + beta_adjusted * g_regularized + # Apply SG with momentum update rule. + x_new = x - r * v_new + return x_new, v_new + + +def _apply_nesterov(r, t, x, g, v, norm_coefficient, alpha, beta): # type: ignore + # Add gradient of regularization term. + g_regularized = norm_coefficient * x + g + # Coefficient of gradient should be 1 at the first iteration. + beta_adjusted = beta if t > 0 else 1 + # Update momentum. + v_new = alpha * v + beta_adjusted * g_regularized + # Apply Nesterov with momentum update rule. + x_new = x - r * (g_regularized + alpha * v_new) + return x_new, v_new + +def momentum(*data, alpha=None, beta=None, mode=None, norm_coefficient=None): # type: ignore + if len(data) == 5: + r, t, x, g, v = data + return _run1( # type: ignore + r, + t, + x, + g, + v, + norm_coefficient=norm_coefficient, + alpha=alpha, + beta=beta, + mode=mode, + ) + n = (len(data) - 2) // 3 + xs = [] + vs = [] + for i in range(0, n): + a, b = _run1( # type: ignore + *data[:2], # r and t + data[2 + i], + data[2 + n + i], + data[2 + n * 2 + i], + norm_coefficient=norm_coefficient, + alpha=alpha, + beta=beta, + mode=mode, + ) + xs.append(a) + vs.append(b) + return tuple(xs + vs) + + + +class Momentum(RunAll): + @staticmethod + def export_momentum() -> None: + # Define operator attributes. + norm_coefficient = 0.001 + alpha = 0.95 + beta = 0.1 + + # Define operator inputs. + r = np.array(0.1, dtype=np.float32) # scalar + t = np.array(0, dtype=np.int64) # scalar + x = np.array([1.2, 2.8], dtype=np.float32) + g = np.array([-0.94, -2.5], dtype=np.float32) + v = np.array([1.7, 3.6], dtype=np.float32) + + # Compute expected outputs of Momentum. + x_new, v_new = _apply_momentum(r, t, x, g, v, norm_coefficient, alpha, beta) + + x = np.array([1.2, 2.8, -0.94, -2.5, 1.7, 3.6]) + param = np.array([r, t, alpha, beta, norm_coefficient]) + + x_new = np.array(x_new) + v_new = np.array(v_new) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16)) + param = Tensor(Dtype.FP16x16, param.shape, to_fp(param.flatten(), FixedImpl.FP16x16)) + x_new = Tensor(Dtype.FP16x16, x_new.shape, to_fp(x_new.flatten(), FixedImpl.FP16x16)) + v_new = Tensor(Dtype.FP16x16, v_new.shape, to_fp(v_new.flatten(), FixedImpl.FP16x16)) + + name = "momentum_standard" + func_sig = "TensorTrait::momentum(" + func_sig += "*input_1.data.at(0)," + func_sig += "*input_1.data.at(1)," + func_sig += "@input_0," + func_sig += "*input_1.data.at(2)," + func_sig += "*input_1.data.at(3)," + func_sig += "MODE::STANDARD," + func_sig += "*input_1.data.at(4))" + make_test( + [x, param], [x_new, v_new], func_sig, name) + + @staticmethod + def export_nesterov_momentum() -> None: + # Define operator attributes. + norm_coefficient = 0.01 + alpha = 0.95 + beta = 1.0 + + # Define operator inputs. + r = np.array(0.1, dtype=np.float32) # scalar + t = np.array(0, dtype=np.int64) # scalar + x = np.array([1.2, 2.8], dtype=np.float32) + g = np.array([-0.94, -2.5], dtype=np.float32) + v = np.array([1.7, 3.6], dtype=np.float32) + + # Compute expected outputs of Momentum. + x_new, v_new = _apply_nesterov(r, t, x, g, v, norm_coefficient, alpha, beta) + + + x = np.array([1.2, 2.8, -0.94, -2.5, 1.7, 3.6]) + param = np.array([r, t, alpha, beta, norm_coefficient]) + + x_new = np.array(x_new) + v_new = np.array(v_new) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16)) + param = Tensor(Dtype.FP16x16, param.shape, to_fp(param.flatten(), FixedImpl.FP16x16)) + x_new = Tensor(Dtype.FP16x16, x_new.shape, to_fp(x_new.flatten(), FixedImpl.FP16x16)) + v_new = Tensor(Dtype.FP16x16, v_new.shape, to_fp(v_new.flatten(), FixedImpl.FP16x16)) + + name = "momentum_nesterov" + func_sig = "TensorTrait::momentum(" + func_sig += "*input_1.data.at(0)," + func_sig += "*input_1.data.at(1)," + func_sig += "@input_0," + func_sig += "*input_1.data.at(2)," + func_sig += "*input_1.data.at(3)," + func_sig += "MODE::STANDARD," + func_sig += "*input_1.data.at(4))" + make_test( + [x, param], [x_new, v_new], func_sig, name) + + + + + diff --git a/src/operators/matrix.cairo b/src/operators/matrix.cairo index 5e7564d11..efdee2e3a 100644 --- a/src/operators/matrix.cairo +++ b/src/operators/matrix.cairo @@ -90,66 +90,70 @@ impl MutMatrixImpl< if axis == 0 { let mut col: usize = 0; - while col != self.cols { - let mut max_value = self.get(0, col); - let mut max_value = match max_value { - Option::Some => { max_value.unwrap() }, - Option::None => { NumberTrait::min_value() } - }; - let mut max_index = 0; - - let mut row: usize = 1; - while row != self.rows { - let mut value = self.get(row, col); - let mut value = match value { - Option::Some => { value.unwrap() }, + while col != self + .cols { + let mut max_value = self.get(0, col); + let mut max_value = match max_value { + Option::Some => { max_value.unwrap() }, Option::None => { NumberTrait::min_value() } }; - - if value > max_value { - max_value = value; - max_index = row; - } - - row += 1; + let mut max_index = 0; + + let mut row: usize = 1; + while row != self + .rows { + let mut value = self.get(row, col); + let mut value = match value { + Option::Some => { value.unwrap() }, + Option::None => { NumberTrait::min_value() } + }; + + if value > max_value { + max_value = value; + max_index = row; + } + + row += 1; + }; + + result.append(max_index); + col += 1; }; - result.append(max_index); - col += 1; - }; - return result.span(); } let mut row: usize = 0; - while row != self.rows { - let mut max_value = self.get(row, 0); - let mut max_value = match max_value { - Option::Some => { max_value.unwrap() }, - Option::None => { NumberTrait::min_value() } - }; - let mut max_index = 0; - - let mut col: usize = 1; - while col != self.cols { - let mut value = self.get(row, col); - let mut value = match value { - Option::Some => { value.unwrap() }, + while row != self + .rows { + let mut max_value = self.get(row, 0); + let mut max_value = match max_value { + Option::Some => { max_value.unwrap() }, Option::None => { NumberTrait::min_value() } }; + let mut max_index = 0; - if value > max_value { - max_value = value; - max_index = col; - } + let mut col: usize = 1; + while col != self + .cols { + let mut value = self.get(row, col); + let mut value = match value { + Option::Some => { value.unwrap() }, + Option::None => { NumberTrait::min_value() } + }; + + if value > max_value { + max_value = value; + max_index = col; + } + + col += 1; + }; - col += 1; + result.append(max_index); + row += 1; }; - result.append(max_index); - row += 1; - }; - result.span() } @@ -161,50 +165,56 @@ impl MutMatrixImpl< if axis == 0 { let mut col: usize = 0; - while col != self.cols { - let mut sum_exp = NumberTrait::zero(); - let mut row: usize = 0; - while row != self.rows { - let value = self.get(row, col).unwrap().into(); - sum_exp += value.exp(); - - row += 1; - }; - - row = 0; - while row != self.rows { - let value = self.get(row, col).unwrap().into(); - let softmax_value = (value.exp() / sum_exp).into(); - result.set(row, col, softmax_value); + while col != self + .cols { + let mut sum_exp = NumberTrait::zero(); + let mut row: usize = 0; + while row != self + .rows { + let value = self.get(row, col).unwrap().into(); + sum_exp += value.exp(); + + row += 1; + }; + + row = 0; + while row != self + .rows { + let value = self.get(row, col).unwrap().into(); + let softmax_value = (value.exp() / sum_exp).into(); + result.set(row, col, softmax_value); + + row += 1; + }; - row += 1; + col += 1; }; - - col += 1; - }; } else { let mut row: usize = 0; - while row != self.rows { - let mut sum_exp = NumberTrait::zero(); - let mut col: usize = 0; - while col != self.cols { - let value = self.get(row, col).unwrap().into(); - sum_exp += value.exp(); - - col += 1; - }; + while row != self + .rows { + let mut sum_exp = NumberTrait::zero(); + let mut col: usize = 0; + while col != self + .cols { + let value = self.get(row, col).unwrap().into(); + sum_exp += value.exp(); + + col += 1; + }; + + col = 0; + while col != self + .cols { + let value = self.get(row, col).unwrap().into(); + let softmax_value = (value.exp() / sum_exp).into(); + result.set(row, col, softmax_value); + + col += 1; + }; - col = 0; - while col != self.cols { - let value = self.get(row, col).unwrap().into(); - let softmax_value = (value.exp() / sum_exp).into(); - result.set(row, col, softmax_value); - - col += 1; + row += 1; }; - - row += 1; - }; } result @@ -220,65 +230,71 @@ impl MutMatrixImpl< if axis == 0 { let mut col: usize = 0; - while col != self.cols { - let mut sum_exp = NumberTrait::zero(); - let mut row: usize = 0; - while row != self.rows { - let value = self.get(row, col).unwrap().into(); - - if value != NumberTrait::zero() { - sum_exp += value.exp(); - } + while col != self + .cols { + let mut sum_exp = NumberTrait::zero(); + let mut row: usize = 0; + while row != self + .rows { + let value = self.get(row, col).unwrap().into(); + + if value != NumberTrait::zero() { + sum_exp += value.exp(); + } + + row += 1; + }; + + row = 0; + while row != self + .rows { + let value = self.get(row, col).unwrap().into(); + + if value != NumberTrait::zero() { + let softmax_value = (value.exp() / sum_exp).into(); + result.set(row, col, softmax_value); + } else { + result.set(row, col, NumberTrait::zero()); + } + + row += 1; + }; - row += 1; - }; - - row = 0; - while row != self.rows { - let value = self.get(row, col).unwrap().into(); - - if value != NumberTrait::zero() { - let softmax_value = (value.exp() / sum_exp).into(); - result.set(row, col, softmax_value); - } else { - result.set(row, col, NumberTrait::zero()); - } - - row += 1; + col += 1; }; - - col += 1; - }; } else { let mut row: usize = 0; - while row != self.rows { - let mut sum_exp = NumberTrait::zero(); - let mut col: usize = 0; - while col != self.cols { - let value = self.get(row, col).unwrap().into(); - if value != NumberTrait::zero() { - sum_exp += value.exp(); - } + while row != self + .rows { + let mut sum_exp = NumberTrait::zero(); + let mut col: usize = 0; + while col != self + .cols { + let value = self.get(row, col).unwrap().into(); + if value != NumberTrait::zero() { + sum_exp += value.exp(); + } + + col += 1; + }; + + col = 0; + while col != self + .cols { + let value = self.get(row, col).unwrap().into(); + + if value != NumberTrait::zero() { + let softmax_value = (value.exp() / sum_exp).into(); + result.set(row, col, softmax_value); + } else { + result.set(row, col, NumberTrait::zero()); + } + + col += 1; + }; - col += 1; - }; - - col = 0; - while col != self.cols { - let value = self.get(row, col).unwrap().into(); - - if value != NumberTrait::zero() { - let softmax_value = (value.exp() / sum_exp).into(); - result.set(row, col, softmax_value); - } else { - result.set(row, col, NumberTrait::zero()); - } - - col += 1; + row += 1; }; - - row += 1; - }; } result @@ -289,23 +305,26 @@ impl MutMatrixImpl< let mut result = MutMatrixImpl::new(self.rows, self.cols); let mut row: usize = 0; - while row != self.rows { - let mut col: usize = 0; - while col != self.cols { - let value = self.get(row, col); + while row != self + .rows { + let mut col: usize = 0; + while col != self + .cols { + let value = self.get(row, col); - if value.is_some() { - let value = NumberTrait::one() - / (NumberTrait::one() + (value.unwrap() * NumberTrait::neg_one()).exp()); + if value.is_some() { + let value = NumberTrait::one() + / (NumberTrait::one() + + (value.unwrap() * NumberTrait::neg_one()).exp()); - result.set(row, col, value); - } + result.set(row, col, value); + } - col += 1; - }; + col += 1; + }; - row += 1; - }; + row += 1; + }; result } diff --git a/src/operators/ml/svm/core.cairo b/src/operators/ml/svm/core.cairo index 365cb0c1b..64c853077 100644 --- a/src/operators/ml/svm/core.cairo +++ b/src/operators/ml/svm/core.cairo @@ -81,10 +81,11 @@ fn squared_diff< ) -> T { let mut i = 0; let mut sum = NumberTrait::zero(); - while i != pA.len() { - sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); - i += 1; - }; + while i != pA + .len() { + sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); + i += 1; + }; sum } diff --git a/src/operators/ml/svm/svm_classifier.cairo b/src/operators/ml/svm/svm_classifier.cairo index 4df3d63f6..8d1f8b90b 100644 --- a/src/operators/ml/svm/svm_classifier.cairo +++ b/src/operators/ml/svm/svm_classifier.cairo @@ -266,11 +266,12 @@ impl SVMClassifierImpl< let (vectors_per_class_, starting_vector_) = match self.vectors_per_class { Option::Some(vectors_per_class) => { let mut i = 0; - while i != vectors_per_class.len() { - starting_vector_.append(vector_count_); - vector_count_ += *vectors_per_class.at(i); - i += 1; - }; + while i != vectors_per_class + .len() { + starting_vector_.append(vector_count_); + vector_count_ += *vectors_per_class.at(i); + i += 1; + }; (vectors_per_class, starting_vector_.span()) }, @@ -309,17 +310,19 @@ impl SVMClassifierImpl< MODE::SVM_LINEAR => { let mut res: Array = array![]; let mut n = 0; - while n != *X.shape.at(0) { - let mut x_n = get_row(@X, n); - let scores = run_linear(ref self, x_n, coefs, class_count_, kernel_type_); - let mut i = 0; - while i != scores.len() { - res.append(*scores.at(i)); - i += 1; - }; + while n != *X + .shape + .at(0) { + let mut x_n = get_row(@X, n); + let scores = run_linear(ref self, x_n, coefs, class_count_, kernel_type_); + let mut i = 0; + while i != scores.len() { + res.append(*scores.at(i)); + i += 1; + }; - n += 1; - }; + n += 1; + }; ( TensorTrait::new(array![*X.shape.at(0), class_count_].span(), res.span()), @@ -330,33 +333,35 @@ impl SVMClassifierImpl< let mut res: Array = array![]; let mut votes: Array = array![]; let mut n = 0; - while n != *X.shape.at(0) { - let mut x_n = get_row(@X, n); - let (scores, mut vote) = run_svm( - ref self, - x_n, - sv, - vector_count_, - kernel_type_, - class_count_, - starting_vector_, - coefs, - vectors_per_class_ - ); - let mut i = 0; - while i != scores.len() { - res.append(*scores.at(i)); - i += 1; - }; + while n != *X + .shape + .at(0) { + let mut x_n = get_row(@X, n); + let (scores, mut vote) = run_svm( + ref self, + x_n, + sv, + vector_count_, + kernel_type_, + class_count_, + starting_vector_, + coefs, + vectors_per_class_ + ); + let mut i = 0; + while i != scores.len() { + res.append(*scores.at(i)); + i += 1; + }; - let mut i = 0; - while i != vote.len() { - votes.append(vote.at(i)); - i += 1; - }; + let mut i = 0; + while i != vote.len() { + votes.append(vote.at(i)); + i += 1; + }; - n += 1; - }; + n += 1; + }; ( TensorTrait::new( @@ -377,18 +382,20 @@ impl SVMClassifierImpl< let (scores, has_proba) = if self.prob_a.len() > 0 { let mut scores: Array = array![]; let mut n = 0; - while n != *res.shape.at(0) { - let res_n = get_row(@res, n); - let mut s = probablities(ref self, res_n, class_count_); - - let mut i = 0; - while i != s.len() { - scores.append(s.at(i)); - i += 1; + while n != *res + .shape + .at(0) { + let res_n = get_row(@res, n); + let mut s = probablities(ref self, res_n, class_count_); + + let mut i = 0; + while i != s.len() { + scores.append(s.at(i)); + i += 1; + }; + + n += 1; }; - - n += 1; - }; ( TensorTrait::new( array![*res.shape.at(0), scores.len() / *res.shape.at(0)].span(), @@ -409,50 +416,56 @@ impl SVMClassifierImpl< let mut final_scores: Array = array![]; let mut n = 0; - while n != *scores.shape.at(0) { - let mut scores_n = get_row(@scores, n); - match votes { - Option::Some(votes) => { - let mut votes_n = get_row(@votes, n); - let (label, new_scores) = compute_final_scores( - ref self, - votes_n, - scores_n, - weights_are_all_positive_, - has_proba, - self.classlabels - ); - - let mut i = 0; - while i != new_scores.data.len() { - final_scores.append(*new_scores.data.at(i)); - i += 1; - }; + while n != *scores + .shape + .at(0) { + let mut scores_n = get_row(@scores, n); + match votes { + Option::Some(votes) => { + let mut votes_n = get_row(@votes, n); + let (label, new_scores) = compute_final_scores( + ref self, + votes_n, + scores_n, + weights_are_all_positive_, + has_proba, + self.classlabels + ); - labels.append(label); - }, - Option::None => { - let (label, new_scores) = compute_final_scores( - ref self, - array![].span(), - scores_n, - weights_are_all_positive_, - has_proba, - self.classlabels - ); - - let mut i = 0; - while i != new_scores.data.len() { - final_scores.append(*new_scores.data.at(i)); - i += 1; - }; - - labels.append(label); - }, - } + let mut i = 0; + while i != new_scores + .data + .len() { + final_scores.append(*new_scores.data.at(i)); + i += 1; + }; + + labels.append(label); + }, + Option::None => { + let (label, new_scores) = compute_final_scores( + ref self, + array![].span(), + scores_n, + weights_are_all_positive_, + has_proba, + self.classlabels + ); - n += 1; - }; + let mut i = 0; + while i != new_scores + .data + .len() { + final_scores.append(*new_scores.data.at(i)); + i += 1; + }; + + labels.append(label); + }, + } + + n += 1; + }; let labels = labels.span(); @@ -460,10 +473,11 @@ impl SVMClassifierImpl< if self.classlabels.len() > 0 { let mut class_labels: Array = array![]; let mut i = 0; - while i != labels.len() { - class_labels.append(*self.classlabels.at(*labels.at(i))); - i += 1; - }; + while i != labels + .len() { + class_labels.append(*self.classlabels.at(*labels.at(i))); + i += 1; + }; return ( class_labels.span(), @@ -1070,11 +1084,12 @@ fn dot_start_end< let mut sum = NumberTrait::zero(); let mut index_a = a_start; let mut index_b = b_start; - while index_a != a_end && index_b != b_end { - sum = sum + *pA.at(index_a) * *pB.at(index_b); - index_a += 1; - index_b += 1; - }; + while index_a != a_end + && index_b != b_end { + sum = sum + *pA.at(index_a) * *pB.at(index_b); + index_a += 1; + index_b += 1; + }; sum } @@ -1110,10 +1125,11 @@ fn squared_diff< ) -> T { let mut i = 0; let mut sum = NumberTrait::zero(); - while i != pA.len() { - sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); - i += 1; - }; + while i != pA + .len() { + sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); + i += 1; + }; sum } diff --git a/src/operators/ml/svm/svm_regressor.cairo b/src/operators/ml/svm/svm_regressor.cairo index 1d5858a2f..286729ff4 100644 --- a/src/operators/ml/svm/svm_regressor.cairo +++ b/src/operators/ml/svm/svm_regressor.cairo @@ -189,40 +189,43 @@ impl SVMRegressorImpl< let mut z: Array = array![]; let mut n = 0; - while n != *X.shape.at(0) { - let mut s = NumberTrait::zero(); - match mode_ { - MODE::SVM_LINEAR => { - let mut x_n = get_row(@X, n); - s = kernel_dot(self.kernel_params, x_n, self.coefficients, kernel_type_); - s += *self.rho.at(0); - }, - MODE::SVM_SVC => { - let mut x_n = get_row(@X, n); - let mut j = 0; - while j != self.n_supports { - let mut sv_j = get_row(@sv, j); - let d = kernel_dot(self.kernel_params, x_n, sv_j, kernel_type_); - s += *self.coefficients.at(j) * d; - j += 1; - }; + while n != *X + .shape + .at(0) { + let mut s = NumberTrait::zero(); + match mode_ { + MODE::SVM_LINEAR => { + let mut x_n = get_row(@X, n); + s = kernel_dot(self.kernel_params, x_n, self.coefficients, kernel_type_); + s += *self.rho.at(0); + }, + MODE::SVM_SVC => { + let mut x_n = get_row(@X, n); + let mut j = 0; + while j != self + .n_supports { + let mut sv_j = get_row(@sv, j); + let d = kernel_dot(self.kernel_params, x_n, sv_j, kernel_type_); + s += *self.coefficients.at(j) * d; + j += 1; + }; - s += *self.rho.at(0); - }, - } - if self.one_class == 1 { - let elem = if s > NumberTrait::zero() { - NumberTrait::one() + s += *self.rho.at(0); + }, + } + if self.one_class == 1 { + let elem = if s > NumberTrait::zero() { + NumberTrait::one() + } else { + -NumberTrait::one() + }; + z.append(elem); } else { - -NumberTrait::one() + z.append(s); }; - z.append(elem); - } else { - z.append(s); - }; - n += 1; - }; + n += 1; + }; // Post Transform let mut score = TensorTrait::new(array![*X.shape.at(0)].span(), z.span()); diff --git a/src/operators/nn/functional/col2im.cairo b/src/operators/nn/functional/col2im.cairo index b08d9f650..465f65cfb 100644 --- a/src/operators/nn/functional/col2im.cairo +++ b/src/operators/nn/functional/col2im.cairo @@ -299,4 +299,4 @@ fn prod, +Copy, +NumberTrait, +TensorTrait, +Mul< }; prod -} \ No newline at end of file +} diff --git a/src/operators/nn/functional/conv.cairo b/src/operators/nn/functional/conv.cairo index ac72c336d..2000b0845 100644 --- a/src/operators/nn/functional/conv.cairo +++ b/src/operators/nn/functional/conv.cairo @@ -193,22 +193,23 @@ fn conv< let mut p = 0; let mut i = 0; - while i != res_b.len() { - let cv = *res_cv.at(i); - - let mut n = 0; - while n != cv.data.len() { - final.append(*cv.data.at(n)); - n += 1; - }; + while i != res_b + .len() { + let cv = *res_cv.at(i); + + let mut n = 0; + while n != cv.data.len() { + final.append(*cv.data.at(n)); + n += 1; + }; - p += *cv.shape.at(1); - if p >= td { - p = 0; - } + p += *cv.shape.at(1); + if p >= td { + p = 0; + } - i += 1; - }; + i += 1; + }; let final = final.span(); @@ -217,24 +218,32 @@ fn conv< let mut final_b: Array = array![]; let final_stride = stride(final_shape); let mut i = 0; - while i != *final_shape.at(0) { - let mut j = 0; - while j != B.len() { - let mut k = 0; - while k != *final_stride.at(1) { - final_b - .append( - *final.at(i * *final_stride.at(0) + j * *final_stride.at(1) + k) - + *B.at(j) - ); - k += 1; - }; + while i != *final_shape + .at(0) { + let mut j = 0; + while j != B + .len() { + let mut k = 0; + while k != *final_stride + .at(1) { + final_b + .append( + *final + .at( + i * *final_stride.at(0) + + j * *final_stride.at(1) + + k + ) + + *B.at(j) + ); + k += 1; + }; - j += 1; - }; + j += 1; + }; - i += 1; - }; + i += 1; + }; final_b.span() }, @@ -253,13 +262,14 @@ fn conv< new_shape.append_span(SpanTrait::slice((*W).shape, 0, (*W).shape.len() - nd)); let mut i = 0; - while i != dilations.len() { - let d = *dilations.at(i); - let di = (*W).shape.len() - nd + i; - new_shape.append(*(*W).shape.at(di) + (*(*W).shape.at(di) - 1) * (d - 1)); - new_kernel_shape.append(*kernel_shape.at(i) + (*kernel_shape.at(i) - 1) * (d - 1)); - i += 1; - }; + while i != dilations + .len() { + let d = *dilations.at(i); + let di = (*W).shape.len() - nd + i; + new_shape.append(*(*W).shape.at(di) + (*(*W).shape.at(di) - 1) * (d - 1)); + new_kernel_shape.append(*kernel_shape.at(i) + (*kernel_shape.at(i) - 1) * (d - 1)); + i += 1; + }; let new_shape = new_shape.span(); let new_w_strides = stride(new_shape); @@ -273,12 +283,13 @@ fn conv< indices.append(arange(0, *new_shape.at(1), 1)); let mut i = 0; - while i != dilations.len() { - let d = *dilations.at(i); - let di = (*W).shape.len() - nd + i; - indices.append(arange(0, *new_shape.at(di), d)); - i += 1; - }; + while i != dilations + .len() { + let d = *dilations.at(i); + let di = (*W).shape.len() - nd + i; + indices.append(arange(0, *new_shape.at(di), d)); + i += 1; + }; let set_of_all_indices = cartesian(indices.span()); @@ -286,29 +297,32 @@ fn conv< let mut i = 0; let mut prev = 0; - while i != (*W).data.len() { - let nd_index = *set_of_all_indices.at(i); - let mut flatten_index = 0; - let mut j = 0; - while j != nd_index.len() { - flatten_index += *nd_index.at(j) * *new_w_strides.at(j); - j += 1; - }; + while i != (*W) + .data + .len() { + let nd_index = *set_of_all_indices.at(i); + let mut flatten_index = 0; + let mut j = 0; + while j != nd_index + .len() { + flatten_index += *nd_index.at(j) * *new_w_strides.at(j); + j += 1; + }; - if flatten_index > prev + 1 { - let mut j = prev + 1; - while j != flatten_index { - new_w_arr.append(NumberTrait::zero()); - }; + if flatten_index > prev + 1 { + let mut j = prev + 1; + while j != flatten_index { + new_w_arr.append(NumberTrait::zero()); + }; - j += 1; - } + j += 1; + } - new_w_arr.append(*(*W).data.at(i)); - new_w.set(flatten_index, *(*W).data.at(i)); - prev = flatten_index; - i += 1; - }; + new_w_arr.append(*(*W).data.at(i)); + new_w.set(flatten_index, *(*W).data.at(i)); + prev = flatten_index; + i += 1; + }; } let pads = match auto_pad { @@ -425,42 +439,51 @@ fn conv< let w = SpanTrait::slice((*W).data, nw * sC * kh + c * kh, kh); let mut io = bh; - while io < eh.into() { - let hr = (io - bh) / sth.into(); - if hr < h_out.into() { - let i = io + (kh % 2).into(); - - let ih1 = I32Number::max(0, i + oh).into(); - let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); - let img = SpanTrait::slice((*X).data, n * sN + c * sC + ih1, ih2 - ih1); - - let s = if w.len() != img.len() { - let jh1 = I32Number::max(0, -i - oh).into(); - let jh2 = I32Number::min(sH.into() - (i + oh), kh.into()).into(); - - let w_ = SpanTrait::slice(w, jh1, jh2 - jh1); - assert(w_.len() == img.len(), 'unexpected w and img len'); - dot(img, w_) - } else { - dot(img, w) - }; + while io < eh + .into() { + let hr = (io - bh) / sth.into(); + if hr < h_out.into() { + let i = io + (kh % 2).into(); + + let ih1 = I32Number::max(0, i + oh).into(); + let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); + let img = SpanTrait::slice( + (*X).data, n * sN + c * sC + ih1, ih2 - ih1 + ); - let hr = if hr < 0 { - *res_strides.at(1) - hr.into() - } else { - hr.into() - }; + let s = if w.len() != img.len() { + let jh1 = I32Number::max(0, -i - oh).into(); + let jh2 = I32Number::min(sH.into() - (i + oh), kh.into()) + .into(); - res - .set( - n * *res_strides.at(0) + nw * *res_strides.at(1) + hr, - res.at(n * *res_strides.at(0) + nw * *res_strides.at(1) + hr) - + s - ); - } + let w_ = SpanTrait::slice(w, jh1, jh2 - jh1); + assert(w_.len() == img.len(), 'unexpected w and img len'); + dot(img, w_) + } else { + dot(img, w) + }; - io += sth.into(); - }; + let hr = if hr < 0 { + *res_strides.at(1) - hr.into() + } else { + hr.into() + }; + + res + .set( + n * *res_strides.at(0) + nw * *res_strides.at(1) + hr, + res + .at( + n * *res_strides.at(0) + + nw * *res_strides.at(1) + + hr + ) + + s + ); + } + + io += sth.into(); + }; c += 1; }; @@ -558,102 +581,114 @@ fn conv< ); let mut io = bh; - while io < eh.into() { - let hr = (io - bh) / sth.into(); - if hr < h_out.into() { - let i = io + (kh % 2).into(); - let ih1 = I32Number::max(0, i + oh).into(); - let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); - - let mut jo = bw; - while jo < ew.into() { - let wr = (jo - bw) / stw.into(); - if wr < w_out.into() { - let j = jo + (kw % 2).into(); - let iw1 = I32Number::max(0, j + ow).into(); - let iw2 = I32Number::min(j + ow + kw.into(), sW.into()).into(); - - let mut img: Array = array![]; - let mut ihi = ih1; - while ihi != ih2 { - img - .append_span( - SpanTrait::slice( - (*X).data, - n * (sC * sH * sW) - + c * (sH * sW) - + ihi * sW - + iw1, - iw2 - iw1 - ) - ); - ihi += 1; - }; + while io < eh + .into() { + let hr = (io - bh) / sth.into(); + if hr < h_out.into() { + let i = io + (kh % 2).into(); + let ih1 = I32Number::max(0, i + oh).into(); + let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); + + let mut jo = bw; + while jo < ew + .into() { + let wr = (jo - bw) / stw.into(); + if wr < w_out.into() { + let j = jo + (kw % 2).into(); + let iw1 = I32Number::max(0, j + ow).into(); + let iw2 = I32Number::min(j + ow + kw.into(), sW.into()) + .into(); - let img = img.span(); + let mut img: Array = array![]; + let mut ihi = ih1; + while ihi != ih2 { + img + .append_span( + SpanTrait::slice( + (*X).data, + n * (sC * sH * sW) + + c * (sH * sW) + + ihi * sW + + iw1, + iw2 - iw1 + ) + ); + ihi += 1; + }; - let s = if w.len() != img.len() { - let jh1 = I32Number::max(0, -i - oh).into(); - let jh2 = I32Number::min(sH.into() - (i + oh), kh.into()) - .into(); + let img = img.span(); - let jw1 = I32Number::max(0, -j - ow).into(); - let jw2 = I32Number::min(sW.into() - (j + ow), kw.into()) - .into(); + let s = if w.len() != img.len() { + let jh1 = I32Number::max(0, -i - oh).into(); + let jh2 = I32Number::min( + sH.into() - (i + oh), kh.into() + ) + .into(); - let mut w_: Array = array![]; - let mut jhj = jh1; - while jhj != jh2 { - w_ - .append_span( - SpanTrait::slice(w, jhj * kw + jw1, jw2 - jw1) - ); - jhj += 1; - }; + let jw1 = I32Number::max(0, -j - ow).into(); + let jw2 = I32Number::min( + sW.into() - (j + ow), kw.into() + ) + .into(); - let w_ = w_.span(); + let mut w_: Array = array![]; + let mut jhj = jh1; + while jhj != jh2 { + w_ + .append_span( + SpanTrait::slice( + w, jhj * kw + jw1, jw2 - jw1 + ) + ); + jhj += 1; + }; - assert(w_.len() == img.len(), 'unexpected w and img len'); - dot(img, w_) - } else { - dot(img, w) - }; + let w_ = w_.span(); - let hr = if hr < 0 { - h_out - hr.into() - } else { - hr.into() - }; + assert( + w_.len() == img.len(), + 'unexpected w and img len' + ); + dot(img, w_) + } else { + dot(img, w) + }; - let wr = if wr < 0 { - w_out - wr.into() - } else { - wr.into() - }; + let hr = if hr < 0 { + h_out - hr.into() + } else { + hr.into() + }; + + let wr = if wr < 0 { + w_out - wr.into() + } else { + wr.into() + }; - res - .set( - n * *res_strides.at(0) - + nw * *res_strides.at(1) - + hr * *res_strides.at(2) - + wr, res - .at( + .set( n * *res_strides.at(0) + nw * *res_strides.at(1) + hr * *res_strides.at(2) - + wr - ) - + s - ); - } + + wr, + res + .at( + n * *res_strides.at(0) + + nw * *res_strides.at(1) + + hr * *res_strides.at(2) + + wr + ) + + s + ); + } - jo += stw.into(); - }; - } + jo += stw.into(); + }; + } - io += sth.into(); - }; + io += sth.into(); + }; c += 1; }; @@ -767,151 +802,165 @@ fn conv< ); let mut io = bh; - while io < eh.into() { - let hr = (io - bh) / sth.into(); - if hr < h_out.into() { - let i = io + (kh % 2).into(); - let ih1 = I32Number::max(0, i + oh).into(); - let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); - - let mut jo = bw; - while jo < ew.into() { - let wr = (jo - bw) / stw.into(); - if wr < w_out.into() { - let j = jo + (kw % 2).into(); - let iw1 = I32Number::max(0, j + ow).into(); - let iw2 = I32Number::min(j + ow + kw.into(), sW.into()).into(); - - let mut zo = bz; - while zo < ez.into() { - let zr = (zo - bz) / stz.into(); - if zr < z_out.into() { - let z = zo + (kz % 2).into(); - let iz1 = I32Number::max(0, z + oz).into(); - let iz2 = I32Number::min(z + oz + kz.into(), sW.into()) + while io < eh + .into() { + let hr = (io - bh) / sth.into(); + if hr < h_out.into() { + let i = io + (kh % 2).into(); + let ih1 = I32Number::max(0, i + oh).into(); + let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); + + let mut jo = bw; + while jo < ew + .into() { + let wr = (jo - bw) / stw.into(); + if wr < w_out.into() { + let j = jo + (kw % 2).into(); + let iw1 = I32Number::max(0, j + ow).into(); + let iw2 = I32Number::min(j + ow + kw.into(), sW.into()) .into(); - let mut img: Array = array![]; - let mut ihi = ih1; - while ihi != ih2 { - let mut iwi = iw1; - while iwi != iw2 { - img - .append_span( - SpanTrait::slice( - (*X).data, - n * (sC * sH * sW * sZ) - + c * (sH * sW * sZ) - + ihi * (sW * sZ) - + iwi * sZ - + iz1, - iz2 - iz1 + let mut zo = bz; + while zo < ez + .into() { + let zr = (zo - bz) / stz.into(); + if zr < z_out.into() { + let z = zo + (kz % 2).into(); + let iz1 = I32Number::max(0, z + oz).into(); + let iz2 = I32Number::min( + z + oz + kz.into(), sW.into() + ) + .into(); + + let mut img: Array = array![]; + let mut ihi = ih1; + while ihi != ih2 { + let mut iwi = iw1; + while iwi != iw2 { + img + .append_span( + SpanTrait::slice( + (*X).data, + n * (sC * sH * sW * sZ) + + c * (sH * sW * sZ) + + ihi * (sW * sZ) + + iwi * sZ + + iz1, + iz2 - iz1 + ) + ); + iwi += 1; + }; + + ihi += 1; + }; + + let img = img.span(); + + let s = if w.len() != img.len() { + let jh1 = I32Number::max(0, -i - oh) + .into(); + let jh2 = I32Number::min( + sH.into() - (i + oh), kh.into() ) - ); - iwi += 1; - }; + .into(); - ihi += 1; - }; - - let img = img.span(); - - let s = if w.len() != img.len() { - let jh1 = I32Number::max(0, -i - oh).into(); - let jh2 = I32Number::min( - sH.into() - (i + oh), kh.into() - ) - .into(); - - let jw1 = I32Number::max(0, -j - ow).into(); - let jw2 = I32Number::min( - sW.into() - (j + ow), kw.into() - ) - .into(); - - let jz1 = I32Number::max(0, -z - oz).into(); - let jz2 = I32Number::min( - sZ.into() - (z + oz), kz.into() - ) - .into(); + let jw1 = I32Number::max(0, -j - ow) + .into(); + let jw2 = I32Number::min( + sW.into() - (j + ow), kw.into() + ) + .into(); - let mut w_: Array = array![]; - let mut jhj = jh1; - while jhj != jh2 { - let mut jwj = jw1; - while jwj != jw2 { - w_ - .append_span( - SpanTrait::slice( - w, - jhj * kw * kz + jwj * kz + jz1, - jz2 - jz1 - ) + let jz1 = I32Number::max(0, -z - oz) + .into(); + let jz2 = I32Number::min( + sZ.into() - (z + oz), kz.into() + ) + .into(); + + let mut w_: Array = array![]; + let mut jhj = jh1; + while jhj != jh2 { + let mut jwj = jw1; + while jwj != jw2 { + w_ + .append_span( + SpanTrait::slice( + w, + jhj * kw * kz + + jwj * kz + + jz1, + jz2 - jz1 + ) + ); + jwj += 1; + }; + + jhj += 1; + }; + + let w_ = w_.span(); + + assert( + w_.len() == img.len(), + 'unexpected w and img len' ); - jwj += 1; - }; + dot(img, w_) + } else { + dot(img, w) + }; + + let hr = if hr < 0 { + h_out - hr.into() + } else { + hr.into() + }; + + let wr = if wr < 0 { + w_out - wr.into() + } else { + wr.into() + }; + + let zr = if zr < 0 { + z_out - zr.into() + } else { + zr.into() + }; + + res + .set( + n * *res_strides.at(0) + + nw * *res_strides.at(1) + + hr * *res_strides.at(2) + + wr * *res_strides.at(3) + + zr, + res + .at( + n * *res_strides.at(0) + + nw + * *res_strides.at(1) + + hr + * *res_strides.at(2) + + wr + * *res_strides.at(3) + + zr + ) + + s + ); + } - jhj += 1; + zo += stz.into(); }; - - let w_ = w_.span(); - - assert( - w_.len() == img.len(), - 'unexpected w and img len' - ); - dot(img, w_) - } else { - dot(img, w) - }; - - let hr = if hr < 0 { - h_out - hr.into() - } else { - hr.into() - }; - - let wr = if wr < 0 { - w_out - wr.into() - } else { - wr.into() - }; - - let zr = if zr < 0 { - z_out - zr.into() - } else { - zr.into() - }; - - res - .set( - n * *res_strides.at(0) - + nw * *res_strides.at(1) - + hr * *res_strides.at(2) - + wr * *res_strides.at(3) - + zr, - res - .at( - n * *res_strides.at(0) - + nw * *res_strides.at(1) - + hr * *res_strides.at(2) - + wr * *res_strides.at(3) - + zr - ) - + s - ); } - zo += stz.into(); + jo += stw.into(); }; - } - - jo += stw.into(); - }; - } + } - io += sth.into(); - }; + io += sth.into(); + }; c += 1; }; @@ -990,10 +1039,11 @@ fn conv< while j != sM { let b_j = *B.at(j); let mut k = 0; - while k != *res_strides.at(1) { - res.set(i * *res_strides.at(0) + j * *res_strides.at(1) + k, b_j); - k += 1; - }; + while k != *res_strides + .at(1) { + res.set(i * *res_strides.at(0) + j * *res_strides.at(1) + k, b_j); + k += 1; + }; j += 1; }; @@ -1014,185 +1064,211 @@ fn conv< (*W).data, nw * *w_stride.at(0) + c * *w_stride.at(1), *w_stride.at(1) ); let mut i = 0; - while i != *range_len.at(0) * *range_stride.at(0) { - let mut io_index: Array = array![]; - let mut r_index: Array = array![]; - let mut flatten_index = i; - - let mut nx = 0; - while nx != nd { - let (n_index, rem) = DivRem::div_rem( - flatten_index, (*range_stride.at(nx)).try_into().unwrap() - ); - - flatten_index = rem; - io_index - .append(n_index.into() * (*strides.at(nx)).into() + *b_index.at(nx)); - r_index.append(n_index.into()); - nx += 1; - }; + while i != *range_len.at(0) + * *range_stride + .at(0) { + let mut io_index: Array = array![]; + let mut r_index: Array = array![]; + let mut flatten_index = i; - if r_index_check(r_index.span(), shape_out) { - let mut indices: Array = array![]; - let mut i1_index: Array = array![]; - let mut i2_index: Array = array![]; - let mut idiff_index: Array = array![]; - - let mut nx = 0; - while nx != nd { - indices.append(*io_index.at(nx) + (*kernel_shape.at(nx) % 2).into()); - i1_index - .append( - I32Number::max(0, *indices.at(nx) + *o_index.at(nx)).into() - ); - i2_index - .append( - I32Number::min( - (*(*X).shape.at(nx + 2)).into(), - *indices.at(nx) - + *o_index.at(nx) - + (*kernel_shape.at(nx)).into() - ) - .into() + let mut nx = 0; + while nx != nd { + let (n_index, rem) = DivRem::div_rem( + flatten_index, (*range_stride.at(nx)).try_into().unwrap() ); - if nx != nd - 1 { - idiff_index.append(*i2_index.at(nx) - *i1_index.at(nx)); - } - nx += 1; - }; - - let i1_index = i1_index.span(); - let mut img: Array = array![]; - - let img = if nx == 1 { - let img = SpanTrait::slice( - (*X).data, - n * sN + c * sC + *i1_index.at(nd - 1), - *i2_index.at(nd - 1) - *i1_index.at(nd - 1) - ); - img - } else { - let i_stride = stride(idiff_index.span()); + flatten_index = rem; + io_index + .append( + n_index.into() * (*strides.at(nx)).into() + *b_index.at(nx) + ); + r_index.append(n_index.into()); + nx += 1; + }; - let mut ii = 0; - while ii != *i_stride.at(0) * *idiff_index.at(0) { - let mut flatten_index = ii; - let mut start = n * *x_stride.at(0) + c * *x_stride.at(1); + if r_index_check(r_index.span(), shape_out) { + let mut indices: Array = array![]; + let mut i1_index: Array = array![]; + let mut i2_index: Array = array![]; + let mut idiff_index: Array = array![]; let mut nx = 0; - while nx != nd - 1 { - let (ii_index, rem) = DivRem::div_rem( - flatten_index, (*i_stride.at(nx)).try_into().unwrap() - ); - flatten_index = rem; + while nx != nd { + indices + .append( + *io_index.at(nx) + (*kernel_shape.at(nx) % 2).into() + ); + i1_index + .append( + I32Number::max(0, *indices.at(nx) + *o_index.at(nx)) + .into() + ); + i2_index + .append( + I32Number::min( + (*(*X).shape.at(nx + 2)).into(), + *indices.at(nx) + + *o_index.at(nx) + + (*kernel_shape.at(nx)).into() + ) + .into() + ); - start += (*i1_index.at(nx) + ii_index) * *x_stride.at(2 + nx); + if nx != nd - 1 { + idiff_index.append(*i2_index.at(nx) - *i1_index.at(nx)); + } nx += 1; }; - img - .append_span( - SpanTrait::slice( - (*X).data, - start + *i1_index.at(nd - 1), - *i2_index.at(nd - 1) - *i1_index.at(nd - 1) - ) - ); - ii += 1; - }; + let i1_index = i1_index.span(); + let mut img: Array = array![]; - img.span() - }; - - let s = if w.len() != img.len() { - let mut j1_index: Array = array![]; - let mut j2_index: Array = array![]; - let mut jdiff_index: Array = array![]; - - let mut nx = 0; - while nx != nd { - j1_index - .append( - I32Number::max(0, -*indices.at(nx) - *o_index.at(nx)).into() - ); - j2_index - .append( - I32Number::min( - (*(*X).shape.at(nx + 2)).into() - - *indices.at(nx) - - *o_index.at(nx), - (*kernel_shape.at(nx)).into() - ) - .into() + let img = if nx == 1 { + let img = SpanTrait::slice( + (*X).data, + n * sN + c * sC + *i1_index.at(nd - 1), + *i2_index.at(nd - 1) - *i1_index.at(nd - 1) ); - if nx != nd - 1 { - jdiff_index.append(*j2_index.at(nx) - *j1_index.at(nx)); - } - nx += 1; - }; + img + } else { + let i_stride = stride(idiff_index.span()); + + let mut ii = 0; + while ii != *i_stride.at(0) + * *idiff_index + .at(0) { + let mut flatten_index = ii; + let mut start = n * *x_stride.at(0) + + c * *x_stride.at(1); + + let mut nx = 0; + while nx != nd + - 1 { + let (ii_index, rem) = DivRem::div_rem( + flatten_index, + (*i_stride.at(nx)).try_into().unwrap() + ); + flatten_index = rem; - let j1_index = j1_index.span(); + start += (*i1_index.at(nx) + ii_index) + * *x_stride.at(2 + nx); + nx += 1; + }; - let mut w_: Array = array![]; + img + .append_span( + SpanTrait::slice( + (*X).data, + start + *i1_index.at(nd - 1), + *i2_index.at(nd - 1) + - *i1_index.at(nd - 1) + ) + ); + ii += 1; + }; - let w_ = if nx == 1 { - let w_ = SpanTrait::slice( - w, - *j1_index.at(nd - 1), - *j2_index.at(nd - 1) - *j1_index.at(nd - 1) - ); - w_ - } else { - let j_stride = stride(jdiff_index.span()); + img.span() + }; - let mut jj = 0; - while jj != *j_stride.at(0) * *jdiff_index.at(0) { - let mut flatten_index = jj; - let mut start = 0; + let s = if w.len() != img.len() { + let mut j1_index: Array = array![]; + let mut j2_index: Array = array![]; + let mut jdiff_index: Array = array![]; let mut nx = 0; - while nx != nd - 1 { - let (jj_index, rem) = DivRem::div_rem( - flatten_index, (*j_stride.at(nx)).try_into().unwrap() - ); - flatten_index = rem; - start += (*j1_index.at(nx) + jj_index) - * *kernel_shape.at(nx); + while nx != nd { + j1_index + .append( + I32Number::max( + 0, -*indices.at(nx) - *o_index.at(nx) + ) + .into() + ); + j2_index + .append( + I32Number::min( + (*(*X).shape.at(nx + 2)).into() + - *indices.at(nx) + - *o_index.at(nx), + (*kernel_shape.at(nx)).into() + ) + .into() + ); + if nx != nd - 1 { + jdiff_index.append(*j2_index.at(nx) - *j1_index.at(nx)); + } nx += 1; }; - w_ - .append_span( - SpanTrait::slice( - w, - start + *j1_index.at(nd - 1), - *j2_index.at(nd - 1) - *j1_index.at(nd - 1) - ) + + let j1_index = j1_index.span(); + + let mut w_: Array = array![]; + + let w_ = if nx == 1 { + let w_ = SpanTrait::slice( + w, + *j1_index.at(nd - 1), + *j2_index.at(nd - 1) - *j1_index.at(nd - 1) ); - jj += 1; - }; + w_ + } else { + let j_stride = stride(jdiff_index.span()); + + let mut jj = 0; + while jj != *j_stride.at(0) + * *jdiff_index + .at(0) { + let mut flatten_index = jj; + let mut start = 0; + + let mut nx = 0; + while nx != nd + - 1 { + let (jj_index, rem) = DivRem::div_rem( + flatten_index, + (*j_stride.at(nx)) + .try_into() + .unwrap() + ); + flatten_index = rem; + start += (*j1_index.at(nx) + jj_index) + * *kernel_shape.at(nx); + nx += 1; + }; + w_ + .append_span( + SpanTrait::slice( + w, + start + *j1_index.at(nd - 1), + *j2_index.at(nd - 1) + - *j1_index.at(nd - 1) + ) + ); + jj += 1; + }; - w_.span() - }; + w_.span() + }; - dot(img, w_) - } else { - dot(img, w) - }; + dot(img, w_) + } else { + dot(img, w) + }; - let mut res_index = n * *res_strides.at(0) + nw * *res_strides.at(1); + let mut res_index = n * *res_strides.at(0) + + nw * *res_strides.at(1); - let mut nx = 0; - while nx != nd { - res_index += (*r_index.at(nx)).into() * *res_strides.at(2 + nx); - nx += 1; - }; + let mut nx = 0; + while nx != nd { + res_index += (*r_index.at(nx)).into() * *res_strides.at(2 + nx); + nx += 1; + }; - res.set(res_index, res.at(res_index) + s); - }; + res.set(res_index, res.at(res_index) + s); + }; - i += 1 - }; + i += 1 + }; c += 1; }; @@ -1306,14 +1382,15 @@ fn cartesian(mut arrays: Span>,) -> Span> { let mut m = n; let mut i = 0; - while i != arrays.len() { - m = m / (*(arrays.at(i))).len(); - let mut out = repeat(*(arrays.at(i)), m); - out = repeat_2(out, size_arrays, i); + while i != arrays + .len() { + m = m / (*(arrays.at(i))).len(); + let mut out = repeat(*(arrays.at(i)), m); + out = repeat_2(out, size_arrays, i); - output_arrays.append(out); - i += 1; - }; + output_arrays.append(out); + i += 1; + }; let output_arrays = output_arrays.span(); @@ -1339,15 +1416,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A let mut i = 0; while i != index { let mut j = 1; - while j != *size_array.at(index - 1 - i) { - let mut k = 0; - while k != size { - array.append(*array.at(k)); - k += 1; - }; + while j != *size_array + .at(index - 1 - i) { + let mut k = 0; + while k != size { + array.append(*array.at(k)); + k += 1; + }; - j += 1; - }; + j += 1; + }; size = size * *size_array.at(index - 1 - i); i += 1; @@ -1359,15 +1437,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A fn repeat(array: Span, m: usize,) -> Array { let mut out: Array = array![]; let mut j = 0; - while j != array.len() { - let mut k = 0; - while k != m { - out.append(*array.at(j)); - k += 1; - }; + while j != array + .len() { + let mut k = 0; + while k != m { + out.append(*array.at(j)); + k += 1; + }; - j += 1; - }; + j += 1; + }; out } diff --git a/src/operators/nn/functional/grid_sample.cairo b/src/operators/nn/functional/grid_sample.cairo index aed560e37..909065bfa 100644 --- a/src/operators/nn/functional/grid_sample.cairo +++ b/src/operators/nn/functional/grid_sample.cairo @@ -99,68 +99,70 @@ fn grid_sample< let all_coords = get_all_coords(SpanTrait::slice(grid_dims, 1, grid_dims.len() - 2)); let mut ix = 0; - while ix != all_coords.len() { - let ox = *all_coords.at(ix); - let nx = get_sub(grid_data, grid_data_stride, ox); - let nx = reverse(nx); - let x = gs_denormalize_coordinates(nx, dims, align_corner); - - let x = match mode { - MODE::NEAREST => { rint(x) }, - MODE::LINEAR => { x }, - MODE::CUBIC => { x }, - }; + while ix != all_coords + .len() { + let ox = *all_coords.at(ix); + let nx = get_sub(grid_data, grid_data_stride, ox); + let nx = reverse(nx); + let x = gs_denormalize_coordinates(nx, dims, align_corner); + + let x = match mode { + MODE::NEAREST => { rint(x) }, + MODE::LINEAR => { x }, + MODE::CUBIC => { x }, + }; - let mut new_x: Array = array![]; - let mut i = 0; - while i != x.len() { - let v = *x.at(i); - let mut x_min = *border.at(i); - let mut x_max = *border.at(i + num_dims); - let new_v = if v < x_min || v > x_max { - let v = match padding_mode { - PADDING_MODE::ZEROS => { v }, - PADDING_MODE::BORDER => { - clamp( - v, - NumberTrait::zero(), - NumberTrait::new_unscaled((*dims.at(i)).into(), false) - - NumberTrait::one() - ) - }, - PADDING_MODE::REFLECTION => { gs_reflect(v, x_min, x_max) }, + let mut new_x: Array = array![]; + let mut i = 0; + while i != x + .len() { + let v = *x.at(i); + let mut x_min = *border.at(i); + let mut x_max = *border.at(i + num_dims); + let new_v = if v < x_min || v > x_max { + let v = match padding_mode { + PADDING_MODE::ZEROS => { v }, + PADDING_MODE::BORDER => { + clamp( + v, + NumberTrait::zero(), + NumberTrait::new_unscaled((*dims.at(i)).into(), false) + - NumberTrait::one() + ) + }, + PADDING_MODE::REFLECTION => { gs_reflect(v, x_min, x_max) }, + }; + v + } else { + v + }; + + new_x.append(new_v); + i += 1; }; - v - } else { - v - }; - new_x.append(new_v); - i += 1; - }; + let x = new_x.span(); + + let y = match mode { + MODE::NEAREST => { + pixel_at_ndarray(X_data, dims, X_data_stride, x, border, padding_mode) + }, + MODE::LINEAR => { + gs_linear_interpolation_nd_with_x( + X_data, dims, X_data_stride, x, border, padding_mode + ) + }, + MODE::CUBIC => { + gs_cubic_interpolation_nd_with_x( + X_data, dims, X_data_stride, x, border, padding_mode + ) + }, + }; - let x = new_x.span(); - - let y = match mode { - MODE::NEAREST => { - pixel_at_ndarray(X_data, dims, X_data_stride, x, border, padding_mode) - }, - MODE::LINEAR => { - gs_linear_interpolation_nd_with_x( - X_data, dims, X_data_stride, x, border, padding_mode - ) - }, - MODE::CUBIC => { - gs_cubic_interpolation_nd_with_x( - X_data, dims, X_data_stride, x, border, padding_mode - ) - }, + Y.append(y); + ix += 1; }; - Y.append(y); - ix += 1; - }; - c += 1; }; @@ -288,26 +290,27 @@ fn gs_cubic_interpolation_nd_with_x< let mut res1d: Array = array![]; let mut i = 0; - while i != *data_dims.at(0) { - let sub_data = SpanTrait::slice(data, i * *data_stride.at(0), *data_stride.at(0)); - let sub_x = SpanTrait::slice(x, 1, x.len() - 1); - - let data_dims_sub = SpanTrait::slice(data_dims, 1, data_dims.len() - 1); - let data_stride_sub = SpanTrait::slice(data_stride, 1, data_stride.len() - 1); - - let border1 = SpanTrait::slice(border, 1, num_dims - 1); - let border2 = SpanTrait::slice(border, num_dims + 1, num_dims - 1); - let mut border = ArrayTrait::new(); - border.append_span(border1); - border.append_span(border2); - - let r = gs_cubic_interpolation_nd_with_x( - sub_data, data_dims_sub, data_stride_sub, sub_x, border.span(), padding_mode - ); + while i != *data_dims + .at(0) { + let sub_data = SpanTrait::slice(data, i * *data_stride.at(0), *data_stride.at(0)); + let sub_x = SpanTrait::slice(x, 1, x.len() - 1); + + let data_dims_sub = SpanTrait::slice(data_dims, 1, data_dims.len() - 1); + let data_stride_sub = SpanTrait::slice(data_stride, 1, data_stride.len() - 1); + + let border1 = SpanTrait::slice(border, 1, num_dims - 1); + let border2 = SpanTrait::slice(border, num_dims + 1, num_dims - 1); + let mut border = ArrayTrait::new(); + border.append_span(border1); + border.append_span(border2); + + let r = gs_cubic_interpolation_nd_with_x( + sub_data, data_dims_sub, data_stride_sub, sub_x, border.span(), padding_mode + ); - res1d.append(r); - i += 1; - }; + res1d.append(r); + i += 1; + }; gs_cubic_interpolation_1d_with_x( res1d.span(), *x.at(0), array![*border.at(0), *border.at(num_dims)].span(), padding_mode @@ -408,26 +411,27 @@ fn gs_linear_interpolation_nd_with_x< let mut res1d: Array = array![]; let mut i = 0; - while i != *data_dims.at(0) { - let sub_data = SpanTrait::slice(data, i * *data_stride.at(0), *data_stride.at(0)); - let sub_x = SpanTrait::slice(x, 1, x.len() - 1); - - let data_dims_sub = SpanTrait::slice(data_dims, 1, data_dims.len() - 1); - let data_stride_sub = SpanTrait::slice(data_stride, 1, data_stride.len() - 1); - - let border1 = SpanTrait::slice(border, 1, num_dims - 1); - let border2 = SpanTrait::slice(border, num_dims + 1, num_dims - 1); - let mut border = ArrayTrait::new(); - border.append_span(border1); - border.append_span(border2); - - let r = gs_linear_interpolation_nd_with_x( - sub_data, data_dims_sub, data_stride_sub, sub_x, border.span(), padding_mode - ); + while i != *data_dims + .at(0) { + let sub_data = SpanTrait::slice(data, i * *data_stride.at(0), *data_stride.at(0)); + let sub_x = SpanTrait::slice(x, 1, x.len() - 1); + + let data_dims_sub = SpanTrait::slice(data_dims, 1, data_dims.len() - 1); + let data_stride_sub = SpanTrait::slice(data_stride, 1, data_stride.len() - 1); + + let border1 = SpanTrait::slice(border, 1, num_dims - 1); + let border2 = SpanTrait::slice(border, num_dims + 1, num_dims - 1); + let mut border = ArrayTrait::new(); + border.append_span(border1); + border.append_span(border2); + + let r = gs_linear_interpolation_nd_with_x( + sub_data, data_dims_sub, data_stride_sub, sub_x, border.span(), padding_mode + ); - res1d.append(r); - i += 1; - }; + res1d.append(r); + i += 1; + }; gs_linear_interpolation_1d_with_x( res1d.span(), *x.at(0), array![*border.at(0), *border.at(num_dims)].span(), padding_mode @@ -586,20 +590,21 @@ fn rint< let two: T = NumberTrait::one() + NumberTrait::one(); let mut i = 0; - while i != data.len() { - let x = *data.at(i); - let mut round = NumberTrait::round(x); - - let diff = round - x; - if diff == NumberTrait::half() { - if round % two != NumberTrait::zero() { - round -= NumberTrait::one() + while i != data + .len() { + let x = *data.at(i); + let mut round = NumberTrait::round(x); + + let diff = round - x; + if diff == NumberTrait::half() { + if round % two != NumberTrait::zero() { + round -= NumberTrait::one() + } } - } - rint.append(round); - i += 1; - }; + rint.append(round); + i += 1; + }; rint.span() } @@ -779,12 +784,13 @@ fn gs_denormalize_coordinates< let mut x: Array = array![]; let mut i = 0; - while i != n.len() { - let v = *n.at(i); - let dim = *dims.at(i); - x.append(gs_denormalize(v, dim, align_corner)); - i += 1; - }; + while i != n + .len() { + let v = *n.at(i); + let dim = *dims.at(i); + x.append(gs_denormalize(v, dim, align_corner)); + i += 1; + }; x.span() } @@ -851,14 +857,15 @@ fn cartesian(mut arrays: Span>,) -> Span> { let mut m = n; let mut i = 0; - while i != arrays.len() { - m = m / (*(arrays.at(i))).len(); - let mut out = repeat(*(arrays.at(i)), m); - out = repeat_2(out, size_arrays, i); - - output_arrays.append(out); - i += 1; - }; + while i != arrays + .len() { + m = m / (*(arrays.at(i))).len(); + let mut out = repeat(*(arrays.at(i)), m); + out = repeat_2(out, size_arrays, i); + + output_arrays.append(out); + i += 1; + }; let output_arrays = output_arrays.span(); @@ -884,15 +891,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A let mut i = 0; while i != index { let mut j = 1; - while j != *size_array.at(index - 1 - i) { - let mut k = 0; - while k != size { - array.append(*array.at(k)); - k += 1; - }; + while j != *size_array + .at(index - 1 - i) { + let mut k = 0; + while k != size { + array.append(*array.at(k)); + k += 1; + }; - j += 1; - }; + j += 1; + }; size = size * *size_array.at(index - 1 - i); i += 1; @@ -904,15 +912,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A fn repeat(array: Span, m: usize,) -> Array { let mut out: Array = array![]; let mut j = 0; - while j != array.len() { - let mut k = 0; - while k != m { - out.append(*array.at(j)); - k += 1; - }; + while j != array + .len() { + let mut k = 0; + while k != m { + out.append(*array.at(j)); + k += 1; + }; - j += 1; - }; + j += 1; + }; out } diff --git a/src/operators/tensor.cairo b/src/operators/tensor.cairo index adf2076b5..2a5bd2de1 100644 --- a/src/operators/tensor.cairo +++ b/src/operators/tensor.cairo @@ -6,6 +6,7 @@ mod quantization; mod implementations; mod manipulation; mod ml; +mod preview_training; use orion::operators::tensor::core::{Tensor, TensorSerde, TensorTrait}; diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 02f9cc6e4..bef5088f8 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -4,8 +4,7 @@ use core::serde::Serde; use core::option::OptionTrait; use alexandria_data_structures::array_ext::{SpanTraitExt}; -//::resize::{MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE}; - +use orion::operators::tensor::preview_training::momentum::MODE; use orion::operators::tensor::helpers::{len_from_shape, check_shape}; use orion::numbers::{NumberTrait, I32IntoU32, U32IntoI32}; @@ -683,7 +682,7 @@ trait TensorTrait { axes: Option>, keepdims: Option, noop_with_empty_axes: Option - ) -> Tensor; + ) -> Tensor; /// # tensor.argmax /// /// ```rust @@ -5858,6 +5857,90 @@ trait TensorTrait { values: Option>, values_tensor: Option> ) -> Tensor; + + /// # tensor.momentum + /// + /// ```rust + /// fn momentum( + /// r: T, t: T, inputs: @Tensor, alpha: T, beta: T, mode: MODE, norm_coefficient: T, + /// ) -> (Tensor, Tensor); + /// ``` + /// + /// Compute one iteration of stochastic gradient update with momentum. + /// + /// ## Args + /// + /// * `r`(`T`) - The learning rate. + /// * `i`(`T`) - Update count of "X". + /// * `inputs`(`@Tensor`) - It sequentially contains the current values of optimized tensors, then their gradient tensors, and finally their momentum tensors. For example, if two tensors "X_1" and "X_2" are optimized, The expected input list would be ["X_1", "X_2", gradient of "X_1", gradient of "X_2", momentum of "X_1", momentum of "X_2"]. + /// * `alpha`(`T`) - The decay factor of momentum. + /// * `beta`(`T`) - The coefficient of gradient in computing new momentum. + /// * `mode`(`MODE`) - Its value should be either "nesterov" or "standard". The value "nesterov" leads to the use of Nesterov's momentum while "standard" invokes stochastic gradient method using standard momentum + /// * `norm_coefficient`(`T`) - Coefficient of 0.5 * norm_coefficient * ||X||^2. + /// ## Returns + /// + /// Two `Tensor` containing the new values of optimized tensors and then the new values of their momentum tensors. + /// + /// ## Type Constraints + /// + /// * `T` in (`Tensor`, `Tensor`, `Tensor`, `tensor,`) + /// + /// ## Examples + /// + /// ```rust + /// use array::{ArrayTrait, SpanTrait}; + /// use orion::operators::tensor::{FP16x16Tensor}; + /// use orion::operators::tensor::{TensorTrait, Tensor}; + /// use orion::operators::tensor::preview_training::momentum::MODE; + /// + /// fn example_momentum() -> (Tensor, Tensor){ + /// let mut shape = ArrayTrait::::new(); + /// shape.append(6); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 78643, sign: false }); + /// data.append(FP16x16 { mag: 183500, sign: false }); + /// data.append(FP16x16 { mag: 61603, sign: true }); + /// data.append(FP16x16 { mag: 163840, sign: true }); + /// data.append(FP16x16 { mag: 111411, sign: false }); + /// data.append(FP16x16 { mag: 235929, sign: false }); + /// let mut X = TensorTrait::new(shape.span(), data.span()); + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(3); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 65, sign: false }); + /// data.append(FP16x16 { mag: 62259, sign: false }); + /// data.append(FP16x16 { mag: 6553, sign: false }); + /// let param = TensorTrait::new(shape.span(), data.span()); + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(2); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 74211, sign: false }); + /// data.append(FP16x16 { mag: 177453, sign: false }); + /// let expected_output = TensorTrait::new(shape.span(), data.span()); + /// + /// + /// return TensorTrait::momentum( + /// FP16x16 { mag: 6553, sign: false }, + /// FP16x16 { mag: 0, sign: false }, + /// @X, + /// *param.data.at(1), + /// *param.data.at(2), + /// MODE::STANDARD, + /// *param.data.at(0), + /// ); + /// } + /// >>> ([1.13238 2.70772],[0.67620003 0.9227998 ]) + /// + /// ``` + /// + fn momentum( + r: T, t: T, inputs: @Tensor, alpha: T, beta: T, mode: MODE, norm_coefficient: T, + ) -> (Tensor, Tensor); } /// Cf: TensorTrait::new docstring diff --git a/src/operators/tensor/helpers.cairo b/src/operators/tensor/helpers.cairo index 550ff45c5..8482e5cdf 100644 --- a/src/operators/tensor/helpers.cairo +++ b/src/operators/tensor/helpers.cairo @@ -52,32 +52,33 @@ fn check_compatibility(mut shape_1: Span, mut shape_2: Span) { let mut iter_2 = shape_2.len(); // Iterate while there are dimensions left in either shape - while iter_1 > 0 || iter_2 > 0 { - // Get the current dimension for each shape, defaulting to 1 if we've run out of dimensions - let dim_1 = if iter_1 > 0 { - *shape_1[iter_1 - 1] - } else { - 1 - }; - let dim_2 = if iter_2 > 0 { - *shape_2[iter_2 - 1] - } else { - 1 - }; + while iter_1 > 0 + || iter_2 > 0 { + // Get the current dimension for each shape, defaulting to 1 if we've run out of dimensions + let dim_1 = if iter_1 > 0 { + *shape_1[iter_1 - 1] + } else { + 1 + }; + let dim_2 = if iter_2 > 0 { + *shape_2[iter_2 - 1] + } else { + 1 + }; - // Check the broadcasting rule for the current dimension - if dim_1 != dim_2 && dim_1 != 1 && dim_2 != 1 { - panic(array!['tensors shape must match']); - } + // Check the broadcasting rule for the current dimension + if dim_1 != dim_2 && dim_1 != 1 && dim_2 != 1 { + panic(array!['tensors shape must match']); + } - // Move to the next dimension - if iter_1 > 0 { - iter_1 -= 1; - } - if iter_2 > 0 { - iter_2 -= 1; + // Move to the next dimension + if iter_1 > 0 { + iter_1 -= 1; + } + if iter_2 > 0 { + iter_2 -= 1; + } } - } } /// Computes the index in the broadcasted tensor corresponding to the given indices and shape. @@ -250,17 +251,18 @@ fn combine_indices(mut output_indices: Span, axis_index: usize, axis: usi let mut result: Array = array![]; let mut n: usize = 0; - while n != output_indices.len() + 1 { - if n == axis { - result.append(axis_index); - } else if n > axis { - result.append(*output_indices[n - 1_usize]); - } else { - result.append(*output_indices[n]); - } + while n != output_indices.len() + + 1 { + if n == axis { + result.append(axis_index); + } else if n > axis { + result.append(*output_indices[n - 1_usize]); + } else { + result.append(*output_indices[n]); + } - n += 1; - }; + n += 1; + }; result.span() } @@ -313,13 +315,15 @@ fn broadcast_shape(mut shape1: Span, mut shape2: Span) -> Span = array![]; - while !shape1.is_empty() || !shape2.is_empty() { - let dim1 = *shape1.pop_back().unwrap_or(@1); - let dim2 = *shape2.pop_back().unwrap_or(@1); + while !shape1.is_empty() + || !shape2 + .is_empty() { + let dim1 = *shape1.pop_back().unwrap_or(@1); + let dim2 = *shape2.pop_back().unwrap_or(@1); - let broadcasted_dim = u32_max(dim1, dim2); - result.append(broadcasted_dim); - }; + let broadcasted_dim = u32_max(dim1, dim2); + result.append(broadcasted_dim); + }; result.reverse().span() } diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index e8ca7e2d8..2d5ceabb3 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -3,7 +3,9 @@ use orion::operators::tensor::core::{ constant_of_shape, new_tensor, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_ops, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_ops, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait}; use orion::operators::tensor::implementations::tensor_u32::U32Tensor; @@ -552,6 +554,18 @@ impl BoolTensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } + + fn momentum( + r: bool, + t: bool, + inputs: @Tensor, + alpha: bool, + beta: bool, + mode: preview_training::momentum::MODE, + norm_coefficient: bool, + ) -> (Tensor, Tensor) { + panic(array!['not supported!']) + } } /// Implements partial equal for two `Tensor` using the `PartialEq` trait. diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 8acb0891e..183b1e561 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -3,7 +3,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait, FP64x64, FP64x64Impl}; use orion::numbers::fixed_point::implementations::fp64x64::core::ONE; use orion::operators::tensor::implementations::{ @@ -89,10 +91,7 @@ impl Complex64Tensor of TensorTrait { } fn argmax( - self: @Tensor, - axis: i32, - keepdims: Option, - select_last_index: Option + self: @Tensor, axis: i32, keepdims: Option, select_last_index: Option ) -> Tensor { panic(array!['not supported!']) } @@ -592,6 +591,18 @@ impl Complex64Tensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } + + fn momentum( + r: complex64, + t: complex64, + inputs: @Tensor, + alpha: complex64, + beta: complex64, + mode: preview_training::momentum::MODE, + norm_coefficient: complex64, + ) -> (Tensor, Tensor) { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 27f853df5..f8685d188 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -4,7 +4,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait, FP16x16, I8IntoFP16x16}; use orion::operators::tensor::implementations::{ tensor_i8::I8Tensor, tensor_u32::U32Tensor, tensor_bool::BoolTensor @@ -71,7 +73,9 @@ impl FP16x16Tensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -353,9 +357,7 @@ impl FP16x16Tensor of TensorTrait { core_tensor::slice::(self, starts, ends, axes, steps) } - fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather::gather(self, indices, axis) } @@ -644,6 +646,18 @@ impl FP16x16Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: FP16x16, + t: FP16x16, + inputs: @Tensor, + alpha: FP16x16, + beta: FP16x16, + mode: preview_training::momentum::MODE, + norm_coefficient: FP16x16, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. @@ -765,17 +779,19 @@ fn relative_eq(lhs: @FP16x16, rhs: @FP16x16) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 61485bae6..f0b737f9d 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -4,7 +4,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait, FP16x16W}; use orion::operators::tensor::implementations::{ tensor_i8::I8Tensor, tensor_u32::U32Tensor, tensor_bool::BoolTensor @@ -75,7 +77,9 @@ impl FP16x16WTensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -93,10 +97,7 @@ impl FP16x16WTensor of TensorTrait { } fn argmax( - self: @Tensor, - axis: i32, - keepdims: Option, - select_last_index: Option + self: @Tensor, axis: i32, keepdims: Option, select_last_index: Option ) -> Tensor { math::argmax::argmax(self, axis, keepdims, select_last_index) } @@ -604,6 +605,18 @@ impl FP16x16WTensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: FP16x16W, + t: FP16x16W, + inputs: @Tensor, + alpha: FP16x16W, + beta: FP16x16W, + mode: preview_training::momentum::MODE, + norm_coefficient: FP16x16W, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. @@ -724,17 +737,19 @@ fn relative_eq(lhs: @FP16x16W, rhs: @FP16x16W) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 6ea3c7d94..bfc06d321 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -4,7 +4,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait, FP32x32, FP32x32Impl, I8IntoFP32x32}; use orion::numbers::fixed_point::implementations::fp32x32::core::ONE; use orion::operators::tensor::implementations::{ @@ -68,7 +70,9 @@ impl FP32x32Tensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -350,9 +354,7 @@ impl FP32x32Tensor of TensorTrait { core_tensor::slice::(self, starts, ends, axes, steps) } - fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather::gather(self, indices, axis) } @@ -640,6 +642,18 @@ impl FP32x32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: FP32x32, + t: FP32x32, + inputs: @Tensor, + alpha: FP32x32, + beta: FP32x32, + mode: preview_training::momentum::MODE, + norm_coefficient: FP32x32, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. @@ -771,17 +785,19 @@ fn relative_eq(lhs: @FP32x32, rhs: @FP32x32) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index af955fff1..4b72cc622 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -4,7 +4,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait, FP64x64, FP64x64Impl, I8IntoFP64x64}; use orion::numbers::fixed_point::implementations::fp64x64::core::ONE; use orion::operators::tensor::implementations::{ @@ -68,7 +70,9 @@ impl FP64x64Tensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -350,9 +354,7 @@ impl FP64x64Tensor of TensorTrait { core_tensor::slice::(self, starts, ends, axes, steps) } - fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather::gather(self, indices, axis) } @@ -640,6 +642,18 @@ impl FP64x64Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: FP64x64, + t: FP64x64, + inputs: @Tensor, + alpha: FP64x64, + beta: FP64x64, + mode: preview_training::momentum::MODE, + norm_coefficient: FP64x64, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. @@ -771,17 +785,19 @@ fn relative_eq(lhs: @FP64x64, rhs: @FP64x64) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.shape.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 19681e641..4dba20c76 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -4,7 +4,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_ops, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_ops, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait, FP8x23, I8IntoFP8x23}; use orion::operators::tensor::implementations::{ tensor_i8::I8Tensor, tensor_u32::U32Tensor, tensor_bool::BoolTensor @@ -636,6 +638,18 @@ impl FP8x23Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: FP8x23, + t: FP8x23, + inputs: @Tensor, + alpha: FP8x23, + beta: FP8x23, + mode: preview_training::momentum::MODE, + norm_coefficient: FP8x23, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index ef65871d4..00134b433 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -4,7 +4,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait, FP8x23W}; use orion::operators::tensor::implementations::{ tensor_i8::I8Tensor, tensor_u32::U32Tensor, tensor_bool::BoolTensor @@ -71,7 +73,9 @@ impl FP8x23WTensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -304,9 +308,7 @@ impl FP8x23WTensor of TensorTrait { core_tensor::slice::(self, starts, ends, axes, steps) } - fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather::gather(self, indices, axis) } @@ -581,6 +583,18 @@ impl FP8x23WTensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: FP8x23W, + t: FP8x23W, + inputs: @Tensor, + alpha: FP8x23W, + beta: FP8x23W, + mode: preview_training::momentum::MODE, + norm_coefficient: FP8x23W, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. @@ -725,17 +739,19 @@ fn relative_eq(lhs: @FP8x23W, rhs: @FP8x23W) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 924a6b1fd..705d360c4 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -5,7 +5,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait}; use orion::operators::tensor::implementations::{ tensor_u32::U32Tensor, tensor_i8::I8Tensor, tensor_bool::BoolTensor @@ -435,9 +437,7 @@ impl I32Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_elements( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather_elements(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather_elements::gather_elements(self, indices, axis) } @@ -604,6 +604,18 @@ impl I32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: i32, + t: i32, + inputs: @Tensor, + alpha: i32, + beta: i32, + mode: preview_training::momentum::MODE, + norm_coefficient: i32, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. @@ -716,17 +728,19 @@ impl I32TensorPartialOrd of PartialOrd> { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index f523c47b2..914271ea8 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -5,7 +5,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait}; use orion::operators::tensor::implementations::{tensor_u32::U32Tensor, tensor_bool::BoolTensor}; @@ -438,9 +440,7 @@ impl I8Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_elements( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather_elements(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather_elements::gather_elements(self, indices, axis) } @@ -607,6 +607,18 @@ impl I8Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: i8, + t: i8, + inputs: @Tensor, + alpha: i8, + beta: i8, + mode: preview_training::momentum::MODE, + norm_coefficient: i8, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. @@ -707,17 +719,19 @@ impl I8TensorPartialOrd of PartialOrd> { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() == 0 && !is_eq { - is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); - }; + while lhs.data.len() == 0 + && !is_eq { + is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 7aa2ade26..853f8fee0 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -4,7 +4,9 @@ use orion::operators::tensor::core::{ new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, }; -use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml, manipulation}; +use orion::operators::tensor::{ + math, linalg, quantization, core as core_tensor, ml, manipulation, preview_training +}; use orion::numbers::{NumberTrait}; use orion::operators::tensor::implementations::{tensor_i8::I8Tensor, tensor_bool::BoolTensor}; @@ -382,9 +384,7 @@ impl U32Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_elements( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather_elements(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather_elements::gather_elements(self, indices, axis) } @@ -551,6 +551,18 @@ impl U32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn momentum( + r: u32, + t: u32, + inputs: @Tensor, + alpha: u32, + beta: u32, + mode: preview_training::momentum::MODE, + norm_coefficient: u32, + ) -> (Tensor, Tensor) { + preview_training::momentum::momentum(r, t, inputs, alpha, beta, mode, norm_coefficient) + } } /// Implements addition for `Tensor` using the `Add` trait. @@ -661,17 +673,19 @@ impl U32TensorPartialOrd of PartialOrd> { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); + }; is_eq } diff --git a/src/operators/tensor/linalg/transpose.cairo b/src/operators/tensor/linalg/transpose.cairo index 97ad240b4..b1f4381c4 100644 --- a/src/operators/tensor/linalg/transpose.cairo +++ b/src/operators/tensor/linalg/transpose.cairo @@ -29,12 +29,13 @@ fn transpose, impl TCopy: Copy, impl TDrop: D let mut input_indices: Array = array![]; let mut output_axis: usize = 0; - while output_axis != axes.len() { - let input_axis = find_axis(axes, output_axis); - input_indices.append(*output_indices[input_axis]); + while output_axis != axes + .len() { + let input_axis = find_axis(axes, output_axis); + input_indices.append(*output_indices[input_axis]); - output_axis += 1; - }; + output_axis += 1; + }; let input_index = ravel_index(*self.shape, input_indices.span()); output_data.append(*(*self.data)[input_index]); diff --git a/src/operators/tensor/manipulation/split.cairo b/src/operators/tensor/manipulation/split.cairo index a8036f219..fb8621614 100644 --- a/src/operators/tensor/manipulation/split.cairo +++ b/src/operators/tensor/manipulation/split.cairo @@ -69,42 +69,45 @@ fn split_num_outputs, +Drop, +TensorTrait,>( let mut sli: MutMatrix = MutMatrixImpl::new((*t).shape.len(), 2); let mut pos: usize = 0; let mut i = 0; - while i != (*t).shape.len() { - let s: usize = *(*t).shape.at(i); - sli.set(i, 0, 0); - sli.set(i, 1, s); - i += 1; - }; + while i != (*t) + .shape + .len() { + let s: usize = *(*t).shape.at(i); + sli.set(i, 0, 0); + sli.set(i, 1, s); + i += 1; + }; let mut i: usize = 0; - while i != split.len() { - let spl = *split.at(i); - sli.set(axis, 0, pos); - pos += spl; - sli.set(axis, 1, pos); + while i != split + .len() { + let spl = *split.at(i); + sli.set(axis, 0, pos); + pos += spl; + sli.set(axis, 1, pos); - let end_ele_0 = match sli.get(axis, 0) { - Option::Some(res) => res, - Option::None => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, - }; - let end_ele_1 = match sli.get(axis, 1) { - Option::Some(res) => res, - Option::None => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, + let end_ele_0 = match sli.get(axis, 0) { + Option::Some(res) => res, + Option::None => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let end_ele_1 = match sli.get(axis, 1) { + Option::Some(res) => res, + Option::None => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); + let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); + let axes: Option> = Option::None(()); + let steps: Option> = Option::None(()); + let sub_t: Tensor = t.slice(starts, ends, axes, steps); + splited_t.append(sub_t); + i += 1; }; - let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); - let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); - let axes: Option> = Option::None(()); - let steps: Option> = Option::None(()); - let sub_t: Tensor = t.slice(starts, ends, axes, steps); - splited_t.append(sub_t); - i += 1; - }; splited_t } @@ -118,42 +121,46 @@ fn split_has_split, +Drop, +TensorTrait,>( let mut sli: MutMatrix = MutMatrixImpl::new((*t).shape.len(), 2); let mut pos: usize = 0; let mut i = 0; - while i != (*t).shape.len() { - let s: usize = *(*t).shape.at(i); - sli.set(i, 0, 0); - sli.set(i, 1, s); - i += 1; - }; + while i != (*t) + .shape + .len() { + let s: usize = *(*t).shape.at(i); + sli.set(i, 0, 0); + sli.set(i, 1, s); + i += 1; + }; let mut i: usize = 0; - while i != split.data.len() { - let spl: usize = split.at(indices: array![i].span()); - sli.set(axis, 0, pos); - pos += spl; - sli.set(axis, 1, pos); + while i != split + .data + .len() { + let spl: usize = split.at(indices: array![i].span()); + sli.set(axis, 0, pos); + pos += spl; + sli.set(axis, 1, pos); - let end_ele_0 = match sli.get(axis, 0) { - Option::Some(res) => res, - Option::None => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, - }; - let end_ele_1 = match sli.get(axis, 1) { - Option::Some(res) => res, - Option::None => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, + let end_ele_0 = match sli.get(axis, 0) { + Option::Some(res) => res, + Option::None => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let end_ele_1 = match sli.get(axis, 1) { + Option::Some(res) => res, + Option::None => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); + let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); + let axes: Option> = Option::None(()); + let steps: Option> = Option::None(()); + let sub_t: Tensor = t.slice(starts, ends, axes, steps); + splited_t.append(sub_t); + i += 1; }; - let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); - let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); - let axes: Option> = Option::None(()); - let steps: Option> = Option::None(()); - let sub_t: Tensor = t.slice(starts, ends, axes, steps); - splited_t.append(sub_t); - i += 1; - }; splited_t } diff --git a/src/operators/tensor/manipulation/split_to_sequence.cairo b/src/operators/tensor/manipulation/split_to_sequence.cairo index 2e8e4704c..51c3662e8 100644 --- a/src/operators/tensor/manipulation/split_to_sequence.cairo +++ b/src/operators/tensor/manipulation/split_to_sequence.cairo @@ -202,4 +202,4 @@ fn split_has_split, +Drop, +TensorTrait,>( }; splited_t -} \ No newline at end of file +} diff --git a/src/operators/tensor/math/cumsum.cairo b/src/operators/tensor/math/cumsum.cairo index 6fef885d2..247494617 100644 --- a/src/operators/tensor/math/cumsum.cairo +++ b/src/operators/tensor/math/cumsum.cairo @@ -51,37 +51,38 @@ fn cumsum_forward< let mut index: usize = 0; - while index != data.len() { - let current_indices = unravel_index(index, *self.shape); - let axis_value = *current_indices[axis]; - - if axis_value == 0 { - if exclusive { - output_data.append(zero); + while index != data + .len() { + let current_indices = unravel_index(index, *self.shape); + let axis_value = *current_indices[axis]; + + if axis_value == 0 { + if exclusive { + output_data.append(zero); + } else { + output_data.append(*(data)[index]); + } } else { - output_data.append(*(data)[index]); + let previous_axis_element_indices = replace_index( + current_indices, axis, axis_value - 1 + ); + let previous_axis_element_index = ravel_index( + *self.shape, previous_axis_element_indices + ); + + if exclusive { + output_data + .append( + *output_data[previous_axis_element_index] + + *(data)[previous_axis_element_index] + ); + } else { + output_data.append(*output_data[previous_axis_element_index] + *(data)[index]); + }; } - } else { - let previous_axis_element_indices = replace_index( - current_indices, axis, axis_value - 1 - ); - let previous_axis_element_index = ravel_index( - *self.shape, previous_axis_element_indices - ); - - if exclusive { - output_data - .append( - *output_data[previous_axis_element_index] - + *(data)[previous_axis_element_index] - ); - } else { - output_data.append(*output_data[previous_axis_element_index] + *(data)[index]); - }; - } - index += 1; - }; + index += 1; + }; TensorTrait::::new(*self.shape, output_data.span()) } @@ -106,54 +107,59 @@ fn cumsum_reverse< let data = *self.data; let mut output_data = array![]; let mut index: usize = 0; - while index != data.len() { - let current_indices = unravel_index(index, *self.shape); - let mut axis_value = *current_indices[axis]; - - if axis_value == 0 { - // If the axis value is 0, we need to sum all the elements - // in the axis. - let mut sum = *(data)[index]; - if exclusive { - sum = zero; - } + while index != data + .len() { + let current_indices = unravel_index(index, *self.shape); + let mut axis_value = *current_indices[axis]; + + if axis_value == 0 { + // If the axis value is 0, we need to sum all the elements + // in the axis. + let mut sum = *(data)[index]; + if exclusive { + sum = zero; + } - let end_index = *(*self.shape)[axis] - 1; + let end_index = *(*self.shape)[axis] - 1; - loop { - axis_value += 1; - if axis_value > end_index { - break (); - } + loop { + axis_value += 1; + if axis_value > end_index { + break (); + } - let next_axis_element_indices = replace_index(current_indices, axis, axis_value); - let next_axis_element_index = ravel_index(*self.shape, next_axis_element_indices); - sum += *data[next_axis_element_index]; - }; - - output_data.append(sum); - } else { - // If the axis value is not 0, we only need to do a subtraction - let previous_axis_element_indices = replace_index( - current_indices, axis, axis_value - 1 - ); - let previous_axis_element_index = ravel_index( - *self.shape, previous_axis_element_indices - ); - - if exclusive { - output_data.append(*output_data[previous_axis_element_index] - *(data)[index]); - } else { - output_data - .append( - *output_data[previous_axis_element_index] - - *(data)[previous_axis_element_index] + let next_axis_element_indices = replace_index( + current_indices, axis, axis_value ); + let next_axis_element_index = ravel_index( + *self.shape, next_axis_element_indices + ); + sum += *data[next_axis_element_index]; + }; + + output_data.append(sum); + } else { + // If the axis value is not 0, we only need to do a subtraction + let previous_axis_element_indices = replace_index( + current_indices, axis, axis_value - 1 + ); + let previous_axis_element_index = ravel_index( + *self.shape, previous_axis_element_indices + ); + + if exclusive { + output_data.append(*output_data[previous_axis_element_index] - *(data)[index]); + } else { + output_data + .append( + *output_data[previous_axis_element_index] + - *(data)[previous_axis_element_index] + ); + } } - } - index += 1; - }; + index += 1; + }; TensorTrait::::new(*self.shape, output_data.span()) } diff --git a/src/operators/tensor/math/gather_nd.cairo b/src/operators/tensor/math/gather_nd.cairo index e5f340487..99564de6f 100644 --- a/src/operators/tensor/math/gather_nd.cairo +++ b/src/operators/tensor/math/gather_nd.cairo @@ -127,10 +127,11 @@ fn gather_nd, impl TCopy: Copy, impl TDr if (index == *indices_shape_last - 1) { let mut data_ind: usize = result; - while data_ind != result + incrementer { - index_data.append(data_ind + incr); - data_ind += 1; - }; + while data_ind != result + + incrementer { + index_data.append(data_ind + incr); + data_ind += 1; + }; result = 0; }; diff --git a/src/operators/tensor/math/layer_normalization.cairo b/src/operators/tensor/math/layer_normalization.cairo index b6aa33ec0..6473ce15a 100644 --- a/src/operators/tensor/math/layer_normalization.cairo +++ b/src/operators/tensor/math/layer_normalization.cairo @@ -99,7 +99,8 @@ fn layer_normalization< let x_diff = x_mat - x_mean; let x_squared_diff = x_diff * x_diff; - let variance = x_squared_diff.reduce_sum(Option::Some(array![1].span()), Option::Some(true), Option::Some(false)) + let variance = x_squared_diff + .reduce_sum(Option::Some(array![1].span()), Option::Some(true), Option::Some(false)) / TensorTrait::new(shape_one.span(), col_number_tensor.span()); let variance_eps = variance + TensorTrait::new(shape_one.span(), epsilon_tensor.span()); diff --git a/src/operators/tensor/math/less_equal.cairo b/src/operators/tensor/math/less_equal.cairo index dea786878..dd54a0a41 100644 --- a/src/operators/tensor/math/less_equal.cairo +++ b/src/operators/tensor/math/less_equal.cairo @@ -4,12 +4,7 @@ use orion::operators::tensor::helpers::{ }; /// Cf: TensorTrait::less_equal docstring -fn less_equal< - T, - impl TPartialOrd: PartialOrd, - impl TCopy: Copy, - impl TDrop: Drop ->( +fn less_equal, impl TCopy: Copy, impl TDrop: Drop>( y: @Tensor, z: @Tensor ) -> Tensor { let broadcasted_shape = broadcast_shape(*y.shape, *z.shape); diff --git a/src/operators/tensor/math/max.cairo b/src/operators/tensor/math/max.cairo index 3ce6d4919..c6a55576a 100644 --- a/src/operators/tensor/math/max.cairo +++ b/src/operators/tensor/math/max.cairo @@ -28,35 +28,36 @@ fn max< let mut tensor_counter: usize = 1; - while tensor_counter != tensors.len() { - let mut new_max_data: Array = array![]; + while tensor_counter != tensors + .len() { + let mut new_max_data: Array = array![]; - let mut current_tensor = *tensors.at(tensor_counter); + let mut current_tensor = *tensors.at(tensor_counter); - let mut broadcasted_shape = broadcast_shape(max_shape, current_tensor.shape); + let mut broadcasted_shape = broadcast_shape(max_shape, current_tensor.shape); - let num_elements = len_from_shape(broadcasted_shape); - let mut n: usize = 0; - while n != num_elements { - let mut indices_broadcasted = unravel_index(n, broadcasted_shape); + let num_elements = len_from_shape(broadcasted_shape); + let mut n: usize = 0; + while n != num_elements { + let mut indices_broadcasted = unravel_index(n, broadcasted_shape); - let mut indices_self = broadcast_index_mapping(max_shape, indices_broadcasted); - let mut indices_other = broadcast_index_mapping( - current_tensor.shape, indices_broadcasted - ); + let mut indices_self = broadcast_index_mapping(max_shape, indices_broadcasted); + let mut indices_other = broadcast_index_mapping( + current_tensor.shape, indices_broadcasted + ); - let mut max_value = NumberTrait::max( - *(max_data)[indices_self], *(current_tensor.data)[indices_other] - ); - new_max_data.append(max_value); + let mut max_value = NumberTrait::max( + *(max_data)[indices_self], *(current_tensor.data)[indices_other] + ); + new_max_data.append(max_value); - n += 1; - }; + n += 1; + }; - max_shape = broadcasted_shape; - max_data = new_max_data.span(); - tensor_counter += 1; - }; + max_shape = broadcasted_shape; + max_data = new_max_data.span(); + tensor_counter += 1; + }; TensorTrait::::new(max_shape, max_data) } diff --git a/src/operators/tensor/math/min.cairo b/src/operators/tensor/math/min.cairo index 2e7acadab..9f0de2dfd 100644 --- a/src/operators/tensor/math/min.cairo +++ b/src/operators/tensor/math/min.cairo @@ -28,35 +28,36 @@ fn min< let mut tensor_counter: usize = 1; - while tensor_counter != tensors.len() { - let mut new_min_data: Array = array![]; + while tensor_counter != tensors + .len() { + let mut new_min_data: Array = array![]; - let mut current_tensor = *tensors.at(tensor_counter); + let mut current_tensor = *tensors.at(tensor_counter); - let mut broadcasted_shape = broadcast_shape(min_shape, current_tensor.shape); + let mut broadcasted_shape = broadcast_shape(min_shape, current_tensor.shape); - let num_elements = len_from_shape(broadcasted_shape); - let mut n: usize = 0; - while n != num_elements { - let mut indices_broadcasted = unravel_index(n, broadcasted_shape); + let num_elements = len_from_shape(broadcasted_shape); + let mut n: usize = 0; + while n != num_elements { + let mut indices_broadcasted = unravel_index(n, broadcasted_shape); - let mut indices_self = broadcast_index_mapping(min_shape, indices_broadcasted); - let mut indices_other = broadcast_index_mapping( - current_tensor.shape, indices_broadcasted - ); + let mut indices_self = broadcast_index_mapping(min_shape, indices_broadcasted); + let mut indices_other = broadcast_index_mapping( + current_tensor.shape, indices_broadcasted + ); - let mut min_value = NumberTrait::min( - *(min_data)[indices_self], *(current_tensor.data)[indices_other] - ); - new_min_data.append(min_value); + let mut min_value = NumberTrait::min( + *(min_data)[indices_self], *(current_tensor.data)[indices_other] + ); + new_min_data.append(min_value); - n += 1; - }; + n += 1; + }; - min_shape = broadcasted_shape; - min_data = new_min_data.span(); - tensor_counter += 1; - }; + min_shape = broadcasted_shape; + min_data = new_min_data.span(); + tensor_counter += 1; + }; TensorTrait::::new(min_shape, min_data) } diff --git a/src/operators/tensor/math/range.cairo b/src/operators/tensor/math/range.cairo index 1edc0f628..b1950f2b9 100644 --- a/src/operators/tensor/math/range.cairo +++ b/src/operators/tensor/math/range.cairo @@ -18,11 +18,12 @@ fn range< ) -> Tensor { let mut result: Array = array![]; let zero: T = NumberTrait::zero(); - while !(step >= zero && start >= end) && !(step <= zero && start <= end) { - let v = start; - result.append(v); - start += step; - }; + while !(step >= zero && start >= end) + && !(step <= zero && start <= end) { + let v = start; + result.append(v); + start += step; + }; let shape = array![result.len()]; diff --git a/src/operators/tensor/math/reduce_l1.cairo b/src/operators/tensor/math/reduce_l1.cairo index 29b83b69d..422af1a6d 100644 --- a/src/operators/tensor/math/reduce_l1.cairo +++ b/src/operators/tensor/math/reduce_l1.cairo @@ -16,5 +16,10 @@ fn reduce_l1< ) -> Tensor { let data_abs = self.abs(); - data_abs.reduce_sum(Option::Some(array![axis.try_into().unwrap()].span()), Option::Some(keepdims), Option::Some(false)) + data_abs + .reduce_sum( + Option::Some(array![axis.try_into().unwrap()].span()), + Option::Some(keepdims), + Option::Some(false) + ) } diff --git a/src/operators/tensor/math/resize.cairo b/src/operators/tensor/math/resize.cairo index ab0ef86f7..5b93e5497 100644 --- a/src/operators/tensor/math/resize.cairo +++ b/src/operators/tensor/math/resize.cairo @@ -283,13 +283,14 @@ fn interpolate_nd< KEEP_ASPECT_RATIO_POLICY::NOT_LARGER => { let mut scale = *scale_factors.at(*axes.at(0)); let mut i = 1; - while i != axes.len() { - if scale > *scale_factors.at(*axes.at(i)) { - scale = *scale_factors.at(*axes.at(i)); - } + while i != axes + .len() { + if scale > *scale_factors.at(*axes.at(i)) { + scale = *scale_factors.at(*axes.at(i)); + } - i += 1; - }; + i += 1; + }; let mut scale_factors: Array = array![]; let mut d = 0; @@ -341,13 +342,14 @@ fn interpolate_nd< KEEP_ASPECT_RATIO_POLICY::NOT_SMALLER => { let mut scale = *scale_factors.at(*axes.at(0)); let mut i = 1; - while i != axes.len() { - if scale < *scale_factors.at(*axes.at(i)) { - scale = *scale_factors.at(*axes.at(i)); - } + while i != axes + .len() { + if scale < *scale_factors.at(*axes.at(i)) { + scale = *scale_factors.at(*axes.at(i)); + } - i += 1; - }; + i += 1; + }; let mut scale_factors: Array = array![]; let mut d = 0; @@ -409,12 +411,13 @@ fn interpolate_nd< }; let mut i = 0; - while i != scale_factors.len() { - let item = *scale_factors.at(i) - * NumberTrait::new_unscaled((*(*(data).shape).at(i)).into(), false); - output_size.append(item.try_into().unwrap()); - i += 1; - }; + while i != scale_factors + .len() { + let item = *scale_factors.at(i) + * NumberTrait::new_unscaled((*(*(data).shape).at(i)).into(), false); + output_size.append(item.try_into().unwrap()); + i += 1; + }; (output_size.span(), scale_factors) }, @@ -422,17 +425,18 @@ fn interpolate_nd< let mut ret: Array> = array![]; let mut i = 0; - while i != output_size.len() { - let mut temp = ArrayTrait::::new(); - let mut j = 0; - while j != *output_size.at(i) { - temp.append(j); - j += 1; - }; + while i != output_size + .len() { + let mut temp = ArrayTrait::::new(); + let mut j = 0; + while j != *output_size.at(i) { + temp.append(j); + j += 1; + }; - ret.append(temp.span()); - i += 1; - }; + ret.append(temp.span()); + i += 1; + }; let mut ret = cartesian(ret.span()); let mut ret_data = array![]; @@ -442,10 +446,11 @@ fn interpolate_nd< Option::Some(X) => { let mut x: Array = array![]; let mut i = 0; - while i != X.len() { - x.append(NumberTrait::new_unscaled((*X.at(i)).into(), false)); - i += 1; - }; + while i != X + .len() { + x.append(NumberTrait::new_unscaled((*X.at(i)).into(), false)); + i += 1; + }; let mut x = x.span(); let item = interpolate_nd_with_x( @@ -499,14 +504,15 @@ fn cartesian(mut arrays: Span>,) -> Array> { let mut m = n; let mut i = 0; - while i != arrays.len() { - m = m / (*(arrays.at(i))).len(); - let mut out = repeat(*(arrays.at(i)), m); - out = repeat_2(out, size_arrays, i); + while i != arrays + .len() { + m = m / (*(arrays.at(i))).len(); + let mut out = repeat(*(arrays.at(i)), m); + out = repeat_2(out, size_arrays, i); - output_arrays.append(out); - i += 1; - }; + output_arrays.append(out); + i += 1; + }; let output_arrays = output_arrays.span(); @@ -532,15 +538,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A let mut i = 0; while i != index { let mut j = 1; - while j != *size_array.at(index - 1 - i) { - let mut k = 0; - while k != size { - array.append(*array.at(k)); - k += 1; - }; + while j != *size_array + .at(index - 1 - i) { + let mut k = 0; + while k != size { + array.append(*array.at(k)); + k += 1; + }; - j += 1; - }; + j += 1; + }; size = size * *size_array.at(index - 1 - i); i += 1; @@ -552,15 +559,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A fn repeat(array: Span, m: usize,) -> Array { let mut out = array![]; let mut j = 0; - while j != array.len() { - let mut k = 0; - while k != m { - out.append(*array.at(j)); - k += 1; - }; + while j != array + .len() { + let mut k = 0; + while k != m { + out.append(*array.at(j)); + k += 1; + }; - j += 1; - }; + j += 1; + }; out } @@ -648,35 +656,37 @@ fn interpolate_nd_with_x< }; let mut i = 0; - while i != *(*data).shape.at(0) { - let data = get_row_n(data, i); - - let mut r = interpolate_nd_with_x( - @data, - n - 1, - scale_factor, - output_size, - x, - antialias, - mode, - nearest_mode, - reduced_roi, - extrapolation_value, - coordinate_transformation_mode, - exclude_outside, - cubic_coeff_a - ); + while i != *(*data) + .shape + .at(0) { + let data = get_row_n(data, i); + + let mut r = interpolate_nd_with_x( + @data, + n - 1, + scale_factor, + output_size, + x, + antialias, + mode, + nearest_mode, + reduced_roi, + extrapolation_value, + coordinate_transformation_mode, + exclude_outside, + cubic_coeff_a + ); + + loop { + match r.data.pop_front() { + Option::Some(item) => { res1d.append(*item); }, + Option::None => { break; } + } + }; - loop { - match r.data.pop_front() { - Option::Some(item) => { res1d.append(*item); }, - Option::None => { break; } - } + i += 1; }; - i += 1; - }; - let mut shape = array![]; shape.append(res1d.len()); @@ -727,14 +737,16 @@ fn get_row_n, +Copy, +Drop,>( let mut stride_output = 1; let mut i = 0; - while i != (*data).shape.len() { - if i != 0 { - output_shape.append(*(*data).shape.at(i)); - stride_output = stride_output * *(*data).shape.at(i); - } + while i != (*data) + .shape + .len() { + if i != 0 { + output_shape.append(*(*data).shape.at(i)); + stride_output = stride_output * *(*data).shape.at(i); + } - i += 1; - }; + i += 1; + }; let mut i = 0; while i != stride_output { @@ -897,17 +909,19 @@ fn interpolate_1d_with_x< let mut coeffs_exclude_outside: Array = array![]; let mut sum = NumberTrait::zero(); let mut i = 0; - while i != idxes.data.len() { - if *idxes.data.at(i) { - coeffs_exclude_outside.append(NumberTrait::zero()); - sum += NumberTrait::zero(); - } else { - coeffs_exclude_outside.append(*coeffs.data.at(i)); - sum += *coeffs.data.at(i); - } + while i != idxes + .data + .len() { + if *idxes.data.at(i) { + coeffs_exclude_outside.append(NumberTrait::zero()); + sum += NumberTrait::zero(); + } else { + coeffs_exclude_outside.append(*coeffs.data.at(i)); + sum += *coeffs.data.at(i); + } - i += 1; - }; + i += 1; + }; let mut coeff_div: Array = array![]; let mut i = 0; @@ -974,21 +988,23 @@ fn get_neighbor< let mut idxes_centered = array![]; let mut ret = array![]; let mut i = 0; - while i != idxes.data.len() { - ret.append(*padded.at(*idxes.data.at(i))); - - if *idxes.data.at(i) >= pad_width { - if (*idxes.data.at(i) - pad_width) >= (*data).data.len() { - idxes_centered.append(true); + while i != idxes + .data + .len() { + ret.append(*padded.at(*idxes.data.at(i))); + + if *idxes.data.at(i) >= pad_width { + if (*idxes.data.at(i) - pad_width) >= (*data).data.len() { + idxes_centered.append(true); + } else { + idxes_centered.append(false); + } } else { - idxes_centered.append(false); + idxes_centered.append(true); } - } else { - idxes_centered.append(true); - } - i += 1; - }; + i += 1; + }; let mut shape = array![]; shape.append(idxes.data.len()); @@ -1049,22 +1065,23 @@ fn get_neighbor_idxes< } let mut i = 0; - while i != n / 2 { - if i_low - i < 0 { - idxes.append(i_high + i); - i_high += 1; - } else { - idxes.append(i_low - i); - } - if i_high + i >= limit { - i_low -= 1; - idxes.append(i_low - i); - } else { - idxes.append(i_high + i); - } + while i != n + / 2 { + if i_low - i < 0 { + idxes.append(i_high + i); + i_high += 1; + } else { + idxes.append(i_low - i); + } + if i_high + i >= limit { + i_low -= 1; + idxes.append(i_low - i); + } else { + idxes.append(i_high + i); + } - i += 1; - } + i += 1; + } } else { core::panic_with_felt252('MUST BE EVEN'); } @@ -1129,21 +1146,22 @@ fn linear_coeffs_antialias< // arange and clip + compute sum let mut i = start; - while i != start + footprint { - let value = NumberTrait::one() - NumberTrait::abs((i - ratio) * scale); - - if value < NumberTrait::zero() { - coeffs.append(NumberTrait::zero()); - } else if value > NumberTrait::one() { - coeffs.append(NumberTrait::one()); - sum += NumberTrait::one(); - } else { - coeffs.append(value); - sum += value; - } + while i != start + + footprint { + let value = NumberTrait::one() - NumberTrait::abs((i - ratio) * scale); + + if value < NumberTrait::zero() { + coeffs.append(NumberTrait::zero()); + } else if value > NumberTrait::one() { + coeffs.append(NumberTrait::one()); + sum += NumberTrait::one(); + } else { + coeffs.append(value); + sum += value; + } - i += NumberTrait::one(); - }; + i += NumberTrait::one(); + }; let n = coeffs.len(); diff --git a/src/operators/tensor/preview_training.cairo b/src/operators/tensor/preview_training.cairo new file mode 100644 index 000000000..45773ddfd --- /dev/null +++ b/src/operators/tensor/preview_training.cairo @@ -0,0 +1 @@ +mod momentum; diff --git a/src/operators/tensor/preview_training/momentum.cairo b/src/operators/tensor/preview_training/momentum.cairo new file mode 100644 index 000000000..913a6faee --- /dev/null +++ b/src/operators/tensor/preview_training/momentum.cairo @@ -0,0 +1,145 @@ +use core::array::ArrayTrait; +use orion::numbers::NumberTrait; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[derive(Copy, Drop)] +enum MODE { + STANDARD, + NESTEROV +} + +/// Cf: TensorTrait::momentum docstring +fn momentum< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +Add, + +Mul, + +Sub, + +PartialOrd, +>( + r: T, t: T, inputs: @Tensor, alpha: T, beta: T, mode: MODE, norm_coefficient: T, +) -> (Tensor, Tensor) { + if (*inputs).data.len() == 3 { + let (x, v) = run_momentum( + r, + t, + *(*inputs).data.at(0), + *(*inputs).data.at(1), + *(*inputs).data.at(2), + alpha, + beta, + mode, + norm_coefficient + ); + return ( + TensorTrait::new(array![1].span(), array![x].span()), + TensorTrait::new(array![1].span(), array![v].span()) + ); + } + + let n = (*inputs).data.len() / 3; + let mut xs = ArrayTrait::new(); + let mut vs = ArrayTrait::new(); + + let mut i = 0; + while i != n { + let (x, v) = run_momentum( + r, + t, + *(*inputs).data.at(i), + *(*inputs).data.at(n + i), + *(*inputs).data.at(n * 2 + i), + alpha, + beta, + mode, + norm_coefficient + ); + xs.append(x); + vs.append(v); + i += 1; + }; + return ( + TensorTrait::new(array![xs.len()].span(), xs.span()), + TensorTrait::new(array![vs.len()].span(), vs.span()) + ); +} + + +fn run_momentum< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +Add, + +Mul, + +Sub, + +PartialOrd, +>( + r: T, t: T, x: T, g: T, v: T, alpha: T, beta: T, mode: MODE, norm_coefficient: T, +) -> (T, T) { + match mode { + MODE::STANDARD => { apply_momentum(r, t, x, g, v, alpha, beta, mode, norm_coefficient) }, + MODE::NESTEROV => { apply_nesterov(r, t, x, g, v, alpha, beta, mode, norm_coefficient) } + } +} + +fn apply_momentum< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +Add, + +Mul, + +Sub, + +PartialOrd, +>( + r: T, t: T, x: T, g: T, v: T, alpha: T, beta: T, mode: MODE, norm_coefficient: T, +) -> (T, T) { + let g_regularized = norm_coefficient * x + g; + + let beta_adjusted = if t > NumberTrait::zero() { + beta + } else { + NumberTrait::one() + }; + let v_new = alpha * v + beta_adjusted * g_regularized; + let x_new = x - r * v_new; + + return (x_new, v_new); +} + +fn apply_nesterov< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +Add, + +Mul, + +Sub, + +PartialOrd, +>( + r: T, t: T, x: T, g: T, v: T, alpha: T, beta: T, mode: MODE, norm_coefficient: T, +) -> (T, T) { + let g_regularized = norm_coefficient * x + g; + + let beta_adjusted = if t > NumberTrait::zero() { + beta + } else { + NumberTrait::one() + }; + let v_new = alpha * v + beta_adjusted * g_regularized; + let x_new = x - r * (g_regularized + alpha * v_new); + + return (x_new, v_new); +} + diff --git a/src/operators/tensor/quantization/qlinear_matmul.cairo b/src/operators/tensor/quantization/qlinear_matmul.cairo index 325f4fd30..f869ae6d0 100644 --- a/src/operators/tensor/quantization/qlinear_matmul.cairo +++ b/src/operators/tensor/quantization/qlinear_matmul.cairo @@ -78,14 +78,15 @@ fn qlinear_matmul< b_shape_reduced.append(n); let mut i = 0; - while i != stride(a_shape) / (m * k) { - result_updates( - @subtensor(@dequantized_a, i * (m * k), a_shape_reduced.span()), - @subtensor(@dequantized_b, i * (k * n), b_shape_reduced.span()), - ref x_data - ); - i += 1; - }; + while i != stride(a_shape) + / (m * k) { + result_updates( + @subtensor(@dequantized_a, i * (m * k), a_shape_reduced.span()), + @subtensor(@dequantized_b, i * (k * n), b_shape_reduced.span()), + ref x_data + ); + i += 1; + }; x_shape(ref x_shape, a_shape, m, n); let x = TensorTrait::new(x_shape.span(), x_data.span()); @@ -94,12 +95,13 @@ fn qlinear_matmul< } fn x_shape(ref x_data: Array, mut shape: Span, m: usize, n: usize) { - while shape.len() != 2 { - match shape.pop_front() { - Option::Some(elem) => { x_data.append(*elem); }, - Option::None => { break; } + while shape + .len() != 2 { + match shape.pop_front() { + Option::Some(elem) => { x_data.append(*elem); }, + Option::None => { break; } + }; }; - }; x_data.append(m); x_data.append(n); diff --git a/tests/lib.cairo b/tests/lib.cairo index c408347ef..f5cecb77d 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -5,4 +5,3 @@ mod nodes; mod ml; mod operators; - diff --git a/tests/nodes.cairo b/tests/nodes.cairo index 244d8b0c9..b2ecca90e 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -985,3 +985,5 @@ mod argmax_negative_axis_keepdims; mod argmax_negative_axis_keepdims_select_last_index; mod argmax_no_keepdims; mod argmax_no_keepdims_select_last_index; +mod momentum_standard; +mod momentum_nesterov; diff --git a/tests/nodes/gather_elements_axis1.cairo b/tests/nodes/gather_elements_axis1.cairo index 82b08e271..638f2d60b 100644 --- a/tests/nodes/gather_elements_axis1.cairo +++ b/tests/nodes/gather_elements_axis1.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_elements(indices:input_1, axis:Option::Some(1)); + let y_0 = input_0.gather_elements(indices: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_elements_axis2.cairo b/tests/nodes/gather_elements_axis2.cairo index 0e0b7caea..dc7b0c373 100644 --- a/tests/nodes/gather_elements_axis2.cairo +++ b/tests/nodes/gather_elements_axis2.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_elements(indices:input_1, axis:Option::Some(2)); + let y_0 = input_0.gather_elements(indices: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_elements_default.cairo b/tests/nodes/gather_elements_default.cairo index 9d1a099c1..98d66dd58 100644 --- a/tests/nodes/gather_elements_default.cairo +++ b/tests/nodes/gather_elements_default.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y_0 = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_elements_negative_indices.cairo b/tests/nodes/gather_elements_negative_indices.cairo index 0aff55566..8980eff88 100644 --- a/tests/nodes/gather_elements_negative_indices.cairo +++ b/tests/nodes/gather_elements_negative_indices.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_negative_indices() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y_0 = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_fp16x16_3d_axis1.cairo b/tests/nodes/gather_fp16x16_3d_axis1.cairo index d10ab5245..643aacd78 100644 --- a/tests/nodes/gather_fp16x16_3d_axis1.cairo +++ b/tests/nodes/gather_fp16x16_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_gather_fp16x16_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(1)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_fp16x16_3d_axis2.cairo b/tests/nodes/gather_fp16x16_3d_axis2.cairo index 40ef5691d..e256664b6 100644 --- a/tests/nodes/gather_fp16x16_3d_axis2.cairo +++ b/tests/nodes/gather_fp16x16_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_gather_fp16x16_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(2)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_fp16x16_3d_default.cairo b/tests/nodes/gather_fp16x16_3d_default.cairo index 2003b0838..affd608a4 100644 --- a/tests/nodes/gather_fp16x16_3d_default.cairo +++ b/tests/nodes/gather_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(0)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_negative_axis.cairo b/tests/nodes/gather_negative_axis.cairo index 27c511614..b5b4854e5 100644 --- a/tests/nodes/gather_negative_axis.cairo +++ b/tests/nodes/gather_negative_axis.cairo @@ -18,7 +18,7 @@ fn test_gather_negative_axis() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(-1)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(-1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_negative_indices.cairo b/tests/nodes/gather_negative_indices.cairo index 559a276ea..03dfbaa74 100644 --- a/tests/nodes/gather_negative_indices.cairo +++ b/tests/nodes/gather_negative_indices.cairo @@ -18,7 +18,7 @@ fn test_gather_negative_indices() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(0)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/momentum_nesterov.cairo b/tests/nodes/momentum_nesterov.cairo new file mode 100644 index 000000000..9aef810d9 --- /dev/null +++ b/tests/nodes/momentum_nesterov.cairo @@ -0,0 +1,33 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::preview_training::momentum::MODE; + +#[test] +#[available_gas(2000000000)] +fn test_momentum_nesterov() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let (y0, y1) = TensorTrait::momentum( + *input_1.data.at(0), + *input_1.data.at(1), + @input_0, + *input_1.data.at(2), + *input_1.data.at(3), + MODE::NESTEROV, + *input_1.data.at(4) + ); + + assert_eq(y0, *z.at(0)); + assert_eq(y1, *z.at(1)); +} + diff --git a/tests/nodes/momentum_nesterov/input_0.cairo b/tests/nodes/momentum_nesterov/input_0.cairo new file mode 100644 index 000000000..2c6d4f43e --- /dev/null +++ b/tests/nodes/momentum_nesterov/input_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 183500, sign: false }); + data.append(FP16x16 { mag: 61603, sign: true }); + data.append(FP16x16 { mag: 163840, sign: true }); + data.append(FP16x16 { mag: 111411, sign: false }); + data.append(FP16x16 { mag: 235929, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/momentum_nesterov/input_1.cairo b/tests/nodes/momentum_nesterov/input_1.cairo new file mode 100644 index 000000000..cffe4e3b2 --- /dev/null +++ b/tests/nodes/momentum_nesterov/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 6553, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 62259, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 655, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/momentum_nesterov/output_0.cairo b/tests/nodes/momentum_nesterov/output_0.cairo new file mode 100644 index 000000000..569ab2597 --- /dev/null +++ b/tests/nodes/momentum_nesterov/output_0.cairo @@ -0,0 +1,28 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 80447, sign: false }); + data.append(FP16x16 { mag: 193799, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 45023, sign: false }); + data.append(FP16x16 { mag: 62128, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/momentum_standard.cairo b/tests/nodes/momentum_standard.cairo new file mode 100644 index 000000000..70d5ab948 --- /dev/null +++ b/tests/nodes/momentum_standard.cairo @@ -0,0 +1,33 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::preview_training::momentum::MODE; + +#[test] +#[available_gas(2000000000)] +fn test_momentum_standard() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z = output_0::output_0(); + + let (y0, y1) = TensorTrait::momentum( + *input_1.data.at(0), + *input_1.data.at(1), + @input_0, + *input_1.data.at(2), + *input_1.data.at(3), + MODE::STANDARD, + *input_1.data.at(4) + ); + + assert_eq(y0, *z.at(0)); + assert_eq(y1, *z.at(1)); +} + diff --git a/tests/nodes/momentum_standard/input_0.cairo b/tests/nodes/momentum_standard/input_0.cairo new file mode 100644 index 000000000..2c6d4f43e --- /dev/null +++ b/tests/nodes/momentum_standard/input_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 183500, sign: false }); + data.append(FP16x16 { mag: 61603, sign: true }); + data.append(FP16x16 { mag: 163840, sign: true }); + data.append(FP16x16 { mag: 111411, sign: false }); + data.append(FP16x16 { mag: 235929, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/momentum_standard/input_1.cairo b/tests/nodes/momentum_standard/input_1.cairo new file mode 100644 index 000000000..1e2564367 --- /dev/null +++ b/tests/nodes/momentum_standard/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 6553, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 62259, sign: false }); + data.append(FP16x16 { mag: 6553, sign: false }); + data.append(FP16x16 { mag: 65, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/momentum_standard/output_0.cairo b/tests/nodes/momentum_standard/output_0.cairo new file mode 100644 index 000000000..8d2987fba --- /dev/null +++ b/tests/nodes/momentum_standard/output_0.cairo @@ -0,0 +1,28 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 74211, sign: false }); + data.append(FP16x16 { mag: 177453, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 44315, sign: false }); + data.append(FP16x16 { mag: 60476, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/reshape_reduced_dims.cairo b/tests/nodes/reshape_reduced_dims.cairo index 7952505d1..4d42db34e 100644 --- a/tests/nodes/reshape_reduced_dims.cairo +++ b/tests/nodes/reshape_reduced_dims.cairo @@ -14,7 +14,7 @@ fn test_reshape_reduced_dims() { let input_0 = input_0::input_0(); let z_0 = output_0::output_0(); - let y_0 = input_0.reshape(array![2,12].span(), false); + let y_0 = input_0.reshape(array![2, 12].span(), false); assert_eq(y_0, z_0); } diff --git a/tests/nodes/reshape_reordered_all_dims.cairo b/tests/nodes/reshape_reordered_all_dims.cairo index 237c867c2..b9d1f456e 100644 --- a/tests/nodes/reshape_reordered_all_dims.cairo +++ b/tests/nodes/reshape_reordered_all_dims.cairo @@ -14,7 +14,7 @@ fn test_reshape_reordered_all_dims() { let input_0 = input_0::input_0(); let z_0 = output_0::output_0(); - let y_0 = input_0.reshape(array![4,2,3].span(), false); + let y_0 = input_0.reshape(array![4, 2, 3].span(), false); assert_eq(y_0, z_0); } diff --git a/tests/nodes/reshape_reordered_last_dims.cairo b/tests/nodes/reshape_reordered_last_dims.cairo index 5c5f4fd7e..79a82982a 100644 --- a/tests/nodes/reshape_reordered_last_dims.cairo +++ b/tests/nodes/reshape_reordered_last_dims.cairo @@ -14,7 +14,7 @@ fn test_reshape_reordered_last_dims() { let input_0 = input_0::input_0(); let z_0 = output_0::output_0(); - let y_0 = input_0.reshape(array![2,4,3].span(), false); + let y_0 = input_0.reshape(array![2, 4, 3].span(), false); assert_eq(y_0, z_0); }