Skip to content

Commit

Permalink
Merge pull request #389 from gizatechxyz/fix-overflow
Browse files Browse the repository at this point in the history
Bug: Fix overflow with Softmax - Logsoftmax
  • Loading branch information
raphaelDkhn authored Oct 23, 2023
2 parents 7ab41bc + 19d5e99 commit 89d637c
Show file tree
Hide file tree
Showing 29 changed files with 7,271 additions and 7 deletions.
330 changes: 330 additions & 0 deletions src/numbers.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,171 @@ impl FP8x23Number of NumberTrait<FP8x23, u32> {
}
}

use orion::numbers::fixed_point::implementations::fp8x23wide::core::{FP8x23WImpl, FP8x23W};
use orion::numbers::fixed_point::implementations::fp8x23wide::math::core as core_fp8x23wide;
use orion::numbers::fixed_point::implementations::fp8x23wide::math::comp as comp_fp8x23wide;

impl FP8x23WNumber of NumberTrait<FP8x23W, u64> {
fn new(mag: u64, sign: bool) -> FP8x23W {
FP8x23WImpl::new(mag, sign)
}

fn new_unscaled(mag: u64, sign: bool) -> FP8x23W {
FP8x23WImpl::new_unscaled(mag, sign)
}

fn from_felt(val: felt252) -> FP8x23W {
FP8x23WImpl::from_felt(val)
}

fn ceil(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::ceil(self)
}

fn exp(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::exp(self)
}

fn exp2(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::exp2(self)
}

fn floor(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::floor(self)
}

fn ln(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::ln(self)
}

fn log2(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::log2(self)
}

fn log10(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::log10(self)
}

fn pow(self: FP8x23W, b: FP8x23W) -> FP8x23W {
FP8x23WImpl::pow(self, b)
}

fn round(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::round(self)
}

fn sqrt(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::sqrt(self)
}

fn acos(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::acos(self)
}

fn asin(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::asin(self)
}

fn atan(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::atan(self)
}

fn cos(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::cos(self)
}

fn sin(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::sin(self)
}

fn tan(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::tan(self)
}

fn acosh(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::acosh(self)
}

fn asinh(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::asinh(self)
}

fn atanh(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::atanh(self)
}

fn cosh(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::cosh(self)
}

fn sinh(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::sinh(self)
}

fn tanh(self: FP8x23W) -> FP8x23W {
FP8x23WImpl::tanh(self)
}

fn zero() -> FP8x23W {
FP8x23WImpl::ZERO()
}
fn is_zero(self: FP8x23W) -> bool {
core_fp8x23wide::eq(@self, @FP8x23WImpl::ZERO())
}

fn one() -> FP8x23W {
FP8x23WImpl::ONE()
}

fn neg_one() -> FP8x23W {
FP8x23W { mag: core_fp8x23wide::ONE, sign: true }
}

fn is_one(self: FP8x23W) -> bool {
core_fp8x23wide::eq(@self, @FP8x23WImpl::ONE())
}

fn abs(self: FP8x23W) -> FP8x23W {
core_fp8x23wide::abs(self)
}

fn min_value() -> FP8x23W {
FP8x23W { mag: core_fp8x23wide::MAX, sign: true }
}

fn max_value() -> FP8x23W {
FP8x23W { mag: core_fp8x23wide::MAX, sign: false }
}

fn min(self: FP8x23W, other: FP8x23W) -> FP8x23W {
comp_fp8x23wide::min(self, other)
}

fn max(self: FP8x23W, other: FP8x23W) -> FP8x23W {
comp_fp8x23wide::max(self, other)
}

fn mag(self: FP8x23W) -> u64 {
self.mag
}

fn is_neg(self: FP8x23W) -> bool {
self.sign
}

fn xor(lhs: FP8x23W, rhs: FP8x23W) -> bool {
comp_fp8x23wide::xor(lhs, rhs)
}

fn or(lhs: FP8x23W, rhs: FP8x23W) -> bool {
comp_fp8x23wide::or(lhs, rhs)
}

fn sign(self: FP8x23W) -> FP8x23W {
core_fp8x23wide::sign(self)
}
}

use orion::numbers::fixed_point::implementations::fp16x16::core::{FP16x16Impl, FP16x16};
use orion::numbers::fixed_point::implementations::fp16x16::math::core as core_fp16x16;
use orion::numbers::fixed_point::implementations::fp16x16::math::comp as comp_fp16x16;
Expand Down Expand Up @@ -378,6 +543,171 @@ impl FP16x16Number of NumberTrait<FP16x16, u32> {
}
}

