diff --git a/src/operators/nn/functional/logsoftmax.cairo b/src/operators/nn/functional/logsoftmax.cairo
index 6d19cbb62..bd38d138c 100644
--- a/src/operators/nn/functional/logsoftmax.cairo
+++ b/src/operators/nn/functional/logsoftmax.cairo
@@ -2,6 +2,8 @@
 use array::SpanTrait;
 use orion::numbers::NumberTrait;
 use orion::operators::tensor::core::{Tensor, TensorTrait};
+use orion::numbers::fixed_point::core::FixedTrait;
+use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast};
 
 /// Cf: NNTrait::logsoftmax docstring
 fn logsoftmax<
@@ -16,3 +18,29 @@ fn logsoftmax<
 
     return logsoftmax;
 }
+
+/// Cf: NNTrait::logsoftmax docstring
+fn logsoftmaxWide<
+    T,
+    TMAG,
+    W,
+    WMAG,
+    impl TTensor: TensorTrait<T>,
+    impl WTensor: TensorTrait<W>,
+    impl TDiv: Div<T>,
+    impl TIntoW: Into<T, W>,
+    impl WTryIntoT: TryInto<W, T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>,
+    impl WCopy: Copy<W>,
+    impl WDrop: Drop<W>,
+    impl TFixed: FixedTrait<T, TMAG>,
+    impl WFixed: FixedTrait<W, WMAG>,
+>(
+    z: @Tensor<T>, axis: usize
+) -> Tensor<T> {
+    let exp_tensor: Tensor<W> = exp_upcast(*z);
+    let sum = exp_tensor.reduce_sum(axis, true);
+    let softmax = div_downcast(@exp_tensor, @sum);
+    softmax.log()
+}
\ No newline at end of file
diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo
index b940d8742..de81cde6d 100644
--- a/src/operators/nn/implementations/nn_fp16x16.cairo
+++ b/src/operators/nn/implementations/nn_fp16x16.cairo
@@ -28,7 +28,7 @@ impl FP16x16NN of NNTrait<FP16x16> {
     }
 
     fn logsoftmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
-        functional::logsoftmax::logsoftmax(tensor, axis)
+        functional::logsoftmax::logsoftmaxWide::<FP16x16, u32, FP16x16W, u64>(tensor, axis)
     }
 
     fn softsign(tensor: @Tensor<FP16x16>) -> Tensor<FP16x16> {
diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo
index 510f8cebd..d837b8fef 100644
--- a/src/operators/nn/implementations/nn_fp8x23.cairo
+++ b/src/operators/nn/implementations/nn_fp8x23.cairo
@@ -26,7 +26,7 @@ impl FP8x23NN of NNTrait<FP8x23> {
     }
 
     fn logsoftmax(tensor: @Tensor<FP8x23>, axis: usize) -> Tensor<FP8x23> {
-        functional::logsoftmax::logsoftmax(tensor, axis)
+        functional::logsoftmax::logsoftmaxWide::<FP8x23, u32, FP8x23W, u64>(tensor, axis)
     }
 
     fn softsign(tensor: @Tensor<FP8x23>) -> Tensor<FP8x23> {
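
Note (commentary, not part of the patch): the new logsoftmaxWide path runs exp_upcast and reduce_sum in the wider fixed-point types (FP16x16W / FP8x23W, with u64 magnitudes) and only downcasts the softmax quotient back to the narrow type, presumably so that exp() of moderately large inputs does not overflow the u32 magnitude of FP16x16 / FP8x23. Below is a minimal usage sketch, assuming Orion's documented TensorTrait::new, FixedTrait::new_unscaled, and NNTrait::logsoftmax APIs; the function name, shape, and values are hypothetical:

```cairo
// Hypothetical example: log-softmax over a 2x2 FP16x16 tensor. With this
// patch, NNTrait::logsoftmax dispatches to
// logsoftmaxWide::<FP16x16, u32, FP16x16W, u64> under the hood.
use array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor};
use orion::operators::nn::{NNTrait, FP16x16NN};
use orion::numbers::{FixedTrait, FP16x16};

fn logsoftmax_example() -> Tensor<FP16x16> {
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(2);
    shape.append(2);

    // Values 1, 2, 3, 4 as unscaled (integer) fixed-point numbers.
    let mut data = ArrayTrait::<FP16x16>::new();
    data.append(FixedTrait::new_unscaled(1, false));
    data.append(FixedTrait::new_unscaled(2, false));
    data.append(FixedTrait::new_unscaled(3, false));
    data.append(FixedTrait::new_unscaled(4, false));

    let x = TensorTrait::new(shape.span(), data.span());

    // exp() and the row sums are accumulated in FP16x16W (u64 magnitude),
    // then the softmax quotient is downcast back to FP16x16 before log().
    NNTrait::logsoftmax(@x, 1)
}
```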