add logsoftmaxwide
raphaelDkhn committed Oct 23, 2023
1 parent 8a55e5c commit 19d5e99
Showing 3 changed files with 30 additions and 2 deletions.
28 changes: 28 additions & 0 deletions src/operators/nn/functional/logsoftmax.cairo
@@ -2,6 +2,8 @@ use array::SpanTrait;
 
 use orion::numbers::NumberTrait;
 use orion::operators::tensor::core::{Tensor, TensorTrait};
+use orion::numbers::fixed_point::core::FixedTrait;
+use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast};
 
 /// Cf: NNTrait::logsoftmax docstring
 fn logsoftmax<
@@ -16,3 +18,29 @@
 
     return logsoftmax;
 }
+
+/// Cf: NNTrait::logsoftmax docstring
+fn logsoftmaxWide<
+    T,
+    TMAG,
+    W,
+    WMAG,
+    impl TTensor: TensorTrait<T>,
+    impl WTensor: TensorTrait<W>,
+    impl TDiv: Div<T>,
+    impl TIntoW: Into<T, W>,
+    impl WTryIntoT: TryInto<W, T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>,
+    impl WCopy: Copy<W>,
+    impl WDrop: Drop<W>,
+    impl TFixed: FixedTrait<T, TMAG>,
+    impl WFixed: FixedTrait<W, WMAG>,
+>(
+    z: @Tensor<T>, axis: usize
+) -> Tensor<T> {
+    let exp_tensor: Tensor<W> = exp_upcast(*z);
+    let sum = exp_tensor.reduce_sum(axis, true);
+    let softmax = div_downcast(@exp_tensor, @sum);
+    softmax.log()
+}
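Note (a reading of the diff above; the commit message itself does not say this): logsoftmaxWide computes the same quantity as logsoftmax,

$$\operatorname{logsoftmax}(z)_i = \log\!\left(\frac{e^{z_i}}{\sum_j e^{z_j}}\right),$$

but upcasts the input from the fixed-point type T to the wider type W before exponentiating (exp_upcast), keeps the reduce_sum and the division in W, and only downcasts back to T at the end (div_downcast). Because exp grows quickly, the exponentials and especially the reduced sum can exceed the magnitude range of a narrow fixed-point type; carrying the intermediates in a wider magnitude (u64 instead of u32, per the call sites below) is presumably what avoids that overflow.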
2 changes: 1 addition & 1 deletion src/operators/nn/implementations/nn_fp16x16.cairo
@@ -28,7 +28,7 @@ impl FP16x16NN of NNTrait<FP16x16> {
     }
 
     fn logsoftmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
-        functional::logsoftmax::logsoftmax(tensor, axis)
+        functional::logsoftmax::logsoftmaxWide::<FP16x16, u32, FP16x16W, u64>(tensor, axis)
     }
 
     fn softsign(tensor: @Tensor<FP16x16>) -> Tensor<FP16x16> {
2 changes: 1 addition & 1 deletion src/operators/nn/implementations/nn_fp8x23.cairo
@@ -26,7 +26,7 @@ impl FP8x23NN of NNTrait<FP8x23> {
     }
 
     fn logsoftmax(tensor: @Tensor<FP8x23>, axis: usize) -> Tensor<FP8x23> {
-        functional::logsoftmax::logsoftmax(tensor, axis)
+        functional::logsoftmax::logsoftmaxWide::<FP8x23, u32, FP8x23W, u64>(tensor, axis)
     }
 
     fn softsign(tensor: @Tensor<FP8x23>) -> Tensor<FP8x23> {
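For illustration, a minimal caller-side sketch (not part of this commit; the import paths and the TensorTrait::new signature are assumptions based on Orion docstring examples of this era and may differ in this exact revision). The public NNTrait::logsoftmax entry point is unchanged, so existing callers pick up the wide path automatically:

use array::{ArrayTrait, SpanTrait};

use orion::operators::tensor::core::{Tensor, TensorTrait};
use orion::operators::tensor::implementations::tensor_fp16x16::FP16x16Tensor;
use orion::operators::nn::core::NNTrait;
use orion::operators::nn::implementations::nn_fp16x16::FP16x16NN;
use orion::numbers::fixed_point::core::FixedTrait;
use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16;

fn logsoftmax_example() -> Tensor<FP16x16> {
    // Build a 2x2 FP16x16 tensor holding [[0, 1], [2, 3]].
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(2);
    shape.append(2);

    let mut data = ArrayTrait::<FP16x16>::new();
    data.append(FixedTrait::new_unscaled(0, false));
    data.append(FixedTrait::new_unscaled(1, false));
    data.append(FixedTrait::new_unscaled(2, false));
    data.append(FixedTrait::new_unscaled(3, false));

    let x = TensorTrait::<FP16x16>::new(shape.span(), data.span());

    // After this commit, this call dispatches to
    // logsoftmaxWide::<FP16x16, u32, FP16x16W, u64>: exp and the row sums
    // run on the 64-bit-magnitude FP16x16W before downcasting to FP16x16.
    NNTrait::logsoftmax(@x, 1)
}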
