From de04bd0c4422cd3ea994a781a455af60ea55df69 Mon Sep 17 00:00:00 2001 From: Yuval Shekel Date: Mon, 1 Jul 2024 15:14:10 +0300 Subject: [PATCH] bit reverse in rust --- .../rust_v3/icicle-core/src/vec_ops/mod.rs | 160 ++++++++++-------- .../rust_v3/icicle-core/src/vec_ops/tests.rs | 128 +++++++------- 2 files changed, 158 insertions(+), 130 deletions(-) diff --git a/wrappers/rust_v3/icicle-core/src/vec_ops/mod.rs b/wrappers/rust_v3/icicle-core/src/vec_ops/mod.rs index 07a8eda326..f30835eeb2 100644 --- a/wrappers/rust_v3/icicle-core/src/vec_ops/mod.rs +++ b/wrappers/rust_v3/icicle-core/src/vec_ops/mod.rs @@ -60,17 +60,16 @@ pub trait VecOps { cfg: &VecOpsConfig, ) -> Result<(), eIcicleError>; - // TODO Yuval : bit reverse - // fn bit_reverse( - // input: &(impl HostOrDeviceSlice + ?Sized), - // cfg: &BitReverseConfig, - // output: &mut (impl HostOrDeviceSlice + ?Sized), - // ) -> Result<(), eIcicleError>; - - // fn bit_reverse_inplace( - // input: &mut (impl HostOrDeviceSlice + ?Sized), - // cfg: &BitReverseConfig, - // ) -> Result<(), eIcicleError>; + fn bit_reverse( + input: &(impl HostOrDeviceSlice + ?Sized), + cfg: &VecOpsConfig, + output: &mut (impl HostOrDeviceSlice + ?Sized), + ) -> Result<(), eIcicleError>; + + fn bit_reverse_inplace( + input: &mut (impl HostOrDeviceSlice + ?Sized), + cfg: &VecOpsConfig, + ) -> Result<(), eIcicleError>; } fn check_vec_ops_args<'a, F>( @@ -169,6 +168,31 @@ where <::Config as VecOps>::transpose(input, nof_rows, nof_cols, output, &cfg) } +pub fn bit_reverse( + input: &(impl HostOrDeviceSlice + ?Sized), + cfg: &VecOpsConfig, + output: &mut (impl HostOrDeviceSlice + ?Sized), +) -> Result<(), eIcicleError> +where + F: FieldImpl, + ::Config: VecOps, +{ + let cfg = check_vec_ops_args(input, input /*dummy*/, output, cfg); + <::Config as VecOps>::bit_reverse(input, &cfg, output) +} + +pub fn bit_reverse_inplace( + input: &mut (impl HostOrDeviceSlice + ?Sized), + cfg: &VecOpsConfig, +) -> Result<(), eIcicleError> +where + F: FieldImpl, + ::Config: VecOps, +{ + let cfg = check_vec_ops_args(input, input /*dummy*/, input, cfg); + <::Config as VecOps>::bit_reverse_inplace(input, &cfg) +} + #[macro_export] macro_rules! impl_vec_ops_field { ( @@ -179,14 +203,13 @@ macro_rules! impl_vec_ops_field { ) => { mod $field_prefix_ident { - use crate::vec_ops::{$field, HostOrDeviceSlice}; - // use icicle_core::vec_ops::BitReverseConfig; + use crate::vec_ops::{$field, HostOrDeviceSlice}; use icicle_core::vec_ops::VecOpsConfig; use icicle_runtime::errors::eIcicleError; extern "C" { #[link_name = concat!($field_prefix, "_vector_add")] - pub(crate) fn vector_add( + pub(crate) fn vector_add_ffi( a: *const $field, b: *const $field, size: u32, @@ -195,7 +218,7 @@ macro_rules! impl_vec_ops_field { ) -> eIcicleError; #[link_name = concat!($field_prefix, "_vector_sub")] - pub(crate) fn vector_sub( + pub(crate) fn vector_sub_ffi( a: *const $field, b: *const $field, size: u32, @@ -204,7 +227,7 @@ macro_rules! impl_vec_ops_field { ) -> eIcicleError; #[link_name = concat!($field_prefix, "_vector_mul")] - pub(crate) fn vector_mul( + pub(crate) fn vector_mul_ffi( a: *const $field, b: *const $field, size: u32, @@ -213,7 +236,7 @@ macro_rules! impl_vec_ops_field { ) -> eIcicleError; #[link_name = concat!($field_prefix, "_matrix_transpose")] - pub(crate) fn _matrix_transpose( + pub(crate) fn matrix_transpose_ffi( input: *const $field, nof_rows: u32, nof_cols: u32, @@ -221,13 +244,13 @@ macro_rules! impl_vec_ops_field { output: *mut $field, ) -> eIcicleError; - // #[link_name = concat!($field_prefix, "_bit_reverse_cuda")] - // pub(crate) fn bit_reverse_cuda( - // input: *const $field, - // size: u64, - // config: *const BitReverseConfig, - // output: *mut $field, - // ) -> CudaError; + #[link_name = concat!($field_prefix, "_bit_reverse")] + pub(crate) fn bit_reverse_ffi( + input: *const $field, + size: u64, + config: *const VecOpsConfig, + output: *mut $field, + ) -> eIcicleError; } } @@ -239,7 +262,7 @@ macro_rules! impl_vec_ops_field { cfg: &VecOpsConfig, ) -> Result<(), eIcicleError> { unsafe { - $field_prefix_ident::vector_add( + $field_prefix_ident::vector_add_ffi( a.as_ptr(), b.as_ptr(), a.len() as u32, @@ -257,7 +280,7 @@ macro_rules! impl_vec_ops_field { cfg: &VecOpsConfig, ) -> Result<(), eIcicleError> { unsafe { - $field_prefix_ident::vector_sub( + $field_prefix_ident::vector_sub_ffi( a.as_ptr(), b.as_ptr(), a.len() as u32, @@ -275,7 +298,7 @@ macro_rules! impl_vec_ops_field { cfg: &VecOpsConfig, ) -> Result<(), eIcicleError> { unsafe { - $field_prefix_ident::vector_mul( + $field_prefix_ident::vector_mul_ffi( a.as_ptr(), b.as_ptr(), a.len() as u32, @@ -294,7 +317,7 @@ macro_rules! impl_vec_ops_field { cfg: &VecOpsConfig, ) -> Result<(), eIcicleError> { unsafe { - $field_prefix_ident::_matrix_transpose( + $field_prefix_ident::matrix_transpose_ffi( input.as_ptr(), nof_rows, nof_cols, @@ -305,38 +328,36 @@ macro_rules! impl_vec_ops_field { } } - // TODO Yuval : bit reverse - - // fn bit_reverse( - // input: &(impl HostOrDeviceSlice<$field> + ?Sized), - // cfg: &BitReverseConfig, - // output: &mut (impl HostOrDeviceSlice<$field> + ?Sized), - // ) -> Result<(), eIcicleError> { - // unsafe { - // $field_prefix_ident::bit_reverse_cuda( - // input.as_ptr(), - // input.len() as u64, - // cfg as *const BitReverseConfig, - // output.as_mut_ptr(), - // ) - // .wrap() - // } - // } - - // fn bit_reverse_inplace( - // input: &mut (impl HostOrDeviceSlice<$field> + ?Sized), - // cfg: &BitReverseConfig, - // ) -> Result<(), eIcicleError> { - // unsafe { - // $field_prefix_ident::bit_reverse_cuda( - // input.as_ptr(), - // input.len() as u64, - // cfg as *const BitReverseConfig, - // input.as_mut_ptr(), - // ) - // .wrap() - // } - // } + fn bit_reverse( + input: &(impl HostOrDeviceSlice<$field> + ?Sized), + cfg: &VecOpsConfig, + output: &mut (impl HostOrDeviceSlice<$field> + ?Sized), + ) -> Result<(), eIcicleError> { + unsafe { + $field_prefix_ident::bit_reverse_ffi( + input.as_ptr(), + input.len() as u64, + cfg as *const VecOpsConfig, + output.as_mut_ptr(), + ) + .wrap() + } + } + + fn bit_reverse_inplace( + input: &mut (impl HostOrDeviceSlice<$field> + ?Sized), + cfg: &VecOpsConfig, + ) -> Result<(), eIcicleError> { + unsafe { + $field_prefix_ident::bit_reverse_ffi( + input.as_ptr(), + input.len() as u64, + cfg as *const VecOpsConfig, + input.as_mut_ptr(), + ) + .wrap() + } + } } }; } @@ -368,14 +389,17 @@ macro_rules! impl_vec_ops_tests { check_matrix_transpose::<$field>() } - // #[test] - // pub fn test_bit_reverse() { - // check_bit_reverse::<$field>() - // } - // #[test] - // pub fn test_bit_reverse_inplace() { - // check_bit_reverse_inplace::<$field>() - // } + #[test] + pub fn test_bit_reverse() { + initialize(); + check_bit_reverse::<$field>() + } + + #[test] + pub fn test_bit_reverse_inplace() { + initialize(); + check_bit_reverse_inplace::<$field>() + } } }; } diff --git a/wrappers/rust_v3/icicle-core/src/vec_ops/tests.rs b/wrappers/rust_v3/icicle-core/src/vec_ops/tests.rs index e767235f1c..497b0deb0d 100644 --- a/wrappers/rust_v3/icicle-core/src/vec_ops/tests.rs +++ b/wrappers/rust_v3/icicle-core/src/vec_ops/tests.rs @@ -1,7 +1,7 @@ #![allow(unused_imports)] use crate::test_utilities; use crate::traits::GenerateRandom; -use crate::vec_ops::{add_scalars, mul_scalars, sub_scalars, transpose_matrix, FieldImpl, VecOps, VecOpsConfig}; +use crate::vec_ops::{add_scalars, mul_scalars, sub_scalars, bit_reverse, bit_reverse_inplace, transpose_matrix, FieldImpl, VecOps, VecOpsConfig}; use icicle_runtime::device::Device; use icicle_runtime::memory::{DeviceVec, HostSlice}; use icicle_runtime::{runtime, stream::IcicleStream}; @@ -103,64 +103,68 @@ where assert_eq!(result_main, result_ref); } -// pub fn check_bit_reverse() -// where -// ::Config: VecOps + GenerateRandom, -// { -// const LOG_SIZE: u32 = 20; -// const TEST_SIZE: usize = 1 << LOG_SIZE; -// let input_vec = F::Config::generate_random(TEST_SIZE); -// let input = HostSlice::from_slice(&input_vec); -// let mut intermediate = DeviceVec::::cuda_malloc(TEST_SIZE).unwrap(); -// let cfg = BitReverseConfig::default(); -// bit_reverse(input, &cfg, &mut intermediate[..]).unwrap(); - -// let mut intermediate_host = vec![F::one(); TEST_SIZE]; -// intermediate -// .copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..])) -// .unwrap(); -// let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE); -// intermediate_host -// .iter() -// .enumerate() -// .for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)])); - -// let mut result = vec![F::one(); TEST_SIZE]; -// let result = HostSlice::from_mut_slice(&mut result); -// let cfg = BitReverseConfig::default(); -// bit_reverse(&intermediate[..], &cfg, result).unwrap(); -// assert_eq!(input.as_slice(), result.as_slice()); -// } - -// pub fn check_bit_reverse_inplace() -// where -// ::Config: VecOps + GenerateRandom, -// { -// const LOG_SIZE: u32 = 20; -// const TEST_SIZE: usize = 1 << LOG_SIZE; -// let input_vec = F::Config::generate_random(TEST_SIZE); -// let input = HostSlice::from_slice(&input_vec); -// let mut intermediate = DeviceVec::::cuda_malloc(TEST_SIZE).unwrap(); -// intermediate -// .copy_from_host(&input) -// .unwrap(); -// let cfg = BitReverseConfig::default(); -// bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap(); - -// let mut intermediate_host = vec![F::one(); TEST_SIZE]; -// intermediate -// .copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..])) -// .unwrap(); -// let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE); -// intermediate_host -// .iter() -// .enumerate() -// .for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)])); - -// bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap(); -// let mut result_host = vec![F::one(); TEST_SIZE]; -// intermediate -// .copy_to_host(HostSlice::from_mut_slice(&mut result_host[..])) -// .unwrap(); -// assert_eq!(input.as_slice(), result_host.as_slice()); -// } +pub fn check_bit_reverse() +where + ::Config: VecOps + GenerateRandom, +{ + test_utilities::test_set_main_device(); + + const LOG_SIZE: u32 = 20; + const TEST_SIZE: usize = 1 << LOG_SIZE; + let input_vec = F::Config::generate_random(TEST_SIZE); + let input = HostSlice::from_slice(&input_vec); + let mut intermediate = DeviceVec::::device_malloc(TEST_SIZE).unwrap(); + let cfg = VecOpsConfig::default(); + bit_reverse(input, &cfg, &mut intermediate[..]).unwrap(); + + let mut intermediate_host = vec![F::one(); TEST_SIZE]; + intermediate + .copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..])) + .unwrap(); + let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE); + intermediate_host + .iter() + .enumerate() + .for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)])); + + let mut result = vec![F::one(); TEST_SIZE]; + let result = HostSlice::from_mut_slice(&mut result); + let cfg = VecOpsConfig::default(); + bit_reverse(&intermediate[..], &cfg, result).unwrap(); + assert_eq!(input.as_slice(), result.as_slice()); +} + +pub fn check_bit_reverse_inplace() +where + ::Config: VecOps + GenerateRandom, +{ + test_utilities::test_set_main_device(); + + const LOG_SIZE: u32 = 20; + const TEST_SIZE: usize = 1 << LOG_SIZE; + let input_vec = F::Config::generate_random(TEST_SIZE); + let input = HostSlice::from_slice(&input_vec); + let mut intermediate = DeviceVec::::device_malloc(TEST_SIZE).unwrap(); + intermediate + .copy_from_host(&input) + .unwrap(); + let cfg = VecOpsConfig::default(); + bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap(); + + let mut intermediate_host = vec![F::one(); TEST_SIZE]; + intermediate + .copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..])) + .unwrap(); + let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE); + intermediate_host + .iter() + .enumerate() + .for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)])); + + bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap(); + let mut result_host = vec![F::one(); TEST_SIZE]; + intermediate + .copy_to_host(HostSlice::from_mut_slice(&mut result_host[..])) + .unwrap(); + assert_eq!(input.as_slice(), result_host.as_slice()); +}