Skip to content

Commit

Permalink
impl tensor_sqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Aug 15, 2024
1 parent 5fee3ab commit 5c7e732
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
2 changes: 1 addition & 1 deletion packages/deep-learning/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub(crate) mod ops;
pub(crate) mod utils;

pub use ops::binary::{BinaryOpMetadata, tensor_add, tensor_mul, tensor_rem, tensor_lt};
pub use ops::unary::{tensor_log2, tensor_exp2, tensor_sin};
pub use ops::unary::{tensor_log2, tensor_exp2, tensor_sin, tensor_sqrt};

#[derive(Drop, Copy)]
pub struct Tensor<T> {
Expand Down
33 changes: 32 additions & 1 deletion packages/deep-learning/src/ops/unary.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,24 @@ pub(crate) fn tensor_sin<T, S, +FixedTrait<T, S>, +Copy<T>, +Drop<T>>(
Tensor { data: result_data.span() }
}

pub(crate) fn tensor_sqrt<T, S, +FixedTrait<T, S>, +Copy<T>, +Drop<T>>(
ref self: Tensor<T>
) -> Tensor<T> {
let mut result_data = ArrayTrait::new();

loop {
match self.data.pop_front() {
Option::Some(ele) => { result_data.append(FixedTrait::sqrt(*ele)); },
Option::None(_) => { break; }
};
};

Tensor { data: result_data.span() }
}

#[cfg(test)]
mod tests {
use super::{Tensor, tensor_log2, tensor_exp2, tensor_sin};
use super::{Tensor, tensor_log2, tensor_exp2, tensor_sin, tensor_sqrt};
use orion_numbers::{F64, F64Impl, f64::helpers::assert_precise_span};

#[test]
Expand Down Expand Up @@ -100,4 +115,20 @@ mod tests {

assert_precise_span(result.data, expected, 'Incorrect sin result', error);
}

#[test]
fn test_tensor_sqrt() {
let self_data: Array<F64> = array![
F64Impl::new_unscaled(1),
F64Impl::new_unscaled(2),
F64Impl::new_unscaled(3),
F64Impl::new_unscaled(4),
];
let mut self = Tensor { data: self_data.span() };

let result = tensor_sqrt(ref self);
let expected = array![4294967296, 6074000999, 7439101573, 8589934592].span();

assert_precise_span(result.data, expected, 'Incorrect sin result', Option::None);
}
}

0 comments on commit 5c7e732

Please sign in to comment.