use orion::numbers::fixed_point::implementations::fp16x16wide::core::{FP16x16WImpl, FP16x16W};
use orion::numbers::fixed_point::implementations::fp16x16wide::math::core as core_fp16x16wide;
use orion::numbers::fixed_point::implementations::fp16x16wide::math::comp as comp_fp16x16wide;

impl FP16x16WNumber of NumberTrait<FP16x16W, u64> {
fn new(mag: u64, sign: bool) -> FP16x16W {
FP16x16WImpl::new(mag, sign)
}

fn new_unscaled(mag: u64, sign: bool) -> FP16x16W {
FP16x16WImpl::new_unscaled(mag, sign)
}

fn from_felt(val: felt252) -> FP16x16W {
FP16x16WImpl::from_felt(val)
}

fn ceil(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::ceil(self)
}

fn exp(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::exp(self)
}

fn exp2(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::exp2(self)
}

fn floor(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::floor(self)
}

fn ln(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::ln(self)
}

fn log2(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::log2(self)
}

fn log10(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::log10(self)
}

fn pow(self: FP16x16W, b: FP16x16W) -> FP16x16W {
FP16x16WImpl::pow(self, b)
}

fn round(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::round(self)
}

fn sqrt(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::sqrt(self)
}

fn acos(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::acos(self)
}

fn asin(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::asin(self)
}

fn atan(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::atan(self)
}

fn cos(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::cos(self)
}

fn sin(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::sin(self)
}

fn tan(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::tan(self)
}

fn acosh(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::acosh(self)
}

fn asinh(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::asinh(self)
}

fn atanh(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::atanh(self)
}

fn cosh(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::cosh(self)
}

fn sinh(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::sinh(self)
}

fn tanh(self: FP16x16W) -> FP16x16W {
FP16x16WImpl::tanh(self)
}

fn zero() -> FP16x16W {
FP16x16WImpl::ZERO()
}
fn is_zero(self: FP16x16W) -> bool {
core_fp16x16wide::eq(@self, @FP16x16WImpl::ZERO())
}

fn one() -> FP16x16W {
FP16x16WImpl::ONE()
}

fn neg_one() -> FP16x16W {
FP16x16W { mag: core_fp16x16wide::ONE, sign: true }
}

fn is_one(self: FP16x16W) -> bool {
core_fp16x16wide::eq(@self, @FP16x16WImpl::ONE())
}

fn abs(self: FP16x16W) -> FP16x16W {
core_fp16x16wide::abs(self)
}

fn min_value() -> FP16x16W {
FP16x16W { mag: core_fp16x16wide::MAX, sign: true }
}

fn max_value() -> FP16x16W {
FP16x16W { mag: core_fp16x16wide::MAX, sign: false }
}

fn min(self: FP16x16W, other: FP16x16W) -> FP16x16W {
comp_fp16x16wide::min(self, other)
}

fn max(self: FP16x16W, other: FP16x16W) -> FP16x16W {
comp_fp16x16wide::max(self, other)
}

fn mag(self: FP16x16W) -> u64 {
self.mag
}

fn is_neg(self: FP16x16W) -> bool {
self.sign
}

fn xor(lhs: FP16x16W, rhs: FP16x16W) -> bool {
comp_fp16x16wide::xor(lhs, rhs)
}

fn or(lhs: FP16x16W, rhs: FP16x16W) -> bool {
comp_fp16x16wide::or(lhs, rhs)
}

fn sign(self: FP16x16W) -> FP16x16W {
core_fp16x16wide::sign(self)
}
}

use orion::numbers::fixed_point::implementations::fp64x64::core::{FP64x64Impl, FP64x64};
use orion::numbers::fixed_point::implementations::fp64x64::core as core_fp64x64;
use orion::numbers::fixed_point::implementations::fp64x64::comp as comp_fp64x64;
Expand Down
2 changes: 2 additions & 0 deletions src/numbers/fixed_point/implementations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ mod fp8x23;
mod fp16x16;
mod fp64x64;
mod fp32x32;
mod fp16x16wide;
mod fp8x23wide;
3 changes: 3 additions & 0 deletions src/numbers/fixed_point/implementations/fp16x16wide.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod core;
mod math;
mod helpers;
Loading

0 comments on commit 89d637c

Please sign in to comment.