diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index fa13bee87..10ca0b173 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -4,11 +4,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] - 2023-12-01 + +## Added +- Reduce LogSum Operator + +## [Unreleased] - 2023-12-05 + +## Added +- Erf Operator + ## [Unreleased] - 2023-11-27 ## Added - Reduce Prod Operator -- + ## [Unreleased] - 2023-11-20 ## Added diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index cd345421c..4e6a77a15 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -123,6 +123,8 @@ * [tensor.is\_nan](framework/operators/tensor/tensor.is\_nan.md) * [tensor.is_inf](framework/operators/tensor/tensor.is\_inf.md) * [tensor.not](framework/operators/tensor/tensor.not.md) + * [tensor.erf](framework/operators/tensor/tensor.erf.md) + * [tensor.reduce_log_sum](framework/operators/tensor/tensor.reduce_log_sum.md) * [Neural Network](framework/operators/neural-network/README.md) * [nn.relu](framework/operators/neural-network/nn.relu.md) * [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index ac25bb642..2419946b6 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -101,6 +101,8 @@ You can see below the list of current supported ONNX Operators: | [IsNaN](operators/tensor/tensor.is\_nan.md) | :white\_check\_mark: | | [IsInf](operators/tensor/tensor.is\_inf.md) | :white\_check\_mark: | | [Not](operators/tensor/tensor.not.md) | :white\_check\_mark: | +| [ReduceLogSum](operators/tensor/tensor.reduce\_log\_sum.md) | :white\_check\_mark: | +| [Erf](operators/tensor/tensor.erf.md) | :white\_check\_mark: | -Current Operators support: **95/156 (60%)** +Current Operators support: **96/156 (62%)** diff --git a/docs/framework/numbers/fixed-point/README.md b/docs/framework/numbers/fixed-point/README.md index f30122676..7da277d55 100644 --- a/docs/framework/numbers/fixed-point/README.md +++ b/docs/framework/numbers/fixed-point/README.md @@ -69,6 +69,7 @@ use orion::numbers::fixed_point::core::FixedTrait; | [`fp.sinh`](fp.sinh.md) | Returns the value of the hyperbolic sine of the fixed point number. | | [`fp.tanh`](fp.tanh.md) | Returns the value of the hyperbolic tangent of the fixed point number. | | [`fp.sign`](fp.sign.md) | Returns the element-wise indication of the sign of the input fixed point number. | +| [`fp.erf`](fp.erf.md) | The error function of the input fixed point number computed element-wise.| ### Arithmetic & Comparison operators diff --git a/docs/framework/numbers/fixed-point/fp.erf.md b/docs/framework/numbers/fixed-point/fp.erf.md new file mode 100644 index 000000000..82b5d826d --- /dev/null +++ b/docs/framework/numbers/fixed-point/fp.erf.md @@ -0,0 +1,30 @@ +# fp.erf + +```rust +fn erf(self: T) -> T; +``` + +Returns the error function of the input fixed point number computed element-wise. + +## Args + +* `self`(`T`) - The input fixed point + +## Returns + +The error function of the input fixed point number computed element-wise. + +## Examples + +```rust +use orion::numbers::{FP16x16, FP16x16Impl, FixedTrait}; + +fn erf_fp_example() -> FP16x16 { + // We instantiate fixed point here. + let fp = FixedTrait::new(65536, false); + + // We can call `erf` function as follows. + fp.erf() +} +>>> {mag: 55227, sign: false} // = -1 +``` diff --git a/docs/framework/operators/neural-network/README.md b/docs/framework/operators/neural-network/README.md index cd1c92f8d..242882109 100644 --- a/docs/framework/operators/neural-network/README.md +++ b/docs/framework/operators/neural-network/README.md @@ -23,15 +23,15 @@ Orion supports currently these `NN` types. | function | description | | --- | --- | -| [`nn.relu`](nn.relu.md) | Applies the rectified linear unit function element-wise. | -| [`nn.leaky_relu`](nn.leaky\_relu.md) | Applies the leaky rectified linear unit (Leaky ReLU) activation function element-wise. | -| [`nn.sigmoid`](nn.sigmoid.md) | Applies the Sigmoid function to an n-dimensional input tensor. | -| [`nn.softmax`](nn.softmax.md) | Computes softmax activations. | -| [`nn.logsoftmax`](nn.logsoftmax.md) | Applies the natural log to Softmax function to an n-dimensional input Tensor. | -| [`nn.softsign`](nn.softsign.md) | Applies the Softsign function element-wise. | -| [`nn.softplus`](nn.softplus.md) | Applies the Softplus function element-wise. | -| [`nn.linear`](nn.linear.md) | Performs a linear transformation of the input tensor using the provided weights and bias. | -| [`nn.hard_sigmoid`](nn.hard\_sigmoid.md) | Applies the Hard Sigmoid function to an n-dimensional input tensor. | -| [`nn.thresholded_relu`](nn.thresholded\_relu.md) | Performs the thresholded relu activation function element-wise. | -| [`nn.gemm`](nn.gemm.md) | Performs General Matrix multiplication. | +| [`nn.relu`](nn.relu.md) | Applies the rectified linear unit function element-wise. | +| [`nn.leaky_relu`](nn.leaky\_relu.md) | Applies the leaky rectified linear unit (Leaky ReLU) activation function element-wise. | +| [`nn.sigmoid`](nn.sigmoid.md) | Applies the Sigmoid function to an n-dimensional input tensor. | +| [`nn.softmax`](nn.softmax.md) | Computes softmax activations. | +| [`nn.logsoftmax`](nn.logsoftmax.md) | Applies the natural log to Softmax function to an n-dimensional input Tensor. | +| [`nn.softsign`](nn.softsign.md) | Applies the Softsign function element-wise. | +| [`nn.softplus`](nn.softplus.md) | Applies the Softplus function element-wise. | +| [`nn.linear`](nn.linear.md) | Performs a linear transformation of the input tensor using the provided weights and bias. | +| [`nn.hard_sigmoid`](nn.hard\_sigmoid.md) | Applies the Hard Sigmoid function to an n-dimensional input tensor. | +| [`nn.thresholded_relu`](nn.thresholded\_relu.md) | Performs the thresholded relu activation function element-wise. | +| [`nn.gemm`](nn.gemm.md) | Performs General Matrix multiplication. | diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index 1127918e5..575546094 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -122,6 +122,8 @@ use orion::operators::tensor::TensorTrait; | [`tensor.is_nan`](tensor.is\_nan.md) | Returns which elements of the input are NaN. | | [`tensor.is_inf`](tensor.is\_inf.md) | Maps infinity to true and other values to false. | | [`tensor.not`](tensor.not.md) | Computes the logical negation of all elements in the input tensor. | +| [`tensor.reduce_log_sum`](tensor.reduce\_log\_sum.md) | Computes the log sum of the input tensor's elements along the provided axes. | +| [`tensor.erf`](tensor.erf.md) | Computes the error function of the given input tensor element-wise. | ## Arithmetic Operations diff --git a/docs/framework/operators/tensor/tensor.erf.md b/docs/framework/operators/tensor/tensor.erf.md new file mode 100644 index 000000000..19ce86a94 --- /dev/null +++ b/docs/framework/operators/tensor/tensor.erf.md @@ -0,0 +1,48 @@ +## tensor.erf + +```rust + fn erf(self: @Tensor) -> Tensor; +``` + +Computes the mean of the input tensor's elements along the provided axes. + +## Args + +* `self`(`@Tensor`) - The input tensor. + +## Returns + +A new `Tensor` of the same shape as the input tensor with +the the error function of the input tensor computed element-wise. + +## Type Constraints + +Constrain input and output types to fixed point tensors. + +## Examples + +```rust +use core::array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor, FP16x16Tensor}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn erf_example() -> Tensor { + // The erf inputs is [1.0, 0.134, 0.520, 2.0, 3.5, 5.164] + let tensor = TensorTrait::::new( + shape: array![6].span(), + data: array![ + FixedTrait::new_unscaled(65536, false), + FixedTrait::new_unscaled(8832, false), + FixedTrait::new_unscaled(34079, false), + FixedTrait::new_unscaled(131072, false), + FixedTrait::new_unscaled(229376, false), + FixedTrait::new_unscaled(338428, false), + ] + .span(), + ); + + return tensor.erf(); +} +>>> [55227,9560,35252,65229,65536,65536] +``` diff --git a/docs/framework/operators/tensor/tensor.reduce_log_sum.md b/docs/framework/operators/tensor/tensor.reduce_log_sum.md new file mode 100644 index 000000000..f15bce4a9 --- /dev/null +++ b/docs/framework/operators/tensor/tensor.reduce_log_sum.md @@ -0,0 +1,45 @@ +## tensor.reduce_log_sum + +```rust + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor; +``` + +Computes the log sum of the input tensor's elements along the provided axes. +## Args + +* `self`(`@Tensor`) - The input tensor. +* `axis`(`usize`) - The dimension to reduce. +* `keepdims`(`bool`) - If true, retains reduced dimensions with length 1. + +## Panics + +* Panics if axis is not in the range of the input tensor's dimensions. + +## Returns + +A new `Tensor` instance with the specified axis reduced by summing its elements. + +fn reduce_log_sum() -> Tensor { + + let mut sizes = ArrayTrait::new(); + sizes.append(2); + sizes.append(2); + sizes.append(2); + + let mut data = ArrayTrait::new(); + data.append(FixedTrait::new_unscaled(1, false)); + data.append(FixedTrait::new_unscaled(2, false)); + data.append(FixedTrait::new_unscaled(3, false)); + data.append(FixedTrait::new_unscaled(4, false)); + data.append(FixedTrait::new_unscaled(5, false)); + data.append(FixedTrait::new_unscaled(6, false)); + data.append(FixedTrait::new_unscaled(7, false)); + data.append(FixedTrait::new_unscaled(8, false)); + + let tensor = TensorTrait::::new(sizes.span(), data.span()); + + We can call `reduce_log_sum` function as follows. + return tensor.reduce_log_sum(axis: 2, keepdims: false); +} +>>> [[0x11938, 0x1f203], [0x265d9, 0x2b540]] +``` diff --git a/nodegen/node/erf.py b/nodegen/node/erf.py new file mode 100644 index 000000000..d14dab073 --- /dev/null +++ b/nodegen/node/erf.py @@ -0,0 +1,34 @@ +import numpy as np +from math import erf +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl + + +class Erf(RunAll): + + @staticmethod + def erf_fp8x23(): + x = np.asarray([0.12, -1.66, 3.4, 4.8, 2.7]).astype(np.float64).reshape(1,5) + y = np.asarray([erf(value) for value in x[0]]).astype(np.float64).reshape(1,5) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + + name = "erf_fp8x23" + make_test([x], y, "input_0.erf()", name) + + + @staticmethod + def erf_fp16x16(): + x = np.asarray([0.12, -1.66, 3.4, 4.8, 2.7]).astype(np.float64).reshape(1,5) + y = np.asarray([erf(value) for value in x[0]]).astype(np.float64).reshape(1,5) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp( + x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "erf_fp16x16" + make_test([x], y, "input_0.erf()", name) diff --git a/nodegen/node/reduce_log_sum.py b/nodegen/node/reduce_log_sum.py new file mode 100644 index 000000000..259081f5a --- /dev/null +++ b/nodegen/node/reduce_log_sum.py @@ -0,0 +1,117 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl +import numpy as np + + +class Reduce_log_sum(RunAll): + @staticmethod + def reduce_log_sum_fp8x23(): + def reduce_log_sum_export_do_not_keepdims(): + shape = [3, 2, 2] + axes = np.array([2], dtype=np.int64) + keepdims = False + x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + y = np.log(np.sum(x, axis=tuple(axes), keepdims=False)) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + + name = "reduce_log_sum_fp8x23_export_do_not_keepdims" + make_test( + [x], y, "input_0.reduce_log_sum(2, false)", name) + + def reduce_log_sum_export_keepdims(): + shape = [3, 2, 2] + axes = np.array([2], dtype=np.int64) + keepdims = True + x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape).astype(np.int64) + y = np.log(np.sum(x, axis=tuple(axes), keepdims=True)) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + + name = "reduce_log_sum_fp8x23_export_keepdims" + make_test( + [x], y, "input_0.reduce_log_sum(2, true)", name) + + def reduce_log_sum_axis_0(): + shape = [3, 3, 3] + axes = np.array([0], dtype=np.int64) + keepdims = True + x = np.reshape(np.arange(1, np.prod(shape) + 1), shape) + y = np.log(np.sum(x, axis=tuple(axes), keepdims=True)) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + + name = "reduce_log_sum_fp8x23_export_negative_axes_keepdims" + make_test( + [x], y, "input_0.reduce_log_sum(0, true)", name) + + + reduce_log_sum_export_do_not_keepdims() + reduce_log_sum_export_keepdims() + reduce_log_sum_axis_0() + + @staticmethod + def reduce_log_sum_fp16x16(): + def reduce_log_sum_export_do_not_keepdims(): + shape = [3, 2, 2] + axes = np.array([2], dtype=np.int64) + keepdims = False + x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape).astype(np.int64) + y = np.log(np.sum(x, axis=tuple(axes), keepdims=False)) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + + name = "reduce_log_sum_fp16x16_export_do_not_keepdims" + make_test( + [x], y, "input_0.reduce_log_sum(2, false)", name) + + def reduce_log_sum_export_keepdims(): + shape = [3, 2, 2] + axes = np.array([2], dtype=np.int64) + keepdims = True + x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape).astype(np.int64) + y = np.log(np.sum(x, axis=tuple(axes), keepdims=True)) + + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + + name = "reduce_log_sum_fp16x16_export_keepdims" + make_test( + [x], y, "input_0.reduce_log_sum(2, true)", name) + + def reduce_log_sum_axis_0(): + shape = [2, 2, 2] + axes = np.array([0], dtype=np.int64) + keepdims = True + x = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape).astype(np.int64) + y = np.log(np.sum(x, axis=tuple(axes), keepdims=True)) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + + name = "reduce_log_sum_fp16x16_export_negative_axes_keepdims" + make_test( + [x], y, "input_0.reduce_log_sum(0, true)", name) + + + reduce_log_sum_export_do_not_keepdims() + reduce_log_sum_export_keepdims() + reduce_log_sum_axis_0() \ No newline at end of file diff --git a/src/numbers/fixed_point/core.cairo b/src/numbers/fixed_point/core.cairo index 769190f93..0ef1f8c6f 100644 --- a/src/numbers/fixed_point/core.cairo +++ b/src/numbers/fixed_point/core.cairo @@ -1098,6 +1098,38 @@ trait FixedTrait { /// ``` /// fn sign(self: T) -> T; + /// # fp.erf + /// + /// ```rust + /// fn erf(self: T) -> T; + /// ``` + /// + /// Returns the error function of the input fixed point number computed element-wise. + /// + /// ## Args + /// + /// * `self`(`T`) - The input fixed point + /// + /// ## Returns + /// + /// The error function of the input fixed point number computed element-wise. + /// + /// ## Examples + /// + /// ```rust + /// use orion::numbers::{FP16x16, FP16x16Impl, FixedTrait}; + /// + /// fn erf_fp_example() -> FP16x16 { + /// // We instantiate fixed point here. + /// let fp = FixedTrait::new(65536, false); + /// + /// // We can call `erf` function as follows. + /// fp.erf() + /// } + /// >>> {mag: 55227, sign: false} // = -1 + /// ``` + /// + fn erf(self: T) -> T; fn ZERO() -> T; fn HALF() -> T; diff --git a/src/numbers/fixed_point/implementations/fp16x16/core.cairo b/src/numbers/fixed_point/implementations/fp16x16/core.cairo index a4a99bb82..a6fc18c8d 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/core.cairo @@ -6,7 +6,9 @@ use core::traits::{TryInto, Into}; use orion::numbers::signed_integer::{i32::i32, i8::i8}; use orion::numbers::fixed_point::core::FixedTrait; -use orion::numbers::fixed_point::implementations::fp16x16::math::{core as core_math, trig, hyp}; +use orion::numbers::fixed_point::implementations::fp16x16::math::{ + core as core_math, trig, hyp, erf +}; use orion::numbers::fixed_point::utils; /// A struct representing a fixed point number. @@ -218,6 +220,10 @@ impl FP16x16Impl of FixedTrait { fn is_neg_inf(self: FP16x16) -> bool { self.is_inf() && self.sign } + + fn erf(self: FP16x16) -> FP16x16 { + return erf::erf(self); + } } diff --git a/src/numbers/fixed_point/implementations/fp16x16/math.cairo b/src/numbers/fixed_point/implementations/fp16x16/math.cairo index 970c65f30..b0cf1d5e7 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/math.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/math.cairo @@ -3,3 +3,4 @@ mod comp; mod lut; mod trig; mod hyp; +mod erf; diff --git a/src/numbers/fixed_point/implementations/fp16x16/math/erf.cairo b/src/numbers/fixed_point/implementations/fp16x16/math/erf.cairo new file mode 100644 index 000000000..86f87f5ca --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16/math/erf.cairo @@ -0,0 +1,24 @@ +use core::traits::Into; +use orion::numbers::fixed_point::implementations::fp16x16::core::{ONE, FP16x16, FixedTrait}; +use orion::numbers::fixed_point::implementations::fp16x16::math::lut::erf_lut; + +const ERF_COMPUTATIONAL_ACCURACY: u32 = 100; +const ROUND_CHECK_NUMBER: u32 = 10; +// Values > MAX_ERF_NUMBER return 1 +const MAX_ERF_NUMBER: u32 = 229376; +// Values <= ERF_TRUNCATION_NUMBER -> two decimal places, and values > ERF_TRUNCATION_NUMBER -> one decimal place +const ERF_TRUNCATION_NUMBER: u32 = 131072; + +fn erf(x: FP16x16) -> FP16x16 { + // Lookup + // 1. if x.mag < 3.5 { lookup table } + // 2. else{ return 1} + let mut erf_value: u32 = 0; + + if x.mag < MAX_ERF_NUMBER { + erf_value = erf_lut(x.mag); + } else { + erf_value = ONE; + } + FP16x16 { mag: erf_value, sign: x.sign } +} diff --git a/src/numbers/fixed_point/implementations/fp16x16/math/lut.cairo b/src/numbers/fixed_point/implementations/fp16x16/math/lut.cairo index 0586963aa..65c9746c1 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/math/lut.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/math/lut.cairo @@ -1,3 +1,5 @@ +use orion::numbers::fixed_point::implementations::fp16x16::core::ONE; + // Calculates the most significant bit fn msb(whole: u32) -> (u32, u32) { if whole < 256 { @@ -1233,3 +1235,695 @@ fn atan(a: u32) -> (u32, u32, u32) { return (45416, 39716, 40025); } + +fn erf_lut(x: u32) -> u32 { + // Construct the erf lookup table + if x <= 5898 { + if x <= 0 { + return 0; + } + if x <= 655 { + return 739; + } + if x <= 1310 { + return 1478; + } + if x <= 1966 { + return 2217; + } + if x <= 2621 { + return 2956; + } + if x <= 3276 { + return 3694; + } + if x <= 3932 { + return 4431; + } + if x <= 4587 { + return 5168; + } + if x <= 5242 { + return 5903; + } + if x <= 5898 { + return 6637; + } + } + if x <= 12451 { + if x <= 6553 { + return 7370; + } + if x <= 7208 { + return 8101; + } + if x <= 7864 { + return 8831; + } + if x <= 8519 { + return 9559; + } + if x <= 9175 { + return 10285; + } + if x <= 9830 { + return 11009; + } + if x <= 10485 { + return 11731; + } + if x <= 11141 { + return 12451; + } + if x <= 11796 { + return 13168; + } + if x <= 12451 { + return 13883; + } + } + if x <= 19005 { + if x <= 13107 { + return 14595; + } + if x <= 13762 { + return 15304; + } + if x <= 14417 { + return 16010; + } + if x <= 15073 { + return 16713; + } + if x <= 15728 { + return 17412; + } + if x <= 16384 { + return 18109; + } + if x <= 17039 { + return 18802; + } + if x <= 17694 { + return 19491; + } + if x <= 18350 { + return 20177; + } + if x <= 19005 { + return 20859; + } + } + if x <= 25559 { + if x <= 19660 { + return 21536; + } + if x <= 20316 { + return 22210; + } + if x <= 20971 { + return 22880; + } + if x <= 21626 { + return 23545; + } + if x <= 22282 { + return 24206; + } + if x <= 22937 { + return 24863; + } + if x <= 23592 { + return 25515; + } + if x <= 24248 { + return 26162; + } + if x <= 24903 { + return 26804; + } + if x <= 25559 { + return 27442; + } + } + if x <= 32112 { + if x <= 26214 { + return 28075; + } + if x <= 26869 { + return 28702; + } + if x <= 27525 { + return 29325; + } + if x <= 28180 { + return 29942; + } + if x <= 28835 { + return 30554; + } + if x <= 29491 { + return 31161; + } + if x <= 30146 { + return 31762; + } + if x <= 30801 { + return 32358; + } + if x <= 31457 { + return 32948; + } + if x <= 32112 { + return 33532; + } + } + if x <= 38666 { + if x <= 32768 { + return 34111; + } + if x <= 33423 { + return 34684; + } + if x <= 34078 { + return 35251; + } + if x <= 34734 { + return 35813; + } + if x <= 35389 { + return 36368; + } + if x <= 36044 { + return 36917; + } + if x <= 36700 { + return 37461; + } + if x <= 37355 { + return 37998; + } + if x <= 38010 { + return 38530; + } + if x <= 38666 { + return 39055; + } + } + if x <= 45219 { + if x <= 39321 { + return 39574; + } + if x <= 39976 { + return 40087; + } + if x <= 40632 { + return 40593; + } + if x <= 41287 { + return 41094; + } + if x <= 41943 { + return 41588; + } + if x <= 42598 { + return 42076; + } + if x <= 43253 { + return 42557; + } + if x <= 43909 { + return 43032; + } + if x <= 44564 { + return 43501; + } + if x <= 45219 { + return 43964; + } + } + if x <= 51773 { + if x <= 45875 { + return 44420; + } + if x <= 46530 { + return 44870; + } + if x <= 47185 { + return 45313; + } + if x <= 47841 { + return 45750; + } + if x <= 48496 { + return 46181; + } + if x <= 49152 { + return 46606; + } + if x <= 49807 { + return 47024; + } + if x <= 50462 { + return 47436; + } + if x <= 51118 { + return 47841; + } + if x <= 51773 { + return 48241; + } + } + if x <= 58327 { + if x <= 52428 { + return 48634; + } + if x <= 53084 { + return 49021; + } + if x <= 53739 { + return 49401; + } + if x <= 54394 { + return 49776; + } + if x <= 55050 { + return 50144; + } + if x <= 55705 { + return 50506; + } + if x <= 56360 { + return 50862; + } + if x <= 57016 { + return 51212; + } + if x <= 57671 { + return 51556; + } + if x <= 58327 { + return 51894; + } + } + if x <= 64880 { + if x <= 58982 { + return 52226; + } + if x <= 59637 { + return 52552; + } + if x <= 60293 { + return 52872; + } + if x <= 60948 { + return 53186; + } + if x <= 61603 { + return 53495; + } + if x <= 62259 { + return 53797; + } + if x <= 62914 { + return 54094; + } + if x <= 63569 { + return 54386; + } + if x <= 64225 { + return 54672; + } + if x <= 64880 { + return 54952; + } + } + if x <= 71434 { + if x <= 65536 { + return 55227; + } + if x <= 66191 { + return 55496; + } + if x <= 66846 { + return 55760; + } + if x <= 67502 { + return 56019; + } + if x <= 68157 { + return 56272; + } + if x <= 68812 { + return 56520; + } + if x <= 69468 { + return 56763; + } + if x <= 70123 { + return 57001; + } + if x <= 70778 { + return 57234; + } + if x <= 71434 { + return 57462; + } + } + if x <= 77987 { + if x <= 72089 { + return 57685; + } + if x <= 72744 { + return 57903; + } + if x <= 73400 { + return 58116; + } + if x <= 74055 { + return 58325; + } + if x <= 74711 { + return 58529; + } + if x <= 75366 { + return 58728; + } + if x <= 76021 { + return 58923; + } + if x <= 76677 { + return 59113; + } + if x <= 77332 { + return 59299; + } + if x <= 77987 { + return 59481; + } + } + if x <= 84541 { + if x <= 78643 { + return 59658; + } + if x <= 79298 { + return 59831; + } + if x <= 79953 { + return 60000; + } + if x <= 80609 { + return 60165; + } + if x <= 81264 { + return 60326; + } + if x <= 81920 { + return 60483; + } + if x <= 82575 { + return 60636; + } + if x <= 83230 { + return 60785; + } + if x <= 83886 { + return 60931; + } + if x <= 84541 { + return 61072; + } + } + if x <= 91095 { + if x <= 85196 { + return 61211; + } + if x <= 85852 { + return 61345; + } + if x <= 86507 { + return 61477; + } + if x <= 87162 { + return 61604; + } + if x <= 87818 { + return 61729; + } + if x <= 88473 { + return 61850; + } + if x <= 89128 { + return 61968; + } + if x <= 89784 { + return 62083; + } + if x <= 90439 { + return 62194; + } + if x <= 91095 { + return 62303; + } + } + if x <= 97648 { + if x <= 91750 { + return 62408; + } + if x <= 92405 { + return 62511; + } + if x <= 93061 { + return 62611; + } + if x <= 93716 { + return 62708; + } + if x <= 94371 { + return 62802; + } + if x <= 95027 { + return 62894; + } + if x <= 95682 { + return 62983; + } + if x <= 96337 { + return 63070; + } + if x <= 96993 { + return 63154; + } + if x <= 97648 { + return 63235; + } + } + if x <= 104202 { + if x <= 98304 { + return 63314; + } + if x <= 98959 { + return 63391; + } + if x <= 99614 { + return 63465; + } + if x <= 100270 { + return 63538; + } + if x <= 100925 { + return 63608; + } + if x <= 101580 { + return 63676; + } + if x <= 102236 { + return 63742; + } + if x <= 102891 { + return 63806; + } + if x <= 103546 { + return 63867; + } + if x <= 104202 { + return 63927; + } + } + if x <= 110755 { + if x <= 104857 { + return 63985; + } + if x <= 105512 { + return 64042; + } + if x <= 106168 { + return 64096; + } + if x <= 106823 { + return 64149; + } + if x <= 107479 { + return 64200; + } + if x <= 108134 { + return 64249; + } + if x <= 108789 { + return 64297; + } + if x <= 109445 { + return 64343; + } + if x <= 110100 { + return 64388; + } + if x <= 110755 { + return 64431; + } + } + if x <= 117309 { + if x <= 111411 { + return 64473; + } + if x <= 112066 { + return 64514; + } + if x <= 112721 { + return 64553; + } + if x <= 113377 { + return 64590; + } + if x <= 114032 { + return 64627; + } + if x <= 114688 { + return 64662; + } + if x <= 115343 { + return 64696; + } + if x <= 115998 { + return 64729; + } + if x <= 116654 { + return 64760; + } + if x <= 117309 { + return 64791; + } + } + if x <= 123863 { + if x <= 117964 { + return 64821; + } + if x <= 118620 { + return 64849; + } + if x <= 119275 { + return 64876; + } + if x <= 119930 { + return 64903; + } + if x <= 120586 { + return 64928; + } + if x <= 121241 { + return 64953; + } + if x <= 121896 { + return 64977; + } + if x <= 122552 { + return 64999; + } + if x <= 123207 { + return 65021; + } + if x <= 123863 { + return 65043; + } + } + if x <= 130416 { + if x <= 124518 { + return 65063; + } + if x <= 125173 { + return 65083; + } + if x <= 125829 { + return 65102; + } + if x <= 126484 { + return 65120; + } + if x <= 127139 { + return 65137; + } + if x <= 127795 { + return 65154; + } + if x <= 128450 { + return 65170; + } + if x <= 129105 { + return 65186; + } + if x <= 129761 { + return 65201; + } + if x <= 130416 { + return 65215; + } + } + if x <= 222822 { + if x <= 131072 { + return 65229; + } + if x <= 137625 { + return 65340; + } + if x <= 144179 { + return 65413; + } + if x <= 150732 { + return 65461; + } + if x <= 157286 { + return 65490; + } + if x <= 163840 { + return 65509; + } + if x <= 170393 { + return 65520; + } + if x <= 176947 { + return 65527; + } + if x <= 183500 { + return 65531; + } + if x <= 190054 { + return 65533; + } + if x <= 196608 { + return 65534; + } + if x <= 203161 { + return 65535; + } + if x <= 209715 { + return 65535; + } + if x <= 216268 { + return 65535; + } + if x <= 222822 { + return 65535; + } + } + return ONE; +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo index c6703b7d3..26d6feec0 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo @@ -6,7 +6,9 @@ use core::traits::{TryInto, Into}; use orion::numbers::signed_integer::{i32::i32, i8::i8}; use orion::numbers::{fixed_point::core::FixedTrait, FP16x16}; -use orion::numbers::fixed_point::implementations::fp16x16wide::math::{core as core_math, trig, hyp}; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::{ + core as core_math, trig, hyp, erf +}; use orion::numbers::fixed_point::utils; /// A struct representing a fixed point number. @@ -218,6 +220,10 @@ impl FP16x16WImpl of FixedTrait { fn is_neg_inf(self: FP16x16W) -> bool { self.is_inf() && self.sign } + + fn erf(self: FP16x16W) -> FP16x16W { + return erf::erf(self); + } } diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo index 970c65f30..b0cf1d5e7 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo @@ -3,3 +3,4 @@ mod comp; mod lut; mod trig; mod hyp; +mod erf; diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/erf.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/erf.cairo new file mode 100644 index 000000000..49d19bf20 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/erf.cairo @@ -0,0 +1,25 @@ +use core::traits::Into; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ONE, FP16x16W, FixedTrait}; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut::erf_lut; + + +const ERF_COMPUTATIONAL_ACCURACY: u64 = 100; +const ROUND_CHECK_NUMBER: u64 = 10; +// Values > MAX_ERF_NUMBER return 1 +const MAX_ERF_NUMBER: u64 = 229376; +// Values <= ERF_TRUNCATION_NUMBER -> two decimal places, and values > ERF_TRUNCATION_NUMBER -> one decimal place +const ERF_TRUNCATION_NUMBER: u64 = 131072; + +fn erf(x: FP16x16W) -> FP16x16W { + // Lookup + // 1. if x.mag < 3.5 { lookup table } + // 2. else{ return 1} + let mut erf_value: u64 = 0; + + if x.mag < MAX_ERF_NUMBER { + erf_value = erf_lut(x.mag); + } else { + erf_value = ONE; + } + FP16x16W { mag: erf_value, sign: x.sign } +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo index e96b0d389..62c58537e 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo @@ -1,3 +1,5 @@ +use orion::numbers::fixed_point::implementations::fp8x23wide::core::ONE; + // Calculates the most significant bit fn msb(whole: u64) -> (u64, u64) { if whole < 256 { @@ -1233,3 +1235,695 @@ fn atan(a: u64) -> (u64, u64, u64) { return (45416, 39716, 40025); } + +fn erf_lut(x: u64) -> u64 { + // Construct the erf lookup table + if x <= 5898 { + if x <= 0 { + return 0; + } + if x <= 655 { + return 739; + } + if x <= 1310 { + return 1478; + } + if x <= 1966 { + return 2217; + } + if x <= 2621 { + return 2956; + } + if x <= 3276 { + return 3694; + } + if x <= 3932 { + return 4431; + } + if x <= 4587 { + return 5168; + } + if x <= 5242 { + return 5903; + } + if x <= 5898 { + return 6637; + } + } + if x <= 12451 { + if x <= 6553 { + return 7370; + } + if x <= 7208 { + return 8101; + } + if x <= 7864 { + return 8831; + } + if x <= 8519 { + return 9559; + } + if x <= 9175 { + return 10285; + } + if x <= 9830 { + return 11009; + } + if x <= 10485 { + return 11731; + } + if x <= 11141 { + return 12451; + } + if x <= 11796 { + return 13168; + } + if x <= 12451 { + return 13883; + } + } + if x <= 19005 { + if x <= 13107 { + return 14595; + } + if x <= 13762 { + return 15304; + } + if x <= 14417 { + return 16010; + } + if x <= 15073 { + return 16713; + } + if x <= 15728 { + return 17412; + } + if x <= 16384 { + return 18109; + } + if x <= 17039 { + return 18802; + } + if x <= 17694 { + return 19491; + } + if x <= 18350 { + return 20177; + } + if x <= 19005 { + return 20859; + } + } + if x <= 25559 { + if x <= 19660 { + return 21536; + } + if x <= 20316 { + return 22210; + } + if x <= 20971 { + return 22880; + } + if x <= 21626 { + return 23545; + } + if x <= 22282 { + return 24206; + } + if x <= 22937 { + return 24863; + } + if x <= 23592 { + return 25515; + } + if x <= 24248 { + return 26162; + } + if x <= 24903 { + return 26804; + } + if x <= 25559 { + return 27442; + } + } + if x <= 32112 { + if x <= 26214 { + return 28075; + } + if x <= 26869 { + return 28702; + } + if x <= 27525 { + return 29325; + } + if x <= 28180 { + return 29942; + } + if x <= 28835 { + return 30554; + } + if x <= 29491 { + return 31161; + } + if x <= 30146 { + return 31762; + } + if x <= 30801 { + return 32358; + } + if x <= 31457 { + return 32948; + } + if x <= 32112 { + return 33532; + } + } + if x <= 38666 { + if x <= 32768 { + return 34111; + } + if x <= 33423 { + return 34684; + } + if x <= 34078 { + return 35251; + } + if x <= 34734 { + return 35813; + } + if x <= 35389 { + return 36368; + } + if x <= 36044 { + return 36917; + } + if x <= 36700 { + return 37461; + } + if x <= 37355 { + return 37998; + } + if x <= 38010 { + return 38530; + } + if x <= 38666 { + return 39055; + } + } + if x <= 45219 { + if x <= 39321 { + return 39574; + } + if x <= 39976 { + return 40087; + } + if x <= 40632 { + return 40593; + } + if x <= 41287 { + return 41094; + } + if x <= 41943 { + return 41588; + } + if x <= 42598 { + return 42076; + } + if x <= 43253 { + return 42557; + } + if x <= 43909 { + return 43032; + } + if x <= 44564 { + return 43501; + } + if x <= 45219 { + return 43964; + } + } + if x <= 51773 { + if x <= 45875 { + return 44420; + } + if x <= 46530 { + return 44870; + } + if x <= 47185 { + return 45313; + } + if x <= 47841 { + return 45750; + } + if x <= 48496 { + return 46181; + } + if x <= 49152 { + return 46606; + } + if x <= 49807 { + return 47024; + } + if x <= 50462 { + return 47436; + } + if x <= 51118 { + return 47841; + } + if x <= 51773 { + return 48241; + } + } + if x <= 58327 { + if x <= 52428 { + return 48634; + } + if x <= 53084 { + return 49021; + } + if x <= 53739 { + return 49401; + } + if x <= 54394 { + return 49776; + } + if x <= 55050 { + return 50144; + } + if x <= 55705 { + return 50506; + } + if x <= 56360 { + return 50862; + } + if x <= 57016 { + return 51212; + } + if x <= 57671 { + return 51556; + } + if x <= 58327 { + return 51894; + } + } + if x <= 64880 { + if x <= 58982 { + return 52226; + } + if x <= 59637 { + return 52552; + } + if x <= 60293 { + return 52872; + } + if x <= 60948 { + return 53186; + } + if x <= 61603 { + return 53495; + } + if x <= 62259 { + return 53797; + } + if x <= 62914 { + return 54094; + } + if x <= 63569 { + return 54386; + } + if x <= 64225 { + return 54672; + } + if x <= 64880 { + return 54952; + } + } + if x <= 71434 { + if x <= 65536 { + return 55227; + } + if x <= 66191 { + return 55496; + } + if x <= 66846 { + return 55760; + } + if x <= 67502 { + return 56019; + } + if x <= 68157 { + return 56272; + } + if x <= 68812 { + return 56520; + } + if x <= 69468 { + return 56763; + } + if x <= 70123 { + return 57001; + } + if x <= 70778 { + return 57234; + } + if x <= 71434 { + return 57462; + } + } + if x <= 77987 { + if x <= 72089 { + return 57685; + } + if x <= 72744 { + return 57903; + } + if x <= 73400 { + return 58116; + } + if x <= 74055 { + return 58325; + } + if x <= 74711 { + return 58529; + } + if x <= 75366 { + return 58728; + } + if x <= 76021 { + return 58923; + } + if x <= 76677 { + return 59113; + } + if x <= 77332 { + return 59299; + } + if x <= 77987 { + return 59481; + } + } + if x <= 84541 { + if x <= 78643 { + return 59658; + } + if x <= 79298 { + return 59831; + } + if x <= 79953 { + return 60000; + } + if x <= 80609 { + return 60165; + } + if x <= 81264 { + return 60326; + } + if x <= 81920 { + return 60483; + } + if x <= 82575 { + return 60636; + } + if x <= 83230 { + return 60785; + } + if x <= 83886 { + return 60931; + } + if x <= 84541 { + return 61072; + } + } + if x <= 91095 { + if x <= 85196 { + return 61211; + } + if x <= 85852 { + return 61345; + } + if x <= 86507 { + return 61477; + } + if x <= 87162 { + return 61604; + } + if x <= 87818 { + return 61729; + } + if x <= 88473 { + return 61850; + } + if x <= 89128 { + return 61968; + } + if x <= 89784 { + return 62083; + } + if x <= 90439 { + return 62194; + } + if x <= 91095 { + return 62303; + } + } + if x <= 97648 { + if x <= 91750 { + return 62408; + } + if x <= 92405 { + return 62511; + } + if x <= 93061 { + return 62611; + } + if x <= 93716 { + return 62708; + } + if x <= 94371 { + return 62802; + } + if x <= 95027 { + return 62894; + } + if x <= 95682 { + return 62983; + } + if x <= 96337 { + return 63070; + } + if x <= 96993 { + return 63154; + } + if x <= 97648 { + return 63235; + } + } + if x <= 104202 { + if x <= 98304 { + return 63314; + } + if x <= 98959 { + return 63391; + } + if x <= 99614 { + return 63465; + } + if x <= 100270 { + return 63538; + } + if x <= 100925 { + return 63608; + } + if x <= 101580 { + return 63676; + } + if x <= 102236 { + return 63742; + } + if x <= 102891 { + return 63806; + } + if x <= 103546 { + return 63867; + } + if x <= 104202 { + return 63927; + } + } + if x <= 110755 { + if x <= 104857 { + return 63985; + } + if x <= 105512 { + return 64042; + } + if x <= 106168 { + return 64096; + } + if x <= 106823 { + return 64149; + } + if x <= 107479 { + return 64200; + } + if x <= 108134 { + return 64249; + } + if x <= 108789 { + return 64297; + } + if x <= 109445 { + return 64343; + } + if x <= 110100 { + return 64388; + } + if x <= 110755 { + return 64431; + } + } + if x <= 117309 { + if x <= 111411 { + return 64473; + } + if x <= 112066 { + return 64514; + } + if x <= 112721 { + return 64553; + } + if x <= 113377 { + return 64590; + } + if x <= 114032 { + return 64627; + } + if x <= 114688 { + return 64662; + } + if x <= 115343 { + return 64696; + } + if x <= 115998 { + return 64729; + } + if x <= 116654 { + return 64760; + } + if x <= 117309 { + return 64791; + } + } + if x <= 123863 { + if x <= 117964 { + return 64821; + } + if x <= 118620 { + return 64849; + } + if x <= 119275 { + return 64876; + } + if x <= 119930 { + return 64903; + } + if x <= 120586 { + return 64928; + } + if x <= 121241 { + return 64953; + } + if x <= 121896 { + return 64977; + } + if x <= 122552 { + return 64999; + } + if x <= 123207 { + return 65021; + } + if x <= 123863 { + return 65043; + } + } + if x <= 130416 { + if x <= 124518 { + return 65063; + } + if x <= 125173 { + return 65083; + } + if x <= 125829 { + return 65102; + } + if x <= 126484 { + return 65120; + } + if x <= 127139 { + return 65137; + } + if x <= 127795 { + return 65154; + } + if x <= 128450 { + return 65170; + } + if x <= 129105 { + return 65186; + } + if x <= 129761 { + return 65201; + } + if x <= 130416 { + return 65215; + } + } + if x <= 222822 { + if x <= 131072 { + return 65229; + } + if x <= 137625 { + return 65340; + } + if x <= 144179 { + return 65413; + } + if x <= 150732 { + return 65461; + } + if x <= 157286 { + return 65490; + } + if x <= 163840 { + return 65509; + } + if x <= 170393 { + return 65520; + } + if x <= 176947 { + return 65527; + } + if x <= 183500 { + return 65531; + } + if x <= 190054 { + return 65533; + } + if x <= 196608 { + return 65534; + } + if x <= 203161 { + return 65535; + } + if x <= 209715 { + return 65535; + } + if x <= 216268 { + return 65535; + } + if x <= 222822 { + return 65535; + } + } + return ONE; +} diff --git a/src/numbers/fixed_point/implementations/fp32x32.cairo b/src/numbers/fixed_point/implementations/fp32x32.cairo index 1c347faf1..3456a2ebd 100644 --- a/src/numbers/fixed_point/implementations/fp32x32.cairo +++ b/src/numbers/fixed_point/implementations/fp32x32.cairo @@ -1,2 +1,4 @@ mod core; mod comp; +mod erf; +mod lut; diff --git a/src/numbers/fixed_point/implementations/fp32x32/core.cairo b/src/numbers/fixed_point/implementations/fp32x32/core.cairo index 869a8c519..64bbadb74 100644 --- a/src/numbers/fixed_point/implementations/fp32x32/core.cairo +++ b/src/numbers/fixed_point/implementations/fp32x32/core.cairo @@ -9,6 +9,7 @@ use cubit::f64::Fixed as FP32x32; use cubit::f64::{ONE, HALF}; use cubit::f64::types::fixed; +use orion::numbers::fixed_point::implementations::fp32x32::erf; use orion::numbers::fixed_point::core::{FixedTrait}; use orion::numbers::fixed_point::utils; use orion::numbers::{i32, i8}; @@ -209,6 +210,10 @@ impl FP32x32Impl of FixedTrait { fn is_neg_inf(self: FP32x32) -> bool { self.is_inf() && self.sign } + + fn erf(self: FP32x32) -> FP32x32 { + return erf::erf(self); + } } diff --git a/src/numbers/fixed_point/implementations/fp32x32/erf.cairo b/src/numbers/fixed_point/implementations/fp32x32/erf.cairo new file mode 100644 index 000000000..63ee48f85 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp32x32/erf.cairo @@ -0,0 +1,26 @@ +use core::traits::Into; +use orion::numbers::{FP32x32, FixedTrait}; +use cubit::f64::ONE; + +use orion::numbers::fixed_point::implementations::fp32x32::lut::erf_lut; + +const ERF_COMPUTATIONAL_ACCURACY: u64 = 100; +const ROUND_CHECK_NUMBER: u64 = 10; +// Values > MAX_ERF_NUMBER return 1 +const MAX_ERF_NUMBER: u64 = 15032385536; +// Values <= ERF_TRUNCATION_NUMBER -> two decimal places, and values > ERF_TRUNCATION_NUMBER -> one decimal place +const ERF_TRUNCATION_NUMBER: u64 = 8589934592; + +fn erf(x: FP32x32) -> FP32x32 { + // Lookup + // 1. if x.mag < 3.5 { lookup table } + // 2. else{ return 1} + let mut erf_value: u64 = 0_u64; + + if x.mag < MAX_ERF_NUMBER { + erf_value = erf_lut(x.mag); + } else { + erf_value = ONE; + } + FP32x32 { mag: erf_value, sign: x.sign } +} diff --git a/src/numbers/fixed_point/implementations/fp32x32/lut.cairo b/src/numbers/fixed_point/implementations/fp32x32/lut.cairo new file mode 100644 index 000000000..03a576452 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp32x32/lut.cairo @@ -0,0 +1,693 @@ +use orion::numbers::fixed_point::implementations::fp32x32::core::ONE; + +fn erf_lut(x: u64) -> u64 { + // Construct the erf lookup table + if x <= 386547056 { + if x <= 0 { + return 0; + } + if x <= 42949672 { + return 48461900; + } + if x <= 85899345 { + return 96914110; + } + if x <= 128849018 { + return 145346943; + } + if x <= 171798691 { + return 193750725; + } + if x <= 214748364 { + return 242115801; + } + if x <= 257698037 { + return 290432536; + } + if x <= 300647710 { + return 338691327; + } + if x <= 343597383 { + return 386882604; + } + if x <= 386547056 { + return 434996838; + } + } + if x <= 816043786 { + if x <= 429496729 { + return 483024546; + } + if x <= 472446402 { + return 530956296; + } + if x <= 515396075 { + return 578782713; + } + if x <= 558345748 { + return 626494487; + } + if x <= 601295421 { + return 674082374; + } + if x <= 644245094 { + return 721537203; + } + if x <= 687194767 { + return 768849883; + } + if x <= 730144440 { + return 816011407; + } + if x <= 773094113 { + return 863012857; + } + if x <= 816043786 { + return 909845408; + } + } + if x <= 1245540515 { + if x <= 858993459 { + return 956500337; + } + if x <= 901943132 { + return 1002969022; + } + if x <= 944892805 { + return 1049242950; + } + if x <= 987842478 { + return 1095313724; + } + if x <= 1030792151 { + return 1141173063; + } + if x <= 1073741824 { + return 1186812808; + } + if x <= 1116691496 { + return 1232224928; + } + if x <= 1159641169 { + return 1277401521; + } + if x <= 1202590842 { + return 1322334823; + } + if x <= 1245540515 { + return 1367017205; + } + } + if x <= 1675037245 { + if x <= 1288490188 { + return 1411441184; + } + if x <= 1331439861 { + return 1455599421; + } + if x <= 1374389534 { + return 1499484729; + } + if x <= 1417339207 { + return 1543090073; + } + if x <= 1460288880 { + return 1586408573; + } + if x <= 1503238553 { + return 1629433512; + } + if x <= 1546188226 { + return 1672158333; + } + if x <= 1589137899 { + return 1714576645; + } + if x <= 1632087572 { + return 1756682226; + } + if x <= 1675037245 { + return 1798469022; + } + } + if x <= 2104533975 { + if x <= 1717986918 { + return 1839931154; + } + if x <= 1760936591 { + return 1881062918; + } + if x <= 1803886264 { + return 1921858787; + } + if x <= 1846835937 { + return 1962313411; + } + if x <= 1889785610 { + return 2002421622; + } + if x <= 1932735283 { + return 2042178436; + } + if x <= 1975684956 { + return 2081579049; + } + if x <= 2018634629 { + return 2120618846; + } + if x <= 2061584302 { + return 2159293393; + } + if x <= 2104533975 { + return 2197598448; + } + } + if x <= 2534030704 { + if x <= 2147483648 { + return 2235529952; + } + if x <= 2190433320 { + return 2273084038; + } + if x <= 2233382993 { + return 2310257026; + } + if x <= 2276332666 { + return 2347045424; + } + if x <= 2319282339 { + return 2383445931; + } + if x <= 2362232012 { + return 2419455435; + } + if x <= 2405181685 { + return 2455071011; + } + if x <= 2448131358 { + return 2490289925; + } + if x <= 2491081031 { + return 2525109629; + } + if x <= 2534030704 { + return 2559527765; + } + } + if x <= 2963527434 { + if x <= 2576980377 { + return 2593542161; + } + if x <= 2619930050 { + return 2627150830; + } + if x <= 2662879723 { + return 2660351971; + } + if x <= 2705829396 { + return 2693143967; + } + if x <= 2748779069 { + return 2725525382; + } + if x <= 2791728742 { + return 2757494964; + } + if x <= 2834678415 { + return 2789051637; + } + if x <= 2877628088 { + return 2820194507; + } + if x <= 2920577761 { + return 2850922852; + } + if x <= 2963527434 { + return 2881236128; + } + } + if x <= 3393024163 { + if x <= 3006477107 { + return 2911133960; + } + if x <= 3049426780 { + return 2940616146; + } + if x <= 3092376453 { + return 2969682651; + } + if x <= 3135326126 { + return 2998333604; + } + if x <= 3178275799 { + return 3026569298; + } + if x <= 3221225472 { + return 3054390188; + } + if x <= 3264175144 { + return 3081796886; + } + if x <= 3307124817 { + return 3108790160; + } + if x <= 3350074490 { + return 3135370928; + } + if x <= 3393024163 { + return 3161540260; + } + } + if x <= 3822520893 { + if x <= 3435973836 { + return 3187299373; + } + if x <= 3478923509 { + return 3212649627; + } + if x <= 3521873182 { + return 3237592522; + } + if x <= 3564822855 { + return 3262129696; + } + if x <= 3607772528 { + return 3286262922; + } + if x <= 3650722201 { + return 3309994103; + } + if x <= 3693671874 { + return 3333325270; + } + if x <= 3736621547 { + return 3356258580; + } + if x <= 3779571220 { + return 3378796308; + } + if x <= 3822520893 { + return 3400940848; + } + } + if x <= 4252017623 { + if x <= 3865470566 { + return 3422694710; + } + if x <= 3908420239 { + return 3444060511; + } + if x <= 3951369912 { + return 3465040979; + } + if x <= 3994319585 { + return 3485638942; + } + if x <= 4037269258 { + return 3505857331; + } + if x <= 4080218931 { + return 3525699170; + } + if x <= 4123168604 { + return 3545167580; + } + if x <= 4166118277 { + return 3564265768; + } + if x <= 4209067950 { + return 3582997028; + } + if x <= 4252017623 { + return 3601364736; + } + } + if x <= 4681514352 { + if x <= 4294967296 { + return 3619372346; + } + if x <= 4337916968 { + return 3637023387; + } + if x <= 4380866641 { + return 3654321460; + } + if x <= 4423816314 { + return 3671270233; + } + if x <= 4466765987 { + return 3687873439; + } + if x <= 4509715660 { + return 3704134870; + } + if x <= 4552665333 { + return 3720058378; + } + if x <= 4595615006 { + return 3735647866; + } + if x <= 4638564679 { + return 3750907289; + } + if x <= 4681514352 { + return 3765840647; + } + } + if x <= 5111011082 { + if x <= 4724464025 { + return 3780451987; + } + if x <= 4767413698 { + return 3794745393; + } + if x <= 4810363371 { + return 3808724986; + } + if x <= 4853313044 { + return 3822394923; + } + if x <= 4896262717 { + return 3835759389; + } + if x <= 4939212390 { + return 3848822598; + } + if x <= 4982162063 { + return 3861588787; + } + if x <= 5025111736 { + return 3874062214; + } + if x <= 5068061409 { + return 3886247156; + } + if x <= 5111011082 { + return 3898147905; + } + } + if x <= 5540507811 { + if x <= 5153960755 { + return 3909768765; + } + if x <= 5196910428 { + return 3921114049; + } + if x <= 5239860101 { + return 3932188077; + } + if x <= 5282809774 { + return 3942995173; + } + if x <= 5325759447 { + return 3953539662; + } + if x <= 5368709120 { + return 3963825868; + } + if x <= 5411658792 { + return 3973858111; + } + if x <= 5454608465 { + return 3983640704; + } + if x <= 5497558138 { + return 3993177952; + } + if x <= 5540507811 { + return 4002474150; + } + } + if x <= 5970004541 { + if x <= 5583457484 { + return 4011533577; + } + if x <= 5626407157 { + return 4020360499; + } + if x <= 5669356830 { + return 4028959162; + } + if x <= 5712306503 { + return 4037333795; + } + if x <= 5755256176 { + return 4045488602; + } + if x <= 5798205849 { + return 4053427767; + } + if x <= 5841155522 { + return 4061155446; + } + if x <= 5884105195 { + return 4068675768; + } + if x <= 5927054868 { + return 4075992834; + } + if x <= 5970004541 { + return 4083110714; + } + } + if x <= 6399501271 { + if x <= 6012954214 { + return 4090033445; + } + if x <= 6055903887 { + return 4096765032; + } + if x <= 6098853560 { + return 4103309442; + } + if x <= 6141803233 { + return 4109670609; + } + if x <= 6184752906 { + return 4115852426; + } + if x <= 6227702579 { + return 4121858749; + } + if x <= 6270652252 { + return 4127693393; + } + if x <= 6313601925 { + return 4133360131; + } + if x <= 6356551598 { + return 4138862695; + } + if x <= 6399501271 { + return 4144204773; + } + } + if x <= 6828998000 { + if x <= 6442450944 { + return 4149390008; + } + if x <= 6485400616 { + return 4154421999; + } + if x <= 6528350289 { + return 4159304298; + } + if x <= 6571299962 { + return 4164040410; + } + if x <= 6614249635 { + return 4168633795; + } + if x <= 6657199308 { + return 4173087863; + } + if x <= 6700148981 { + return 4177405975; + } + if x <= 6743098654 { + return 4181591444; + } + if x <= 6786048327 { + return 4185647533; + } + if x <= 6828998000 { + return 4189577456; + } + } + if x <= 7258494730 { + if x <= 6871947673 { + return 4193384375; + } + if x <= 6914897346 { + return 4197071404; + } + if x <= 6957847019 { + return 4200641603; + } + if x <= 7000796692 { + return 4204097984; + } + if x <= 7043746365 { + return 4207443505; + } + if x <= 7086696038 { + return 4210681075; + } + if x <= 7129645711 { + return 4213813550; + } + if x <= 7172595384 { + return 4216843737; + } + if x <= 7215545057 { + return 4219774388; + } + if x <= 7258494730 { + return 4222608207; + } + } + if x <= 7687991459 { + if x <= 7301444403 { + return 4225347845; + } + if x <= 7344394076 { + return 4227995903; + } + if x <= 7387343749 { + return 4230554929; + } + if x <= 7430293422 { + return 4233027424; + } + if x <= 7473243095 { + return 4235415834; + } + if x <= 7516192768 { + return 4237722559; + } + if x <= 7559142440 { + return 4239949947; + } + if x <= 7602092113 { + return 4242100295; + } + if x <= 7645041786 { + return 4244175854; + } + if x <= 7687991459 { + return 4246178824; + } + } + if x <= 8117488189 { + if x <= 7730941132 { + return 4248111357; + } + if x <= 7773890805 { + return 4249975557; + } + if x <= 7816840478 { + return 4251773482; + } + if x <= 7859790151 { + return 4253507139; + } + if x <= 7902739824 { + return 4255178493; + } + if x <= 7945689497 { + return 4256789460; + } + if x <= 7988639170 { + return 4258341912; + } + if x <= 8031588843 { + return 4259837674; + } + if x <= 8074538516 { + return 4261278529; + } + if x <= 8117488189 { + return 4262666214; + } + } + if x <= 8546984919 { + if x <= 8160437862 { + return 4264002425; + } + if x <= 8203387535 { + return 4265288813; + } + if x <= 8246337208 { + return 4266526989; + } + if x <= 8289286881 { + return 4267718520; + } + if x <= 8332236554 { + return 4268864936; + } + if x <= 8375186227 { + return 4269967724; + } + if x <= 8418135900 { + return 4271028331; + } + if x <= 8461085573 { + return 4272048167; + } + if x <= 8504035246 { + return 4273028604; + } + if x <= 8546984919 { + return 4273970975; + } + } + if x <= 14602888806 { + if x <= 8589934592 { + return 4274876577; + } + if x <= 9019431321 { + return 4282170584; + } + if x <= 9448928051 { + return 4286966432; + } + if x <= 9878424780 { + return 4290057389; + } + if x <= 10307921510 { + return 4292010151; + } + if x <= 10737418240 { + return 4293219450; + } + if x <= 11166914969 { + return 4293953535; + } + if x <= 11596411699 { + return 4294390341; + } + if x <= 12025908428 { + return 4294645116; + } + if x <= 12455405158 { + return 4294790781; + } + if x <= 12884901888 { + return 4294872418; + } + if x <= 13314398617 { + return 4294917265; + } + if x <= 13743895347 { + return 4294941415; + } + if x <= 14173392076 { + return 4294954163; + } + if x <= 14602888806 { + return 4294960759; + } + } + return ONE; +} diff --git a/src/numbers/fixed_point/implementations/fp64x64.cairo b/src/numbers/fixed_point/implementations/fp64x64.cairo index 1c347faf1..3456a2ebd 100644 --- a/src/numbers/fixed_point/implementations/fp64x64.cairo +++ b/src/numbers/fixed_point/implementations/fp64x64.cairo @@ -1,2 +1,4 @@ mod core; mod comp; +mod erf; +mod lut; diff --git a/src/numbers/fixed_point/implementations/fp64x64/core.cairo b/src/numbers/fixed_point/implementations/fp64x64/core.cairo index 6cfba5423..9d15035cc 100644 --- a/src/numbers/fixed_point/implementations/fp64x64/core.cairo +++ b/src/numbers/fixed_point/implementations/fp64x64/core.cairo @@ -9,6 +9,7 @@ use cubit::f128::types::Fixed as FP64x64; use cubit::f128::ONE_u128 as ONE; use cubit::f128::ops::MAX_u128 as MAX; +use orion::numbers::fixed_point::implementations::fp64x64::erf; use orion::numbers::fixed_point::core::{FixedTrait}; use orion::numbers::fixed_point::utils; use orion::numbers::{i32, i8}; @@ -209,6 +210,10 @@ impl FP64x64Impl of FixedTrait { fn is_neg_inf(self: FP64x64) -> bool { self.is_inf() && self.sign } + + fn erf(self: FP64x64) -> FP64x64 { + return erf::erf(self); + } } diff --git a/src/numbers/fixed_point/implementations/fp64x64/erf.cairo b/src/numbers/fixed_point/implementations/fp64x64/erf.cairo new file mode 100644 index 000000000..3f5101b65 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp64x64/erf.cairo @@ -0,0 +1,26 @@ +use core::traits::Into; +use orion::numbers::{FP64x64, FixedTrait}; +use cubit::f128::ONE_u128 as ONE; + +use orion::numbers::fixed_point::implementations::fp64x64::lut::erf_lut; + +const ERF_COMPUTATIONAL_ACCURACY: u128 = 100_u128; +const ROUND_CHECK_NUMBER: u128 = 10_u128; +// Values > MAX_ERF_NUMBER return 1 +const MAX_ERF_NUMBER: u128 = 64563604257983430656_u128; +// Values <= ERF_TRUNCATION_NUMBER -> two decimal places, and values > ERF_TRUNCATION_NUMBER -> one decimal place +const ERF_TRUNCATION_NUMBER: u128 = 36893488147419103232_u128; + +fn erf(x: FP64x64) -> FP64x64 { + // Lookup + // 1. if x.mag < 3.5 { lookup table } + // 2. else{ return 1} + let mut erf_value: u128 = 0_u128; + + if x.mag <= MAX_ERF_NUMBER { + erf_value = erf_lut(x.mag); + } else { + erf_value = ONE; + } + FP64x64 { mag: erf_value, sign: x.sign } +} diff --git a/src/numbers/fixed_point/implementations/fp64x64/lut.cairo b/src/numbers/fixed_point/implementations/fp64x64/lut.cairo new file mode 100644 index 000000000..34042bf26 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp64x64/lut.cairo @@ -0,0 +1,693 @@ +use orion::numbers::fixed_point::implementations::fp64x64::core::ONE; + +fn erf_lut(x: u128) -> u128 { + // Construct the erf lookup table + if x <= 1660206966633859584 { + if x <= 0 { + return 0; + } + if x <= 184467440737095520 { + return 208142279036071072; + } + if x <= 368934881474191040 { + return 416242934472567232; + } + if x <= 553402322211286528 { + return 624260367679495296; + } + if x <= 737869762948382080 { + return 832153029941062528; + } + if x <= 922337203685477632 { + return 1039879447350402944; + } + if x <= 1106804644422573056 { + return 1247398245629553408; + } + if x <= 1291272085159668736 { + return 1454668174849927424; + } + if x <= 1475739525896764160 { + return 1661648134028665088; + } + if x <= 1660206966633859584 { + return 1868297195576427008; + } + } + if x <= 3504881374004814848 { + if x <= 1844674407370955264 { + return 2074574629572391936; + } + if x <= 2029141848108050688 { + return 2280439927842463744; + } + if x <= 2213609288845146112 { + return 2485852827816977408; + } + if x <= 2398076729582241792 { + return 2690773336144481280; + } + if x <= 2582544170319337472 { + return 2895161752038532608; + } + if x <= 2767011611056432640 { + return 3098978690334796800; + } + if x <= 2951479051793528320 { + return 3302185104236156928; + } + if x <= 3135946492530624000 { + return 3504742307723958272; + } + if x <= 3320413933267719168 { + return 3706611997613982720; + } + if x <= 3504881374004814848 { + return 3907756275236240384; + } + } + if x <= 5349555781375769600 { + if x <= 3689348814741910528 { + return 4108137667718166528; + } + if x <= 3873816255479005696 { + return 4307719148851377152; + } + if x <= 4058283696216101376 { + return 4506464159522699776; + } + if x <= 4242751136953197056 { + return 4704336627690769408; + } + if x <= 4427218577690292224 { + return 4901300987890141184; + } + if x <= 4611686018427387904 { + return 5097322200245477376; + } + if x <= 4796153459164483584 { + return 5292365768979031040; + } + if x <= 4980620899901579264 { + return 5486397760395360256; + } + if x <= 5165088340638674944 { + return 5679384820327877632; + } + if x <= 5349555781375769600 { + return 5871294191032579072; + } + } + if x <= 7194230188746725376 { + if x <= 5534023222112865280 { + return 6062093727515032576; + } + if x <= 5718490662849960960 { + return 6251751913277435904; + } + if x <= 5902958103587056640 { + return 6440237875473368064; + } + if x <= 6087425544324152320 { + return 6627521399458594816; + } + if x <= 6271892985061248000 { + return 6813572942727099392; + } + if x <= 6456360425798342656 { + return 6998363648222307328; + } + if x <= 6640827866535438336 { + return 7181865357014296576; + } + if x <= 6825295307272534016 { + return 7364050620334585856; + } + if x <= 7009762748009629696 { + return 7544892710960923648; + } + if x <= 7194230188746725376 { + return 7724365633945352192; + } + } + if x <= 9038904596117680128 { + if x <= 7378697629483821056 { + return 7902444136679609344; + } + if x <= 7563165070220915712 { + return 8079103718292817920; + } + if x <= 7747632510958011392 { + return 8254320638377208832; + } + if x <= 7932099951695107072 { + return 8428071925038478336; + } + if x <= 8116567392432202752 { + return 8600335382268215296; + } + if x <= 8301034833169298432 { + return 8771089596636659712; + } + if x <= 8485502273906394112 { + return 8940313943304876032; + } + if x <= 8669969714643488768 { + return 9107988591356256256; + } + if x <= 8854437155380584448 { + return 9274094508448081920; + } + if x <= 9038904596117680128 { + return 9438613464784658432; + } + } + if x <= 10883579003488634880 { + if x <= 9223372036854775808 { + return 9601528036414361600; + } + if x <= 9407839477591871488 { + return 9762821607853701120; + } + if x <= 9592306918328967168 { + return 9922478374042292224; + } + if x <= 9776774359066062848 { + return 10080483341633368064; + } + if x <= 9961241799803158528 { + return 10236822329625237504; + } + if x <= 10145709240540254208 { + return 10391481969339820032; + } + if x <= 10330176681277349888 { + return 10544449703755059200; + } + if x <= 10514644122014443520 { + return 10695713786198818816; + } + if x <= 10699111562751539200 { + return 10845263278412423168; + } + if x <= 10883579003488634880 { + return 10993088047992748032; + } + } + if x <= 12728253410859589632 { + if x <= 11068046444225730560 { + return 11139178765222393856; + } + if x <= 11252513884962826240 { + return 11283526899298078720; + } + if x <= 11436981325699921920 { + return 11426124713968005120; + } + if x <= 11621448766437017600 { + return 11566965262589513728; + } + if x <= 11805916207174113280 { + return 11706042382618923008; + } + if x <= 11990383647911208960 { + return 11843350689545969664; + } + if x <= 12174851088648304640 { + return 11978885570285762560; + } + if x <= 12359318529385400320 { + return 12112643176041672704; + } + if x <= 12543785970122496000 { + return 12244620414653018112; + } + if x <= 12728253410859589632 { + return 12374814942441867264; + } + } + if x <= 14572927818230546432 { + if x <= 12912720851596685312 { + return 12503225155573657600; + } + if x <= 13097188292333780992 { + return 12629850180946728960; + } + if x <= 13281655733070876672 { + return 12754689866626244608; + } + if x <= 13466123173807972352 { + return 12877744771838261248; + } + if x <= 13650590614545068032 { + return 12999016156540069888; + } + if x <= 13835058055282163712 { + return 13118505970583140352; + } + if x <= 14019525496019259392 { + return 13236216842485327872; + } + if x <= 14203992936756355072 { + return 13352152067829151744; + } + if x <= 14388460377493450752 { + return 13466315597303212032; + } + if x <= 14572927818230546432 { + return 13578712024403965952; + } + } + if x <= 16417602225601501184 { + if x <= 14757395258967642112 { + return 13689346572815177728; + } + if x <= 14941862699704737792 { + return 13798225083482576896; + } + if x <= 15126330140441831424 { + return 13905354001401262080; + } + if x <= 15310797581178927104 { + return 14010740362133477376; + } + if x <= 15495265021916022784 { + return 14114391778074478592; + } + if x <= 15679732462653118464 { + return 14216316424484128768; + } + if x <= 15864199903390214144 { + return 14316523025301962752; + } + if x <= 16048667344127309824 { + return 14415020838763323392; + } + if x <= 16233134784864405504 { + return 14511819642834194432; + } + if x <= 16417602225601501184 { + return 14606929720482222080; + } + } + if x <= 18262276632972455936 { + if x <= 16602069666338596864 { + return 14700361844801351680; + } + if x <= 16786537107075692544 { + return 14792127264007346176; + } + if x <= 16971004547812788224 { + return 14882237686321330176; + } + if x <= 17155471988549883904 { + return 14970705264758325248; + } + if x <= 17339939429286977536 { + return 15057542581837537280; + } + if x <= 17524406870024073216 { + return 15142762634230988800; + } + if x <= 17708874310761168896 { + return 15226378817366812672; + } + if x <= 17893341751498264576 { + return 15308404910003300352; + } + if x <= 18077809192235360256 { + return 15388855058789533696; + } + if x <= 18262276632972455936 { + return 15467743762828154880; + } + } + if x <= 20106951040343412736 { + if x <= 18446744073709551616 { + return 15545085858255493120; + } + if x <= 18631211514446647296 { + return 15620896502854008832; + } + if x <= 18815678955183742976 { + return 15695191160711634944; + } + if x <= 19000146395920838656 { + return 15767985586942304256; + } + if x <= 19184613836657934336 { + return 15839295812481531904; + } + if x <= 19369081277395030016 { + return 15909138128970633216; + } + if x <= 19553548718132125696 { + return 15977529073742716928; + } + if x <= 19738016158869221376 { + return 16044485414923208704; + } + if x <= 19922483599606317056 { + return 16110024136657332224; + } + if x <= 20106951040343412736 { + return 16174162424476436480; + } + } + if x <= 21951625447714365440 { + if x <= 20291418481080508416 { + return 16236917650814795776; + } + if x <= 20475885921817604096 { + return 16298307360687947776; + } + if x <= 20660353362554699776 { + return 16358349257543309312; + } + if x <= 20844820803291791360 { + return 16417061189293291520; + } + if x <= 21029288244028887040 { + return 16474461134540791808; + } + if x <= 21213755684765982720 { + return 16530567189006364672; + } + if x <= 21398223125503078400 { + return 16585397552166088704; + } + if x <= 21582690566240174080 { + return 16638970514108524544; + } + if x <= 21767158006977269760 { + return 16691304442618875904; + } + if x <= 21951625447714365440 { + return 16742417770497863680; + } + } + if x <= 23796299855085322240 { + if x <= 22136092888451461120 { + return 16792328983122491392; + } + if x <= 22320560329188556800 { + return 16841056606255333376; + } + if x <= 22505027769925652480 { + return 16888619194108602368; + } + if x <= 22689495210662748160 { + return 16935035317668771840; + } + if x <= 22873962651399843840 { + return 16980323553287045120; + } + if x <= 23058430092136939520 { + return 17024502471540604928; + } + if x <= 23242897532874035200 { + return 17067590626369081344; + } + if x <= 23427364973611130880 { + return 17109606544490213376; + } + if x <= 23611832414348226560 { + return 17150568715098355712; + } + if x <= 23796299855085322240 { + return 17190495579848931328; + } + } + if x <= 25640974262456274944 { + if x <= 23980767295822417920 { + return 17229405523131617280; + } + if x <= 24165234736559513600 { + return 17267316862634573824; + } + if x <= 24349702177296609280 { + return 17304247840201650176; + } + if x <= 24534169618033704960 { + return 17340216612984107008; + } + if x <= 24718637058770800640 { + return 17375241244887996416; + } + if x <= 24903104499507896320 { + return 17409339698317971456; + } + if x <= 25087571940244992000 { + return 17442529826217906176; + } + if x <= 25272039380982087680 { + return 17474829364408369152; + } + if x <= 25456506821719179264 { + return 17506255924220641280; + } + if x <= 25640974262456274944 { + return 17536826985426591744; + } + } + if x <= 27485648669827231744 { + if x <= 25825441703193370624 { + return 17566559889463431168; + } + if x <= 26009909143930466304 { + return 17595471832952045568; + } + if x <= 26194376584667561984 { + return 17623579861507229696; + } + if x <= 26378844025404657664 { + return 17650900863837954048; + } + if x <= 26563311466141753344 { + return 17677451566135410688; + } + if x <= 26747778906878849024 { + return 17703248526746337280; + } + if x <= 26932246347615944704 { + return 17728308131128877056; + } + if x <= 27116713788353040384 { + return 17752646587087935488; + } + if x <= 27301181229090136064 { + return 17776279920286781440; + } + if x <= 27485648669827231744 { + return 17799223970031376384; + } + } + if x <= 29330323077198188544 { + if x <= 27670116110564327424 { + return 17821494385323737088; + } + if x <= 27854583551301423104 { + return 17843106621180358656; + } + if x <= 28039050992038518784 { + return 17864075935211624448; + } + if x <= 28223518432775614464 { + return 17884417384457840640; + } + if x <= 28407985873512710144 { + return 17904145822477408256; + } + if x <= 28592453314249805824 { + return 17923275896682506240; + } + if x <= 28776920754986901504 { + return 17941822045917437952; + } + if x <= 28961388195723997184 { + return 17959798498274711552; + } + if x <= 29145855636461092864 { + return 17977219269143760896; + } + if x <= 29330323077198188544 { + return 17994098159487121408; + } + } + if x <= 31174997484569141248 { + if x <= 29514790517935284224 { + return 18010448754338713600; + } + if x <= 29699257958672379904 { + return 18026284421518878720; + } + if x <= 29883725399409475584 { + return 18041618310560610304; + } + if x <= 30068192840146567168 { + return 18056463351841458176; + } + if x <= 30252660280883662848 { + return 18070832255915431936; + } + if x <= 30437127721620758528 { + return 18084737513039206400; + } + if x <= 30621595162357854208 { + return 18098191392886906880; + } + if x <= 30806062603094949888 { + return 18111205944447655936; + } + if x <= 30990530043832045568 { + return 18123792996100098048; + } + if x <= 31174997484569141248 { + return 18135964155858038784; + } + } + if x <= 33019671891940098048 { + if x <= 31359464925306236928 { + return 18147730811781371904; + } + if x <= 31543932366043332608 { + return 18159104132546453504; + } + if x <= 31728399806780428288 { + return 18170095068170047488; + } + if x <= 31912867247517523968 { + return 18180714350881038336; + } + if x <= 32097334688254619648 { + return 18190972496134107136; + } + if x <= 32281802128991715328 { + return 18200879803759552512; + } + if x <= 32466269569728811008 { + return 18210446359243550720; + } + if x <= 32650737010465906688 { + return 18219682035133120512; + } + if x <= 32835204451203002368 { + return 18228596492560154624; + } + if x <= 33019671891940098048 { + return 18237199182878894080; + } + } + if x <= 34864346299311050752 { + if x <= 33204139332677193728 { + return 18245499349411323904; + } + if x <= 33388606773414289408 { + return 18253506029294995456; + } + if x <= 33573074214151385088 { + return 18261228055427880960; + } + if x <= 33757541654888480768 { + return 18268674058504921088; + } + if x <= 33942009095625576448 { + return 18275852469141008384; + } + if x <= 34126476536362672128 { + return 18282771520075268096; + } + if x <= 34310943977099767808 { + return 18289439248451522560; + } + if x <= 34495411417836863488 { + return 18295863498169980928; + } + if x <= 34679878858573955072 { + return 18302051922305267712; + } + if x <= 34864346299311050752 { + return 18308011985585967104; + } + } + if x <= 36709020706682007552 { + if x <= 35048813740048146432 { + return 18313750966931048448; + } + if x <= 35233281180785242112 { + return 18319275962038544384; + } + if x <= 35417748621522337792 { + return 18324593886022047744; + } + if x <= 35602216062259433472 { + return 18329711476090615808; + } + if x <= 35786683502996529152 { + return 18334635294267887616; + } + if x <= 35971150943733624832 { + return 18339371730146226176; + } + if x <= 36155618384470720512 { + return 18343927003671875584; + } + if x <= 36340085825207816192 { + return 18348307167957243904; + } + if x <= 36524553265944911872 { + return 18352518112116494336; + } + if x <= 36709020706682007552 { + return 18356565564120772608; + } + } + if x <= 62718929850612473856 { + if x <= 36893488147419103232 { + return 18360455093669533696; + } + if x <= 38738162554790060032 { + return 18391782614824026112; + } + if x <= 40582836962161016832 { + return 18412380624802023424; + } + if x <= 42427511369531965440 { + return 18425656187587059712; + } + if x <= 44272185776902922240 { + return 18434043234066948096; + } + if x <= 46116860184273879040 { + return 18439237133993463808; + } + if x <= 47961534591644835840 { + return 18442390007235248128; + } + if x <= 49806208999015792640 { + return 18444266072035147776; + } + if x <= 51650883406386741248 { + return 18445360324505407488; + } + if x <= 53495557813757698048 { + return 18445985951670278144; + } + if x <= 55340232221128654848 { + return 18446336575964956672; + } + if x <= 57184906628499611648 { + return 18446529193908295680; + } + if x <= 59029581035870568448 { + return 18446632918035736576; + } + if x <= 60874255443241517056 { + return 18446687668919484416; + } + if x <= 62718929850612473856 { + return 18446715997887504384; + } + } + return ONE; +} diff --git a/src/numbers/fixed_point/implementations/fp8x23/core.cairo b/src/numbers/fixed_point/implementations/fp8x23/core.cairo index dcbfefa96..cdee93541 100644 --- a/src/numbers/fixed_point/implementations/fp8x23/core.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23/core.cairo @@ -6,7 +6,7 @@ use core::traits::{TryInto, Into}; use orion::numbers::signed_integer::{i32::i32, i8::i8}; use orion::numbers::fixed_point::core::{FixedTrait}; -use orion::numbers::fixed_point::implementations::fp8x23::math::{core as core_math, trig, hyp}; +use orion::numbers::fixed_point::implementations::fp8x23::math::{core as core_math, trig, hyp, erf}; use orion::numbers::fixed_point::utils; /// A struct representing a fixed point number. @@ -218,6 +218,10 @@ impl FP8x23Impl of FixedTrait { fn is_neg_inf(self: FP8x23) -> bool { self.is_inf() && self.sign } + + fn erf(self: FP8x23) -> FP8x23 { + return erf::erf(self); + } } diff --git a/src/numbers/fixed_point/implementations/fp8x23/math.cairo b/src/numbers/fixed_point/implementations/fp8x23/math.cairo index 970c65f30..b0cf1d5e7 100644 --- a/src/numbers/fixed_point/implementations/fp8x23/math.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23/math.cairo @@ -3,3 +3,4 @@ mod comp; mod lut; mod trig; mod hyp; +mod erf; diff --git a/src/numbers/fixed_point/implementations/fp8x23/math/erf.cairo b/src/numbers/fixed_point/implementations/fp8x23/math/erf.cairo new file mode 100644 index 000000000..8121e170b --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23/math/erf.cairo @@ -0,0 +1,25 @@ +use core::traits::Into; +use orion::numbers::fixed_point::implementations::fp8x23::core::{ONE, FP8x23, FixedTrait}; +use orion::numbers::fixed_point::implementations::fp8x23::math::lut::erf_lut; + +const ERF_COMPUTATIONAL_ACCURACY: u32 = 100; +const MAX_ERF_COMPUTATIONAL_ACCURACY: u32 = 10; +const ROUND_CHECK_NUMBER: u32 = 1; +// Values > MAX_ERF_NUMBER return 1 +const MAX_ERF_NUMBER: u32 = 29360128; +// Values <= ERF_TRUNCATION_NUMBER -> two decimal places, and values > ERF_TRUNCATION_NUMBER -> one decimal place +const ERF_TRUNCATION_NUMBER: u32 = 16777216; + +fn erf(x: FP8x23) -> FP8x23 { + // Lookup + // 1. if x.mag < 3.5 { lookup table } + // 2. else{ return 1} + let mut erf_value: u32 = 0; + + if x.mag < MAX_ERF_NUMBER { + erf_value = erf_lut(x.mag); + } else { + erf_value = ONE; + } + FP8x23 { mag: erf_value, sign: x.sign } +} diff --git a/src/numbers/fixed_point/implementations/fp8x23/math/lut.cairo b/src/numbers/fixed_point/implementations/fp8x23/math/lut.cairo index e18d32a0a..fdb9dfea3 100644 --- a/src/numbers/fixed_point/implementations/fp8x23/math/lut.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23/math/lut.cairo @@ -1,3 +1,5 @@ +use orion::numbers::fixed_point::implementations::fp8x23::core::ONE; + // Calculates the most significant bit fn msb(whole: u32) -> (u32, u32) { if whole < 256 { @@ -1227,3 +1229,695 @@ fn atan(a: u32) -> (u32, u32, u32) { return (5813305, 5083601, 5123141); } + +fn erf_lut(x: u32) -> u32 { + // Construct the erf lookup table + if x <= 754974 { + if x <= 0 { + return 0; + } + if x <= 83886 { + return 94652; + } + if x <= 167772 { + return 189285; + } + if x <= 251658 { + return 283880; + } + if x <= 335544 { + return 378419; + } + if x <= 419430 { + return 472882; + } + if x <= 503316 { + return 567251; + } + if x <= 587202 { + return 661506; + } + if x <= 671088 { + return 755630; + } + if x <= 754974 { + return 849603; + } + } + if x <= 1593835 { + if x <= 838860 { + return 943407; + } + if x <= 922746 { + return 1037024; + } + if x <= 1006632 { + return 1130434; + } + if x <= 1090519 { + return 1223622; + } + if x <= 1174405 { + return 1316567; + } + if x <= 1258291 { + return 1409252; + } + if x <= 1342177 { + return 1501659; + } + if x <= 1426063 { + return 1593772; + } + if x <= 1509949 { + return 1685571; + } + if x <= 1593835 { + return 1777041; + } + } + if x <= 2432696 { + if x <= 1677721 { + return 1868164; + } + if x <= 1761607 { + return 1958923; + } + if x <= 1845493 { + return 2049302; + } + if x <= 1929379 { + return 2139284; + } + if x <= 2013265 { + return 2228853; + } + if x <= 2097152 { + return 2317993; + } + if x <= 2181038 { + return 2406689; + } + if x <= 2264924 { + return 2494924; + } + if x <= 2348810 { + return 2582685; + } + if x <= 2432696 { + return 2669955; + } + } + if x <= 3271557 { + if x <= 2516582 { + return 2756721; + } + if x <= 2600468 { + return 2842967; + } + if x <= 2684354 { + return 2928681; + } + if x <= 2768240 { + return 3013847; + } + if x <= 2852126 { + return 3098454; + } + if x <= 2936012 { + return 3182487; + } + if x <= 3019898 { + return 3265934; + } + if x <= 3103784 { + return 3348782; + } + if x <= 3187671 { + return 3431019; + } + if x <= 3271557 { + return 3512634; + } + } + if x <= 4110417 { + if x <= 3355443 { + return 3593615; + } + if x <= 3439329 { + return 3673951; + } + if x <= 3523215 { + return 3753630; + } + if x <= 3607101 { + return 3832643; + } + if x <= 3690987 { + return 3910979; + } + if x <= 3774873 { + return 3988629; + } + if x <= 3858759 { + return 4065584; + } + if x <= 3942645 { + return 4141833; + } + if x <= 4026531 { + return 4217369; + } + if x <= 4110417 { + return 4292184; + } + } + if x <= 4949278 { + if x <= 4194304 { + return 4366269; + } + if x <= 4278190 { + return 4439617; + } + if x <= 4362076 { + return 4512220; + } + if x <= 4445962 { + return 4584073; + } + if x <= 4529848 { + return 4655167; + } + if x <= 4613734 { + return 4725498; + } + if x <= 4697620 { + return 4795060; + } + if x <= 4781506 { + return 4863847; + } + if x <= 4865392 { + return 4931854; + } + if x <= 4949278 { + return 4999077; + } + } + if x <= 5788139 { + if x <= 5033164 { + return 5065512; + } + if x <= 5117050 { + return 5131153; + } + if x <= 5200936 { + return 5195999; + } + if x <= 5284823 { + return 5260046; + } + if x <= 5368709 { + return 5323291; + } + if x <= 5452595 { + return 5385732; + } + if x <= 5536481 { + return 5447366; + } + if x <= 5620367 { + return 5508192; + } + if x <= 5704253 { + return 5568208; + } + if x <= 5788139 { + return 5627414; + } + } + if x <= 6627000 { + if x <= 5872025 { + return 5685808; + } + if x <= 5955911 { + return 5743390; + } + if x <= 6039797 { + return 5800161; + } + if x <= 6123683 { + return 5856120; + } + if x <= 6207569 { + return 5911268; + } + if x <= 6291456 { + return 5965605; + } + if x <= 6375342 { + return 6019134; + } + if x <= 6459228 { + return 6071855; + } + if x <= 6543114 { + return 6123771; + } + if x <= 6627000 { + return 6174883; + } + } + if x <= 7465861 { + if x <= 6710886 { + return 6225194; + } + if x <= 6794772 { + return 6274706; + } + if x <= 6878658 { + return 6323422; + } + if x <= 6962544 { + return 6371347; + } + if x <= 7046430 { + return 6418482; + } + if x <= 7130316 { + return 6464832; + } + if x <= 7214202 { + return 6510400; + } + if x <= 7298088 { + return 6555192; + } + if x <= 7381975 { + return 6599211; + } + if x <= 7465861 { + return 6642462; + } + } + if x <= 8304721 { + if x <= 7549747 { + return 6684950; + } + if x <= 7633633 { + return 6726680; + } + if x <= 7717519 { + return 6767658; + } + if x <= 7801405 { + return 6807888; + } + if x <= 7885291 { + return 6847377; + } + if x <= 7969177 { + return 6886131; + } + if x <= 8053063 { + return 6924155; + } + if x <= 8136949 { + return 6961456; + } + if x <= 8220835 { + return 6998041; + } + if x <= 8304721 { + return 7033915; + } + } + if x <= 9143582 { + if x <= 8388608 { + return 7069086; + } + if x <= 8472494 { + return 7103561; + } + if x <= 8556380 { + return 7137346; + } + if x <= 8640266 { + return 7170449; + } + if x <= 8724152 { + return 7202877; + } + if x <= 8808038 { + return 7234638; + } + if x <= 8891924 { + return 7265739; + } + if x <= 8975810 { + return 7296187; + } + if x <= 9059696 { + return 7325990; + } + if x <= 9143582 { + return 7355157; + } + } + if x <= 9982443 { + if x <= 9227468 { + return 7383695; + } + if x <= 9311354 { + return 7411612; + } + if x <= 9395240 { + return 7438915; + } + if x <= 9479127 { + return 7465615; + } + if x <= 9563013 { + return 7491717; + } + if x <= 9646899 { + return 7517231; + } + if x <= 9730785 { + return 7542165; + } + if x <= 9814671 { + return 7566527; + } + if x <= 9898557 { + return 7590326; + } + if x <= 9982443 { + return 7613570; + } + } + if x <= 10821304 { + if x <= 10066329 { + return 7636267; + } + if x <= 10150215 { + return 7658425; + } + if x <= 10234101 { + return 7680054; + } + if x <= 10317987 { + return 7701162; + } + if x <= 10401873 { + return 7721757; + } + if x <= 10485760 { + return 7741847; + } + if x <= 10569646 { + return 7761441; + } + if x <= 10653532 { + return 7780548; + } + if x <= 10737418 { + return 7799175; + } + if x <= 10821304 { + return 7817332; + } + } + if x <= 11660165 { + if x <= 10905190 { + return 7835026; + } + if x <= 10989076 { + return 7852266; + } + if x <= 11072962 { + return 7869060; + } + if x <= 11156848 { + return 7885417; + } + if x <= 11240734 { + return 7901344; + } + if x <= 11324620 { + return 7916851; + } + if x <= 11408506 { + return 7931944; + } + if x <= 11492392 { + return 7946632; + } + if x <= 11576279 { + return 7960923; + } + if x <= 11660165 { + return 7974825; + } + } + if x <= 12499025 { + if x <= 11744051 { + return 7988346; + } + if x <= 11827937 { + return 8001494; + } + if x <= 11911823 { + return 8014276; + } + if x <= 11995709 { + return 8026700; + } + if x <= 12079595 { + return 8038774; + } + if x <= 12163481 { + return 8050505; + } + if x <= 12247367 { + return 8061901; + } + if x <= 12331253 { + return 8072969; + } + if x <= 12415139 { + return 8083716; + } + if x <= 12499025 { + return 8094149; + } + } + if x <= 13337886 { + if x <= 12582912 { + return 8104277; + } + if x <= 12666798 { + return 8114105; + } + if x <= 12750684 { + return 8123641; + } + if x <= 12834570 { + return 8132891; + } + if x <= 12918456 { + return 8141862; + } + if x <= 13002342 { + return 8150562; + } + if x <= 13086228 { + return 8158996; + } + if x <= 13170114 { + return 8167170; + } + if x <= 13254000 { + return 8175092; + } + if x <= 13337886 { + return 8182768; + } + } + if x <= 14176747 { + if x <= 13421772 { + return 8190203; + } + if x <= 13505658 { + return 8197405; + } + if x <= 13589544 { + return 8204378; + } + if x <= 13673431 { + return 8211128; + } + if x <= 13757317 { + return 8217663; + } + if x <= 13841203 { + return 8223986; + } + if x <= 13925089 { + return 8230104; + } + if x <= 14008975 { + return 8236022; + } + if x <= 14092861 { + return 8241746; + } + if x <= 14176747 { + return 8247281; + } + } + if x <= 15015608 { + if x <= 14260633 { + return 8252632; + } + if x <= 14344519 { + return 8257804; + } + if x <= 14428405 { + return 8262802; + } + if x <= 14512291 { + return 8267631; + } + if x <= 14596177 { + return 8272296; + } + if x <= 14680064 { + return 8276801; + } + if x <= 14763950 { + return 8281152; + } + if x <= 14847836 { + return 8285352; + } + if x <= 14931722 { + return 8289405; + } + if x <= 15015608 { + return 8293318; + } + } + if x <= 15854469 { + if x <= 15099494 { + return 8297092; + } + if x <= 15183380 { + return 8300733; + } + if x <= 15267266 { + return 8304245; + } + if x <= 15351152 { + return 8307631; + } + if x <= 15435038 { + return 8310895; + } + if x <= 15518924 { + return 8314041; + } + if x <= 15602810 { + return 8317074; + } + if x <= 15686696 { + return 8319995; + } + if x <= 15770583 { + return 8322809; + } + if x <= 15854469 { + return 8325519; + } + } + if x <= 16693329 { + if x <= 15938355 { + return 8328129; + } + if x <= 16022241 { + return 8330642; + } + if x <= 16106127 { + return 8333060; + } + if x <= 16190013 { + return 8335387; + } + if x <= 16273899 { + return 8337626; + } + if x <= 16357785 { + return 8339780; + } + if x <= 16441671 { + return 8341852; + } + if x <= 16525557 { + return 8343844; + } + if x <= 16609443 { + return 8345758; + } + if x <= 16693329 { + return 8347599; + } + } + if x <= 28521267 { + if x <= 16777216 { + return 8349368; + } + if x <= 17616076 { + return 8363614; + } + if x <= 18454937 { + return 8372981; + } + if x <= 19293798 { + return 8379018; + } + if x <= 20132659 { + return 8382832; + } + if x <= 20971520 { + return 8385194; + } + if x <= 21810380 { + return 8386627; + } + if x <= 22649241 { + return 8387481; + } + if x <= 23488102 { + return 8387978; + } + if x <= 24326963 { + return 8388263; + } + if x <= 25165824 { + return 8388422; + } + if x <= 26004684 { + return 8388510; + } + if x <= 26843545 { + return 8388557; + } + if x <= 27682406 { + return 8388582; + } + if x <= 28521267 { + return 8388595; + } + } + return ONE; +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo index d33ea4524..1f7ad81a6 100644 --- a/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo @@ -6,7 +6,9 @@ use core::traits::{TryInto, Into}; use orion::numbers::signed_integer::{i32::i32, i8::i8}; use orion::numbers::{fixed_point::core::{FixedTrait}, FP8x23}; -use orion::numbers::fixed_point::implementations::fp8x23wide::math::{core as core_math, trig, hyp}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::{ + core as core_math, trig, hyp, erf +}; use orion::numbers::fixed_point::utils; /// A struct representing a fixed point number. @@ -218,6 +220,10 @@ impl FP8x23WImpl of FixedTrait { fn is_neg_inf(self: FP8x23W) -> bool { self.is_inf() && self.sign } + + fn erf(self: FP8x23W) -> FP8x23W { + return erf::erf(self); + } } diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo index 970c65f30..b0cf1d5e7 100644 --- a/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo @@ -3,3 +3,4 @@ mod comp; mod lut; mod trig; mod hyp; +mod erf; diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/erf.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/erf.cairo new file mode 100644 index 000000000..83f33f9ad --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/erf.cairo @@ -0,0 +1,25 @@ +use core::traits::Into; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ONE, FP8x23W, FixedTrait}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::lut::erf_lut; + +const ERF_COMPUTATIONAL_ACCURACY: u64 = 100; +const MAX_ERF_COMPUTATIONAL_ACCURACY: u64 = 10; +const ROUND_CHECK_NUMBER: u64 = 1; +// Values > MAX_ERF_NUMBER return 1 +const MAX_ERF_NUMBER: u64 = 29360128; +// Values <= ERF_TRUNCATION_NUMBER -> two decimal places, and values > ERF_TRUNCATION_NUMBER -> one decimal place +const ERF_TRUNCATION_NUMBER: u64 = 16777216; + +fn erf(x: FP8x23W) -> FP8x23W { + // Lookup + // 1. if x.mag < 3.5 { lookup table } + // 2. else{ return 1} + let mut erf_value: u64 = 0; + + if x.mag < MAX_ERF_NUMBER { + erf_value = erf_lut(x.mag); + } else { + erf_value = ONE; + } + FP8x23W { mag: erf_value, sign: x.sign } +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo index 157499b5b..eea11e46a 100644 --- a/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo @@ -1,3 +1,5 @@ +use orion::numbers::fixed_point::implementations::fp8x23wide::core::ONE; + // Calculates the most significant bit fn msb(whole: u64) -> (u64, u64) { if whole < 256 { @@ -1227,3 +1229,695 @@ fn atan(a: u64) -> (u64, u64, u64) { return (5813305, 5083601, 5123141); } + +fn erf_lut(x: u64) -> u64 { + // Construct the erf lookup table + if x <= 754974 { + if x <= 0 { + return 0; + } + if x <= 83886 { + return 94652; + } + if x <= 167772 { + return 189285; + } + if x <= 251658 { + return 283880; + } + if x <= 335544 { + return 378419; + } + if x <= 419430 { + return 472882; + } + if x <= 503316 { + return 567251; + } + if x <= 587202 { + return 661506; + } + if x <= 671088 { + return 755630; + } + if x <= 754974 { + return 849603; + } + } + if x <= 1593835 { + if x <= 838860 { + return 943407; + } + if x <= 922746 { + return 1037024; + } + if x <= 1006632 { + return 1130434; + } + if x <= 1090519 { + return 1223622; + } + if x <= 1174405 { + return 1316567; + } + if x <= 1258291 { + return 1409252; + } + if x <= 1342177 { + return 1501659; + } + if x <= 1426063 { + return 1593772; + } + if x <= 1509949 { + return 1685571; + } + if x <= 1593835 { + return 1777041; + } + } + if x <= 2432696 { + if x <= 1677721 { + return 1868164; + } + if x <= 1761607 { + return 1958923; + } + if x <= 1845493 { + return 2049302; + } + if x <= 1929379 { + return 2139284; + } + if x <= 2013265 { + return 2228853; + } + if x <= 2097152 { + return 2317993; + } + if x <= 2181038 { + return 2406689; + } + if x <= 2264924 { + return 2494924; + } + if x <= 2348810 { + return 2582685; + } + if x <= 2432696 { + return 2669955; + } + } + if x <= 3271557 { + if x <= 2516582 { + return 2756721; + } + if x <= 2600468 { + return 2842967; + } + if x <= 2684354 { + return 2928681; + } + if x <= 2768240 { + return 3013847; + } + if x <= 2852126 { + return 3098454; + } + if x <= 2936012 { + return 3182487; + } + if x <= 3019898 { + return 3265934; + } + if x <= 3103784 { + return 3348782; + } + if x <= 3187671 { + return 3431019; + } + if x <= 3271557 { + return 3512634; + } + } + if x <= 4110417 { + if x <= 3355443 { + return 3593615; + } + if x <= 3439329 { + return 3673951; + } + if x <= 3523215 { + return 3753630; + } + if x <= 3607101 { + return 3832643; + } + if x <= 3690987 { + return 3910979; + } + if x <= 3774873 { + return 3988629; + } + if x <= 3858759 { + return 4065584; + } + if x <= 3942645 { + return 4141833; + } + if x <= 4026531 { + return 4217369; + } + if x <= 4110417 { + return 4292184; + } + } + if x <= 4949278 { + if x <= 4194304 { + return 4366269; + } + if x <= 4278190 { + return 4439617; + } + if x <= 4362076 { + return 4512220; + } + if x <= 4445962 { + return 4584073; + } + if x <= 4529848 { + return 4655167; + } + if x <= 4613734 { + return 4725498; + } + if x <= 4697620 { + return 4795060; + } + if x <= 4781506 { + return 4863847; + } + if x <= 4865392 { + return 4931854; + } + if x <= 4949278 { + return 4999077; + } + } + if x <= 5788139 { + if x <= 5033164 { + return 5065512; + } + if x <= 5117050 { + return 5131153; + } + if x <= 5200936 { + return 5195999; + } + if x <= 5284823 { + return 5260046; + } + if x <= 5368709 { + return 5323291; + } + if x <= 5452595 { + return 5385732; + } + if x <= 5536481 { + return 5447366; + } + if x <= 5620367 { + return 5508192; + } + if x <= 5704253 { + return 5568208; + } + if x <= 5788139 { + return 5627414; + } + } + if x <= 6627000 { + if x <= 5872025 { + return 5685808; + } + if x <= 5955911 { + return 5743390; + } + if x <= 6039797 { + return 5800161; + } + if x <= 6123683 { + return 5856120; + } + if x <= 6207569 { + return 5911268; + } + if x <= 6291456 { + return 5965605; + } + if x <= 6375342 { + return 6019134; + } + if x <= 6459228 { + return 6071855; + } + if x <= 6543114 { + return 6123771; + } + if x <= 6627000 { + return 6174883; + } + } + if x <= 7465861 { + if x <= 6710886 { + return 6225194; + } + if x <= 6794772 { + return 6274706; + } + if x <= 6878658 { + return 6323422; + } + if x <= 6962544 { + return 6371347; + } + if x <= 7046430 { + return 6418482; + } + if x <= 7130316 { + return 6464832; + } + if x <= 7214202 { + return 6510400; + } + if x <= 7298088 { + return 6555192; + } + if x <= 7381975 { + return 6599211; + } + if x <= 7465861 { + return 6642462; + } + } + if x <= 8304721 { + if x <= 7549747 { + return 6684950; + } + if x <= 7633633 { + return 6726680; + } + if x <= 7717519 { + return 6767658; + } + if x <= 7801405 { + return 6807888; + } + if x <= 7885291 { + return 6847377; + } + if x <= 7969177 { + return 6886131; + } + if x <= 8053063 { + return 6924155; + } + if x <= 8136949 { + return 6961456; + } + if x <= 8220835 { + return 6998041; + } + if x <= 8304721 { + return 7033915; + } + } + if x <= 9143582 { + if x <= 8388608 { + return 7069086; + } + if x <= 8472494 { + return 7103561; + } + if x <= 8556380 { + return 7137346; + } + if x <= 8640266 { + return 7170449; + } + if x <= 8724152 { + return 7202877; + } + if x <= 8808038 { + return 7234638; + } + if x <= 8891924 { + return 7265739; + } + if x <= 8975810 { + return 7296187; + } + if x <= 9059696 { + return 7325990; + } + if x <= 9143582 { + return 7355157; + } + } + if x <= 9982443 { + if x <= 9227468 { + return 7383695; + } + if x <= 9311354 { + return 7411612; + } + if x <= 9395240 { + return 7438915; + } + if x <= 9479127 { + return 7465615; + } + if x <= 9563013 { + return 7491717; + } + if x <= 9646899 { + return 7517231; + } + if x <= 9730785 { + return 7542165; + } + if x <= 9814671 { + return 7566527; + } + if x <= 9898557 { + return 7590326; + } + if x <= 9982443 { + return 7613570; + } + } + if x <= 10821304 { + if x <= 10066329 { + return 7636267; + } + if x <= 10150215 { + return 7658425; + } + if x <= 10234101 { + return 7680054; + } + if x <= 10317987 { + return 7701162; + } + if x <= 10401873 { + return 7721757; + } + if x <= 10485760 { + return 7741847; + } + if x <= 10569646 { + return 7761441; + } + if x <= 10653532 { + return 7780548; + } + if x <= 10737418 { + return 7799175; + } + if x <= 10821304 { + return 7817332; + } + } + if x <= 11660165 { + if x <= 10905190 { + return 7835026; + } + if x <= 10989076 { + return 7852266; + } + if x <= 11072962 { + return 7869060; + } + if x <= 11156848 { + return 7885417; + } + if x <= 11240734 { + return 7901344; + } + if x <= 11324620 { + return 7916851; + } + if x <= 11408506 { + return 7931944; + } + if x <= 11492392 { + return 7946632; + } + if x <= 11576279 { + return 7960923; + } + if x <= 11660165 { + return 7974825; + } + } + if x <= 12499025 { + if x <= 11744051 { + return 7988346; + } + if x <= 11827937 { + return 8001494; + } + if x <= 11911823 { + return 8014276; + } + if x <= 11995709 { + return 8026700; + } + if x <= 12079595 { + return 8038774; + } + if x <= 12163481 { + return 8050505; + } + if x <= 12247367 { + return 8061901; + } + if x <= 12331253 { + return 8072969; + } + if x <= 12415139 { + return 8083716; + } + if x <= 12499025 { + return 8094149; + } + } + if x <= 13337886 { + if x <= 12582912 { + return 8104277; + } + if x <= 12666798 { + return 8114105; + } + if x <= 12750684 { + return 8123641; + } + if x <= 12834570 { + return 8132891; + } + if x <= 12918456 { + return 8141862; + } + if x <= 13002342 { + return 8150562; + } + if x <= 13086228 { + return 8158996; + } + if x <= 13170114 { + return 8167170; + } + if x <= 13254000 { + return 8175092; + } + if x <= 13337886 { + return 8182768; + } + } + if x <= 14176747 { + if x <= 13421772 { + return 8190203; + } + if x <= 13505658 { + return 8197405; + } + if x <= 13589544 { + return 8204378; + } + if x <= 13673431 { + return 8211128; + } + if x <= 13757317 { + return 8217663; + } + if x <= 13841203 { + return 8223986; + } + if x <= 13925089 { + return 8230104; + } + if x <= 14008975 { + return 8236022; + } + if x <= 14092861 { + return 8241746; + } + if x <= 14176747 { + return 8247281; + } + } + if x <= 15015608 { + if x <= 14260633 { + return 8252632; + } + if x <= 14344519 { + return 8257804; + } + if x <= 14428405 { + return 8262802; + } + if x <= 14512291 { + return 8267631; + } + if x <= 14596177 { + return 8272296; + } + if x <= 14680064 { + return 8276801; + } + if x <= 14763950 { + return 8281152; + } + if x <= 14847836 { + return 8285352; + } + if x <= 14931722 { + return 8289405; + } + if x <= 15015608 { + return 8293318; + } + } + if x <= 15854469 { + if x <= 15099494 { + return 8297092; + } + if x <= 15183380 { + return 8300733; + } + if x <= 15267266 { + return 8304245; + } + if x <= 15351152 { + return 8307631; + } + if x <= 15435038 { + return 8310895; + } + if x <= 15518924 { + return 8314041; + } + if x <= 15602810 { + return 8317074; + } + if x <= 15686696 { + return 8319995; + } + if x <= 15770583 { + return 8322809; + } + if x <= 15854469 { + return 8325519; + } + } + if x <= 16693329 { + if x <= 15938355 { + return 8328129; + } + if x <= 16022241 { + return 8330642; + } + if x <= 16106127 { + return 8333060; + } + if x <= 16190013 { + return 8335387; + } + if x <= 16273899 { + return 8337626; + } + if x <= 16357785 { + return 8339780; + } + if x <= 16441671 { + return 8341852; + } + if x <= 16525557 { + return 8343844; + } + if x <= 16609443 { + return 8345758; + } + if x <= 16693329 { + return 8347599; + } + } + if x <= 28521267 { + if x <= 16777216 { + return 8349368; + } + if x <= 17616076 { + return 8363614; + } + if x <= 18454937 { + return 8372981; + } + if x <= 19293798 { + return 8379018; + } + if x <= 20132659 { + return 8382832; + } + if x <= 20971520 { + return 8385194; + } + if x <= 21810380 { + return 8386627; + } + if x <= 22649241 { + return 8387481; + } + if x <= 23488102 { + return 8387978; + } + if x <= 24326963 { + return 8388263; + } + if x <= 25165824 { + return 8388422; + } + if x <= 26004684 { + return 8388510; + } + if x <= 26843545 { + return 8388557; + } + if x <= 27682406 { + return 8388582; + } + if x <= 28521267 { + return 8388595; + } + } + return ONE; +} diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index fd80e2b93..fb5345ce8 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -121,6 +121,8 @@ impl TensorSerde, impl TDrop: Drop> of Serde { /// # tensor.new /// @@ -4853,6 +4855,103 @@ trait TensorTrait { /// ``` /// fn not(self: @Tensor) -> Tensor; + /// ## tensor.reduce_log_sum + /// + /// ```rust + /// fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor; + /// ``` + /// + /// Computes the log sum of the input tensor's elements along the provided axes. + /// ## Args + /// + /// * `self`(`@Tensor`) - The input tensor. + /// * `axis`(`usize`) - The dimension to reduce. + /// * `keepdims`(`bool`) - If true, retains reduced dimensions with length 1. + /// + /// ## Panics + /// + /// * Panics if axis is not in the range of the input tensor's dimensions. + /// + /// ## Returns + /// + /// A new `Tensor` instance with the specified axis reduced by summing its elements. + /// + /// fn reduce_log_sum() -> Tensor { + /// + /// let mut sizes = ArrayTrait::new(); + /// sizes.append(2); + /// sizes.append(2); + /// sizes.append(2); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FixedTrait::new_unscaled(1, false)); + /// data.append(FixedTrait::new_unscaled(2, false)); + /// data.append(FixedTrait::new_unscaled(3, false)); + /// data.append(FixedTrait::new_unscaled(4, false)); + /// data.append(FixedTrait::new_unscaled(5, false)); + /// data.append(FixedTrait::new_unscaled(6, false)); + /// data.append(FixedTrait::new_unscaled(7, false)); + /// data.append(FixedTrait::new_unscaled(8, false)); + /// + /// let tensor = TensorTrait::::new(sizes.span(), data.span()); + /// + /// We can call `reduce_log_sum` function as follows. + /// return tensor.reduce_log_sum(axis: 2, keepdims: false); + /// } + /// >>> [[0x11938, 0x1f203], [0x265d9, 0x2b540]] + /// ``` + /// + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor; + /// ## tensor.erf + /// + /// ```rust + /// fn erf(self: @Tensor) -> Tensor; + /// ``` + /// + /// Computes the mean of the input tensor's elements along the provided axes. + /// + /// ## Args + /// + /// * `self`(`@Tensor`) - The input tensor. + /// + /// ## Returns + /// + /// A new `Tensor` of the same shape as the input tensor with + /// the the error function of the input tensor computed element-wise. + /// + /// ## Type Constraints + /// + /// Constrain input and output types to fixed point tensors. + /// + /// ## Examples + /// + /// ```rust + /// use core::array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, FP16x16Tensor}; + /// use orion::numbers::{FixedTrait, FP16x16}; + /// + /// fn erf_example() -> Tensor { + /// // The erf inputs is [1.0, 0.134, 0.520, 2.0, 3.5, 5.164] + /// let tensor = TensorTrait::::new( + /// shape: array![6].span(), + /// data: array![ + /// FixedTrait::new_unscaled(65536, false), + /// FixedTrait::new_unscaled(8832, false), + /// FixedTrait::new_unscaled(34079, false), + /// FixedTrait::new_unscaled(131072, false), + /// FixedTrait::new_unscaled(229376, false), + /// FixedTrait::new_unscaled(338428, false), + /// ] + /// .span(), + /// ); + /// + /// return tensor.erf(); + /// } + /// >>> [55227,9560,35252,65229,65536,65536] + /// ``` + /// + fn erf(self: @Tensor) -> Tensor; } /// Cf: TensorTrait::new docstring diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index da17b38a9..5436f2131 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -457,6 +457,14 @@ impl BoolTensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } + + fn erf(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + panic(array!['not supported!']) + } } /// Implements partial equal for two `Tensor` using the `PartialEq` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 0d0318cd7..fb1925a26 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -518,6 +518,15 @@ impl FP16x16Tensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) + } + + + fn erf(self: @Tensor) -> Tensor { + math::erf::erf(*self) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index f5172a006..42deb993c 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -484,6 +484,15 @@ impl FP16x16WTensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) + } + + + fn erf(self: @Tensor) -> Tensor { + math::erf::erf(*self) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 05fdebecb..a5f70535b 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -519,6 +519,15 @@ impl FP32x32Tensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) + } + + + fn erf(self: @Tensor) -> Tensor { + math::erf::erf(*self) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 40b78c7d3..7e802fa6e 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -520,6 +520,15 @@ impl FP64x64Tensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) + } + + + fn erf(self: @Tensor) -> Tensor { + math::erf::erf(*self) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 62af2b3e4..bda045fa5 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -518,6 +518,14 @@ impl FP8x23Tensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) + } + + fn erf(self: @Tensor) -> Tensor { + math::erf::erf(*self) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 85d8aa2fb..c3c756d0b 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -471,6 +471,14 @@ impl FP8x23WTensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) + } + + fn erf(self: @Tensor) -> Tensor { + math::erf::erf(*self) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index a25cef6ee..707dfce73 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -515,6 +515,14 @@ impl I32Tensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + panic(array!['not supported!']) + } + + fn erf(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 94f32f70b..60ec52dcc 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -513,6 +513,14 @@ impl I8Tensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + panic(array!['not supported!']) + } + + fn erf(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 56abb440a..841b21db6 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -456,6 +456,14 @@ impl U32Tensor of TensorTrait { ) -> Tensor { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + panic(array!['not supported!']) + } + + fn erf(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index d5a74322c..8635b60bf 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -62,3 +62,5 @@ mod sequence_insert; mod concat_from_sequence; mod is_nan; mod is_inf; +mod reduce_log_sum; +mod erf; diff --git a/src/operators/tensor/math/erf.cairo b/src/operators/tensor/math/erf.cairo new file mode 100644 index 000000000..44a755c15 --- /dev/null +++ b/src/operators/tensor/math/erf.cairo @@ -0,0 +1,31 @@ +use core::array::ArrayTrait; +use core::array::SpanTrait; +use core::option::OptionTrait; + +use orion::numbers::fixed_point::core::FixedTrait; +use orion::operators::tensor::core::{Tensor, TensorTrait}; +use orion::numbers::NumberTrait; + + +/// Cf: TensorTrait::erf docstring +fn erf< + T, + MAG, + impl TTensor: TensorTrait, + impl TFixed: FixedTrait, + impl TCopy: Copy, + impl TDrop: Drop, +>( + mut z: Tensor +) -> Tensor { + let mut data_result = ArrayTrait::::new(); + + loop { + match z.data.pop_front() { + Option::Some(item) => { data_result.append((*item).erf()); }, + Option::None(_) => { break; } + }; + }; + + return TensorTrait::::new(z.shape, data_result.span()); +} diff --git a/src/operators/tensor/math/reduce_log_sum.cairo b/src/operators/tensor/math/reduce_log_sum.cairo new file mode 100644 index 000000000..d2cccaad4 --- /dev/null +++ b/src/operators/tensor/math/reduce_log_sum.cairo @@ -0,0 +1,28 @@ +use core::option::OptionTrait; +use core::array::ArrayTrait; +use core::array::SpanTrait; +use core::debug::PrintTrait; + +use orion::numbers::NumberTrait; +use orion::operators::tensor::core::{Tensor, TensorTrait, ravel_index, unravel_index}; +use orion::numbers::signed_integer::integer_trait::IntegerTrait; +use orion::numbers::fixed_point::core::FixedTrait; + +/// Cf: TensorTrait::reduce_sum_square docstring +fn reduce_log_sum< + T, + MAG, + impl TTensor: TensorTrait, + impl TNumber: NumberTrait, + impl TMul: Mul, + impl TAddEq: AddEq, + impl TCopy: Copy, + impl TDrop: Drop, +>( + self: @Tensor, axis: usize, keepdims: bool +) -> Tensor { + let tensor_square_sum = self.reduce_sum(axis: axis, keepdims: keepdims); + let tensor_square_sum_log = tensor_square_sum.log(); + + return tensor_square_sum_log; +} diff --git a/tests/nodes.cairo b/tests/nodes.cairo index 12bc31c9b..5ed3f0eb2 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -484,7 +484,7 @@ mod where_i8; mod where_i8_broadcast; mod where_u32; mod where_u32_broadcast; -mod not_bool; +mod not_bool; mod round_fp16x16; mod round_fp8x23; mod max_fp16x16_three_tensors; @@ -819,4 +819,12 @@ mod is_pos_inf_i32; mod is_neg_inf_i32; mod is_pos_inf_i8; mod is_neg_inf_i8; +mod reduce_log_sum_fp8x23_export_do_not_keepdims; +mod reduce_log_sum_fp8x23_export_keepdims; +mod reduce_log_sum_fp8x23_export_negative_axes_keepdims; +mod reduce_log_sum_fp16x16_export_do_not_keepdims; +mod reduce_log_sum_fp16x16_export_keepdims; +mod reduce_log_sum_fp16x16_export_negative_axes_keepdims; mod and_bool; +mod erf_fp16x16; +mod erf_fp8x23; diff --git a/tests/nodes/concat_from_sequence_fp16x16_new_axis_default.cairo b/tests/nodes/concat_from_sequence_fp16x16_new_axis_default.cairo index c1b0bd466..ce6b56929 100644 --- a/tests/nodes/concat_from_sequence_fp16x16_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_fp16x16_new_axis_default.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_fp16x16_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(())); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::None(()) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_fp16x16_new_axis_one.cairo b/tests/nodes/concat_from_sequence_fp16x16_new_axis_one.cairo index 01433366d..bbb5d9fc0 100644 --- a/tests/nodes/concat_from_sequence_fp16x16_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_fp16x16_new_axis_one.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_fp16x16_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(1) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_fp16x16_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_fp16x16_new_axis_zero.cairo index 95dcbab39..7ae21d053 100644 --- a/tests/nodes/concat_from_sequence_fp16x16_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_fp16x16_new_axis_zero.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_fp16x16_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(0) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_fp8x23_new_axis_default.cairo b/tests/nodes/concat_from_sequence_fp8x23_new_axis_default.cairo index 3a7a12230..3383ce9f6 100644 --- a/tests/nodes/concat_from_sequence_fp8x23_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_fp8x23_new_axis_default.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_fp8x23_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(())); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::None(()) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_fp8x23_new_axis_one.cairo b/tests/nodes/concat_from_sequence_fp8x23_new_axis_one.cairo index 237026a61..a032dd956 100644 --- a/tests/nodes/concat_from_sequence_fp8x23_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_fp8x23_new_axis_one.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_fp8x23_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(1) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_fp8x23_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_fp8x23_new_axis_zero.cairo index e34c89da5..3696d14af 100644 --- a/tests/nodes/concat_from_sequence_fp8x23_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_fp8x23_new_axis_zero.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_fp8x23_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(0) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_i32_new_axis_default.cairo b/tests/nodes/concat_from_sequence_i32_new_axis_default.cairo index 1829e386d..8c53ebf67 100644 --- a/tests/nodes/concat_from_sequence_i32_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_i32_new_axis_default.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_i32_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(())); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::None(()) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_i32_new_axis_one.cairo b/tests/nodes/concat_from_sequence_i32_new_axis_one.cairo index e18d2007a..855321bf4 100644 --- a/tests/nodes/concat_from_sequence_i32_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_i32_new_axis_one.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_i32_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(1) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_i32_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_i32_new_axis_zero.cairo index 2abd37613..ca190e242 100644 --- a/tests/nodes/concat_from_sequence_i32_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_i32_new_axis_zero.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_i32_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(0) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_i8_new_axis_default.cairo b/tests/nodes/concat_from_sequence_i8_new_axis_default.cairo index f9369d2ea..0cab162e7 100644 --- a/tests/nodes/concat_from_sequence_i8_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_i8_new_axis_default.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_i8_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(())); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::None(()) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_i8_new_axis_one.cairo b/tests/nodes/concat_from_sequence_i8_new_axis_one.cairo index 445102ae2..59d295d3c 100644 --- a/tests/nodes/concat_from_sequence_i8_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_i8_new_axis_one.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_i8_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(1) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_i8_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_i8_new_axis_zero.cairo index 7d25b7625..0d4e4daea 100644 --- a/tests/nodes/concat_from_sequence_i8_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_i8_new_axis_zero.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_i8_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(0) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_u32_new_axis_default.cairo b/tests/nodes/concat_from_sequence_u32_new_axis_default.cairo index 53337ccce..8a787ea18 100644 --- a/tests/nodes/concat_from_sequence_u32_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_u32_new_axis_default.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_u32_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(())); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::None(()) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_u32_new_axis_one.cairo b/tests/nodes/concat_from_sequence_u32_new_axis_one.cairo index 51b1c5a77..fa1fa4d9c 100644 --- a/tests/nodes/concat_from_sequence_u32_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_u32_new_axis_one.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_u32_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(1) + ); assert_eq(y, z); } diff --git a/tests/nodes/concat_from_sequence_u32_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_u32_new_axis_zero.cairo index 45550fc53..5a2b1fa30 100644 --- a/tests/nodes/concat_from_sequence_u32_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_u32_new_axis_zero.cairo @@ -15,7 +15,9 @@ fn test_concat_from_sequence_u32_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0)); + let y = TensorTrait::concat_from_sequence( + input_0, IntegerTrait::::new(1, false), Option::Some(0) + ); assert_eq(y, z); } diff --git a/tests/nodes/erf_fp16x16.cairo b/tests/nodes/erf_fp16x16.cairo new file mode 100644 index 000000000..2f6cecd8f --- /dev/null +++ b/tests/nodes/erf_fp16x16.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; + +#[test] +#[available_gas(2000000000)] +fn test_erf_fp16x16() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.erf(); + + assert_eq(y, z); +} diff --git a/tests/nodes/erf_fp16x16/input_0.cairo b/tests/nodes/erf_fp16x16/input_0.cairo new file mode 100644 index 000000000..b40c47ea4 --- /dev/null +++ b/tests/nodes/erf_fp16x16/input_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7864, sign: false }); + data.append(FP16x16 { mag: 108789, sign: true }); + data.append(FP16x16 { mag: 222822, sign: false }); + data.append(FP16x16 { mag: 314572, sign: false }); + data.append(FP16x16 { mag: 176947, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/erf_fp16x16/output_0.cairo b/tests/nodes/erf_fp16x16/output_0.cairo new file mode 100644 index 000000000..3de9849fb --- /dev/null +++ b/tests/nodes/erf_fp16x16/output_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 8831, sign: false }); + data.append(FP16x16 { mag: 64297, sign: true }); + data.append(FP16x16 { mag: 65535, sign: false }); + data.append(FP16x16 { mag: 65535, sign: false }); + data.append(FP16x16 { mag: 65527, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/erf_fp8x23.cairo b/tests/nodes/erf_fp8x23.cairo new file mode 100644 index 000000000..83893a8f6 --- /dev/null +++ b/tests/nodes/erf_fp8x23.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::operators::tensor::FP8x23Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; + +#[test] +#[available_gas(2000000000)] +fn test_erf_fp8x23() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.erf(); + + assert_eq(y, z); +} diff --git a/tests/nodes/erf_fp8x23/input_0.cairo b/tests/nodes/erf_fp8x23/input_0.cairo new file mode 100644 index 000000000..c0f13ff6b --- /dev/null +++ b/tests/nodes/erf_fp8x23/input_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 1006632, sign: false }); + data.append(FP8x23 { mag: 13925089, sign: true }); + data.append(FP8x23 { mag: 28521267, sign: false }); + data.append(FP8x23 { mag: 40265318, sign: false }); + data.append(FP8x23 { mag: 22649241, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/erf_fp8x23/output_0.cairo b/tests/nodes/erf_fp8x23/output_0.cairo new file mode 100644 index 000000000..0abcbe8a5 --- /dev/null +++ b/tests/nodes/erf_fp8x23/output_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 1130434, sign: false }); + data.append(FP8x23 { mag: 8230104, sign: true }); + data.append(FP8x23 { mag: 8388595, sign: false }); + data.append(FP8x23 { mag: 8388607, sign: false }); + data.append(FP8x23 { mag: 8387481, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/gather_elements_fp16x16_3d_axis1.cairo b/tests/nodes/gather_elements_fp16x16_3d_axis1.cairo index 2afad0443..53757865c 100644 --- a/tests/nodes/gather_elements_fp16x16_3d_axis1.cairo +++ b/tests/nodes/gather_elements_fp16x16_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_fp16x16_3d_axis1() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(1)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(1)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_fp16x16_3d_axis2.cairo b/tests/nodes/gather_elements_fp16x16_3d_axis2.cairo index 1f85eefc7..7952db17a 100644 --- a/tests/nodes/gather_elements_fp16x16_3d_axis2.cairo +++ b/tests/nodes/gather_elements_fp16x16_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_fp16x16_3d_axis2() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(2)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(2)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_fp16x16_3d_default.cairo b/tests/nodes/gather_elements_fp16x16_3d_default.cairo index b8e927081..a1070d7d3 100644 --- a/tests/nodes/gather_elements_fp16x16_3d_default.cairo +++ b/tests/nodes/gather_elements_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_fp8x23_3d_axis1.cairo b/tests/nodes/gather_elements_fp8x23_3d_axis1.cairo index f31041ee5..0814a2c6c 100644 --- a/tests/nodes/gather_elements_fp8x23_3d_axis1.cairo +++ b/tests/nodes/gather_elements_fp8x23_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_fp8x23_3d_axis1() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(1)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(1)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_fp8x23_3d_axis2.cairo b/tests/nodes/gather_elements_fp8x23_3d_axis2.cairo index deb41b14e..96bfcc8c4 100644 --- a/tests/nodes/gather_elements_fp8x23_3d_axis2.cairo +++ b/tests/nodes/gather_elements_fp8x23_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_fp8x23_3d_axis2() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(2)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(2)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_fp8x23_3d_default.cairo b/tests/nodes/gather_elements_fp8x23_3d_default.cairo index 75098af27..fee79d361 100644 --- a/tests/nodes/gather_elements_fp8x23_3d_default.cairo +++ b/tests/nodes/gather_elements_fp8x23_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_fp8x23_3d_default() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_i32_3d_axis1.cairo b/tests/nodes/gather_elements_i32_3d_axis1.cairo index 38562aea0..4282c3e48 100644 --- a/tests/nodes/gather_elements_i32_3d_axis1.cairo +++ b/tests/nodes/gather_elements_i32_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_i32_3d_axis1() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(1)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(1)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_i32_3d_axis2.cairo b/tests/nodes/gather_elements_i32_3d_axis2.cairo index a30803206..a7641f948 100644 --- a/tests/nodes/gather_elements_i32_3d_axis2.cairo +++ b/tests/nodes/gather_elements_i32_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_i32_3d_axis2() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(2)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(2)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_i32_3d_default.cairo b/tests/nodes/gather_elements_i32_3d_default.cairo index b518f78d6..e0d3471bf 100644 --- a/tests/nodes/gather_elements_i32_3d_default.cairo +++ b/tests/nodes/gather_elements_i32_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_i32_3d_default() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_i8_3d_axis1.cairo b/tests/nodes/gather_elements_i8_3d_axis1.cairo index 283601233..2c9e468dc 100644 --- a/tests/nodes/gather_elements_i8_3d_axis1.cairo +++ b/tests/nodes/gather_elements_i8_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_i8_3d_axis1() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(1)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(1)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_i8_3d_default.cairo b/tests/nodes/gather_elements_i8_3d_default.cairo index ed280e873..7c8787c9e 100644 --- a/tests/nodes/gather_elements_i8_3d_default.cairo +++ b/tests/nodes/gather_elements_i8_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_i8_3d_default() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_u32_axis1.cairo b/tests/nodes/gather_elements_u32_axis1.cairo index 209834966..523d5b244 100644 --- a/tests/nodes/gather_elements_u32_axis1.cairo +++ b/tests/nodes/gather_elements_u32_axis1.cairo @@ -16,7 +16,7 @@ fn test_gather_elements_u32_axis1() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(1)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(1)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_u32_axis2.cairo b/tests/nodes/gather_elements_u32_axis2.cairo index 9334dd164..d9135016c 100644 --- a/tests/nodes/gather_elements_u32_axis2.cairo +++ b/tests/nodes/gather_elements_u32_axis2.cairo @@ -16,7 +16,7 @@ fn test_gather_elements_u32_axis2() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(2)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(2)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_u32_axis3.cairo b/tests/nodes/gather_elements_u32_axis3.cairo index 437a90c70..a5abbe543 100644 --- a/tests/nodes/gather_elements_u32_axis3.cairo +++ b/tests/nodes/gather_elements_u32_axis3.cairo @@ -16,7 +16,7 @@ fn test_gather_elements_u32_axis3() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(3)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(3)); assert_eq(y, z); } diff --git a/tests/nodes/gather_elements_u32_default.cairo b/tests/nodes/gather_elements_u32_default.cairo index 2351bdf7a..1f9cc9683 100644 --- a/tests/nodes/gather_elements_u32_default.cairo +++ b/tests/nodes/gather_elements_u32_default.cairo @@ -16,7 +16,7 @@ fn test_gather_elements_u32_default() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y, z); } diff --git a/tests/nodes/not_bool.cairo b/tests/nodes/not_bool.cairo index 65d2cfea6..cc73e1cd4 100644 --- a/tests/nodes/not_bool.cairo +++ b/tests/nodes/not_bool.cairo @@ -1,5 +1,5 @@ -mod input_0; -mod output_0; +mod input_0; +mod output_0; use core::array::{ArrayTrait, SpanTrait}; @@ -17,4 +17,4 @@ fn test_not_bool() { let y = input_0.not(); assert_eq(y, z); -} \ No newline at end of file +} diff --git a/tests/nodes/not_bool/input_0.cairo b/tests/nodes/not_bool/input_0.cairo index 25222167b..eef582ae5 100644 --- a/tests/nodes/not_bool/input_0.cairo +++ b/tests/nodes/not_bool/input_0.cairo @@ -10,4 +10,4 @@ fn input_0() -> Tensor { let mut data = ArrayTrait::new(); data.append(true); TensorTrait::new(shape.span(), data.span()) -} \ No newline at end of file +} diff --git a/tests/nodes/not_bool/output_0.cairo b/tests/nodes/not_bool/output_0.cairo index 24047b93e..43bb7750d 100644 --- a/tests/nodes/not_bool/output_0.cairo +++ b/tests/nodes/not_bool/output_0.cairo @@ -10,4 +10,4 @@ fn output_0() -> Tensor { let mut data = ArrayTrait::new(); data.append(false); TensorTrait::new(shape.span(), data.span()) -} \ No newline at end of file +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims.cairo new file mode 100644 index 000000000..108ef328f --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; + +#[test] +#[available_gas(2000000000)] +fn test_reduce_log_sum_fp16x16_export_do_not_keepdims() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.reduce_log_sum(2, false); + + assert_eq(y, z); +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims/input_0.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims/input_0.cairo new file mode 100644 index 000000000..d8f5ac09d --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims/input_0.cairo @@ -0,0 +1,26 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims/output_0.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims/output_0.cairo new file mode 100644 index 000000000..4c5cd630b --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_do_not_keepdims/output_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9215827, sign: false }); + data.append(FP8x23 { mag: 16323477, sign: false }); + data.append(FP8x23 { mag: 20115003, sign: false }); + data.append(FP8x23 { mag: 22716771, sign: false }); + data.append(FP8x23 { mag: 24699744, sign: false }); + data.append(FP8x23 { mag: 26302431, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_keepdims.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_keepdims.cairo new file mode 100644 index 000000000..5ee464e1c --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_keepdims.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; + +#[test] +#[available_gas(2000000000)] +fn test_reduce_log_sum_fp16x16_export_keepdims() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.reduce_log_sum(2, true); + + assert_eq(y, z); +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_keepdims/input_0.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_keepdims/input_0.cairo new file mode 100644 index 000000000..d8f5ac09d --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_keepdims/input_0.cairo @@ -0,0 +1,26 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_keepdims/output_0.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_keepdims/output_0.cairo new file mode 100644 index 000000000..39127716e --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_keepdims/output_0.cairo @@ -0,0 +1,20 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9215827, sign: false }); + data.append(FP8x23 { mag: 16323477, sign: false }); + data.append(FP8x23 { mag: 20115003, sign: false }); + data.append(FP8x23 { mag: 22716771, sign: false }); + data.append(FP8x23 { mag: 24699744, sign: false }); + data.append(FP8x23 { mag: 26302431, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims.cairo new file mode 100644 index 000000000..7f7fc7f98 --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; + +#[test] +#[available_gas(2000000000)] +fn test_reduce_log_sum_fp16x16_export_negative_axes_keepdims() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.reduce_log_sum(0, true); + + assert_eq(y, z); +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims/input_0.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims/input_0.cairo new file mode 100644 index 000000000..068fe1a81 --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims/input_0.cairo @@ -0,0 +1,22 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims/output_0.cairo b/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims/output_0.cairo new file mode 100644 index 000000000..0b3cda1be --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp16x16_export_negative_axes_keepdims/output_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 15030367, sign: false }); + data.append(FP8x23 { mag: 17443619, sign: false }); + data.append(FP8x23 { mag: 19315483, sign: false }); + data.append(FP8x23 { mag: 20844907, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims.cairo new file mode 100644 index 000000000..3f0adf3eb --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; + +#[test] +#[available_gas(2000000000)] +fn test_reduce_log_sum_fp8x23_export_do_not_keepdims() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.reduce_log_sum(2, false); + + assert_eq(y, z); +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims/input_0.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims/input_0.cairo new file mode 100644 index 000000000..d8f5ac09d --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims/input_0.cairo @@ -0,0 +1,26 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims/output_0.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims/output_0.cairo new file mode 100644 index 000000000..3ee433b31 --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_do_not_keepdims/output_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9215828, sign: false }); + data.append(FP8x23 { mag: 16323477, sign: false }); + data.append(FP8x23 { mag: 20115004, sign: false }); + data.append(FP8x23 { mag: 22716772, sign: false }); + data.append(FP8x23 { mag: 24699744, sign: false }); + data.append(FP8x23 { mag: 26302432, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_keepdims.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_keepdims.cairo new file mode 100644 index 000000000..5662f1510 --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_keepdims.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; + +#[test] +#[available_gas(2000000000)] +fn test_reduce_log_sum_fp8x23_export_keepdims() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.reduce_log_sum(2, true); + + assert_eq(y, z); +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_keepdims/input_0.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_keepdims/input_0.cairo new file mode 100644 index 000000000..d8f5ac09d --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_keepdims/input_0.cairo @@ -0,0 +1,26 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_keepdims/output_0.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_keepdims/output_0.cairo new file mode 100644 index 000000000..39127716e --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_keepdims/output_0.cairo @@ -0,0 +1,20 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9215827, sign: false }); + data.append(FP8x23 { mag: 16323477, sign: false }); + data.append(FP8x23 { mag: 20115003, sign: false }); + data.append(FP8x23 { mag: 22716771, sign: false }); + data.append(FP8x23 { mag: 24699744, sign: false }); + data.append(FP8x23 { mag: 26302431, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims.cairo new file mode 100644 index 000000000..ec295a396 --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; + +#[test] +#[available_gas(2000000000)] +fn test_reduce_log_sum_fp8x23_export_negative_axes_keepdims() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.reduce_log_sum(0, true); + + assert_eq(y, z); +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims/input_0.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims/input_0.cairo new file mode 100644 index 000000000..7db081374 --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + data.append(FP8x23 { mag: 109051904, sign: false }); + data.append(FP8x23 { mag: 117440512, sign: false }); + data.append(FP8x23 { mag: 125829120, sign: false }); + data.append(FP8x23 { mag: 134217728, sign: false }); + data.append(FP8x23 { mag: 142606336, sign: false }); + data.append(FP8x23 { mag: 150994944, sign: false }); + data.append(FP8x23 { mag: 159383552, sign: false }); + data.append(FP8x23 { mag: 167772160, sign: false }); + data.append(FP8x23 { mag: 176160768, sign: false }); + data.append(FP8x23 { mag: 184549376, sign: false }); + data.append(FP8x23 { mag: 192937984, sign: false }); + data.append(FP8x23 { mag: 201326592, sign: false }); + data.append(FP8x23 { mag: 209715200, sign: false }); + data.append(FP8x23 { mag: 218103808, sign: false }); + data.append(FP8x23 { mag: 226492416, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims/output_0.cairo b/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims/output_0.cairo new file mode 100644 index 000000000..72c608c25 --- /dev/null +++ b/tests/nodes/reduce_log_sum_fp8x23_export_negative_axes_keepdims/output_0.cairo @@ -0,0 +1,23 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 28531311, sign: false }); + data.append(FP8x23 { mag: 29330831, sign: false }); + data.append(FP8x23 { mag: 30060735, sign: false }); + data.append(FP8x23 { mag: 30732182, sign: false }); + data.append(FP8x23 { mag: 31353845, sign: false }); + data.append(FP8x23 { mag: 31932599, sign: false }); + data.append(FP8x23 { mag: 32473987, sign: false }); + data.append(FP8x23 { mag: 32982543, sign: false }); + data.append(FP8x23 { mag: 33462023, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/numbers.cairo b/tests/numbers.cairo index 91ae91e96..9a65c3b28 100644 --- a/tests/numbers.cairo +++ b/tests/numbers.cairo @@ -1,4 +1,4 @@ -// mod fixed_point; +mod fixed_point_test; mod signed_integer_test; mod complex_number_test; diff --git a/tests/numbers/fixed_point_test.cairo b/tests/numbers/fixed_point_test.cairo new file mode 100644 index 000000000..78cac4259 --- /dev/null +++ b/tests/numbers/fixed_point_test.cairo @@ -0,0 +1,6 @@ +mod erf_fp16x16_test; +mod erf_fp16x16wide_test; +mod erf_fp8x23_test; +mod erf_fp8x23wide_test; +mod erf_fp32x32_test; +mod erf_fp64x64_test; diff --git a/tests/numbers/fixed_point_test/erf_fp16x16_test.cairo b/tests/numbers/fixed_point_test/erf_fp16x16_test.cairo new file mode 100644 index 000000000..373189492 --- /dev/null +++ b/tests/numbers/fixed_point_test/erf_fp16x16_test.cairo @@ -0,0 +1,33 @@ +use orion::numbers::fixed_point::implementations::fp16x16::math::erf::erf; +use orion::numbers::fixed_point::implementations::fp16x16::core::{ONE, FP16x16, FixedTrait}; +use core::debug::PrintTrait; +#[test] +#[available_gas(1000000000)] +fn test_erf() { + // 1.0 + let f1: FP16x16 = FP16x16 { mag: 65536, sign: false }; + // 0.134 + let f2: FP16x16 = FP16x16 { mag: 8832, sign: false }; + // 0.520 + let f3: FP16x16 = FP16x16 { mag: 34078, sign: false }; + // 2.0 + let f4: FP16x16 = FP16x16 { mag: 131072, sign: false }; + // 3.5 + let f5: FP16x16 = FP16x16 { mag: 229376, sign: false }; + // 5.164 + let f6: FP16x16 = FP16x16 { mag: 338428, sign: false }; + + let f1_erf: FP16x16 = erf(f1); + let f2_erf: FP16x16 = erf(f2); + let f3_erf: FP16x16 = erf(f3); + let f4_erf: FP16x16 = erf(f4); + let f5_erf: FP16x16 = erf(f5); + let f6_erf: FP16x16 = erf(f6); + + assert(f1_erf.mag == 55227, 'f1_erf it works!'); + assert(f2_erf.mag == 10285, 'f2_erf it works!'); + assert(f3_erf.mag == 35251, 'f3_erf it works!'); + assert(f4_erf.mag == 65229, 'f4_erf it works!'); + assert(f5_erf.mag == 65536, 'f5_erf it works!'); + assert(f6_erf.mag == 65536, 'f6_erf it works!'); +} diff --git a/tests/numbers/fixed_point_test/erf_fp16x16wide_test.cairo b/tests/numbers/fixed_point_test/erf_fp16x16wide_test.cairo new file mode 100644 index 000000000..917364e33 --- /dev/null +++ b/tests/numbers/fixed_point_test/erf_fp16x16wide_test.cairo @@ -0,0 +1,33 @@ +use orion::numbers::fixed_point::implementations::fp16x16wide::math::erf::erf; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ONE, FP16x16W, FixedTrait}; +use core::debug::PrintTrait; +#[test] +#[available_gas(1000000000)] +fn test_erf() { + // 1.0 + let f1: FP16x16W = FP16x16W { mag: 65536, sign: false }; + // 0.134 + let f2: FP16x16W = FP16x16W { mag: 8832, sign: false }; + // 0.520 + let f3: FP16x16W = FP16x16W { mag: 34078, sign: false }; + // 2.0 + let f4: FP16x16W = FP16x16W { mag: 131072, sign: false }; + // 3.5 + let f5: FP16x16W = FP16x16W { mag: 229376, sign: false }; + // 5.164 + let f6: FP16x16W = FP16x16W { mag: 338428, sign: false }; + + let f1_erf: FP16x16W = erf(f1); + let f2_erf: FP16x16W = erf(f2); + let f3_erf: FP16x16W = erf(f3); + let f4_erf: FP16x16W = erf(f4); + let f5_erf: FP16x16W = erf(f5); + let f6_erf: FP16x16W = erf(f6); + + assert(f1_erf.mag == 55227, 'f1_erf it works!'); + assert(f2_erf.mag == 10285, 'f2_erf it works!'); + assert(f3_erf.mag == 35251, 'f3_erf it works!'); + assert(f4_erf.mag == 65229, 'f4_erf it works!'); + assert(f5_erf.mag == 65536, 'f5_erf it works!'); + assert(f6_erf.mag == 65536, 'f6_erf it works!'); +} diff --git a/tests/numbers/fixed_point_test/erf_fp32x32_test.cairo b/tests/numbers/fixed_point_test/erf_fp32x32_test.cairo new file mode 100644 index 000000000..eb21129f6 --- /dev/null +++ b/tests/numbers/fixed_point_test/erf_fp32x32_test.cairo @@ -0,0 +1,32 @@ +use orion::numbers::fixed_point::implementations::fp32x32::erf::erf; +use orion::numbers::fixed_point::implementations::fp32x32::core::{ONE, FP32x32, FixedTrait}; +use core::debug::PrintTrait; +#[test] +#[available_gas(1000000000)] +fn test_erf() { + // 1.0 + let f1: FP32x32 = FP32x32 { mag: 4294967296, sign: false }; + // 0.134 + let f2: FP32x32 = FP32x32 { mag: 575525618, sign: false }; + // 0.520 + let f3: FP32x32 = FP32x32 { mag: 2233382993, sign: false }; + // 2.0 + let f4: FP32x32 = FP32x32 { mag: 8589934592, sign: false }; + // 3.5 + let f5: FP32x32 = FP32x32 { mag: 15032385536, sign: false }; + // 5.164 + let f6: FP32x32 = FP32x32 { mag: 22179211117, sign: false }; + + let f1_erf: FP32x32 = erf(f1); + let f2_erf: FP32x32 = erf(f2); + let f3_erf: FP32x32 = erf(f3); + let f4_erf: FP32x32 = erf(f4); + let f5_erf: FP32x32 = erf(f5); + let f6_erf: FP32x32 = erf(f6); + assert(f1_erf.mag == 3619372346, 'f1_erf it works!'); + assert(f2_erf.mag == 674082374, 'f2_erf it works!'); + assert(f3_erf.mag == 2310257026, 'f3_erf it works!'); + assert(f4_erf.mag == 4274876577, 'f4_erf it works!'); + assert(f5_erf.mag == 4294967296, 'f5_erf it works!'); + assert(f6_erf.mag == 4294967296, 'f6_erf it works!'); +} diff --git a/tests/numbers/fixed_point_test/erf_fp64x64_test.cairo b/tests/numbers/fixed_point_test/erf_fp64x64_test.cairo new file mode 100644 index 000000000..973fdd953 --- /dev/null +++ b/tests/numbers/fixed_point_test/erf_fp64x64_test.cairo @@ -0,0 +1,32 @@ +use orion::numbers::fixed_point::implementations::fp64x64::erf::erf; +use orion::numbers::fixed_point::implementations::fp64x64::core::{ONE, FP64x64, FixedTrait}; +use core::debug::PrintTrait; +#[test] +#[available_gas(1000000000)] +fn test_erf() { + // 1.0 + let f1: FP64x64 = FP64x64 { mag: 18446744073709551616_u128, sign: false }; + // 0.134 + let f2: FP64x64 = FP64x64 { mag: 2471863705877080064_u128, sign: false }; + // 0.520 + let f3: FP64x64 = FP64x64 { mag: 9592306918328967168_u128, sign: false }; + // 2.0 + let f4: FP64x64 = FP64x64 { mag: 36893488147419103232_u128, sign: false }; + // 3.5 + let f5: FP64x64 = FP64x64 { mag: 64563604257983430656_u128, sign: false }; + // 5.164 + let f6: FP64x64 = FP64x64 { mag: 95258986396636119040_u128, sign: false }; + + let f1_erf: FP64x64 = erf(f1); + let f2_erf: FP64x64 = erf(f2); + let f3_erf: FP64x64 = erf(f3); + let f4_erf: FP64x64 = erf(f4); + let f5_erf: FP64x64 = erf(f5); + let f6_erf: FP64x64 = erf(f6); + assert(f1_erf.mag == 15545085858255493120_u128, 'f1_erf it works!'); + assert(f2_erf.mag == 2895161752038532608_u128, 'f2_erf it works!'); + assert(f3_erf.mag == 9922478374042292224_u128, 'f3_erf it works!'); + assert(f4_erf.mag == 18360455093669533696_u128, 'f4_erf it works!'); + assert(f5_erf.mag == 18446744073709551616_u128, 'f5_erf it works!'); + assert(f6_erf.mag == 18446744073709551616_u128, 'f6_erf it works!'); +} diff --git a/tests/numbers/fixed_point_test/erf_fp8x23_test.cairo b/tests/numbers/fixed_point_test/erf_fp8x23_test.cairo new file mode 100644 index 000000000..02053b35c --- /dev/null +++ b/tests/numbers/fixed_point_test/erf_fp8x23_test.cairo @@ -0,0 +1,33 @@ +use orion::numbers::fixed_point::implementations::fp8x23::math::erf::erf; +use orion::numbers::fixed_point::implementations::fp8x23::core::{ONE, FP8x23, FixedTrait}; +use core::debug::PrintTrait; +#[test] +#[available_gas(1000000000)] +fn test_erf() { + // 1.0 + let f1: FP8x23 = FP8x23 { mag: 8388608, sign: false }; + // 0.134 + let f2: FP8x23 = FP8x23 { mag: 1124073, sign: false }; + // 0.520 + let f3: FP8x23 = FP8x23 { mag: 4362076, sign: false }; + // 2.0 + let f4: FP8x23 = FP8x23 { mag: 16777216, sign: false }; + // 3.5 + let f5: FP8x23 = FP8x23 { mag: 29360128, sign: false }; + // 5.164 + let f6: FP8x23 = FP8x23 { mag: 43318772, sign: false }; + + let f1_erf: FP8x23 = erf(f1); + let f2_erf: FP8x23 = erf(f2); + let f3_erf: FP8x23 = erf(f3); + let f4_erf: FP8x23 = erf(f4); + let f5_erf: FP8x23 = erf(f5); + let f6_erf: FP8x23 = erf(f6); + + assert(f1_erf.mag == 7069086, 'f1_erf it works!'); + assert(f2_erf.mag == 1316567, 'f2_erf it works!'); + assert(f3_erf.mag == 4512220, 'f3_erf it works!'); + assert(f4_erf.mag == 8349368, 'f4_erf it works!'); + assert(f5_erf.mag == 8388608, 'f5_erf it works!'); + assert(f6_erf.mag == 8388608, 'f6_erf it works!'); +} diff --git a/tests/numbers/fixed_point_test/erf_fp8x23wide_test.cairo b/tests/numbers/fixed_point_test/erf_fp8x23wide_test.cairo new file mode 100644 index 000000000..be361b7d6 --- /dev/null +++ b/tests/numbers/fixed_point_test/erf_fp8x23wide_test.cairo @@ -0,0 +1,33 @@ +use orion::numbers::fixed_point::implementations::fp8x23wide::math::erf::erf; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ONE, FP8x23W, FixedTrait}; +use core::debug::PrintTrait; +#[test] +#[available_gas(1000000000)] +fn test_erf() { + // 1.0 + let f1: FP8x23W = FP8x23W { mag: 8388608, sign: false }; + // 0.134 + let f2: FP8x23W = FP8x23W { mag: 1124073, sign: false }; + // 0.520 + let f3: FP8x23W = FP8x23W { mag: 4362076, sign: false }; + // 2.0 + let f4: FP8x23W = FP8x23W { mag: 16777216, sign: false }; + // 3.5 + let f5: FP8x23W = FP8x23W { mag: 29360128, sign: false }; + // 5.164 + let f6: FP8x23W = FP8x23W { mag: 43318772, sign: false }; + + let f1_erf: FP8x23W = erf(f1); + let f2_erf: FP8x23W = erf(f2); + let f3_erf: FP8x23W = erf(f3); + let f4_erf: FP8x23W = erf(f4); + let f5_erf: FP8x23W = erf(f5); + let f6_erf: FP8x23W = erf(f6); + + assert(f1_erf.mag == 7069086, 'f1_erf it works!'); + assert(f2_erf.mag == 1316567, 'f2_erf it works!'); + assert(f3_erf.mag == 4512220, 'f3_erf it works!'); + assert(f4_erf.mag == 8349368, 'f4_erf it works!'); + assert(f5_erf.mag == 8388608, 'f5_erf it works!'); + assert(f6_erf.mag == 8388608, 'f6_erf it works!'); +}