Skip to content

Commit

Permalink
bit reverse in rust
Browse files Browse the repository at this point in the history
  • Loading branch information
yshekel committed Jul 1, 2024
1 parent 4d059f9 commit de04bd0
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 130 deletions.
160 changes: 92 additions & 68 deletions wrappers/rust_v3/icicle-core/src/vec_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,16 @@ pub trait VecOps<F> {
cfg: &VecOpsConfig,
) -> Result<(), eIcicleError>;

// TODO Yuval : bit reverse
// fn bit_reverse(
// input: &(impl HostOrDeviceSlice<F> + ?Sized),
// cfg: &BitReverseConfig,
// output: &mut (impl HostOrDeviceSlice<F> + ?Sized),
// ) -> Result<(), eIcicleError>;

// fn bit_reverse_inplace(
// input: &mut (impl HostOrDeviceSlice<F> + ?Sized),
// cfg: &BitReverseConfig,
// ) -> Result<(), eIcicleError>;
fn bit_reverse(
input: &(impl HostOrDeviceSlice<F> + ?Sized),
cfg: &VecOpsConfig,
output: &mut (impl HostOrDeviceSlice<F> + ?Sized),
) -> Result<(), eIcicleError>;

fn bit_reverse_inplace(
input: &mut (impl HostOrDeviceSlice<F> + ?Sized),
cfg: &VecOpsConfig,
) -> Result<(), eIcicleError>;
}

fn check_vec_ops_args<'a, F>(
Expand Down Expand Up @@ -169,6 +168,31 @@ where
<<F as FieldImpl>::Config as VecOps<F>>::transpose(input, nof_rows, nof_cols, output, &cfg)
}

/// Out-of-place bit-reversal permutation over a field slice.
///
/// Dispatches to the field's `VecOps` backend; `input` and `output` may live
/// on host or device memory. `cfg` carries backend execution options.
///
/// # Errors
/// Propagates any `eIcicleError` reported by the backend implementation.
pub fn bit_reverse<F>(
    input: &(impl HostOrDeviceSlice<F> + ?Sized),
    cfg: &VecOpsConfig,
    output: &mut (impl HostOrDeviceSlice<F> + ?Sized),
) -> Result<(), eIcicleError>
where
    F: FieldImpl,
    <F as FieldImpl>::Config: VecOps<F>,
{
    // `check_vec_ops_args` is shaped for binary ops; this unary op passes
    // `input` again as a stand-in second operand so the checks still apply.
    let validated_cfg = check_vec_ops_args(input, input /*dummy*/, output, cfg);
    <<F as FieldImpl>::Config as VecOps<F>>::bit_reverse(input, &validated_cfg, output)
}

/// In-place bit-reversal permutation over a field slice.
///
/// Same permutation as [`bit_reverse`] but reorders `input` in its own
/// storage, avoiding a separate output buffer.
///
/// # Errors
/// Propagates any `eIcicleError` reported by the backend implementation.
pub fn bit_reverse_inplace<F>(
    input: &mut (impl HostOrDeviceSlice<F> + ?Sized),
    cfg: &VecOpsConfig,
) -> Result<(), eIcicleError>
where
    F: FieldImpl,
    <F as FieldImpl>::Config: VecOps<F>,
{
    // Reuse the binary-op validator with `input` standing in for every slot
    // (second operand and output) of this unary, in-place operation.
    let validated_cfg = check_vec_ops_args(input, input /*dummy*/, input, cfg);
    <<F as FieldImpl>::Config as VecOps<F>>::bit_reverse_inplace(input, &validated_cfg)
}

