Skip to content

Commit

Permalink
impl tensor_recip
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Aug 15, 2024
1 parent 5c7e732 commit e1bec01
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
2 changes: 1 addition & 1 deletion packages/deep-learning/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub(crate) mod ops;
pub(crate) mod utils;

pub use ops::binary::{BinaryOpMetadata, tensor_add, tensor_mul, tensor_rem, tensor_lt};
pub use ops::unary::{tensor_log2, tensor_exp2, tensor_sin, tensor_sqrt};
pub use ops::unary::{tensor_log2, tensor_exp2, tensor_sin, tensor_sqrt, tensor_recip};

#[derive(Drop, Copy)]
pub struct Tensor<T> {
Expand Down
33 changes: 32 additions & 1 deletion packages/deep-learning/src/ops/unary.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,24 @@ pub(crate) fn tensor_sqrt<T, S, +FixedTrait<T, S>, +Copy<T>, +Drop<T>>(
Tensor { data: result_data.span() }
}

pub(crate) fn tensor_recip<T, S, +FixedTrait<T, S>, +Div<T>, +Copy<T>, +Drop<T>>(
ref self: Tensor<T>
) -> Tensor<T> {
let mut result_data = ArrayTrait::new();

loop {
match self.data.pop_front() {
Option::Some(ele) => { result_data.append(FixedTrait::ONE() / *ele); },
Option::None(_) => { break; }
};
};

Tensor { data: result_data.span() }
}

#[cfg(test)]
mod tests {
use super::{Tensor, tensor_log2, tensor_exp2, tensor_sin, tensor_sqrt};
use super::{Tensor, tensor_log2, tensor_exp2, tensor_sin, tensor_sqrt, tensor_recip};
use orion_numbers::{F64, F64Impl, f64::helpers::assert_precise_span};

#[test]
Expand Down Expand Up @@ -131,4 +146,20 @@ mod tests {

assert_precise_span(result.data, expected, 'Incorrect sin result', Option::None);
}

#[test]
fn test_tensor_recip() {
let self_data: Array<F64> = array![
F64Impl::new_unscaled(1),
F64Impl::new_unscaled(2),
F64Impl::new_unscaled(3),
F64Impl::new_unscaled(4),
];
let mut self = Tensor { data: self_data.span() };

let result = tensor_recip(ref self);
let expected = array![4294967296, 2147483648, 1431655765, 1073741824].span();

assert_precise_span(result.data, expected, 'Incorrect sin result', Option::None);
}
}

0 comments on commit e1bec01

Please sign in to comment.