Skip to content

Commit

Permalink
implement tensor_fp8x23wide
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Oct 23, 2023
1 parent f79e0f5 commit 8a55e5c
Show file tree
Hide file tree
Showing 5 changed files with 568 additions and 3 deletions.
165 changes: 165 additions & 0 deletions src/numbers.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,171 @@ impl FP8x23Number of NumberTrait<FP8x23, u32> {
}
}

use orion::numbers::fixed_point::implementations::fp8x23wide::core::{FP8x23WImpl, FP8x23W};
use orion::numbers::fixed_point::implementations::fp8x23wide::math::core as core_fp8x23wide;
use orion::numbers::fixed_point::implementations::fp8x23wide::math::comp as comp_fp8x23wide;

// NumberTrait implementation for FP8x23W — the "wide" variant of the
// FP8x23 fixed-point type, carrying a u64 magnitude plus a boolean sign
// flag (sign-magnitude representation). Nearly every method is a thin
// delegation to FP8x23WImpl or the fp8x23wide math/comp modules.
impl FP8x23WNumber of NumberTrait<FP8x23W, u64> {
    // --- Construction / conversion ---

    // Wraps an already-scaled raw magnitude and sign into an FP8x23W.
    fn new(mag: u64, sign: bool) -> FP8x23W {
        FP8x23WImpl::new(mag, sign)
    }

    // Builds from an unscaled integer magnitude (scaling handled by the impl).
    fn new_unscaled(mag: u64, sign: bool) -> FP8x23W {
        FP8x23WImpl::new_unscaled(mag, sign)
    }

    // Converts a felt252 into an FP8x23W.
    fn from_felt(val: felt252) -> FP8x23W {
        FP8x23WImpl::from_felt(val)
    }

    // --- Rounding ---

    fn ceil(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::ceil(self)
    }

    // --- Exponentials and logarithms ---

    fn exp(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::exp(self)
    }

    fn exp2(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::exp2(self)
    }

    fn floor(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::floor(self)
    }

    fn ln(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::ln(self)
    }

    fn log2(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::log2(self)
    }

    fn log10(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::log10(self)
    }

    // --- Powers and roots ---

    fn pow(self: FP8x23W, b: FP8x23W) -> FP8x23W {
        FP8x23WImpl::pow(self, b)
    }

    fn round(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::round(self)
    }

    fn sqrt(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::sqrt(self)
    }

    // --- Trigonometric functions ---

    fn acos(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::acos(self)
    }

    fn asin(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::asin(self)
    }

    fn atan(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::atan(self)
    }

    fn cos(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::cos(self)
    }

    fn sin(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::sin(self)
    }

    fn tan(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::tan(self)
    }

    // --- Hyperbolic functions ---

    fn acosh(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::acosh(self)
    }

    fn asinh(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::asinh(self)
    }

    fn atanh(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::atanh(self)
    }

    fn cosh(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::cosh(self)
    }

    fn sinh(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::sinh(self)
    }

    fn tanh(self: FP8x23W) -> FP8x23W {
        FP8x23WImpl::tanh(self)
    }

    // --- Constants and predicates ---

    fn zero() -> FP8x23W {
        FP8x23WImpl::ZERO()
    }
    // Zero test via structural equality with the ZERO constant.
    fn is_zero(self: FP8x23W) -> bool {
        core_fp8x23wide::eq(@self, @FP8x23WImpl::ZERO())
    }

    fn one() -> FP8x23W {
        FP8x23WImpl::ONE()
    }

    // -1: magnitude of ONE with the sign flag set.
    fn neg_one() -> FP8x23W {
        FP8x23W { mag: core_fp8x23wide::ONE, sign: true }
    }

    fn is_one(self: FP8x23W) -> bool {
        core_fp8x23wide::eq(@self, @FP8x23WImpl::ONE())
    }

    fn abs(self: FP8x23W) -> FP8x23W {
        core_fp8x23wide::abs(self)
    }

    // Most negative value: maximum magnitude with sign set
    // (sign-magnitude, so min is the negation of max).
    fn min_value() -> FP8x23W {
        FP8x23W { mag: core_fp8x23wide::MAX, sign: true }
    }

    fn max_value() -> FP8x23W {
        FP8x23W { mag: core_fp8x23wide::MAX, sign: false }
    }

    // --- Comparison ---

    fn min(self: FP8x23W, other: FP8x23W) -> FP8x23W {
        comp_fp8x23wide::min(self, other)
    }

    fn max(self: FP8x23W, other: FP8x23W) -> FP8x23W {
        comp_fp8x23wide::max(self, other)
    }

    // --- Accessors ---

    // Raw (scaled) magnitude.
    fn mag(self: FP8x23W) -> u64 {
        self.mag
    }

    // True when the sign flag is set (value is negative).
    fn is_neg(self: FP8x23W) -> bool {
        self.sign
    }