#[macro_export]
macro_rules! impl_vec_ops_field {
(
Expand All @@ -179,14 +203,13 @@ macro_rules! impl_vec_ops_field {
) => {
mod $field_prefix_ident {

use crate::vec_ops::{$field, HostOrDeviceSlice};
// use icicle_core::vec_ops::BitReverseConfig;
use crate::vec_ops::{$field, HostOrDeviceSlice};
use icicle_core::vec_ops::VecOpsConfig;
use icicle_runtime::errors::eIcicleError;

extern "C" {
#[link_name = concat!($field_prefix, "_vector_add")]
pub(crate) fn vector_add(
pub(crate) fn vector_add_ffi(
a: *const $field,
b: *const $field,
size: u32,
Expand All @@ -195,7 +218,7 @@ macro_rules! impl_vec_ops_field {
) -> eIcicleError;

#[link_name = concat!($field_prefix, "_vector_sub")]
pub(crate) fn vector_sub(
pub(crate) fn vector_sub_ffi(
a: *const $field,
b: *const $field,
size: u32,
Expand All @@ -204,7 +227,7 @@ macro_rules! impl_vec_ops_field {
) -> eIcicleError;

#[link_name = concat!($field_prefix, "_vector_mul")]
pub(crate) fn vector_mul(
pub(crate) fn vector_mul_ffi(
a: *const $field,
b: *const $field,
size: u32,
Expand All @@ -213,21 +236,21 @@ macro_rules! impl_vec_ops_field {
) -> eIcicleError;

#[link_name = concat!($field_prefix, "_matrix_transpose")]
pub(crate) fn _matrix_transpose(
pub(crate) fn matrix_transpose_ffi(
input: *const $field,
nof_rows: u32,
nof_cols: u32,
cfg: *const VecOpsConfig,
output: *mut $field,
) -> eIcicleError;

// #[link_name = concat!($field_prefix, "_bit_reverse_cuda")]
// pub(crate) fn bit_reverse_cuda(
// input: *const $field,
// size: u64,
// config: *const BitReverseConfig,
// output: *mut $field,
// ) -> CudaError;
// FFI binding to the backend's `<field_prefix>_bit_reverse` symbol.
// Serves both the out-of-place and the in-place wrappers: the in-place
// path passes the same buffer as `input` and `output`.
#[link_name = concat!($field_prefix, "_bit_reverse")]
pub(crate) fn bit_reverse_ffi(
    input: *const $field,
    size: u64,
    config: *const VecOpsConfig,
    output: *mut $field,
) -> eIcicleError;
}
}

Expand All @@ -239,7 +262,7 @@ macro_rules! impl_vec_ops_field {
cfg: &VecOpsConfig,
) -> Result<(), eIcicleError> {
unsafe {
$field_prefix_ident::vector_add(
$field_prefix_ident::vector_add_ffi(
a.as_ptr(),
b.as_ptr(),
a.len() as u32,
Expand All @@ -257,7 +280,7 @@ macro_rules! impl_vec_ops_field {
cfg: &VecOpsConfig,
) -> Result<(), eIcicleError> {
unsafe {
$field_prefix_ident::vector_sub(
$field_prefix_ident::vector_sub_ffi(
a.as_ptr(),
b.as_ptr(),
a.len() as u32,
Expand All @@ -275,7 +298,7 @@ macro_rules! impl_vec_ops_field {
cfg: &VecOpsConfig,
) -> Result<(), eIcicleError> {
unsafe {
$field_prefix_ident::vector_mul(
$field_prefix_ident::vector_mul_ffi(
a.as_ptr(),
b.as_ptr(),
a.len() as u32,
Expand All @@ -294,7 +317,7 @@ macro_rules! impl_vec_ops_field {
cfg: &VecOpsConfig,
) -> Result<(), eIcicleError> {
unsafe {
$field_prefix_ident::_matrix_transpose(
$field_prefix_ident::matrix_transpose_ffi(
input.as_ptr(),
nof_rows,
nof_cols,
Expand All @@ -305,38 +328,36 @@ macro_rules! impl_vec_ops_field {
}
}

// TODO Yuval : bit reverse

// fn bit_reverse(
// input: &(impl HostOrDeviceSlice<$field> + ?Sized),
// cfg: &BitReverseConfig,
// output: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
// ) -> Result<(), eIcicleError> {
// unsafe {
// $field_prefix_ident::bit_reverse_cuda(
// input.as_ptr(),
// input.len() as u64,
// cfg as *const BitReverseConfig,
// output.as_mut_ptr(),
// )
// .wrap()
// }
// }

// fn bit_reverse_inplace(
// input: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
// cfg: &BitReverseConfig,
// ) -> Result<(), eIcicleError> {
// unsafe {
// $field_prefix_ident::bit_reverse_cuda(
// input.as_ptr(),
// input.len() as u64,
// cfg as *const BitReverseConfig,
// input.as_mut_ptr(),
// )
// .wrap()
// }
// }
// Out-of-place bit-reverse: forwards to the C backend through the FFI
// declaration above and converts the returned status via `.wrap()`.
fn bit_reverse(
    input: &(impl HostOrDeviceSlice<$field> + ?Sized),
    cfg: &VecOpsConfig,
    output: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
) -> Result<(), eIcicleError> {
    unsafe {
        // SAFETY: `input.as_ptr()` is valid for `input.len()` reads and `cfg`
        // outlives the call. NOTE(review): `output` is assumed to hold at
        // least `input.len()` elements — confirm the backend (or the caller's
        // arg check) validates this, as no length check is visible here.
        $field_prefix_ident::bit_reverse_ffi(
            input.as_ptr(),
            input.len() as u64,
            cfg as *const VecOpsConfig,
            output.as_mut_ptr(),
        )
        .wrap()
    }
}

// In-place bit-reverse: reuses the same FFI entry point, deliberately
// passing `input` as both source and destination pointers.
fn bit_reverse_inplace(
    input: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
    cfg: &VecOpsConfig,
) -> Result<(), eIcicleError> {
    unsafe {
        // SAFETY: the source and destination pointers alias on purpose; the
        // backend symbol is expected to support in-place operation (the
        // wrappers' tests exercise this path). `cfg` outlives the call.
        $field_prefix_ident::bit_reverse_ffi(
            input.as_ptr(),
            input.len() as u64,
            cfg as *const VecOpsConfig,
            input.as_mut_ptr(),
        )
        .wrap()
    }
}
}
};
}
Expand Down Expand Up @@ -368,14 +389,17 @@ macro_rules! impl_vec_ops_tests {
check_matrix_transpose::<$field>()
}

// #[test]
// pub fn test_bit_reverse() {
// check_bit_reverse::<$field>()
// }
// #[test]
// pub fn test_bit_reverse_inplace() {
// check_bit_reverse_inplace::<$field>()
// }
// Runs the shared out-of-place bit-reverse check for this field type.
#[test]
pub fn test_bit_reverse() {
    initialize(); // one-time runtime/device setup shared by the test module
    check_bit_reverse::<$field>()
}

// Runs the shared in-place bit-reverse check for this field type.
#[test]
pub fn test_bit_reverse_inplace() {
    initialize(); // one-time runtime/device setup shared by the test module
    check_bit_reverse_inplace::<$field>()
}
}
};
}
128 changes: 66 additions & 62 deletions wrappers/rust_v3/icicle-core/src/vec_ops/tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(unused_imports)]
use crate::test_utilities;
use crate::traits::GenerateRandom;
use crate::vec_ops::{add_scalars, mul_scalars, sub_scalars, transpose_matrix, FieldImpl, VecOps, VecOpsConfig};
use crate::vec_ops::{add_scalars, mul_scalars, sub_scalars, bit_reverse, bit_reverse_inplace, transpose_matrix, FieldImpl, VecOps, VecOpsConfig};
use icicle_runtime::device::Device;
use icicle_runtime::memory::{DeviceVec, HostSlice};
use icicle_runtime::{runtime, stream::IcicleStream};
Expand Down Expand Up @@ -103,64 +103,68 @@ where
assert_eq!(result_main, result_ref);
}

// pub fn check_bit_reverse<F: FieldImpl>()
// where
// <F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
// {
// const LOG_SIZE: u32 = 20;
// const TEST_SIZE: usize = 1 << LOG_SIZE;
// let input_vec = F::Config::generate_random(TEST_SIZE);
// let input = HostSlice::from_slice(&input_vec);
// let mut intermediate = DeviceVec::<F>::cuda_malloc(TEST_SIZE).unwrap();
// let cfg = BitReverseConfig::default();
// bit_reverse(input, &cfg, &mut intermediate[..]).unwrap();

// let mut intermediate_host = vec![F::one(); TEST_SIZE];
// intermediate
// .copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..]))
// .unwrap();
// let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE);
// intermediate_host
// .iter()
// .enumerate()
// .for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)]));

// let mut result = vec![F::one(); TEST_SIZE];
// let result = HostSlice::from_mut_slice(&mut result);
// let cfg = BitReverseConfig::default();
// bit_reverse(&intermediate[..], &cfg, result).unwrap();
// assert_eq!(input.as_slice(), result.as_slice());
// }

// pub fn check_bit_reverse_inplace<F: FieldImpl>()
// where
// <F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
// {
// const LOG_SIZE: u32 = 20;
// const TEST_SIZE: usize = 1 << LOG_SIZE;
// let input_vec = F::Config::generate_random(TEST_SIZE);
// let input = HostSlice::from_slice(&input_vec);
// let mut intermediate = DeviceVec::<F>::cuda_malloc(TEST_SIZE).unwrap();
// intermediate
// .copy_from_host(&input)
// .unwrap();
// let cfg = BitReverseConfig::default();
// bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap();

// let mut intermediate_host = vec![F::one(); TEST_SIZE];
// intermediate
// .copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..]))
// .unwrap();
// let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE);
// intermediate_host
// .iter()
// .enumerate()
// .for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)]));

// bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap();
// let mut result_host = vec![F::one(); TEST_SIZE];
// intermediate
// .copy_to_host(HostSlice::from_mut_slice(&mut result_host[..]))
// .unwrap();
// assert_eq!(input.as_slice(), result_host.as_slice());
// }
/// Verifies out-of-place `bit_reverse`:
/// 1. one application moves element `i` to index `reverse_bits(i)`;
/// 2. a second application round-trips back to the original data
///    (bit-reverse is an involution).
///
/// Fix: the original built `VecOpsConfig::default()` twice; the second,
/// shadowing binding was redundant and is removed — the same config is reused.
pub fn check_bit_reverse<F: FieldImpl>()
where
    <F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
{
    test_utilities::test_set_main_device();

    const LOG_SIZE: u32 = 20;
    const TEST_SIZE: usize = 1 << LOG_SIZE;
    let input_vec = F::Config::generate_random(TEST_SIZE);
    let input = HostSlice::from_slice(&input_vec);
    let mut intermediate = DeviceVec::<F>::device_malloc(TEST_SIZE).unwrap();
    let cfg = VecOpsConfig::default();
    bit_reverse(input, &cfg, &mut intermediate[..]).unwrap();

    // Pull the device result back and check the permutation index-by-index.
    let mut intermediate_host = vec![F::one(); TEST_SIZE];
    intermediate
        .copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..]))
        .unwrap();
    let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE);
    intermediate_host
        .iter()
        .enumerate()
        .for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)]));

    // Applying bit-reverse again must restore the original ordering.
    let mut result = vec![F::one(); TEST_SIZE];
    let result = HostSlice::from_mut_slice(&mut result);
    bit_reverse(&intermediate[..], &cfg, result).unwrap();
    assert_eq!(input.as_slice(), result.as_slice());
}

/// Verifies in-place `bit_reverse_inplace`: one application yields the
/// bit-reversed permutation of the uploaded data; a second application
/// restores the original contents (the permutation is its own inverse).
pub fn check_bit_reverse_inplace<F: FieldImpl>()
where
    <F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
{
    test_utilities::test_set_main_device();

    const LOG_SIZE: u32 = 20;
    const TEST_SIZE: usize = 1 << LOG_SIZE;
    let input_vec = F::Config::generate_random(TEST_SIZE);
    let input = HostSlice::from_slice(&input_vec);

    // Upload and permute in place on the active device.
    let mut buf = DeviceVec::<F>::device_malloc(TEST_SIZE).unwrap();
    buf.copy_from_host(&input).unwrap();
    let cfg = VecOpsConfig::default();
    bit_reverse_inplace(&mut buf[..], &cfg).unwrap();

    // Element i must now hold the value originally stored at reverse_bits(i).
    let mut snapshot = vec![F::one(); TEST_SIZE];
    buf.copy_to_host(HostSlice::from_mut_slice(&mut snapshot[..]))
        .unwrap();
    let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE);
    for (i, val) in snapshot.iter().enumerate() {
        assert_eq!(val, &input_vec[index_reverser(i)]);
    }

    // A second in-place pass undoes the first: buffer must equal the input.
    bit_reverse_inplace(&mut buf[..], &cfg).unwrap();
    let mut round_trip = vec![F::one(); TEST_SIZE];
    buf.copy_to_host(HostSlice::from_mut_slice(&mut round_trip[..]))
        .unwrap();
    assert_eq!(input.as_slice(), round_trip.as_slice());
}

0 comments on commit de04bd0

Please sign in to comment.