Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Oct 23, 2023
1 parent 5267e35 commit 8b2276b
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool {

// Calculates the natural exponent of x: e^x
fn exp(a: FP16x16W) -> FP16x16W {
// a.sign.print();
return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203
}

Expand Down
30 changes: 2 additions & 28 deletions src/operators/nn/functional/softmax.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ fn softmax<
) -> Tensor<T> {
let exp_tensor = z.exp();
let sum = exp_tensor.reduce_sum(axis, true);
let softmax = exp_tensor / sum;

return softmax;
exp_tensor / sum
}

/// Cf: NNTrait::softmax docstring
Expand All @@ -37,35 +35,11 @@ fn softmaxWide<
impl WDrop: Drop<W>,
impl TFixed: FixedTrait<T, TMAG>,
impl WFixed: FixedTrait<W, WMAG>,
impl TPrint: PrintTrait<T>,
impl WPrint: PrintTrait<W>
>(
z: @Tensor<T>, axis: usize
) -> Tensor<T> {
let exp_tensor: Tensor<W> = exp_upcast(*z);
let sum = exp_tensor.reduce_sum(axis, true);
let softmax: Tensor<T> = div_downcast(@exp_tensor, @sum);

return softmax;
div_downcast(@exp_tensor, @sum)
}

use orion::numbers::{FP16x16, FP16x16W};
use orion::operators::tensor::{
implementations::tensor_fp16x16wide::{FP16x16WTensor, FP16x16WTensorDiv}, FP16x16Tensor
};
use debug::PrintTrait;

/// Cf: NNTrait::softmax docstring
fn softmaxWide2(z: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
let exp_tensor: Tensor<FP16x16W> = exp_upcast(*z);

(*(*z).data.at(0)).print();
(*(*z).data.at(1)).print();
(*(*z).data.at(2)).print();

// let sum = exp_tensor.reduce_sum(axis, true);
// (*sum.data.at(0)).print();
// let softmax = exp_tensor / sum;
// return exp_tensor;
*z
}
2 changes: 0 additions & 2 deletions src/operators/tensor/math/exp.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ fn exp_upcast<
impl WCopy: Copy<W>,
impl WDrop: Drop<W>,
impl TIntoW: Into<T, W>,
impl TPrint: PrintTrait<T>,
impl WPrint: PrintTrait<W>
>(
mut self: Tensor<T>
) -> Tensor<W> {
Expand Down

0 comments on commit 8b2276b

Please sign in to comment.