    // --- Logical ops (delegated to the comp module) ---
    // NOTE(review): presumably these treat operands as truthy when
    // non-zero — confirm semantics in comp_fp8x23wide.

    fn xor(lhs: FP8x23W, rhs: FP8x23W) -> bool {
        comp_fp8x23wide::xor(lhs, rhs)
    }

    fn or(lhs: FP8x23W, rhs: FP8x23W) -> bool {
        comp_fp8x23wide::or(lhs, rhs)
    }

    // Sign function returned as an FP8x23W value.
    fn sign(self: FP8x23W) -> FP8x23W {
        core_fp8x23wide::sign(self)
    }
}

use orion::numbers::fixed_point::implementations::fp16x16::core::{FP16x16Impl, FP16x16};
use orion::numbers::fixed_point::implementations::fp16x16::math::core as core_fp16x16;
use orion::numbers::fixed_point::implementations::fp16x16::math::comp as comp_fp16x16;
Expand Down
21 changes: 20 additions & 1 deletion src/numbers/fixed_point/implementations/fp8x23wide/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use result::{ResultTrait, ResultTraitImpl};
use traits::{TryInto, Into};

use orion::numbers::signed_integer::{i32::i32, i8::i8};
use orion::numbers::fixed_point::core::{FixedTrait};
use orion::numbers::{fixed_point::core::{FixedTrait}, FP8x23};
use orion::numbers::fixed_point::implementations::fp8x23wide::math::{core, trig, hyp};
use orion::numbers::fixed_point::utils;

Expand Down Expand Up @@ -205,6 +205,25 @@ impl FP8x23WIntoFelt252 of Into<FP8x23W, felt252> {
}
}

// Lossless widening conversion FP8x23 -> FP8x23W.
// The u32 magnitude is zero-extended to u64 and the sign flag is
// carried over unchanged; the fixed-point scale is the same on both
// sides, so no value adjustment is needed and this can never fail.
impl FP8x23IntoFP8x23W of Into<FP8x23, FP8x23W> {
    fn into(self: FP8x23) -> FP8x23W {
        let widened_mag: u64 = self.mag.into();
        let sign = self.sign;
        FP8x23W { mag: widened_mag, sign: sign }
    }
}

// Narrowing conversion FP8x23W -> FP8x23.
// Succeeds only when the u64 magnitude fits into u32 (via the u64 ->
// u32 try_into); on success the sign flag is preserved, otherwise the
// conversion yields Option::None.
impl FP8x23WTryIntoFP8x23 of TryInto<FP8x23W, FP8x23> {
    fn try_into(self: FP8x23W) -> Option<FP8x23> {
        let sign = self.sign;
        match self.mag.try_into() {
            Option::Some(narrow_mag) => Option::Some(FP8x23 { mag: narrow_mag, sign: sign }),
            Option::None(_) => Option::None(())
        }
    }
}

impl FP8x23WTryIntoU128 of TryInto<FP8x23W, u128> {
fn try_into(self: FP8x23W) -> Option<u128> {
if self.sign {
Expand Down
6 changes: 5 additions & 1 deletion src/operators/nn/implementations/nn_fp8x23.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ use orion::numbers::fixed_point::implementations::fp8x23::core::FP8x23;
use orion::operators::tensor::implementations::tensor_fp8x23::{
FP8x23Tensor, FP8x23TensorDiv, FP8x23TensorAdd
};
use orion::numbers::fixed_point::implementations::fp8x23wide::core::{
FP8x23WImpl, FP8x23WTryIntoFP8x23, FP8x23W, FP8x23IntoFP8x23W
};
use orion::operators::tensor::implementations::tensor_fp8x23wide::{FP8x23WTensor};

impl FP8x23NN of NNTrait<FP8x23> {
fn relu(tensor: @Tensor<FP8x23>) -> Tensor<FP8x23> {
Expand All @@ -18,7 +22,7 @@ impl FP8x23NN of NNTrait<FP8x23> {
}

fn softmax(tensor: @Tensor<FP8x23>, axis: usize) -> Tensor<FP8x23> {
functional::softmax::softmax(tensor, axis)
functional::softmax::softmaxWide::<FP8x23, u32, FP8x23W, u64>(tensor, axis)
}

fn logsoftmax(tensor: @Tensor<FP8x23>, axis: usize) -> Tensor<FP8x23> {
Expand Down
3 changes: 2 additions & 1 deletion src/operators/tensor/implementations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ mod tensor_fp8x23;
mod tensor_fp16x16;
mod tensor_fp64x64;
mod tensor_fp32x32;
mod tensor_fp16x16wide;
mod tensor_fp16x16wide;
mod tensor_fp8x23wide;
Loading

0 comments on commit 8a55e5c

Please sign in to comment.